Use StatusCodeSelector default as default accepted StatusCodes

This commit is contained in:
Thomas Zahner 2025-05-23 12:52:37 +02:00
parent c2a0908747
commit 74961d2470
4 changed files with 33 additions and 27 deletions

View file

@ -17,6 +17,7 @@ pub(crate) fn create(cfg: &Config, cookie_jar: Option<&Arc<CookieStoreMutex>>) -
let remaps = parse_remaps(&cfg.remap)?;
let includes = RegexSet::new(&cfg.include)?;
let excludes = RegexSet::new(&cfg.exclude)?;
let accepted: HashSet<StatusCode> = cfg.accept.clone().try_into()?;
// Offline mode overrides the scheme
let schemes = if cfg.offline {
@ -25,14 +26,6 @@ pub(crate) fn create(cfg: &Config, cookie_jar: Option<&Arc<CookieStoreMutex>>) -
cfg.scheme.clone()
};
let accepted = cfg
.accept
.clone()
.into_set()
.iter()
.map(|value| StatusCode::from_u16(*value))
.collect::<Result<HashSet<_>, _>>()?;
let headers = HeaderMap::from_header_pairs(&cfg.header)?;
ClientBuilder::builder()

View file

@ -49,7 +49,7 @@ where
let client = params.client;
let cache = params.cache;
let cache_exclude_status = params.cfg.cache_exclude_status.into_set();
let accept = params.cfg.accept.into_set();
let accept = params.cfg.accept.into();
let pb = if params.cfg.no_progress || params.cfg.verbose.log_level() >= log::Level::Info {
None

View file

@ -238,9 +238,7 @@ pub struct ClientBuilder {
/// Set of accepted return codes / status codes.
///
/// Unmatched return codes/ status codes are deemed as errors.
///
/// TODO: accept all "valid" status codes by default. Maybe use `AcceptRange`?
#[builder(default = HashSet::from([StatusCode::OK]))]
#[builder(default = HashSet::try_from(StatusCodeSelector::default()).unwrap())]
accepted: HashSet<StatusCode>,
/// Response timeout per request in seconds.

View file

@ -1,5 +1,6 @@
use std::{collections::HashSet, fmt::Display, str::FromStr};
use http::StatusCode;
use serde::{Deserialize, de::Visitor};
use thiserror::Error;
@ -98,27 +99,33 @@ impl StatusCodeSelector {
self.ranges.iter().any(|range| range.contains(value))
}
/// Consumes self and creates a [`HashSet`] which contains all
/// accepted status codes.
#[must_use]
pub fn into_set(self) -> HashSet<u16> {
let mut set = HashSet::new();
for range in self.ranges {
for value in range.inner() {
set.insert(value);
}
}
set
}
#[cfg(test)]
pub(crate) fn len(&self) -> usize {
self.ranges.len()
}
}
impl From<StatusCodeSelector> for HashSet<u16> {
fn from(value: StatusCodeSelector) -> Self {
value
.ranges
.into_iter()
.flat_map(|range| range.inner().collect::<Vec<_>>())
.collect()
}
}
impl TryFrom<StatusCodeSelector> for HashSet<StatusCode> {
type Error = http::status::InvalidStatusCode;
fn try_from(value: StatusCodeSelector) -> Result<Self, Self::Error> {
<HashSet<u16>>::from(value)
.into_iter()
.map(StatusCode::from_u16)
.collect()
}
}
struct StatusCodeSelectorVisitor;
impl<'de> Visitor<'de> for StatusCodeSelectorVisitor {
@ -183,6 +190,7 @@ impl<'de> Deserialize<'de> for StatusCodeSelector {
#[cfg(test)]
mod test {
use super::*;
use http::status::InvalidStatusCode;
use rstest::rstest;
#[rstest]
@ -246,4 +254,11 @@ mod test {
let selector = StatusCodeSelector::from_str(input).unwrap();
assert_eq!(selector.to_string(), display);
}
#[rstest]
#[case("100..=102,200..202", HashSet::from([100, 101, 102, 200, 201]))]
fn test_into_u16_set(#[case] input: &str, #[case] expected: HashSet<u16>) {
let actual: HashSet<u16> = StatusCodeSelector::from_str(input).unwrap().into();
assert_eq!(actual, expected);
}
}