From 35610764a133afe37b9b978604fd3e016d22b880 Mon Sep 17 00:00:00 2001 From: Matthias Endler Date: Fri, 23 May 2025 13:37:32 +0200 Subject: [PATCH] Add support for custom headers in input processing (#1561) --- Cargo.lock | 11 + README.md | 9 +- docs/TROUBLESHOOTING.md | 4 +- examples/collect_links/collect_links.rs | 3 + fixtures/configs/headers.toml | 6 + fixtures/configs/smoketest.toml | 2 +- lychee-bin/Cargo.toml | 4 +- lychee-bin/src/client.rs | 9 +- lychee-bin/src/options.rs | 323 +++++++++++++++++++++--- lychee-bin/src/parse.rs | 45 +--- lychee-bin/tests/cli.rs | 114 ++++++++- lychee-lib/src/collector.rs | 28 +- lychee-lib/src/types/input.rs | 78 ++++-- lychee.example.toml | 2 +- 14 files changed, 512 insertions(+), 126 deletions(-) create mode 100644 fixtures/configs/headers.toml diff --git a/Cargo.lock b/Cargo.lock index 6fe799f..788fd7f 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1892,6 +1892,16 @@ dependencies = [ "pin-project-lite", ] +[[package]] +name = "http-serde" +version = "2.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0f056c8559e3757392c8d091e796416e4649d8e49e88b8d76df6c002f05027fd" +dependencies = [ + "http 1.3.1", + "serde", +] + [[package]] name = "httparse" version = "1.10.1" @@ -2529,6 +2539,7 @@ dependencies = [ "futures", "headers", "http 1.3.1", + "http-serde", "human-sort", "humantime", "humantime-serde", diff --git a/README.md b/README.md index 7a0677a..eafe2c2 100644 --- a/README.md +++ b/README.md @@ -458,8 +458,13 @@ Options: Example: --fallback-extensions html,htm,php,asp,aspx,jsp,cgi - --header
- Custom request header + -H, --header + Set custom header for requests + + Some websites require custom headers to be passed in order to return valid responses. + You can specify custom headers in the format 'Name: Value'. For example, 'Accept: text/html'. + This is the same format that other tools like curl or wget use. + Multiple headers can be specified by using the flag multiple times. -a, --accept A List of accepted status codes for valid links diff --git a/docs/TROUBLESHOOTING.md b/docs/TROUBLESHOOTING.md index 11f1e9d..d44c3c9 100644 --- a/docs/TROUBLESHOOTING.md +++ b/docs/TROUBLESHOOTING.md @@ -48,9 +48,9 @@ Some sites expect one or more custom headers to return a valid response. \ For example, crates.io expects a `Accept: text/html` header or else it \ will [return a 404](https://github.com/rust-lang/crates.io/issues/788). -To fix that you can pass additional headers like so: `--header "accept=text/html"`. \ +To fix that you can pass additional headers like so: `--header "Accept: text/html"`. \ You can use that argument multiple times to add more headers. \ -Or, you can accept all content/MIME types: `--header "accept=*/*"`. +Or, you can accept all content/MIME types: `--header "Accept: */*"`. See more info about the Accept header [over at MDN](https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Accept). diff --git a/examples/collect_links/collect_links.rs b/examples/collect_links/collect_links.rs index 57edd5f..a7266b0 100644 --- a/examples/collect_links/collect_links.rs +++ b/examples/collect_links/collect_links.rs @@ -1,3 +1,4 @@ +use http::HeaderMap; use lychee_lib::{Collector, Input, InputSource, Result}; use reqwest::Url; use std::path::PathBuf; @@ -13,11 +14,13 @@ async fn main() -> Result<()> { )), file_type_hint: None, excluded_paths: None, + headers: HeaderMap::new(), }, Input { source: InputSource::FsPath(PathBuf::from("fixtures/TEST.md")), file_type_hint: None, excluded_paths: None, + headers: HeaderMap::new(), }, ]; diff --git a/fixtures/configs/headers.toml b/fixtures/configs/headers.toml new file mode 100644 index 0000000..2873301 --- /dev/null +++ b/fixtures/configs/headers.toml @@ -0,0 +1,6 @@ +[header] +X-Foo = "Bar" +X-Bar = "Baz" + +# Alternative TOML syntax: +# header = { X-Foo = "Bar", X-Bar = "Baz" } diff --git a/fixtures/configs/smoketest.toml b/fixtures/configs/smoketest.toml index 3dc229f..6ae6a33 100644 --- a/fixtures/configs/smoketest.toml +++ b/fixtures/configs/smoketest.toml @@ -69,7 +69,7 @@ require_https = false method = "get" # Custom request headers -headers = [] +header = { X-Foo = "Bar", X-Bar = "Baz" } # Remap URI matching pattern to different URI. # This also supports (named) capturing groups. diff --git a/lychee-bin/Cargo.toml b/lychee-bin/Cargo.toml index b7f942d..40a766c 100644 --- a/lychee-bin/Cargo.toml +++ b/lychee-bin/Cargo.toml @@ -28,8 +28,10 @@ env_logger = "0.11.8" futures = "0.3.31" headers = "0.4.0" http = "1.3.1" +http-serde = "2.1.1" humantime = "2.2.0" humantime-serde = "1.1.1" +human-sort = "0.2.2" indicatif = "0.17.11" log = "0.4.27" openssl-sys = { version = "0.9.108", optional = true } @@ -55,7 +57,7 @@ tokio = { version = "1.45.0", features = ["full"] } tokio-stream = "0.1.17" toml = "0.8.22" url = "2.5.4" -human-sort = "0.2.2" + [dev-dependencies] assert_cmd = "2.0.17" diff --git a/lychee-bin/src/client.rs b/lychee-bin/src/client.rs index e9ccf0d..2dca723 100644 --- a/lychee-bin/src/client.rs +++ b/lychee-bin/src/client.rs @@ -1,7 +1,7 @@ -use crate::options::Config; -use crate::parse::{parse_duration_secs, parse_headers, parse_remaps}; +use crate::options::{Config, HeaderMapExt}; +use crate::parse::{parse_duration_secs, parse_remaps}; use anyhow::{Context, Result}; -use http::StatusCode; +use http::{HeaderMap, StatusCode}; use lychee_lib::{Client, ClientBuilder}; use regex::RegexSet; use reqwest_cookie_store::CookieStoreMutex; @@ -10,7 +10,6 @@ use std::{collections::HashSet, str::FromStr}; /// Creates a client according to the command-line config pub(crate) fn create(cfg: &Config, cookie_jar: Option<&Arc>) -> Result { - let headers = parse_headers(&cfg.header)?; let timeout = parse_duration_secs(cfg.timeout); let retry_wait_time = parse_duration_secs(cfg.retry_wait_time); let method: reqwest::Method = reqwest::Method::from_str(&cfg.method.to_uppercase())?; @@ -34,6 +33,8 @@ pub(crate) fn create(cfg: &Config, cookie_jar: Option<&Arc>) - .map(|value| StatusCode::from_u16(*value)) .collect::, _>>()?; + let headers = HeaderMap::from_header_pairs(&cfg.header)?; + ClientBuilder::builder() .remaps(remaps) .base(cfg.base_url.clone()) diff --git a/lychee-bin/src/options.rs b/lychee-bin/src/options.rs index a10f775..5d99782 100644 --- a/lychee-bin/src/options.rs +++ b/lychee-bin/src/options.rs @@ -5,6 +5,10 @@ use anyhow::{anyhow, Context, Error, Result}; use clap::builder::PossibleValuesParser; use clap::{arg, builder::TypedValueParser, Parser}; use const_format::{concatcp, formatcp}; +use http::{ + header::{HeaderName, HeaderValue}, + HeaderMap, +}; use lychee_lib::{ Base, BasicAuthSelector, FileExtensions, FileType, Input, StatusCodeExcluder, StatusCodeSelector, DEFAULT_MAX_REDIRECTS, DEFAULT_MAX_RETRIES, DEFAULT_RETRY_WAIT_TIME_SECS, @@ -12,7 +16,8 @@ use lychee_lib::{ }; use reqwest::tls; use secrecy::{ExposeSecret, SecretString}; -use serde::Deserialize; +use serde::{Deserialize, Deserializer}; +use std::collections::HashMap; use std::path::Path; use std::{fs, path::PathBuf, str::FromStr, time::Duration}; use strum::{Display, EnumIter, EnumString, VariantNames}; @@ -201,6 +206,98 @@ macro_rules! fold_in { }; } +/// Parse a single header into a [`HeaderName`] and [`HeaderValue`] +/// +/// Headers are expected to be in format `Header-Name: Header-Value`. +/// The header name and value are trimmed of whitespace. +/// +/// If the header contains multiple colons, the part after the first colon is +/// considered the value. +fn parse_single_header(header: &str) -> Result<(HeaderName, HeaderValue)> { + let parts: Vec<&str> = header.splitn(2, ':').collect(); + match parts.as_slice() { + [name, value] => { + let name = HeaderName::from_bytes(name.trim().as_bytes()) + .map_err(|e| anyhow!("Invalid header name '{}': {}", name.trim(), e))?; + let value = HeaderValue::from_str(value.trim()) + .map_err(|e| anyhow!("Invalid header value '{}': {}", value.trim(), e))?; + Ok((name, value)) + } + _ => Err(anyhow!( + "Invalid header format. Expected colon-separated string in the format 'HeaderName: HeaderValue', got '{}'", + header + )), + } +} + +/// Parses a single HTTP header into a tuple of (String, String) +/// +/// This does NOT merge multiple headers into one. +#[derive(Clone, Debug)] +struct HeaderParser; + +impl TypedValueParser for HeaderParser { + type Value = (String, String); + + fn parse_ref( + &self, + _cmd: &clap::Command, + _arg: Option<&clap::Arg>, + value: &std::ffi::OsStr, + ) -> Result { + let header_str = value.to_str().ok_or_else(|| { + clap::Error::raw( + clap::error::ErrorKind::InvalidValue, + "Header value contains invalid UTF-8", + ) + })?; + + match parse_single_header(header_str) { + Ok((name, value)) => { + let Ok(value) = value.to_str() else { + return Err(clap::Error::raw( + clap::error::ErrorKind::InvalidValue, + "Header value contains invalid UTF-8", + )); + }; + + Ok((name.to_string(), value.to_string())) + } + Err(e) => Err(clap::Error::raw( + clap::error::ErrorKind::InvalidValue, + e.to_string(), + )), + } + } +} + +impl clap::builder::ValueParserFactory for HeaderParser { + type Parser = HeaderParser; + fn value_parser() -> Self::Parser { + HeaderParser + } +} + +/// Extension trait for converting a Vec of header pairs to a `HeaderMap` +pub(crate) trait HeaderMapExt { + /// Convert a collection of header key-value pairs to a `HeaderMap` + fn from_header_pairs(headers: &[(String, String)]) -> Result; +} + +impl HeaderMapExt for HeaderMap { + fn from_header_pairs(headers: &[(String, String)]) -> Result { + let mut header_map = HeaderMap::new(); + for (name, value) in headers { + let header_name = HeaderName::from_bytes(name.as_bytes()) + .map_err(|e| anyhow!("Invalid header name '{}': {}", name, e))?; + let header_value = HeaderValue::from_str(value) + .map_err(|e| anyhow!("Invalid header value '{}': {}", value, e))?; + header_map.insert(header_name, header_value); + } + Ok(header_map) + } +} + /// A fast, async link checker /// /// Finds broken URLs and mail addresses inside Markdown, HTML, @@ -235,14 +332,33 @@ impl LycheeOptions { } else { Some(self.config.exclude_path.clone()) }; + let headers = HeaderMap::from_header_pairs(&self.config.header)?; + self.raw_inputs .iter() - .map(|s| Input::new(s, None, self.config.glob_ignore_case, excluded.clone())) + .map(|s| { + Input::new( + s, + None, + self.config.glob_ignore_case, + excluded.clone(), + headers.clone(), + ) + }) .collect::>() .context("Cannot parse inputs from arguments") } } +// Custom deserializer function for the header field +fn deserialize_headers<'de, D>(deserializer: D) -> Result, D::Error> +where + D: Deserializer<'de>, +{ + let map = HashMap::::deserialize(deserializer)?; + Ok(map.into_iter().collect()) +} + /// The main configuration for lychee #[allow(clippy::struct_excessive_bools)] #[derive(Parser, Debug, Deserialize, Clone, Default)] @@ -450,10 +566,27 @@ Example: --fallback-extensions html,htm,php,asp,aspx,jsp,cgi" )] pub(crate) fallback_extensions: Vec, - /// Custom request header - #[arg(long)] + /// Set custom header for requests + #[arg( + short = 'H', + long = "header", + // Note: We use a `Vec<(String, String)>` for headers, which is + // unfortunate. The reason is that `clap::ArgAction::Append` collects + // multiple values, and `clap` cannot automatically convert these tuples + // into a `HashMap`. + action = clap::ArgAction::Append, + value_parser = HeaderParser, + value_name = "HEADER:VALUE", + long_help = "Set custom header for requests + +Some websites require custom headers to be passed in order to return valid responses. +You can specify custom headers in the format 'Name: Value'. For example, 'Accept: text/html'. +This is the same format that other tools like curl or wget use. +Multiple headers can be specified by using the flag multiple times." + )] #[serde(default)] - pub(crate) header: Vec, + #[serde(deserialize_with = "deserialize_headers")] + pub header: Vec<(String, String)>, /// A List of accepted status codes for valid links #[arg( @@ -581,6 +714,20 @@ separated list of accepted status codes. This example will accept 200, 201, } impl Config { + /// Special handling for merging headers + /// + /// Overwrites existing headers in `self` with the values from `other`. + fn merge_headers(&mut self, other: &[(String, String)]) { + let self_map = self.header.iter().cloned().collect::>(); + let other_map = other.iter().cloned().collect::>(); + + // Merge the two maps, with `other` taking precedence + let merged_map: HashMap<_, _> = self_map.into_iter().chain(other_map).collect(); + + // Convert the merged map back to a Vec of tuples + self.header = merged_map.into_iter().collect(); + } + /// Load configuration from a file pub(crate) fn load_from_file(path: &Path) -> Result { // Read configuration file @@ -590,52 +737,57 @@ impl Config { /// Merge the configuration from TOML into the CLI configuration pub(crate) fn merge(&mut self, toml: Config) { + // Special handling for headers before fold_in! + self.merge_headers(&toml.header); + fold_in! { // Destination and source configs self, toml; // Keys with defaults to assign - verbose: Verbosity::default(); - cache: false; - no_progress: false; - max_redirects: DEFAULT_MAX_REDIRECTS; - max_retries: DEFAULT_MAX_RETRIES; - max_concurrency: DEFAULT_MAX_CONCURRENCY; - max_cache_age: humantime::parse_duration(DEFAULT_MAX_CACHE_AGE).unwrap(); - cache_exclude_status: StatusCodeExcluder::default(); - threads: None; - user_agent: DEFAULT_USER_AGENT; - insecure: false; - scheme: Vec::::new(); - include: Vec::::new(); - exclude: Vec::::new(); - exclude_file: Vec::::new(); // deprecated - exclude_path: Vec::::new(); - exclude_all_private: false; - exclude_private: false; - exclude_link_local: false; - exclude_loopback: false; - format: StatsFormat::default(); - remap: Vec::::new(); - fallback_extensions: Vec::::new(); - header: Vec::::new(); - timeout: DEFAULT_TIMEOUT_SECS; - retry_wait_time: DEFAULT_RETRY_WAIT_TIME_SECS; - method: DEFAULT_METHOD; + accept: StatusCodeSelector::default(); base_url: None; basic_auth: None; - skip_missing: false; - include_verbatim: false; - include_mail: false; - glob_ignore_case: false; - output: None; - require_https: false; + cache_exclude_status: StatusCodeExcluder::default(); + cache: false; cookie_jar: None; - include_fragments: false; - accept: StatusCodeSelector::default(); + exclude_all_private: false; + exclude_file: Vec::::new(); // deprecated + exclude_link_local: false; + exclude_loopback: false; + exclude_path: Vec::::new(); + exclude_private: false; + exclude: Vec::::new(); extensions: FileType::default_extensions(); + fallback_extensions: Vec::::new(); + format: StatsFormat::default(); + glob_ignore_case: false; + header: Vec::<(String, String)>::new(); + include_fragments: false; + include_mail: false; + include_verbatim: false; + include: Vec::::new(); + insecure: false; + max_cache_age: humantime::parse_duration(DEFAULT_MAX_CACHE_AGE).unwrap(); + max_concurrency: DEFAULT_MAX_CONCURRENCY; + max_redirects: DEFAULT_MAX_REDIRECTS; + max_retries: DEFAULT_MAX_RETRIES; + method: DEFAULT_METHOD; + no_progress: false; + output: None; + remap: Vec::::new(); + require_https: false; + retry_wait_time: DEFAULT_RETRY_WAIT_TIME_SECS; + scheme: Vec::::new(); + skip_missing: false; + threads: None; + timeout: DEFAULT_TIMEOUT_SECS; + user_agent: DEFAULT_USER_AGENT; + verbose: Verbosity::default(); } + // If the config file has a value for the GitHub token, but the CLI + // doesn't, use the token from the config file. if self .github_token .as_ref() @@ -654,6 +806,8 @@ impl Config { #[cfg(test)] mod tests { + use std::collections::HashMap; + use super::*; #[test] @@ -683,4 +837,93 @@ mod tests { ); assert_eq!(cli.cache_exclude_status, StatusCodeExcluder::new()); } + + #[test] + fn test_parse_custom_headers() { + assert_eq!( + parse_single_header("accept:text/html").unwrap(), + ( + HeaderName::from_static("accept"), + HeaderValue::from_static("text/html") + ) + ); + } + + #[test] + fn test_parse_custom_header_multiple_colons() { + assert_eq!( + parse_single_header("key:x-test:check=this").unwrap(), + ( + HeaderName::from_static("key"), + HeaderValue::from_static("x-test:check=this") + ) + ); + } + + #[test] + fn test_parse_custom_headers_with_equals() { + assert_eq!( + parse_single_header("key:x-test=check=this").unwrap(), + ( + HeaderName::from_static("key"), + HeaderValue::from_static("x-test=check=this") + ) + ); + } + + #[test] + fn test_header_parsing_and_merging() { + // Simulate commandline arguments with multiple headers + let args = vec![ + "lychee", + "--header", + "Accept: text/html", + "--header", + "X-Test: check=this", + "input.md", + ]; + + // Parse the arguments + let opts = crate::LycheeOptions::parse_from(args); + + // Check that the headers were collected correctly + let headers = &opts.config.header; + assert_eq!(headers.len(), 2); + + // Convert to HashMap for easier testing + let header_map: HashMap = headers.iter().cloned().collect(); + assert_eq!(header_map["accept"], "text/html"); + assert_eq!(header_map["x-test"], "check=this"); + } + + #[test] + fn test_merge_headers_with_config() { + let toml = Config { + header: vec![ + ("Accept".to_string(), "text/html".to_string()), + ("X-Test".to_string(), "check=this".to_string()), + ], + ..Default::default() + }; + + // Set X-Test and see if it gets overwritten + let mut cli = Config { + header: vec![("X-Test".to_string(), "check=that".to_string())], + ..Default::default() + }; + cli.merge(toml); + + assert_eq!(cli.header.len(), 2); + + // Sort vector before assert + cli.header.sort(); + + assert_eq!( + cli.header, + vec![ + ("Accept".to_string(), "text/html".to_string()), + ("X-Test".to_string(), "check=this".to_string()), + ] + ); + } } diff --git a/lychee-bin/src/parse.rs b/lychee-bin/src/parse.rs index f7fc4d6..0f2b737 100644 --- a/lychee-bin/src/parse.rs +++ b/lychee-bin/src/parse.rs @@ -1,35 +1,12 @@ -use anyhow::{anyhow, Context, Result}; -use headers::{HeaderMap, HeaderName}; +use anyhow::{Context, Result}; use lychee_lib::{remap::Remaps, Base}; use std::time::Duration; -/// Split a single HTTP header into a (key, value) tuple -fn read_header(input: &str) -> Result<(String, String), anyhow::Error> { - if let Some((key, value)) = input.split_once('=') { - Ok((key.to_string(), value.to_string())) - } else { - Err(anyhow!( - "Header value must be of the form key=value, got {}", - input - )) - } -} - /// Parse seconds into a `Duration` pub(crate) const fn parse_duration_secs(secs: usize) -> Duration { Duration::from_secs(secs as u64) } -/// Parse HTTP headers into a `HeaderMap` -pub(crate) fn parse_headers>(headers: &[T]) -> Result { - let mut out = HeaderMap::new(); - for header in headers { - let (key, val) = read_header(header.as_ref())?; - out.insert(HeaderName::from_bytes(key.as_bytes())?, val.parse()?); - } - Ok(out) -} - /// Parse URI remaps pub(crate) fn parse_remaps(remaps: &[String]) -> Result { Remaps::try_from(remaps) @@ -42,30 +19,10 @@ pub(crate) fn parse_base(src: &str) -> Result { #[cfg(test)] mod tests { - - use headers::HeaderMap; use regex::Regex; - use reqwest::header; use super::*; - #[test] - fn test_parse_custom_headers() { - let mut custom = HeaderMap::new(); - custom.insert(header::ACCEPT, "text/html".parse().unwrap()); - assert_eq!(parse_headers(&["accept=text/html"]).unwrap(), custom); - } - - #[test] - fn test_parse_custom_headers_with_equals() { - let mut custom_with_equals = HeaderMap::new(); - custom_with_equals.insert("x-test", "check=this".parse().unwrap()); - assert_eq!( - parse_headers(&["x-test=check=this"]).unwrap(), - custom_with_equals - ); - } - #[test] fn test_parse_remap() { let remaps = diff --git a/lychee-bin/tests/cli.rs b/lychee-bin/tests/cli.rs index 9090a4d..ae6e7c8 100644 --- a/lychee-bin/tests/cli.rs +++ b/lychee-bin/tests/cli.rs @@ -12,7 +12,7 @@ mod cli { use anyhow::anyhow; use assert_cmd::Command; use assert_json_diff::assert_json_include; - use http::StatusCode; + use http::{Method, StatusCode}; use lychee_lib::{InputSource, ResponseBody}; use predicates::{ prelude::{predicate, PredicateBooleanExt}, @@ -1950,6 +1950,118 @@ mod cli { Ok(()) } + #[tokio::test] + async fn test_no_header_set_on_input() -> Result<()> { + let mut cmd = main_command(); + let server = wiremock::MockServer::start().await; + server + .register( + wiremock::Mock::given(wiremock::matchers::method("GET")) + .respond_with(wiremock::ResponseTemplate::new(200)) + .expect(1), + ) + .await; + + cmd.arg("--verbose").arg(server.uri()).assert().success(); + + let received_requests = server.received_requests().await.unwrap(); + assert_eq!(received_requests.len(), 1); + + let received_request = &received_requests[0]; + assert_eq!(received_request.method, Method::GET); + assert_eq!(received_request.url.path(), "/"); + + // Make sure the request does not contain the custom header + assert!(!received_request.headers.contains_key("X-Foo")); + Ok(()) + } + + #[tokio::test] + async fn test_header_set_on_input() -> Result<()> { + let mut cmd = main_command(); + let server = wiremock::MockServer::start().await; + server + .register( + wiremock::Mock::given(wiremock::matchers::method("GET")) + .and(wiremock::matchers::header("X-Foo", "Bar")) + .respond_with(wiremock::ResponseTemplate::new(200)) + // We expect the mock to be called exactly least once. + .expect(1) + .named("GET expecting custom header"), + ) + .await; + + cmd.arg("--verbose") + .arg("--header") + .arg("X-Foo: Bar") + .arg(server.uri()) + .assert() + .success(); + + // Check that the server received the request with the header + server.verify().await; + Ok(()) + } + + #[tokio::test] + async fn test_multi_header_set_on_input() -> Result<()> { + let mut cmd = main_command(); + let server = wiremock::MockServer::start().await; + server + .register( + wiremock::Mock::given(wiremock::matchers::method("GET")) + .and(wiremock::matchers::header("X-Foo", "Bar")) + .and(wiremock::matchers::header("X-Bar", "Baz")) + .respond_with(wiremock::ResponseTemplate::new(200)) + // We expect the mock to be called exactly least once. + .expect(1) + .named("GET expecting custom header"), + ) + .await; + + cmd.arg("--verbose") + .arg("--header") + .arg("X-Foo: Bar") + .arg("--header") + .arg("X-Bar: Baz") + .arg(server.uri()) + .assert() + .success(); + + // Check that the server received the request with the header + server.verify().await; + Ok(()) + } + + #[tokio::test] + async fn test_header_set_in_config() -> Result<()> { + let mut cmd = main_command(); + let server = wiremock::MockServer::start().await; + server + .register( + wiremock::Mock::given(wiremock::matchers::method("GET")) + .and(wiremock::matchers::header("X-Foo", "Bar")) + .and(wiremock::matchers::header("X-Bar", "Baz")) + .respond_with(wiremock::ResponseTemplate::new(200)) + // We expect the mock to be called exactly least once. + .expect(1) + .named("GET expecting custom header"), + ) + .await; + + let config = fixtures_path().join("configs").join("headers.toml"); + cmd.arg("--verbose") + .arg("--config") + .arg(config) + .arg(server.uri()) + .assert() + .success(); + + // Check that the server received the request with the header + server.verify().await; + Ok(()) + } + #[test] fn test_sorted_error_output() -> Result<()> { let test_files = ["TEST_GITHUB_404.md", "TEST_INVALID_URLS.html"]; diff --git a/lychee-lib/src/collector.rs b/lychee-lib/src/collector.rs index 7794d3b..abe681c 100644 --- a/lychee-lib/src/collector.rs +++ b/lychee-lib/src/collector.rs @@ -181,7 +181,7 @@ impl Collector { mod tests { use std::{collections::HashSet, convert::TryFrom, fs::File, io::Write}; - use http::StatusCode; + use http::{HeaderMap, StatusCode}; use reqwest::Url; use super::*; @@ -230,7 +230,13 @@ mod tests { // Treat as plaintext file (no extension) let file_path = temp_dir.path().join("README"); let _file = File::create(&file_path).unwrap(); - let input = Input::new(&file_path.as_path().display().to_string(), None, true, None)?; + let input = Input::new( + &file_path.as_path().display().to_string(), + None, + true, + None, + HeaderMap::new(), + )?; let contents: Vec<_> = input .get_contents(true, true, true, FileType::default_extensions()) .collect::>() @@ -243,7 +249,7 @@ mod tests { #[tokio::test] async fn test_url_without_extension_is_html() -> Result<()> { - let input = Input::new("https://example.com/", None, true, None)?; + let input = Input::new("https://example.com/", None, true, None, HeaderMap::new())?; let contents: Vec<_> = input .get_contents(true, true, true, FileType::default_extensions()) .collect::>() @@ -278,6 +284,7 @@ mod tests { source: InputSource::String(TEST_STRING.to_owned()), file_type_hint: None, excluded_paths: None, + headers: HeaderMap::new(), }, Input { source: InputSource::RemoteUrl(Box::new( @@ -287,11 +294,13 @@ mod tests { )), file_type_hint: None, excluded_paths: None, + headers: HeaderMap::new(), }, Input { source: InputSource::FsPath(file_path), file_type_hint: None, excluded_paths: None, + headers: HeaderMap::new(), }, Input { source: InputSource::FsGlob { @@ -300,6 +309,7 @@ mod tests { }, file_type_hint: None, excluded_paths: None, + headers: HeaderMap::new(), }, ]; @@ -327,7 +337,8 @@ mod tests { let input = Input { source: InputSource::String("This is [a test](https://endler.dev). This is a relative link test [Relative Link Test](relative_link)".to_string()), file_type_hint: Some(FileType::Markdown), - excluded_paths: None, + excluded_paths: None, + headers: HeaderMap::new(), }; let links = collect(vec![input], None, Some(base)).await.ok().unwrap(); @@ -354,6 +365,7 @@ mod tests { ), file_type_hint: Some(FileType::Html), excluded_paths: None, + headers: HeaderMap::new(), }; let links = collect(vec![input], None, Some(base)).await.ok().unwrap(); @@ -383,6 +395,7 @@ mod tests { ), file_type_hint: Some(FileType::Html), excluded_paths: None, + headers: HeaderMap::new(), }; let links = collect(vec![input], None, Some(base)).await.ok().unwrap(); @@ -409,6 +422,7 @@ mod tests { ), file_type_hint: Some(FileType::Markdown), excluded_paths: None, + headers: HeaderMap::new(), }; let links = collect(vec![input], None, Some(base)).await.ok().unwrap(); @@ -432,6 +446,7 @@ mod tests { source: InputSource::String(input), file_type_hint: Some(FileType::Html), excluded_paths: None, + headers: HeaderMap::new(), }; let links = collect(vec![input], None, Some(base)).await.ok().unwrap(); @@ -464,6 +479,7 @@ mod tests { source: InputSource::RemoteUrl(Box::new(server_uri.clone())), file_type_hint: None, excluded_paths: None, + headers: HeaderMap::new(), }; let links = collect(vec![input], None, None).await.ok().unwrap(); @@ -484,6 +500,7 @@ mod tests { ), file_type_hint: None, excluded_paths: None, + headers: HeaderMap::new(), }; let links = collect(vec![input], None, None).await.ok().unwrap(); @@ -514,6 +531,7 @@ mod tests { )), file_type_hint: Some(FileType::Html), excluded_paths: None, + headers: HeaderMap::new(), }, Input { source: InputSource::RemoteUrl(Box::new( @@ -525,6 +543,7 @@ mod tests { )), file_type_hint: Some(FileType::Html), excluded_paths: None, + headers: HeaderMap::new(), }, ]; @@ -560,6 +579,7 @@ mod tests { ), file_type_hint: Some(FileType::Html), excluded_paths: None, + headers: HeaderMap::new(), }; let links = collect(vec![input], None, Some(base)).await.ok().unwrap(); diff --git a/lychee-lib/src/types/input.rs b/lychee-lib/src/types/input.rs index 104f788..91d8176 100644 --- a/lychee-lib/src/types/input.rs +++ b/lychee-lib/src/types/input.rs @@ -3,6 +3,7 @@ use crate::{utils, ErrorKind, Result}; use async_stream::try_stream; use futures::stream::Stream; use glob::glob_with; +use http::HeaderMap; use ignore::WalkBuilder; use reqwest::Url; use serde::{Deserialize, Serialize}; @@ -101,7 +102,7 @@ impl Display for InputSource { } /// Lychee Input with optional file hint for parsing -#[derive(Clone, Debug, PartialEq, Eq, Hash)] +#[derive(Clone, Debug, PartialEq, Eq)] pub struct Input { /// Origin of input pub source: InputSource, @@ -109,6 +110,8 @@ pub struct Input { pub file_type_hint: Option, /// Excluded paths that will be skipped when reading content pub excluded_paths: Option>, + /// Custom headers to be used when fetching remote URLs + pub headers: reqwest::header::HeaderMap, } impl Input { @@ -125,6 +128,7 @@ impl Input { file_type_hint: Option, glob_ignore_case: bool, excluded_paths: Option>, + headers: reqwest::header::HeaderMap, ) -> Result { let source = if value == STDIN { InputSource::Stdin @@ -190,9 +194,20 @@ impl Input { source, file_type_hint, excluded_paths, + headers, }) } + /// Convenience constructor with sane defaults + /// + /// # Errors + /// + /// Returns an error if the input does not exist (i.e. invalid path) + /// and the input cannot be parsed as a URL. + pub fn from_value(value: &str) -> Result { + Self::new(value, None, false, None, HeaderMap::new()) + } + /// Retrieve the contents from the input /// /// If the input is a path, only search through files that match the given @@ -215,7 +230,7 @@ impl Input { try_stream! { match self.source { InputSource::RemoteUrl(ref url) => { - let content = Self::url_contents(url).await; + let content = Self::url_contents(url, &self.headers).await; match content { Err(_) if skip_missing => (), Err(e) => Err(e)?, @@ -313,7 +328,7 @@ impl Input { } } - async fn url_contents(url: &Url) -> Result { + async fn url_contents(url: &Url, headers: &HeaderMap) -> Result { // Assume HTML for default paths let file_type = if url.path().is_empty() || url.path() == "/" { FileType::Html @@ -321,7 +336,12 @@ impl Input { FileType::from(url.as_str()) }; - let res = reqwest::get(url.clone()) + let client = reqwest::Client::new(); + + let res = client + .get(url.clone()) + .headers(headers.clone()) + .send() .await .map_err(ErrorKind::NetworkRequest)?; let input_content = InputContent { @@ -414,6 +434,14 @@ impl Input { } } +impl TryFrom<&str> for Input { + type Error = crate::ErrorKind; + + fn try_from(value: &str) -> std::result::Result { + Self::from_value(value) + } +} + /// Function for path exclusion tests /// /// This is a standalone function to allow for easier testing @@ -428,6 +456,8 @@ fn is_excluded_path(excluded_paths: &[PathBuf], path: &PathBuf) -> bool { #[cfg(test)] mod tests { + use http::HeaderMap; + use super::*; #[test] @@ -438,14 +468,15 @@ mod tests { assert!(path.exists()); assert!(path.is_relative()); - let input = Input::new(test_file, None, false, None); + let input = Input::new(test_file, None, false, None, HeaderMap::new()); assert!(input.is_ok()); assert!(matches!( input, Ok(Input { source: InputSource::FsPath(PathBuf { .. }), file_type_hint: None, - excluded_paths: None + excluded_paths: None, + headers: _, }) )); } @@ -458,7 +489,7 @@ mod tests { assert!(!path.exists()); assert!(path.is_relative()); - let input = Input::new(test_file, None, false, None); + let input = Input::from_value(test_file); assert!(input.is_err()); assert!(matches!(input, Err(ErrorKind::InvalidFile(PathBuf { .. })))); } @@ -490,7 +521,7 @@ mod tests { #[test] fn test_url_without_scheme() { - let input = Input::new("example.com", None, false, None); + let input = Input::from_value("example.com"); assert_eq!( input.unwrap().source.to_string(), String::from("http://example.com/") @@ -501,7 +532,7 @@ mod tests { #[cfg(windows)] #[test] fn test_windows_style_filepath_not_existing() { - let input = Input::new("C:\\example\\project\\here", None, false, None); + let input = Input::from_value("C:\\example\\project\\here"); assert!(input.is_err()); let input = input.unwrap_err(); @@ -521,7 +552,7 @@ mod tests { let dir = temp_dir(); let file = NamedTempFile::new_in(dir).unwrap(); let path = file.path(); - let input = Input::new(path.to_str().unwrap(), None, false, None).unwrap(); + let input = Input::from_value(path.to_str().unwrap()).unwrap(); match input.source { InputSource::FsPath(_) => (), @@ -533,33 +564,28 @@ mod tests { fn test_url_scheme_check_succeeding() { // Valid http and https URLs assert!(matches!( - Input::new("http://example.com", None, false, None), + Input::from_value("http://example.com"), Ok(Input { source: InputSource::RemoteUrl(_), .. }) )); assert!(matches!( - Input::new("https://example.com", None, false, None), + Input::from_value("https://example.com"), Ok(Input { source: InputSource::RemoteUrl(_), .. }) )); assert!(matches!( - Input::new( - "http://subdomain.example.com/path?query=value", - None, - false, - None - ), + Input::from_value("http://subdomain.example.com/path?query=value",), Ok(Input { source: InputSource::RemoteUrl(_), .. }) )); assert!(matches!( - Input::new("https://example.com:8080", None, false, None), + Input::from_value("https://example.com:8080"), Ok(Input { source: InputSource::RemoteUrl(_), .. @@ -571,19 +597,19 @@ mod tests { fn test_url_scheme_check_failing() { // Invalid schemes assert!(matches!( - Input::new("ftp://example.com", None, false, None), + Input::from_value("ftp://example.com"), Err(ErrorKind::InvalidFile(_)) )); assert!(matches!( - Input::new("httpx://example.com", None, false, None), + Input::from_value("httpx://example.com"), Err(ErrorKind::InvalidFile(_)) )); assert!(matches!( - Input::new("file:///path/to/file", None, false, None), + Input::from_value("file:///path/to/file"), Err(ErrorKind::InvalidFile(_)) )); assert!(matches!( - Input::new("mailto:user@example.com", None, false, None), + Input::from_value("mailto:user@example.com"), Err(ErrorKind::InvalidFile(_)) )); } @@ -592,11 +618,11 @@ mod tests { fn test_non_url_inputs() { // Non-URL inputs assert!(matches!( - Input::new("./local/path", None, false, None), + Input::from_value("./local/path"), Err(ErrorKind::InvalidFile(_)) )); assert!(matches!( - Input::new("*.md", None, false, None), + Input::from_value("*.md"), Ok(Input { source: InputSource::FsGlob { .. }, .. @@ -604,7 +630,7 @@ mod tests { )); // Assuming the current directory exists assert!(matches!( - Input::new(".", None, false, None), + Input::from_value("."), Ok(Input { source: InputSource::FsPath(_), .. diff --git a/lychee.example.toml b/lychee.example.toml index bd0e778..e759aa9 100644 --- a/lychee.example.toml +++ b/lychee.example.toml @@ -69,7 +69,7 @@ require_https = false method = "get" # Custom request headers -header = ["name=value", "other=value"] +header = { "accept" = "text/html", "x-custom-header" = "value" } # Remap URI matching pattern to different URI. remap = ["https://example.com http://example.invalid"]