Extract checking functionality & make chain async

This commit is contained in:
Thomas Zahner 2024-03-20 11:50:26 +01:00
parent 17e2911700
commit d92d3ba733
6 changed files with 93 additions and 64 deletions

View file

@ -25,7 +25,6 @@ impl<T, R> Clone for Chain<T, R> {
}
}
impl<T, R> Chain<T, R> {
pub(crate) fn new(values: Vec<Box<dyn Chainable<T, R> + Send>>) -> Self {
Self(Arc::new(Mutex::new(values)))
@ -55,7 +54,7 @@ impl<T, R> Chain<T, R> {
}
pub(crate) trait Chainable<T, R>: Debug {
fn chain(&mut self, input: T) -> ChainResult<T, R>;
async fn chain(&mut self, input: T) -> ChainResult<T, R>;
}
mod test {
@ -72,7 +71,7 @@ mod test {
struct Result(usize);
impl Chainable<Result, Result> for Add {
fn chain(&mut self, req: Result) -> ChainResult<Result, Result> {
async fn chain(&mut self, req: Result) -> ChainResult<Result, Result> {
let added = req.0 + self.0;
if added > 100 {
Done(Result(req.0))

65
lychee-lib/src/checker.rs Normal file
View file

@ -0,0 +1,65 @@
use crate::{
chain::{Chain, ChainResult, Chainable},
retry::RetryExt,
Status,
};
use http::StatusCode;
use reqwest::Request;
use std::{collections::HashSet, time::Duration};
#[derive(Debug)]
pub(crate) struct Checker {
retry_wait_time: Duration,
max_retries: u64,
reqwest_client: reqwest::Client,
accepted: Option<HashSet<StatusCode>>,
}
impl Checker {
pub(crate) fn new(
retry_wait_time: Duration,
max_retries: u64,
reqwest_client: reqwest::Client,
accepted: Option<HashSet<StatusCode>>,
) -> Self {
Self {
retry_wait_time,
max_retries,
reqwest_client,
accepted,
}
}
/// Retry requests up to `max_retries` times
/// with an exponential backoff.
pub(crate) async fn retry_request(&self, request: reqwest::Request) -> Status {
let mut retries: u64 = 0;
let mut wait_time = self.retry_wait_time;
let mut status = self.check_default(request.try_clone().unwrap()).await; // TODO: try_clone
while retries < self.max_retries {
if status.is_success() || !status.should_retry() {
return status;
}
retries += 1;
tokio::time::sleep(wait_time).await;
wait_time = wait_time.saturating_mul(2);
status = self.check_default(request.try_clone().unwrap()).await; // TODO: try_clone
}
status
}
/// Check a URI using [reqwest](https://github.com/seanmonstar/reqwest).
async fn check_default(&self, request: reqwest::Request) -> Status {
match self.reqwest_client.execute(request).await {
Ok(ref response) => Status::new(response, self.accepted.clone()),
Err(e) => e.into(),
}
}
}
impl Chainable<Request, Status> for Checker {
async fn chain(&mut self, input: Request) -> ChainResult<Request, Status> {
ChainResult::Done(self.retry_request(input).await)
}
}

View file

@ -13,12 +13,7 @@
clippy::default_trait_access,
clippy::used_underscore_binding
)]
use std::{
collections::HashSet,
path::Path,
sync::Arc,
time::Duration,
};
use std::{collections::HashSet, path::Path, sync::Arc, time::Duration};
#[cfg(all(feature = "email-check", feature = "native-tls"))]
use check_if_email_exists::{check_email, CheckEmailInput, Reachable};
@ -35,11 +30,8 @@ use secrecy::{ExposeSecret, SecretString};
use typed_builder::TypedBuilder;
use crate::{
chain::{
Chain,
ChainResult::{Next, Done},
RequestChain,
},
chain::{Chain, RequestChain},
checker::Checker,
filter::{Excludes, Filter, Includes},
quirks::Quirks,
remap::Remaps,
@ -587,7 +579,23 @@ impl Client {
return Status::Unsupported(ErrorKind::InvalidURI(uri.clone()));
}
let status = self.retry_request(uri, request_chain).await;
let request = self
.reqwest_client
.request(self.method.clone(), uri.as_str())
.build();
let request = match request {
Ok(r) => r,
Err(e) => return e.into(),
};
let checker = Checker::new(
self.retry_wait_time,
self.max_retries,
self.reqwest_client.clone(),
self.accepted.clone(),
);
let status = checker.retry_request(request).await;
if status.is_success() {
return status;
}
@ -607,25 +615,6 @@ impl Client {
status
}
/// Retry requests up to `max_retries` times
/// with an exponential backoff.
async fn retry_request(&self, uri: &Uri, request_chain: &mut RequestChain) -> Status {
let mut retries: u64 = 0;
let mut wait_time = self.retry_wait_time;
let mut status = self.check_default(uri, request_chain).await;
while retries < self.max_retries {
if status.is_success() || !status.should_retry() {
return status;
}
retries += 1;
tokio::time::sleep(wait_time).await;
wait_time = wait_time.saturating_mul(2);
status = self.check_default(uri, request_chain).await;
}
status
}
/// Check a `uri` hosted on `GitHub` via the GitHub API.
///
/// # Caveats
@ -661,29 +650,6 @@ impl Client {
Status::Ok(StatusCode::OK)
}
/// Check a URI using [reqwest](https://github.com/seanmonstar/reqwest).
async fn check_default(&self, uri: &Uri, request_chain: &mut RequestChain) -> Status {
let request = self
.reqwest_client
.request(self.method.clone(), uri.as_str())
.build();
let request = match request {
Ok(r) => r,
Err(e) => return e.into(),
};
let result = request_chain.traverse(request);
match result {
Done(status) => status,
Next(r) => match self.reqwest_client.execute(r).await {
Ok(ref response) => Status::new(response, self.accepted.clone()),
Err(e) => e.into(),
},
}
}
/// Check a `file` URI.
pub async fn check_file(&self, uri: &Uri) -> Status {
let Ok(path) = uri.url.to_file_path() else {
@ -1120,14 +1086,12 @@ mod tests {
struct ExampleHandler();
impl Chainable<Request, Status> for ExampleHandler {
fn chain(&mut self, _: Request) -> ChainResult<Request, Status> {
async fn chain(&mut self, _: Request) -> ChainResult<Request, Status> {
ChainResult::Done(Status::Excluded)
}
}
let chain = RequestChain::new(vec![Box::new(
ExampleHandler {},
)]);
let chain = RequestChain::new(vec![Box::new(ExampleHandler {})]);
let client = ClientBuilder::builder()
.request_chain(chain)

View file

@ -51,10 +51,11 @@
doc_comment::doctest!("../../README.md");
mod basic_auth;
mod chain;
mod checker;
mod client;
/// A pool of clients, to handle concurrent checks
pub mod collector;
mod chain;
mod quirks;
mod retry;
mod types;

View file

@ -91,7 +91,7 @@ impl Quirks {
}
impl Chainable<Request, Status> for Quirks {
fn chain(&mut self, input: Request) -> ChainResult<Request, Status> {
async fn chain(&mut self, input: Request) -> ChainResult<Request, Status> {
ChainResult::Next(self.apply(input))
}
}

View file

@ -74,7 +74,7 @@ impl BasicAuthCredentials {
}
impl Chainable<Request, Status> for BasicAuthCredentials {
fn chain(&mut self, mut request: Request) -> ChainResult<Request, Status> {
async fn chain(&mut self, mut request: Request) -> ChainResult<Request, Status> {
request
.headers_mut()
.append(AUTHORIZATION, self.to_authorization().0.encode());