Pass down request_chain instead of credentials & add test

This commit is contained in:
Thomas Zahner 2024-03-12 14:48:38 +01:00
parent 237f690f18
commit 7783cdfe46
2 changed files with 72 additions and 42 deletions

View file

@ -1,23 +1,38 @@
use core::fmt::Debug;
use headers::authorization::Credentials;
use http::header::AUTHORIZATION;
use reqwest::Request;
use crate::BasicAuthCredentials;
pub(crate) type Chain<T> = Vec<Box<dyn Chainable<T> + Send>>;
pub(crate) type RequestChain = Chain<reqwest::Request>;
pub(crate) fn traverse<T>(chain: Chain<T>, mut input: T) -> T {
for mut e in chain {
input = e.handle(input)
#[derive(Debug)]
pub struct Chain<T>(Vec<Box<dyn Chainable<T> + Send>>);
impl<T> Chain<T> {
pub(crate) fn new(values: Vec<Box<dyn Chainable<T> + Send>>) -> Self {
Self(values)
}
input
pub(crate) fn traverse(&mut self, mut input: T) -> T {
for e in self.0.iter_mut() {
input = e.handle(input)
}
input
}
pub(crate) fn push(&mut self, value: Box<dyn Chainable<T> + Send>) {
self.0.push(value);
}
}
pub(crate) trait Chainable<T> {
pub(crate) trait Chainable<T>: Debug {
fn handle(&mut self, input: T) -> T;
}
#[derive(Debug)]
pub(crate) struct BasicAuth {
credentials: BasicAuthCredentials,
}
@ -41,8 +56,10 @@ impl Chainable<Request> for BasicAuth {
mod test {
use super::Chainable;
#[derive(Debug)]
struct Add(i64);
#[derive(Debug)]
struct Request(i64);
impl Chainable<Request> for Add {
@ -53,8 +70,9 @@ mod test {
#[test]
fn example_chain() {
let chain: crate::chain::Chain<Request> = vec![Box::new(Add(10)), Box::new(Add(-3))];
let result = crate::chain::traverse(chain, Request(0));
use super::Chain;
let mut chain: Chain<Request> = Chain::new(vec![Box::new(Add(10)), Box::new(Add(-3))]);
let result = chain.traverse(Request(0));
assert_eq!(result.0, 7);
}
}

View file

@ -30,14 +30,14 @@ use secrecy::{ExposeSecret, SecretString};
use typed_builder::TypedBuilder;
use crate::{
chain::{traverse, BasicAuth, Chain},
chain::{BasicAuth, Chain, RequestChain},
filter::{Excludes, Filter, Includes},
quirks::Quirks,
remap::Remaps,
retry::RetryExt,
types::uri::github::GithubUri,
utils::fragment_checker::FragmentChecker,
BasicAuthCredentials, ErrorKind, Request, Response, Result, Status, Uri,
ErrorKind, Request, Response, Result, Status, Uri,
};
#[cfg(all(feature = "email-check", feature = "native-tls"))]
@ -374,8 +374,6 @@ impl ClientBuilder {
include_mail: self.include_mail,
};
let quirks = Quirks::default();
Ok(Client {
reqwest_client,
github_client,
@ -386,7 +384,6 @@ impl ClientBuilder {
method: self.method,
accepted: self.accepted,
require_https: self.require_https,
quirks,
include_fragments: self.include_fragments,
fragment_checker: FragmentChecker::new(),
})
@ -433,9 +430,6 @@ pub struct Client {
/// This would treat unencrypted links as errors when HTTPS is available.
require_https: bool,
/// Override behaviors for certain known issues with special URIs.
quirks: Quirks,
/// Enable the checking of fragments in links.
include_fragments: bool,
@ -477,16 +471,23 @@ impl Client {
// ));
// }
self.remap(uri)?;
self.remap(uri)?; // todo: possible to do in request chain?
if self.is_excluded(uri) {
return Ok(Response::new(uri.clone(), Status::Excluded, source));
}
let quirks = Quirks::default();
let mut request_chain: RequestChain = Chain::new(vec![Box::new(quirks)]);
if let Some(c) = credentials {
request_chain.push(Box::new(BasicAuth::new(c.clone())));
}
let status = match uri.scheme() {
_ if uri.is_file() => self.check_file(uri).await,
_ if uri.is_mail() => self.check_mail(uri).await,
_ => self.check_website(uri, credentials).await?,
_ => self.check_website(uri, &mut request_chain).await?,
};
Ok(Response::new(uri.clone(), status, source))
@ -522,12 +523,12 @@ impl Client {
pub async fn check_website(
&self,
uri: &Uri,
credentials: &Option<BasicAuthCredentials>,
request_chain: &mut RequestChain,
) -> Result<Status> {
match self.check_website_inner(uri, credentials).await {
match self.check_website_inner(uri, request_chain).await {
Status::Ok(code) if self.require_https && uri.scheme() == "http" => {
if self
.check_website_inner(&uri.to_https()?, credentials)
.check_website_inner(&uri.to_https()?, request_chain)
.await
.is_success()
{
@ -552,11 +553,7 @@ impl Client {
/// - The URI is invalid.
/// - The request failed.
/// - The response status code is not accepted.
pub async fn check_website_inner(
&self,
uri: &Uri,
credentials: &Option<BasicAuthCredentials>,
) -> Status {
pub async fn check_website_inner(&self, uri: &Uri, request_chain: &mut RequestChain) -> Status {
// Workaround for upstream reqwest panic
if validate_url(&uri.url) {
if matches!(uri.scheme(), "http" | "https") {
@ -571,7 +568,7 @@ impl Client {
return Status::Unsupported(ErrorKind::InvalidURI(uri.clone()));
}
let status = self.retry_request(uri, credentials).await;
let status = self.retry_request(uri, request_chain).await;
if status.is_success() {
return status;
}
@ -593,11 +590,11 @@ impl Client {
/// Retry requests up to `max_retries` times
/// with an exponential backoff.
async fn retry_request(&self, uri: &Uri, credentials: &Option<BasicAuthCredentials>) -> Status {
async fn retry_request(&self, uri: &Uri, request_chain: &mut RequestChain) -> Status {
let mut retries: u64 = 0;
let mut wait_time = self.retry_wait_time;
let mut status = self.check_default(uri, credentials).await;
let mut status = self.check_default(uri, request_chain).await;
while retries < self.max_retries {
if status.is_success() || !status.should_retry() {
return status;
@ -605,7 +602,7 @@ impl Client {
retries += 1;
tokio::time::sleep(wait_time).await;
wait_time = wait_time.saturating_mul(2);
status = self.check_default(uri, credentials).await;
status = self.check_default(uri, request_chain).await;
}
status
}
@ -646,10 +643,7 @@ impl Client {
}
/// Check a URI using [reqwest](https://github.com/seanmonstar/reqwest).
async fn check_default(&self, uri: &Uri, credentials: &Option<BasicAuthCredentials>) -> Status {
// todo: middleware
// todo: create credentials middleware
async fn check_default(&self, uri: &Uri, request_chain: &mut RequestChain) -> Status {
let request = self
.reqwest_client
.request(self.method.clone(), uri.as_str())
@ -660,13 +654,7 @@ impl Client {
Err(e) => return e.into(),
};
let mut chain: Chain<reqwest::Request> = vec![Box::new(self.quirks.clone())];
if let Some(c) = credentials {
chain.push(Box::new(BasicAuth::new(c.clone())));
}
let request = traverse(chain, request);
let request = request_chain.traverse(request);
match self.reqwest_client.execute(request).await {
Ok(ref response) => Status::new(response, self.accepted.clone()),
@ -773,7 +761,7 @@ mod tests {
use wiremock::matchers::path;
use super::ClientBuilder;
use crate::{mock_server, test_utils::get_mock_client_response, Uri};
use crate::{mock_server, test_utils::get_mock_client_response, Request, Uri};
#[tokio::test]
async fn test_nonexistent() {
@ -820,6 +808,30 @@ mod tests {
assert!(res.status().is_failure());
}
#[tokio::test]
async fn test_basic_auth() {
let mut r = Request::new(
"https://authenticationtest.com/HTTPAuth/"
.try_into()
.unwrap(),
crate::InputSource::Stdin,
None,
None,
None,
);
let res = get_mock_client_response(r.clone()).await;
assert_eq!(res.status().code(), Some(401.try_into().unwrap()));
r.credentials = Some(crate::BasicAuthCredentials {
username: "user".into(),
password: "pass".into(),
});
let res = get_mock_client_response(r.clone()).await;
assert!(res.status().is_success());
}
#[tokio::test]
async fn test_non_github() {
let mock_server = mock_server!(StatusCode::OK);