Allow excluding cache based on status code (#1403)

This introduces an option `--cache-exclude-status`, which allows specifying a range of HTTP status codes which will be ignored from the cache.

Closes #1400.
This commit is contained in:
Damien Mathieu 2024-10-14 02:41:56 +02:00 committed by GitHub
parent 2a9f11a289
commit f0ebac29a2
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
12 changed files with 491 additions and 47 deletions

View file

@ -335,6 +335,22 @@ Options:
[default: 1d]
--cache-exclude-status <CACHE_EXCLUDE_STATUS>
A list of status codes that will be ignored from the cache
The following accept range syntax is supported: [start]..[=]end|code. Some valid
examples are:
- 429
- 500..=599
- 500..
Use "lychee --cache-exclude-status '429, 500..502' <inputs>..." to provide a comma- separated
list of excluded status codes. This example will not cache results with a status code of 429, 500,
501 and 502.
[default: ]
--dump
Don't perform any link checking. Instead, dump all the links extracted from inputs that would be checked

View file

@ -10,7 +10,7 @@ use reqwest::Url;
use tokio::sync::mpsc;
use tokio_stream::wrappers::ReceiverStream;
use lychee_lib::{Client, ErrorKind, Request, Response};
use lychee_lib::{Client, ErrorKind, Request, Response, Uri};
use lychee_lib::{InputSource, Result};
use lychee_lib::{ResponseBody, Status};
@ -46,6 +46,7 @@ where
let client = params.client;
let cache = params.cache;
let cache_exclude_status = params.cfg.cache_exclude_status.into_set();
let accept = params.cfg.accept.into_set();
let pb = if params.cfg.no_progress || params.cfg.verbose.log_level() >= log::Level::Info {
@ -61,6 +62,7 @@ where
max_concurrency,
client,
cache,
cache_exclude_status,
accept,
));
@ -219,6 +221,7 @@ async fn request_channel_task(
max_concurrency: usize,
client: Client,
cache: Arc<Cache>,
cache_exclude_status: HashSet<u16>,
accept: HashSet<u16>,
) {
StreamExt::for_each_concurrent(
@ -226,7 +229,14 @@ async fn request_channel_task(
max_concurrency,
|request: Result<Request>| async {
let request = request.expect("cannot read request");
let response = handle(&client, cache.clone(), request, accept.clone()).await;
let response = handle(
&client,
cache.clone(),
cache_exclude_status.clone(),
request,
accept.clone(),
)
.await;
send_resp
.send(response)
@ -260,6 +270,7 @@ async fn check_url(client: &Client, request: Request) -> Response {
async fn handle(
client: &Client,
cache: Arc<Cache>,
cache_exclude_status: HashSet<u16>,
request: Request,
accept: HashSet<u16>,
) -> Response {
@ -287,9 +298,10 @@ async fn handle(
// benefit.
// - Skip caching unsupported URLs as they might be supported in a
// future run.
// - Skip caching excluded links; they might not be excluded in the next run
// - Skip caching excluded links; they might not be excluded in the next run.
// - Skip caching links for which the status code has been explicitly excluded from the cache.
let status = response.status();
if uri.is_file() || status.is_excluded() || status.is_unsupported() || status.is_unknown() {
if ignore_cache(&uri, status, &cache_exclude_status) {
return response;
}
@ -297,6 +309,26 @@ async fn handle(
response
}
/// Returns `true` if the response should be ignored in the cache.
///
/// The response should be ignored if:
/// - The URI is a file URI.
/// - The status is excluded.
/// - The status is unsupported.
/// - The status is unknown.
/// - The status code is excluded from the cache.
fn ignore_cache(uri: &Uri, status: &Status, cache_exclude_status: &HashSet<u16>) -> bool {
let status_code_excluded = status
.code()
.map_or(false, |code| cache_exclude_status.contains(&code.as_u16()));
uri.is_file()
|| status.is_excluded()
|| status.is_unsupported()
|| status.is_unknown()
|| status_code_excluded
}
fn show_progress(
output: &mut dyn Write,
progress_bar: &Option<ProgressBar>,
@ -344,8 +376,9 @@ fn get_failed_urls(stats: &mut ResponseStats) -> Vec<(InputSource, Url)> {
#[cfg(test)]
mod tests {
use crate::{formatters::get_response_formatter, options};
use http::StatusCode;
use log::info;
use lychee_lib::{CacheStatus, ClientBuilder, InputSource, Uri};
use lychee_lib::{CacheStatus, ClientBuilder, ErrorKind, InputSource, Uri};
use super::*;
@ -406,4 +439,55 @@ mod tests {
Status::Error(ErrorKind::InvalidURI(_))
));
}
#[test]
fn test_cache_by_default() {
assert!(!ignore_cache(
&Uri::try_from("https://[::1]").unwrap(),
&Status::Ok(StatusCode::OK),
&HashSet::default()
));
}
#[test]
// Cache is ignored for file URLs
fn test_cache_ignore_file_urls() {
assert!(ignore_cache(
&Uri::try_from("file:///home").unwrap(),
&Status::Ok(StatusCode::OK),
&HashSet::default()
));
}
#[test]
// Cache is ignored for unsupported status
fn test_cache_ignore_unsupported_status() {
assert!(ignore_cache(
&Uri::try_from("https://[::1]").unwrap(),
&Status::Unsupported(ErrorKind::EmptyUrl),
&HashSet::default()
));
}
#[test]
// Cache is ignored for unknown status
fn test_cache_ignore_unknown_status() {
assert!(ignore_cache(
&Uri::try_from("https://[::1]").unwrap(),
&Status::UnknownStatusCode(StatusCode::IM_A_TEAPOT),
&HashSet::default()
));
}
#[test]
fn test_cache_ignore_excluded_status() {
// Cache is ignored for excluded status codes
let exclude = [StatusCode::OK.as_u16()].iter().copied().collect();
assert!(ignore_cache(
&Uri::try_from("https://[::1]").unwrap(),
&Status::Ok(StatusCode::OK),
&exclude
));
}
}

View file

@ -6,8 +6,8 @@ use clap::builder::PossibleValuesParser;
use clap::{arg, builder::TypedValueParser, Parser};
use const_format::{concatcp, formatcp};
use lychee_lib::{
AcceptSelector, Base, BasicAuthSelector, Input, DEFAULT_MAX_REDIRECTS, DEFAULT_MAX_RETRIES,
DEFAULT_RETRY_WAIT_TIME_SECS, DEFAULT_TIMEOUT_SECS, DEFAULT_USER_AGENT,
Base, BasicAuthSelector, Input, StatusCodeExcluder, StatusCodeSelector, DEFAULT_MAX_REDIRECTS,
DEFAULT_MAX_RETRIES, DEFAULT_RETRY_WAIT_TIME_SECS, DEFAULT_TIMEOUT_SECS, DEFAULT_USER_AGENT,
};
use secrecy::{ExposeSecret, SecretString};
use serde::Deserialize;
@ -145,7 +145,8 @@ default_function! {
retry_wait_time: usize = DEFAULT_RETRY_WAIT_TIME_SECS;
method: String = DEFAULT_METHOD.to_string();
verbosity: Verbosity = Verbosity::default();
accept_selector: AcceptSelector = AcceptSelector::default();
cache_exclude_selector: StatusCodeExcluder = StatusCodeExcluder::new();
accept_selector: StatusCodeSelector = StatusCodeSelector::default();
}
// Macro for merging configuration values
@ -231,6 +232,26 @@ pub(crate) struct Config {
#[serde(with = "humantime_serde")]
pub(crate) max_cache_age: Duration,
/// A list of status codes that will be excluded from the cache
#[arg(
long,
default_value_t,
long_help = "A list of status codes that will be ignored from the cache
The following accept range syntax is supported: [start]..[=]end|code. Some valid
examples are:
- 429
- 500..=599
- 500..
Use \"lychee --cache-exclude-status '429, 500..502' <inputs>...\" to provide a comma- separated
list of excluded status codes. This example will not cache results with a status code of 429, 500,
501 and 502."
)]
#[serde(default = "cache_exclude_selector")]
pub(crate) cache_exclude_status: StatusCodeExcluder,
/// Don't perform any link checking.
/// Instead, dump all the links extracted from inputs that would be checked
#[arg(long)]
@ -394,7 +415,7 @@ separated list of accepted status codes. This example will accept 200, 201,
202, 203, 204, 429, and 500 as valid status codes."
)]
#[serde(default = "accept_selector")]
pub(crate) accept: AcceptSelector,
pub(crate) accept: StatusCodeSelector,
/// Enable the checking of fragments in links.
#[arg(long)]
@ -509,6 +530,7 @@ impl Config {
max_retries: DEFAULT_MAX_RETRIES;
max_concurrency: DEFAULT_MAX_CONCURRENCY;
max_cache_age: humantime::parse_duration(DEFAULT_MAX_CACHE_AGE).unwrap();
cache_exclude_status: StatusCodeExcluder::default();
threads: None;
user_agent: DEFAULT_USER_AGENT;
insecure: false;
@ -538,7 +560,7 @@ impl Config {
require_https: false;
cookie_jar: None;
include_fragments: false;
accept: AcceptSelector::default();
accept: StatusCodeSelector::default();
}
if self
@ -564,7 +586,7 @@ mod tests {
#[test]
fn test_accept_status_codes() {
let toml = Config {
accept: AcceptSelector::from_str("200..=204, 429, 500").unwrap(),
accept: StatusCodeSelector::from_str("200..=204, 429, 500").unwrap(),
..Default::default()
};
@ -577,4 +599,15 @@ mod tests {
assert!(cli.accept.contains(204));
assert!(!cli.accept.contains(205));
}
#[test]
fn test_default() {
let cli = Config::default();
assert_eq!(
cli.accept,
StatusCodeSelector::from_str("100..=103,200..=299").expect("no error")
);
assert_eq!(cli.cache_exclude_status, StatusCodeExcluder::new());
}
}

View file

@ -895,6 +895,65 @@ mod cli {
Ok(())
}
#[tokio::test]
async fn test_lycheecache_exclude_custom_status_codes() -> Result<()> {
let base_path = fixtures_path().join("cache");
let cache_file = base_path.join(LYCHEE_CACHE_FILE);
// Unconditionally remove cache file if it exists
let _ = fs::remove_file(&cache_file);
let mock_server_ok = mock_server!(StatusCode::OK);
let mock_server_no_content = mock_server!(StatusCode::NO_CONTENT);
let mock_server_too_many_requests = mock_server!(StatusCode::TOO_MANY_REQUESTS);
let dir = tempfile::tempdir()?;
let mut file = File::create(dir.path().join("c.md"))?;
writeln!(file, "{}", mock_server_ok.uri().as_str())?;
writeln!(file, "{}", mock_server_no_content.uri().as_str())?;
writeln!(file, "{}", mock_server_too_many_requests.uri().as_str())?;
let mut cmd = main_command();
let test_cmd = cmd
.current_dir(&base_path)
.arg(dir.path().join("c.md"))
.arg("--verbose")
.arg("--no-progress")
.arg("--cache")
.arg("--cache-exclude-status")
.arg("204,429");
assert!(
!cache_file.exists(),
"cache file should not exist before this test"
);
// run first without cache to generate the cache file
test_cmd
.assert()
.stderr(contains(format!("[200] {}/\n", mock_server_ok.uri())))
.stderr(contains(format!(
"[204] {}/ | OK (204 No Content): No Content\n",
mock_server_no_content.uri()
)))
.stderr(contains(format!(
"[429] {}/ | Failed: Network error: Too Many Requests\n",
mock_server_too_many_requests.uri()
)));
// check content of cache file
let data = fs::read_to_string(&cache_file)?;
assert!(data.contains(&format!("{}/,200", mock_server_ok.uri())));
assert!(!data.contains(&format!("{}/,204", mock_server_no_content.uri())));
assert!(!data.contains(&format!("{}/,429", mock_server_too_many_requests.uri())));
// clear the cache file
fs::remove_file(&cache_file)?;
Ok(())
}
#[tokio::test]
async fn test_lycheecache_accept_custom_status_codes() -> Result<()> {
let base_path = fixtures_path().join("cache_accept_custom_status_codes");

View file

@ -95,8 +95,9 @@ pub use crate::{
collector::Collector,
filter::{Excludes, Filter, Includes},
types::{
uri::valid::Uri, AcceptRange, AcceptRangeError, AcceptSelector, Base, BasicAuthCredentials,
uri::valid::Uri, AcceptRange, AcceptRangeError, Base, BasicAuthCredentials,
BasicAuthSelector, CacheStatus, CookieJar, ErrorKind, FileType, Input, InputContent,
InputSource, Request, Response, ResponseBody, Result, Status,
InputSource, Request, Response, ResponseBody, Result, Status, StatusCodeExcluder,
StatusCodeSelector,
},
};

View file

@ -1,5 +1,3 @@
mod range;
mod selector;
pub use range::*;
pub use selector::*;

View file

@ -7,8 +7,8 @@ use thiserror::Error;
static RANGE_PATTERN: Lazy<Regex> =
Lazy::new(|| Regex::new(r"^([0-9]{3})?\.\.(=?)([0-9]{3})+$|^([0-9]{3})$").unwrap());
/// The [`AcceptRangeParseError`] indicates that the parsing process of an
/// [`AcceptRange`] from a string failed due to various underlying reasons.
/// Indicates that the parsing process of an [`AcceptRange`] from a string
/// failed due to various underlying reasons.
#[derive(Debug, Error, PartialEq)]
pub enum AcceptRangeError {
/// The string input didn't contain any range pattern.

View file

@ -6,7 +6,7 @@ use thiserror::Error;
use tokio::task::JoinError;
use super::InputContent;
use crate::types::AcceptSelectorError;
use crate::types::StatusCodeSelectorError;
use crate::{basic_auth::BasicAuthExtractorError, utils, Uri};
/// Kinds of status errors
@ -142,9 +142,9 @@ pub enum ErrorKind {
#[error("Cannot load cookies")]
Cookies(String),
/// Accept selector parse error
#[error("Accept range error")]
AcceptSelectorError(#[from] AcceptSelectorError),
/// Status code selector parse error
#[error("Status code range error")]
StatusCodeSelectorError(#[from] StatusCodeSelectorError),
}
impl ErrorKind {
@ -290,7 +290,7 @@ impl Hash for ErrorKind {
Self::TooManyRedirects(e) => e.to_string().hash(state),
Self::BasicAuthExtractorError(e) => e.to_string().hash(state),
Self::Cookies(e) => e.to_string().hash(state),
Self::AcceptSelectorError(e) => e.to_string().hash(state),
Self::StatusCodeSelectorError(e) => e.to_string().hash(state),
}
}
}

View file

@ -12,6 +12,7 @@ pub(crate) mod mail;
mod request;
mod response;
mod status;
mod status_code;
pub(crate) mod uri;
pub use accept::*;
@ -25,6 +26,7 @@ pub use input::{Input, InputContent, InputSource};
pub use request::Request;
pub use response::{Response, ResponseBody};
pub use status::Status;
pub use status_code::*;
/// The lychee `Result` type
pub type Result<T> = std::result::Result<T, crate::ErrorKind>;

View file

@ -0,0 +1,242 @@
use std::{collections::HashSet, fmt::Display, str::FromStr};
use serde::{de::Visitor, Deserialize};
use crate::{
types::accept::AcceptRange, types::status_code::StatusCodeSelectorError, AcceptRangeError,
};
/// A [`StatusCodeExcluder`] holds ranges of HTTP status codes, and determines
/// whether a specific code is matched, so the link can be counted as valid (not
/// broken) or excluded. `StatusCodeExcluder` differs from
/// [`StatusCodeSelector`](super::selector::StatusCodeSelector) in the defaults
/// it provides. As this is meant to exclude status codes, the default is to
/// keep everything.
#[derive(Clone, Debug, PartialEq)]
pub struct StatusCodeExcluder {
ranges: Vec<AcceptRange>,
}
impl FromStr for StatusCodeExcluder {
type Err = StatusCodeSelectorError;
fn from_str(input: &str) -> Result<Self, Self::Err> {
let input = input.trim();
if input.is_empty() {
return Ok(Self::new());
}
let ranges = input
.split(',')
.map(|part| AcceptRange::from_str(part.trim()))
.collect::<Result<Vec<AcceptRange>, AcceptRangeError>>()?;
Ok(Self::new_from(ranges))
}
}
impl Default for StatusCodeExcluder {
fn default() -> Self {
Self::new()
}
}
impl Display for StatusCodeExcluder {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
let ranges: Vec<_> = self.ranges.iter().map(ToString::to_string).collect();
write!(f, "{}", ranges.join(","))
}
}
impl StatusCodeExcluder {
/// Creates a new empty [`StatusCodeExcluder`].
#[must_use]
pub const fn new() -> Self {
Self { ranges: Vec::new() }
}
/// Creates a new [`StatusCodeExcluder`] prefilled with `ranges`.
#[must_use]
pub fn new_from(ranges: Vec<AcceptRange>) -> Self {
let mut selector = Self::new();
for range in ranges {
selector.add_range(range);
}
selector
}
/// Adds a range of HTTP status codes to this [`StatusCodeExcluder`].
/// This method merges the new and existing ranges if they overlap.
pub fn add_range(&mut self, range: AcceptRange) -> &mut Self {
// Merge with previous range if possible
if let Some(last) = self.ranges.last_mut() {
if last.merge(&range) {
return self;
}
}
// If neither is the case, the ranges have no overlap at all. Just add
// to the list of ranges.
self.ranges.push(range);
self
}
/// Returns whether this [`StatusCodeExcluder`] contains `value`.
#[must_use]
pub fn contains(&self, value: u16) -> bool {
self.ranges.iter().any(|range| range.contains(value))
}
/// Consumes self and creates a [`HashSet`] which contains all
/// accepted status codes.
#[must_use]
pub fn into_set(self) -> HashSet<u16> {
let mut set = HashSet::new();
for range in self.ranges {
for value in range.inner() {
set.insert(value);
}
}
set
}
#[cfg(test)]
pub(crate) fn len(&self) -> usize {
self.ranges.len()
}
}
struct StatusCodeExcluderVisitor;
impl<'de> Visitor<'de> for StatusCodeExcluderVisitor {
type Value = StatusCodeExcluder;
fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
formatter.write_str("a string or a sequence of strings")
}
fn visit_str<E>(self, v: &str) -> Result<Self::Value, E>
where
E: serde::de::Error,
{
StatusCodeExcluder::from_str(v).map_err(serde::de::Error::custom)
}
fn visit_i64<E>(self, v: i64) -> Result<Self::Value, E>
where
E: serde::de::Error,
{
let value = u16::try_from(v).map_err(serde::de::Error::custom)?;
Ok(StatusCodeExcluder::new_from(vec![AcceptRange::new(
value, value,
)]))
}
fn visit_seq<A>(self, mut seq: A) -> Result<Self::Value, A::Error>
where
A: serde::de::SeqAccess<'de>,
{
let mut selector = StatusCodeExcluder::new();
while let Some(v) = seq.next_element::<toml::Value>()? {
if let Some(v) = v.as_integer() {
let value = u16::try_from(v).map_err(serde::de::Error::custom)?;
selector.add_range(AcceptRange::new(value, value));
continue;
}
if let Some(s) = v.as_str() {
let range = AcceptRange::from_str(s).map_err(serde::de::Error::custom)?;
selector.add_range(range);
continue;
}
return Err(serde::de::Error::custom(
"failed to parse sequence of accept ranges",
));
}
Ok(selector)
}
}
impl<'de> Deserialize<'de> for StatusCodeExcluder {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: serde::Deserializer<'de>,
{
deserializer.deserialize_any(StatusCodeExcluderVisitor)
}
}
#[cfg(test)]
mod test {
use super::*;
use rstest::rstest;
#[rstest]
#[case("", vec![], vec![100, 110, 150, 200, 300, 175, 350], 0)]
#[case("100..=150,200..=300", vec![100, 110, 150, 200, 300], vec![175, 350], 2)]
#[case("200..=300,100..=250", vec![100, 150, 200, 250, 300], vec![350], 1)]
#[case("100..=200,150..=200", vec![100, 150, 200], vec![250, 300], 1)]
#[case("100..=200,300", vec![100, 110, 200, 300], vec![250, 350], 2)]
fn test_from_str(
#[case] input: &str,
#[case] valid_values: Vec<u16>,
#[case] invalid_values: Vec<u16>,
#[case] length: usize,
) {
let selector = StatusCodeExcluder::from_str(input).unwrap();
assert_eq!(selector.len(), length);
for valid in valid_values {
assert!(selector.contains(valid));
}
for invalid in invalid_values {
assert!(!selector.contains(invalid));
}
}
#[rstest]
#[case(r"accept = ['200..204', '429']", vec![200, 203, 429], vec![204, 404], 2)]
#[case(r"accept = '200..204, 429'", vec![200, 203, 429], vec![204, 404], 2)]
#[case(r"accept = ['200', '429']", vec![200, 429], vec![404], 2)]
#[case(r"accept = '200, 429'", vec![200, 429], vec![404], 2)]
#[case(r"accept = [200, 429]", vec![200, 429], vec![404], 2)]
#[case(r"accept = '200'", vec![200], vec![404], 1)]
#[case(r"accept = 200", vec![200], vec![404], 1)]
fn test_deserialize(
#[case] input: &str,
#[case] valid_values: Vec<u16>,
#[case] invalid_values: Vec<u16>,
#[case] length: usize,
) {
#[derive(Deserialize)]
struct Config {
accept: StatusCodeExcluder,
}
let config: Config = toml::from_str(input).unwrap();
assert_eq!(config.accept.len(), length);
for valid in valid_values {
assert!(config.accept.contains(valid));
}
for invalid in invalid_values {
assert!(!config.accept.contains(invalid));
}
}
#[rstest]
#[case("100..=150,200..=300", "100..=150,200..=300")]
#[case("100..=150,300", "100..=150,300..=300")]
fn test_display(#[case] input: &str, #[case] display: &str) {
let selector = StatusCodeExcluder::from_str(input).unwrap();
assert_eq!(selector.to_string(), display);
}
}

View file

@ -0,0 +1,5 @@
mod excluder;
mod selector;
pub use excluder::*;
pub use selector::*;

View file

@ -6,7 +6,7 @@ use thiserror::Error;
use crate::{types::accept::AcceptRange, AcceptRangeError};
#[derive(Debug, Error)]
pub enum AcceptSelectorError {
pub enum StatusCodeSelectorError {
#[error("invalid/empty input")]
InvalidInput,
@ -14,21 +14,25 @@ pub enum AcceptSelectorError {
AcceptRangeError(#[from] AcceptRangeError),
}
/// An [`AcceptSelector`] determines if a returned HTTP status code should be
/// accepted and thus counted as a valid (not broken) link.
/// A [`StatusCodeSelector`] holds ranges of HTTP status codes, and determines
/// whether a specific code is matched, so the link can be counted as valid (not
/// broken) or excluded. `StatusCodeSelector` differs from
/// [`StatusCodeExcluder`](super::excluder::StatusCodeExcluder)
/// in the defaults it provides. As this is meant to
/// select valid status codes, the default includes every successful status.
#[derive(Clone, Debug, PartialEq)]
pub struct AcceptSelector {
pub struct StatusCodeSelector {
ranges: Vec<AcceptRange>,
}
impl FromStr for AcceptSelector {
type Err = AcceptSelectorError;
impl FromStr for StatusCodeSelector {
type Err = StatusCodeSelectorError;
fn from_str(input: &str) -> Result<Self, Self::Err> {
let input = input.trim();
if input.is_empty() {
return Err(AcceptSelectorError::InvalidInput);
return Err(StatusCodeSelectorError::InvalidInput);
}
let ranges = input
@ -40,27 +44,27 @@ impl FromStr for AcceptSelector {
}
}
impl Default for AcceptSelector {
impl Default for StatusCodeSelector {
fn default() -> Self {
Self::new_from(vec![AcceptRange::new(100, 103), AcceptRange::new(200, 299)])
}
}
impl Display for AcceptSelector {
impl Display for StatusCodeSelector {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
let ranges: Vec<_> = self.ranges.iter().map(ToString::to_string).collect();
write!(f, "{}", ranges.join(","))
}
}
impl AcceptSelector {
/// Creates a new empty [`AcceptSelector`].
impl StatusCodeSelector {
/// Creates a new empty [`StatusCodeSelector`].
#[must_use]
pub const fn new() -> Self {
Self { ranges: Vec::new() }
}
/// Creates a new [`AcceptSelector`] prefilled with `ranges`.
/// Creates a new [`StatusCodeSelector`] prefilled with `ranges`.
#[must_use]
pub fn new_from(ranges: Vec<AcceptRange>) -> Self {
let mut selector = Self::new();
@ -72,7 +76,7 @@ impl AcceptSelector {
selector
}
/// Adds a range of accepted HTTP status codes to this [`AcceptSelector`].
/// Adds a range of HTTP status codes to this [`StatusCodeSelector`].
/// This method merges the new and existing ranges if they overlap.
pub fn add_range(&mut self, range: AcceptRange) -> &mut Self {
// Merge with previous range if possible
@ -88,7 +92,7 @@ impl AcceptSelector {
self
}
/// Returns whether this [`AcceptSelector`] contains `value`.
/// Returns whether this [`StatusCodeSelector`] contains `value`.
#[must_use]
pub fn contains(&self, value: u16) -> bool {
self.ranges.iter().any(|range| range.contains(value))
@ -115,10 +119,10 @@ impl AcceptSelector {
}
}
struct AcceptSelectorVisitor;
struct StatusCodeSelectorVisitor;
impl<'de> Visitor<'de> for AcceptSelectorVisitor {
type Value = AcceptSelector;
impl<'de> Visitor<'de> for StatusCodeSelectorVisitor {
type Value = StatusCodeSelector;
fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
formatter.write_str("a string or a sequence of strings")
@ -128,7 +132,7 @@ impl<'de> Visitor<'de> for AcceptSelectorVisitor {
where
E: serde::de::Error,
{
AcceptSelector::from_str(v).map_err(serde::de::Error::custom)
StatusCodeSelector::from_str(v).map_err(serde::de::Error::custom)
}
fn visit_i64<E>(self, v: i64) -> Result<Self::Value, E>
@ -136,7 +140,7 @@ impl<'de> Visitor<'de> for AcceptSelectorVisitor {
E: serde::de::Error,
{
let value = u16::try_from(v).map_err(serde::de::Error::custom)?;
Ok(AcceptSelector::new_from(vec![AcceptRange::new(
Ok(StatusCodeSelector::new_from(vec![AcceptRange::new(
value, value,
)]))
}
@ -145,7 +149,7 @@ impl<'de> Visitor<'de> for AcceptSelectorVisitor {
where
A: serde::de::SeqAccess<'de>,
{
let mut selector = AcceptSelector::new();
let mut selector = StatusCodeSelector::new();
while let Some(v) = seq.next_element::<toml::Value>()? {
if let Some(v) = v.as_integer() {
let value = u16::try_from(v).map_err(serde::de::Error::custom)?;
@ -167,12 +171,12 @@ impl<'de> Visitor<'de> for AcceptSelectorVisitor {
}
}
impl<'de> Deserialize<'de> for AcceptSelector {
impl<'de> Deserialize<'de> for StatusCodeSelector {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: serde::Deserializer<'de>,
{
deserializer.deserialize_any(AcceptSelectorVisitor)
deserializer.deserialize_any(StatusCodeSelectorVisitor)
}
}
@ -192,7 +196,7 @@ mod test {
#[case] invalid_values: Vec<u16>,
#[case] length: usize,
) {
let selector = AcceptSelector::from_str(input).unwrap();
let selector = StatusCodeSelector::from_str(input).unwrap();
assert_eq!(selector.len(), length);
for valid in valid_values {
@ -220,7 +224,7 @@ mod test {
) {
#[derive(Deserialize)]
struct Config {
accept: AcceptSelector,
accept: StatusCodeSelector,
}
let config: Config = toml::from_str(input).unwrap();
@ -239,7 +243,7 @@ mod test {
#[case("100..=150,200..=300", "100..=150,200..=300")]
#[case("100..=150,300", "100..=150,300..=300")]
fn test_display(#[case] input: &str, #[case] display: &str) {
let selector = AcceptSelector::from_str(input).unwrap();
let selector = StatusCodeSelector::from_str(input).unwrap();
assert_eq!(selector.to_string(), display);
}
}