diff --git a/src/checker.rs b/src/checker.rs index b725be9..cbcfb3b 100644 --- a/src/checker.rs +++ b/src/checker.rs @@ -9,6 +9,7 @@ use regex::{Regex, RegexSet}; use reqwest::header::{self, HeaderMap, HeaderValue}; use std::net::IpAddr; use std::{collections::HashSet, convert::TryFrom, time::Duration}; +use tokio::time::delay_for; use url::Url; pub(crate) enum RequestMethod { @@ -204,10 +205,23 @@ impl<'a> Checker<'a> { } pub async fn check_real(&self, url: &Url) -> Status { - let status = self.check_normal(&url).await; - if status.is_success() { - return status; - } + let mut retries: i64 = 3; + let mut wait: u64 = 1; + let status = loop { + let res = self.check_normal(&url).await; + match res.is_success() { + true => return res, + false => { + if retries > 0 { + retries -= 1; + delay_for(Duration::from_secs(wait)).await; + wait *= 2; + } else { + break res; + } + } + } + }; // Pull out the heavy weapons in case of a failed normal request. // This could be a Github URL and we run into the rate limiter. if let Ok((owner, repo)) = self.extract_github(url.as_str()) { @@ -347,7 +361,7 @@ impl<'a> Checker<'a> { mod test { use super::*; use http::StatusCode; - use std::time::Duration; + use std::time::{Duration, Instant}; use url::Url; use wiremock::matchers::method; use wiremock::{Mock, MockServer, ResponseTemplate}; @@ -402,6 +416,20 @@ mod test { assert!(matches!(res, Status::Failed(_))); } + #[tokio::test] + async fn test_exponetial_backoff() { + let start = Instant::now(); + let res = get_checker(false, HeaderMap::new()) + .check(&Uri::Website( + Url::parse("https://endler.dev/abcd").unwrap(), + )) + .await; + let end = start.elapsed(); + + assert!(matches!(res, Status::Failed(_))); + assert!(matches!(end.as_secs(), 7)); + } + #[test] fn test_is_github() { assert_eq!(