Create ClientRequestChain helper structure to combine multiple chains

This commit is contained in:
Thomas Zahner 2024-04-05 08:37:09 +02:00
parent 41e7f88da4
commit 93afae54bb
2 changed files with 40 additions and 27 deletions

View file

@ -34,7 +34,7 @@ impl<T, R> Chain<T, R> {
Self(Arc::new(Mutex::new(values)))
}
pub(crate) async fn traverse(&mut self, mut input: T) -> ChainResult<T, R> {
pub(crate) async fn traverse(&self, mut input: T) -> ChainResult<T, R> {
use ChainResult::{Done, Next};
for e in self.0.lock().await.iter_mut() {
match e.chain(input).await {
@ -47,11 +47,6 @@ impl<T, R> Chain<T, R> {
Next(input)
}
// TODO: probably remove
pub(crate) fn into_inner(self) -> InnerChain<T, R> {
Arc::try_unwrap(self.0).expect("Arc still has multiple owners").into_inner()
}
}
#[async_trait]
@ -59,6 +54,32 @@ pub(crate) trait Chainable<T, R>: Debug {
async fn chain(&mut self, input: T) -> ChainResult<T, R>;
}
#[derive(Debug)]
pub(crate) struct ClientRequestChain<'a> {
chains: Vec<&'a RequestChain>,
}
impl<'a> ClientRequestChain<'a> {
pub(crate) fn new(chains: Vec<&'a RequestChain>) -> Self {
Self { chains }
}
pub(crate) async fn traverse(&self, mut input: reqwest::Request) -> Status {
use ChainResult::{Done, Next};
for e in &self.chains {
match e.traverse(input).await {
Next(r) => input = r,
Done(r) => {
return r;
}
}
}
// consider as excluded if no chain element has converted it to a done
Status::Excluded
}
}
mod test {
use super::{
ChainResult,
@ -88,7 +109,7 @@ mod test {
#[tokio::test]
async fn simple_chain() {
use super::Chain;
let mut chain: Chain<Result, Result> = Chain::new(vec![Box::new(Add(7)), Box::new(Add(3))]);
let chain: Chain<Result, Result> = Chain::new(vec![Box::new(Add(7)), Box::new(Add(3))]);
let result = chain.traverse(Result(0)).await;
assert_eq!(result, Next(Result(10)));
}
@ -96,7 +117,7 @@ mod test {
#[tokio::test]
async fn early_exit_chain() {
use super::Chain;
let mut chain: Chain<Result, Result> =
let chain: Chain<Result, Result> =
Chain::new(vec![Box::new(Add(80)), Box::new(Add(30)), Box::new(Add(1))]);
let result = chain.traverse(Result(0)).await;
assert_eq!(result, Done(Result(80)));

View file

@ -30,7 +30,7 @@ use secrecy::{ExposeSecret, SecretString};
use typed_builder::TypedBuilder;
use crate::{
chain::{Chain, ChainResult, RequestChain},
chain::{Chain, ClientRequestChain, RequestChain},
checker::Checker,
filter::{Excludes, Filter, Includes},
quirks::Quirks,
@ -279,7 +279,7 @@ pub struct ClientBuilder {
/// can modify the request. A chained item can also decide to exit
/// early and return a status, so that subsequent chain items are
/// skipped and no HTTP request is ever made.
request_chain: RequestChain,
plugin_request_chain: RequestChain,
}
impl Default for ClientBuilder {
@ -392,7 +392,7 @@ impl ClientBuilder {
require_https: self.require_https,
include_fragments: self.include_fragments,
fragment_checker: FragmentChecker::new(),
request_chain: self.request_chain,
plugin_request_chain: self.plugin_request_chain,
})
}
}
@ -443,7 +443,7 @@ pub struct Client {
/// Caches Fragments
fragment_checker: FragmentChecker,
request_chain: RequestChain,
plugin_request_chain: RequestChain,
}
impl Client {
@ -487,7 +487,6 @@ impl Client {
}
let request_chain: RequestChain = Chain::new(vec![
// TODO: insert self.request_chain
Box::<Quirks>::default(),
Box::new(credentials),
Box::new(Checker::new(
@ -534,15 +533,11 @@ impl Client {
/// - The request failed.
/// - The response status code is not accepted.
/// - The URI cannot be converted to HTTPS.
pub async fn check_website(
&self,
uri: &Uri,
mut request_chain: RequestChain,
) -> Result<Status> {
match self.check_website_inner(uri, &mut request_chain).await {
pub async fn check_website(&self, uri: &Uri, request_chain: RequestChain) -> Result<Status> {
match self.check_website_inner(uri, &request_chain).await {
Status::Ok(code) if self.require_https && uri.scheme() == "http" => {
if self
.check_website_inner(&uri.to_https()?, &mut request_chain)
.check_website_inner(&uri.to_https()?, &request_chain)
.await
.is_success()
{
@ -567,7 +562,7 @@ impl Client {
/// - The URI is invalid.
/// - The request failed.
/// - The response status code is not accepted.
pub async fn check_website_inner(&self, uri: &Uri, request_chain: &mut RequestChain) -> Status {
pub async fn check_website_inner(&self, uri: &Uri, request_chain: &RequestChain) -> Status {
// Workaround for upstream reqwest panic
if validate_url(&uri.url) {
if matches!(uri.scheme(), "http" | "https") {
@ -592,11 +587,8 @@ impl Client {
Err(e) => return e.into(),
};
let status = match request_chain.traverse(request).await {
// consider as excluded if no chain element has converted it to a done
ChainResult::Next(_) => Status::Excluded,
ChainResult::Done(status) => status,
};
let chain = ClientRequestChain::new(vec![&self.plugin_request_chain, request_chain]);
let status = chain.traverse(request).await;
if status.is_success() {
return status;
@ -1098,7 +1090,7 @@ mod tests {
let chain = RequestChain::new(vec![Box::new(ExampleHandler {})]);
let client = ClientBuilder::builder()
.request_chain(chain)
.plugin_request_chain(chain)
.build()
.client()
.unwrap();