From 8da4592e2ac3fa4aeccdd48823792a52fc616194 Mon Sep 17 00:00:00 2001 From: Thomas Zahner Date: Wed, 13 Mar 2024 08:25:01 +0100 Subject: [PATCH] Introduce early exit in chain --- lychee-lib/src/chain/mod.rs | 74 ++++++++++++++++++++++++------------ lychee-lib/src/client.rs | 12 ++++-- lychee-lib/src/quirks/mod.rs | 29 ++++++++------ 3 files changed, 76 insertions(+), 39 deletions(-) diff --git a/lychee-lib/src/chain/mod.rs b/lychee-lib/src/chain/mod.rs index 7b9a564..8ea86b9 100644 --- a/lychee-lib/src/chain/mod.rs +++ b/lychee-lib/src/chain/mod.rs @@ -1,35 +1,46 @@ +use crate::{BasicAuthCredentials, Response}; use core::fmt::Debug; use headers::authorization::Credentials; use http::header::AUTHORIZATION; use reqwest::Request; -use crate::BasicAuthCredentials; +#[derive(Debug, PartialEq)] +pub(crate) enum ChainResult { + Chained(T), + EarlyExit(R), +} -pub(crate) type RequestChain = Chain; +pub(crate) type RequestChain = Chain; #[derive(Debug)] -pub struct Chain(Vec + Send>>); +pub struct Chain(Vec + Send>>); -impl Chain { - pub(crate) fn new(values: Vec + Send>>) -> Self { +impl Chain { + pub(crate) fn new(values: Vec + Send>>) -> Self { Self(values) } - pub(crate) fn traverse(&mut self, mut input: T) -> T { + pub(crate) fn traverse(&mut self, mut input: T) -> ChainResult { + use ChainResult::*; for e in self.0.iter_mut() { - input = e.handle(input) + match e.handle(input) { + Chained(r) => input = r, + EarlyExit(r) => { + return EarlyExit(r); + } + } } - input + Chained(input) } - pub(crate) fn push(&mut self, value: Box + Send>) { + pub(crate) fn push(&mut self, value: Box + Send>) { self.0.push(value); } } -pub(crate) trait Chainable: Debug { - fn handle(&mut self, input: T) -> T; +pub(crate) trait Chainable: Debug { + fn handle(&mut self, input: T) -> ChainResult; } #[derive(Debug)] @@ -43,36 +54,51 @@ impl BasicAuth { } } -impl Chainable for BasicAuth { - fn handle(&mut self, mut request: Request) -> Request { +impl Chainable for BasicAuth { + fn handle(&mut self, mut request: Request) -> ChainResult { request.headers_mut().append( AUTHORIZATION, self.credentials.to_authorization().0.encode(), ); - request + ChainResult::Chained(request) } } mod test { - use super::Chainable; + use super::{ChainResult, ChainResult::*, Chainable}; #[derive(Debug)] struct Add(i64); - #[derive(Debug)] - struct Request(i64); + #[derive(Debug, PartialEq, Eq)] + struct Result(i64); - impl Chainable for Add { - fn handle(&mut self, req: Request) -> Request { - Request(req.0 + self.0) + impl Chainable for Add { + fn handle(&mut self, req: Result) -> ChainResult { + let added = req.0 + self.0; + if added > 100 { + EarlyExit(Result(req.0)) + } else { + Chained(Result(added)) + } } } #[test] - fn example_chain() { + fn simple_chain() { use super::Chain; - let mut chain: Chain = Chain::new(vec![Box::new(Add(10)), Box::new(Add(-3))]); - let result = chain.traverse(Request(0)); - assert_eq!(result.0, 7); + let mut chain: Chain = + Chain::new(vec![Box::new(Add(10)), Box::new(Add(-3))]); + let result = chain.traverse(Result(0)); + assert_eq!(result, Chained(Result(7))); + } + + #[test] + fn early_exit_chain() { + use super::Chain; + let mut chain: Chain = + Chain::new(vec![Box::new(Add(80)), Box::new(Add(30)), Box::new(Add(1))]); + let result = chain.traverse(Result(0)); + assert_eq!(result, EarlyExit(Result(80))); } } diff --git a/lychee-lib/src/client.rs b/lychee-lib/src/client.rs index 00c11b6..3e2d3e3 100644 --- a/lychee-lib/src/client.rs +++ b/lychee-lib/src/client.rs @@ -30,6 +30,7 @@ use secrecy::{ExposeSecret, SecretString}; use typed_builder::TypedBuilder; use crate::{ + chain::ChainResult::*, chain::{BasicAuth, Chain, RequestChain}, filter::{Excludes, Filter, Includes}, quirks::Quirks, @@ -654,11 +655,14 @@ impl Client { Err(e) => return e.into(), }; - let request = request_chain.traverse(request); + let result = request_chain.traverse(request); - match self.reqwest_client.execute(request).await { - Ok(ref response) => Status::new(response, self.accepted.clone()), - Err(e) => e.into(), + match result { + EarlyExit(r) => r.1.status, + Chained(r) => match self.reqwest_client.execute(r).await { + Ok(ref response) => Status::new(response, self.accepted.clone()), + Err(e) => e.into(), + }, } } diff --git a/lychee-lib/src/quirks/mod.rs b/lychee-lib/src/quirks/mod.rs index 76ffe8f..e93d392 100644 --- a/lychee-lib/src/quirks/mod.rs +++ b/lychee-lib/src/quirks/mod.rs @@ -1,4 +1,7 @@ -use crate::chain::Chainable; +use crate::{ + chain::{ChainResult, Chainable}, + Response, +}; use header::HeaderValue; use http::header; use once_cell::sync::Lazy; @@ -72,11 +75,11 @@ impl Default for Quirks { } } -impl Chainable for Quirks { - /// Handle quirks in a given request. Only the first quirk regex pattern +impl Quirks { + /// Apply quirks to a given request. Only the first quirk regex pattern /// matching the URL will be applied. The rest will be discarded for /// simplicity reasons. This limitation might be lifted in the future. - fn handle(&mut self, request: Request) -> Request { + pub(crate) fn apply(&self, request: Request) -> Request { for quirk in &self.quirks { if quirk.pattern.is_match(request.url().as_str()) { return (quirk.rewrite)(request); @@ -87,14 +90,18 @@ impl Chainable for Quirks { } } +impl Chainable for Quirks { + fn handle(&mut self, input: Request) -> ChainResult { + ChainResult::Chained(self.apply(input)) + } +} + #[cfg(test)] mod tests { use header::HeaderValue; use http::{header, Method}; use reqwest::{Request, Url}; - use crate::chain::Chainable; - use super::Quirks; #[derive(Debug)] @@ -116,7 +123,7 @@ mod tests { fn test_cratesio_request() { let url = Url::parse("https://crates.io/crates/lychee").unwrap(); let request = Request::new(Method::GET, url); - let modified = Quirks::default().handle(request); + let modified = Quirks::default().apply(request); assert_eq!( modified.headers().get(header::ACCEPT).unwrap(), @@ -128,7 +135,7 @@ mod tests { fn test_youtube_video_request() { let url = Url::parse("https://www.youtube.com/watch?v=NlKuICiT470&list=PLbWDhxwM_45mPVToqaIZNbZeIzFchsKKQ&index=7").unwrap(); let request = Request::new(Method::GET, url); - let modified = Quirks::default().handle(request); + let modified = Quirks::default().apply(request); let expected_url = Url::parse("https://img.youtube.com/vi/NlKuICiT470/0.jpg").unwrap(); assert_eq!( @@ -141,7 +148,7 @@ mod tests { fn test_youtube_video_shortlink_request() { let url = Url::parse("https://youtu.be/Rvu7N4wyFpk?t=42").unwrap(); let request = Request::new(Method::GET, url); - let modified = Quirks::default().handle(request); + let modified = Quirks::default().apply(request); let expected_url = Url::parse("https://img.youtube.com/vi/Rvu7N4wyFpk/0.jpg").unwrap(); assert_eq!( @@ -154,7 +161,7 @@ mod tests { fn test_non_video_youtube_url_untouched() { let url = Url::parse("https://www.youtube.com/channel/UCaYhcUwRBNscFNUKTjgPFiA").unwrap(); let request = Request::new(Method::GET, url.clone()); - let modified = Quirks::default().handle(request); + let modified = Quirks::default().apply(request); assert_eq!(MockRequest(modified), MockRequest::new(Method::GET, url)); } @@ -163,7 +170,7 @@ mod tests { fn test_no_quirk_applied() { let url = Url::parse("https://endler.dev").unwrap(); let request = Request::new(Method::GET, url.clone()); - let modified = Quirks::default().handle(request); + let modified = Quirks::default().apply(request); assert_eq!(MockRequest(modified), MockRequest::new(Method::GET, url)); }