Update RequestChain & add chain to client

This commit is contained in:
Thomas Zahner 2024-03-13 14:58:50 +01:00
parent 667105e13e
commit 402482ca01
4 changed files with 71 additions and 14 deletions

View file

@ -1,21 +1,33 @@
use core::fmt::Debug;
use crate::Status;
#[derive(Debug, PartialEq)]
pub(crate) enum ChainResult<T, R> {
Chained(T),
EarlyExit(R),
}
pub(crate) type RequestChain = Chain<reqwest::Request, crate::Response>;
pub(crate) type RequestChain = Chain<reqwest::Request, Status>;
#[derive(Debug)]
pub struct Chain<T, R>(Vec<Box<dyn Chainable<T, R> + Send>>);
impl<T, R> Default for Chain<T, R> {
fn default() -> Self {
Self(vec![])
}
}
impl<T, R> Chain<T, R> {
pub(crate) fn new(values: Vec<Box<dyn Chainable<T, R> + Send>>) -> Self {
Self(values)
}
pub(crate) fn append(&mut self, other: &mut Chain<T, R>) {
self.0.append(&mut other.0);
}
pub(crate) fn traverse(&mut self, mut input: T) -> ChainResult<T, R> {
use ChainResult::*;
for e in self.0.iter_mut() {

View file

@ -13,7 +13,12 @@
clippy::default_trait_access,
clippy::used_underscore_binding
)]
use std::{collections::HashSet, path::Path, sync::Arc, time::Duration};
use std::{
collections::HashSet,
path::Path,
sync::{Arc, Mutex},
time::Duration,
};
#[cfg(all(feature = "email-check", feature = "native-tls"))]
use check_if_email_exists::{check_email, CheckEmailInput, Reachable};
@ -30,7 +35,7 @@ use secrecy::{ExposeSecret, SecretString};
use typed_builder::TypedBuilder;
use crate::{
chain::{Chain, ChainResult::*, RequestChain},
chain::{Chain, ChainResult::{Chained, EarlyExit}, RequestChain},
filter::{Excludes, Filter, Includes},
quirks::Quirks,
remap::Remaps,
@ -274,6 +279,8 @@ pub struct ClientBuilder {
/// Enable the checking of fragments in links.
include_fragments: bool,
request_chain: Arc<Mutex<RequestChain>>,
}
impl Default for ClientBuilder {
@ -386,6 +393,7 @@ impl ClientBuilder {
require_https: self.require_https,
include_fragments: self.include_fragments,
fragment_checker: FragmentChecker::new(),
request_chain: self.request_chain,
})
}
}
@ -435,6 +443,8 @@ pub struct Client {
/// Caches Fragments
fragment_checker: FragmentChecker,
request_chain: Arc<Mutex<RequestChain>>,
}
impl Client {
@ -520,6 +530,8 @@ impl Client {
) -> Result<Status> {
let quirks = Quirks::default();
let mut request_chain: RequestChain = Chain::new(vec![Box::new(quirks)]);
request_chain.append(&mut self.request_chain.lock().unwrap());
if let Some(c) = credentials {
request_chain.push(Box::new(c.clone()));
}
@ -656,7 +668,7 @@ impl Client {
let result = request_chain.traverse(request);
match result {
EarlyExit(r) => r.1.status,
EarlyExit(status) => status,
Chained(r) => match self.reqwest_client.execute(r).await {
Ok(ref response) => Status::new(response, self.accepted.clone()),
Err(e) => e.into(),
@ -754,6 +766,7 @@ where
mod tests {
use std::{
fs::File,
sync::{Arc, Mutex},
time::{Duration, Instant},
};
@ -763,7 +776,12 @@ mod tests {
use wiremock::matchers::path;
use super::ClientBuilder;
use crate::{mock_server, test_utils::get_mock_client_response, Request, Uri};
use crate::{
chain::{ChainResult, Chainable, RequestChain},
mock_server,
test_utils::get_mock_client_response,
Request, Status, Uri,
};
#[tokio::test]
async fn test_nonexistent() {
@ -1092,4 +1110,33 @@ mod tests {
assert!(res.status().is_unsupported());
}
}
#[tokio::test]
async fn test_chain() {
use reqwest::Request;
use ChainResult::EarlyExit;
#[derive(Debug)]
struct ExampleHandler();
impl Chainable<Request, crate::Status> for ExampleHandler {
fn handle(&mut self, _: Request) -> ChainResult<Request, crate::Status> {
EarlyExit(Status::Excluded)
}
}
let chain = Arc::new(Mutex::new(RequestChain::new(vec![Box::new(
ExampleHandler {},
)])));
let client = ClientBuilder::builder()
.request_chain(chain)
.build()
.client()
.unwrap();
let result = client.check("http://example.com");
let res = result.await.unwrap();
assert_eq!(res.status(), &Status::Excluded);
}
}

View file

@ -1,6 +1,6 @@
use crate::{
chain::{ChainResult, Chainable},
Response,
Status,
};
use header::HeaderValue;
use http::header;
@ -90,8 +90,8 @@ impl Quirks {
}
}
impl Chainable<Request, Response> for Quirks {
fn handle(&mut self, input: Request) -> ChainResult<Request, Response> {
impl Chainable<Request, Status> for Quirks {
fn handle(&mut self, input: Request) -> ChainResult<Request, Status> {
ChainResult::Chained(self.apply(input))
}
}

View file

@ -7,10 +7,8 @@ use reqwest::Request;
use serde::Deserialize;
use thiserror::Error;
use crate::{
chain::{ChainResult, Chainable},
Response,
};
use crate::chain::{ChainResult, Chainable};
use crate::Status;
#[derive(Copy, Clone, Debug, Error, PartialEq)]
pub enum BasicAuthCredentialsParseError {
@ -75,8 +73,8 @@ impl BasicAuthCredentials {
}
}
impl Chainable<Request, Response> for BasicAuthCredentials {
fn handle(&mut self, mut request: Request) -> ChainResult<Request, Response> {
impl Chainable<Request, Status> for BasicAuthCredentials {
fn handle(&mut self, mut request: Request) -> ChainResult<Request, Status> {
request
.headers_mut()
.append(AUTHORIZATION, self.to_authorization().0.encode());