Introduce early exit in chain

This commit is contained in:
Thomas Zahner 2024-03-13 08:25:01 +01:00
parent 7783cdfe46
commit 8da4592e2a
3 changed files with 76 additions and 39 deletions

View file

@ -1,35 +1,46 @@
use crate::{BasicAuthCredentials, Response};
use core::fmt::Debug;
use headers::authorization::Credentials;
use http::header::AUTHORIZATION;
use reqwest::Request;
use crate::BasicAuthCredentials;
#[derive(Debug, PartialEq)]
pub(crate) enum ChainResult<T, R> {
Chained(T),
EarlyExit(R),
}
pub(crate) type RequestChain = Chain<reqwest::Request>;
pub(crate) type RequestChain = Chain<Request, Response>;
#[derive(Debug)]
pub struct Chain<T>(Vec<Box<dyn Chainable<T> + Send>>);
pub struct Chain<T, R>(Vec<Box<dyn Chainable<T, R> + Send>>);
impl<T> Chain<T> {
pub(crate) fn new(values: Vec<Box<dyn Chainable<T> + Send>>) -> Self {
impl<T, R> Chain<T, R> {
pub(crate) fn new(values: Vec<Box<dyn Chainable<T, R> + Send>>) -> Self {
Self(values)
}
pub(crate) fn traverse(&mut self, mut input: T) -> T {
pub(crate) fn traverse(&mut self, mut input: T) -> ChainResult<T, R> {
use ChainResult::*;
for e in self.0.iter_mut() {
input = e.handle(input)
match e.handle(input) {
Chained(r) => input = r,
EarlyExit(r) => {
return EarlyExit(r);
}
}
}
input
Chained(input)
}
pub(crate) fn push(&mut self, value: Box<dyn Chainable<T> + Send>) {
pub(crate) fn push(&mut self, value: Box<dyn Chainable<T, R> + Send>) {
self.0.push(value);
}
}
pub(crate) trait Chainable<T>: Debug {
fn handle(&mut self, input: T) -> T;
pub(crate) trait Chainable<T, R>: Debug {
fn handle(&mut self, input: T) -> ChainResult<T, R>;
}
#[derive(Debug)]
@ -43,36 +54,51 @@ impl BasicAuth {
}
}
impl Chainable<Request> for BasicAuth {
fn handle(&mut self, mut request: Request) -> Request {
impl Chainable<Request, Response> for BasicAuth {
fn handle(&mut self, mut request: Request) -> ChainResult<Request, Response> {
request.headers_mut().append(
AUTHORIZATION,
self.credentials.to_authorization().0.encode(),
);
request
ChainResult::Chained(request)
}
}
mod test {
use super::Chainable;
use super::{ChainResult, ChainResult::*, Chainable};
#[derive(Debug)]
struct Add(i64);
#[derive(Debug)]
struct Request(i64);
#[derive(Debug, PartialEq, Eq)]
struct Result(i64);
impl Chainable<Request> for Add {
fn handle(&mut self, req: Request) -> Request {
Request(req.0 + self.0)
impl Chainable<Result, Result> for Add {
fn handle(&mut self, req: Result) -> ChainResult<Result, Result> {
let added = req.0 + self.0;
if added > 100 {
EarlyExit(Result(req.0))
} else {
Chained(Result(added))
}
}
}
#[test]
fn example_chain() {
fn simple_chain() {
use super::Chain;
let mut chain: Chain<Request> = Chain::new(vec![Box::new(Add(10)), Box::new(Add(-3))]);
let result = chain.traverse(Request(0));
assert_eq!(result.0, 7);
let mut chain: Chain<Result, Result> =
Chain::new(vec![Box::new(Add(10)), Box::new(Add(-3))]);
let result = chain.traverse(Result(0));
assert_eq!(result, Chained(Result(7)));
}
#[test]
fn early_exit_chain() {
use super::Chain;
let mut chain: Chain<Result, Result> =
Chain::new(vec![Box::new(Add(80)), Box::new(Add(30)), Box::new(Add(1))]);
let result = chain.traverse(Result(0));
assert_eq!(result, EarlyExit(Result(80)));
}
}

View file

@ -30,6 +30,7 @@ use secrecy::{ExposeSecret, SecretString};
use typed_builder::TypedBuilder;
use crate::{
chain::ChainResult::*,
chain::{BasicAuth, Chain, RequestChain},
filter::{Excludes, Filter, Includes},
quirks::Quirks,
@ -654,11 +655,14 @@ impl Client {
Err(e) => return e.into(),
};
let request = request_chain.traverse(request);
let result = request_chain.traverse(request);
match self.reqwest_client.execute(request).await {
Ok(ref response) => Status::new(response, self.accepted.clone()),
Err(e) => e.into(),
match result {
EarlyExit(r) => r.1.status,
Chained(r) => match self.reqwest_client.execute(r).await {
Ok(ref response) => Status::new(response, self.accepted.clone()),
Err(e) => e.into(),
},
}
}

View file

@ -1,4 +1,7 @@
use crate::chain::Chainable;
use crate::{
chain::{ChainResult, Chainable},
Response,
};
use header::HeaderValue;
use http::header;
use once_cell::sync::Lazy;
@ -72,11 +75,11 @@ impl Default for Quirks {
}
}
impl Chainable<Request> for Quirks {
/// Handle quirks in a given request. Only the first quirk regex pattern
impl Quirks {
/// Apply quirks to a given request. Only the first quirk regex pattern
/// matching the URL will be applied. The rest will be discarded for
/// simplicity reasons. This limitation might be lifted in the future.
fn handle(&mut self, request: Request) -> Request {
pub(crate) fn apply(&self, request: Request) -> Request {
for quirk in &self.quirks {
if quirk.pattern.is_match(request.url().as_str()) {
return (quirk.rewrite)(request);
@ -87,14 +90,18 @@ impl Chainable<Request> for Quirks {
}
}
impl Chainable<Request, Response> for Quirks {
fn handle(&mut self, input: Request) -> ChainResult<Request, Response> {
ChainResult::Chained(self.apply(input))
}
}
#[cfg(test)]
mod tests {
use header::HeaderValue;
use http::{header, Method};
use reqwest::{Request, Url};
use crate::chain::Chainable;
use super::Quirks;
#[derive(Debug)]
@ -116,7 +123,7 @@ mod tests {
fn test_cratesio_request() {
let url = Url::parse("https://crates.io/crates/lychee").unwrap();
let request = Request::new(Method::GET, url);
let modified = Quirks::default().handle(request);
let modified = Quirks::default().apply(request);
assert_eq!(
modified.headers().get(header::ACCEPT).unwrap(),
@ -128,7 +135,7 @@ mod tests {
fn test_youtube_video_request() {
let url = Url::parse("https://www.youtube.com/watch?v=NlKuICiT470&list=PLbWDhxwM_45mPVToqaIZNbZeIzFchsKKQ&index=7").unwrap();
let request = Request::new(Method::GET, url);
let modified = Quirks::default().handle(request);
let modified = Quirks::default().apply(request);
let expected_url = Url::parse("https://img.youtube.com/vi/NlKuICiT470/0.jpg").unwrap();
assert_eq!(
@ -141,7 +148,7 @@ mod tests {
fn test_youtube_video_shortlink_request() {
let url = Url::parse("https://youtu.be/Rvu7N4wyFpk?t=42").unwrap();
let request = Request::new(Method::GET, url);
let modified = Quirks::default().handle(request);
let modified = Quirks::default().apply(request);
let expected_url = Url::parse("https://img.youtube.com/vi/Rvu7N4wyFpk/0.jpg").unwrap();
assert_eq!(
@ -154,7 +161,7 @@ mod tests {
fn test_non_video_youtube_url_untouched() {
let url = Url::parse("https://www.youtube.com/channel/UCaYhcUwRBNscFNUKTjgPFiA").unwrap();
let request = Request::new(Method::GET, url.clone());
let modified = Quirks::default().handle(request);
let modified = Quirks::default().apply(request);
assert_eq!(MockRequest(modified), MockRequest::new(Method::GET, url));
}
@ -163,7 +170,7 @@ mod tests {
fn test_no_quirk_applied() {
let url = Url::parse("https://endler.dev").unwrap();
let request = Request::new(Method::GET, url.clone());
let modified = Quirks::default().handle(request);
let modified = Quirks::default().apply(request);
assert_eq!(MockRequest(modified), MockRequest::new(Method::GET, url));
}