Add support for custom headers in input processing (#1561)

This commit is contained in:
Matthias Endler 2025-05-23 13:37:32 +02:00 committed by GitHub
parent 973b2aa5e0
commit 35610764a1
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
14 changed files with 512 additions and 126 deletions

11
Cargo.lock generated
View file

@ -1892,6 +1892,16 @@ dependencies = [
"pin-project-lite",
]
[[package]]
name = "http-serde"
version = "2.1.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0f056c8559e3757392c8d091e796416e4649d8e49e88b8d76df6c002f05027fd"
dependencies = [
"http 1.3.1",
"serde",
]
[[package]]
name = "httparse"
version = "1.10.1"
@ -2529,6 +2539,7 @@ dependencies = [
"futures",
"headers",
"http 1.3.1",
"http-serde",
"human-sort",
"humantime",
"humantime-serde",

View file

@ -458,8 +458,13 @@ Options:
Example: --fallback-extensions html,htm,php,asp,aspx,jsp,cgi
--header <HEADER>
Custom request header
-H, --header <HEADER:VALUE>
Set custom header for requests
Some websites require custom headers to be passed in order to return valid responses.
You can specify custom headers in the format 'Name: Value'. For example, 'Accept: text/html'.
This is the same format that other tools like curl or wget use.
Multiple headers can be specified by using the flag multiple times.
-a, --accept <ACCEPT>
A List of accepted status codes for valid links

View file

@ -48,9 +48,9 @@ Some sites expect one or more custom headers to return a valid response. \
For example, crates.io expects an `Accept: text/html` header or else it \
will [return a 404](https://github.com/rust-lang/crates.io/issues/788).
To fix that you can pass additional headers like so: `--header "accept=text/html"`. \
To fix that you can pass additional headers like so: `--header "Accept: text/html"`. \
You can use that argument multiple times to add more headers. \
Or, you can accept all content/MIME types: `--header "accept=*/*"`.
Or, you can accept all content/MIME types: `--header "Accept: */*"`.
See more info about the Accept header
[over at MDN](https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Accept).

View file

@ -1,3 +1,4 @@
use http::HeaderMap;
use lychee_lib::{Collector, Input, InputSource, Result};
use reqwest::Url;
use std::path::PathBuf;
@ -13,11 +14,13 @@ async fn main() -> Result<()> {
)),
file_type_hint: None,
excluded_paths: None,
headers: HeaderMap::new(),
},
Input {
source: InputSource::FsPath(PathBuf::from("fixtures/TEST.md")),
file_type_hint: None,
excluded_paths: None,
headers: HeaderMap::new(),
},
];

6
fixtures/configs/headers.toml vendored Normal file
View file

@ -0,0 +1,6 @@
[header]
X-Foo = "Bar"
X-Bar = "Baz"
# Alternative TOML syntax:
# header = { X-Foo = "Bar", X-Bar = "Baz" }

View file

@ -69,7 +69,7 @@ require_https = false
method = "get"
# Custom request headers
headers = []
header = { X-Foo = "Bar", X-Bar = "Baz" }
# Remap URI matching pattern to different URI.
# This also supports (named) capturing groups.

View file

@ -28,8 +28,10 @@ env_logger = "0.11.8"
futures = "0.3.31"
headers = "0.4.0"
http = "1.3.1"
http-serde = "2.1.1"
humantime = "2.2.0"
humantime-serde = "1.1.1"
human-sort = "0.2.2"
indicatif = "0.17.11"
log = "0.4.27"
openssl-sys = { version = "0.9.108", optional = true }
@ -55,7 +57,7 @@ tokio = { version = "1.45.0", features = ["full"] }
tokio-stream = "0.1.17"
toml = "0.8.22"
url = "2.5.4"
human-sort = "0.2.2"
[dev-dependencies]
assert_cmd = "2.0.17"

View file

@ -1,7 +1,7 @@
use crate::options::Config;
use crate::parse::{parse_duration_secs, parse_headers, parse_remaps};
use crate::options::{Config, HeaderMapExt};
use crate::parse::{parse_duration_secs, parse_remaps};
use anyhow::{Context, Result};
use http::StatusCode;
use http::{HeaderMap, StatusCode};
use lychee_lib::{Client, ClientBuilder};
use regex::RegexSet;
use reqwest_cookie_store::CookieStoreMutex;
@ -10,7 +10,6 @@ use std::{collections::HashSet, str::FromStr};
/// Creates a client according to the command-line config
pub(crate) fn create(cfg: &Config, cookie_jar: Option<&Arc<CookieStoreMutex>>) -> Result<Client> {
let headers = parse_headers(&cfg.header)?;
let timeout = parse_duration_secs(cfg.timeout);
let retry_wait_time = parse_duration_secs(cfg.retry_wait_time);
let method: reqwest::Method = reqwest::Method::from_str(&cfg.method.to_uppercase())?;
@ -34,6 +33,8 @@ pub(crate) fn create(cfg: &Config, cookie_jar: Option<&Arc<CookieStoreMutex>>) -
.map(|value| StatusCode::from_u16(*value))
.collect::<Result<HashSet<_>, _>>()?;
let headers = HeaderMap::from_header_pairs(&cfg.header)?;
ClientBuilder::builder()
.remaps(remaps)
.base(cfg.base_url.clone())

View file

@ -5,6 +5,10 @@ use anyhow::{anyhow, Context, Error, Result};
use clap::builder::PossibleValuesParser;
use clap::{arg, builder::TypedValueParser, Parser};
use const_format::{concatcp, formatcp};
use http::{
header::{HeaderName, HeaderValue},
HeaderMap,
};
use lychee_lib::{
Base, BasicAuthSelector, FileExtensions, FileType, Input, StatusCodeExcluder,
StatusCodeSelector, DEFAULT_MAX_REDIRECTS, DEFAULT_MAX_RETRIES, DEFAULT_RETRY_WAIT_TIME_SECS,
@ -12,7 +16,8 @@ use lychee_lib::{
};
use reqwest::tls;
use secrecy::{ExposeSecret, SecretString};
use serde::Deserialize;
use serde::{Deserialize, Deserializer};
use std::collections::HashMap;
use std::path::Path;
use std::{fs, path::PathBuf, str::FromStr, time::Duration};
use strum::{Display, EnumIter, EnumString, VariantNames};
@ -201,6 +206,98 @@ macro_rules! fold_in {
};
}
/// Split a raw `Name: Value` string into a typed [`HeaderName`] /
/// [`HeaderValue`] pair.
///
/// Only the first colon acts as the separator, so any further colons remain
/// part of the value. Whitespace around both halves is stripped before they
/// are validated.
///
/// # Errors
///
/// Fails when the input contains no colon, or when the name or the value is
/// rejected by [`HeaderName::from_bytes`] / [`HeaderValue::from_str`].
fn parse_single_header(header: &str) -> Result<(HeaderName, HeaderValue)> {
    let Some((raw_name, raw_value)) = header.split_once(':') else {
        return Err(anyhow!(
            "Invalid header format. Expected colon-separated string in the format 'HeaderName: HeaderValue', got '{}'",
            header
        ));
    };
    let (raw_name, raw_value) = (raw_name.trim(), raw_value.trim());

    // The name is validated before the value, so a doubly-invalid header
    // reports the name problem.
    let name = HeaderName::from_bytes(raw_name.as_bytes())
        .map_err(|e| anyhow!("Invalid header name '{}': {}", raw_name, e))?;
    let value = HeaderValue::from_str(raw_value)
        .map_err(|e| anyhow!("Invalid header value '{}': {}", raw_value, e))?;

    Ok((name, value))
}
/// Parses a single HTTP header into a tuple of (String, String)
///
/// This does NOT merge multiple headers into one.
#[derive(Clone, Debug)]
struct HeaderParser;
impl TypedValueParser for HeaderParser {
type Value = (String, String);
fn parse_ref(
&self,
_cmd: &clap::Command,
_arg: Option<&clap::Arg>,
value: &std::ffi::OsStr,
) -> Result<Self::Value, clap::Error> {
let header_str = value.to_str().ok_or_else(|| {
clap::Error::raw(
clap::error::ErrorKind::InvalidValue,
"Header value contains invalid UTF-8",
)
})?;
match parse_single_header(header_str) {
Ok((name, value)) => {
let Ok(value) = value.to_str() else {
return Err(clap::Error::raw(
clap::error::ErrorKind::InvalidValue,
"Header value contains invalid UTF-8",
));
};
Ok((name.to_string(), value.to_string()))
}
Err(e) => Err(clap::Error::raw(
clap::error::ErrorKind::InvalidValue,
e.to_string(),
)),
}
}
}
/// Makes [`HeaderParser`] its own `clap` value-parser factory.
impl clap::builder::ValueParserFactory for HeaderParser {
    type Parser = Self;

    /// Hand out a fresh parser; the type is a stateless unit struct.
    fn value_parser() -> Self::Parser {
        Self
    }
}
/// Extension trait for converting a Vec of header pairs to a `HeaderMap`
///
/// The pairs are plain `(name, value)` strings, as stored in the
/// configuration's `header` field.
pub(crate) trait HeaderMapExt {
    /// Convert a collection of header key-value pairs to a `HeaderMap`
    ///
    /// # Errors
    ///
    /// Returns an error if a pair cannot be turned into a valid header name
    /// or header value.
    fn from_header_pairs(headers: &[(String, String)]) -> Result<HeaderMap, Error>;
}
impl HeaderMapExt for HeaderMap {
fn from_header_pairs(headers: &[(String, String)]) -> Result<HeaderMap, Error> {
let mut header_map = HeaderMap::new();
for (name, value) in headers {
let header_name = HeaderName::from_bytes(name.as_bytes())
.map_err(|e| anyhow!("Invalid header name '{}': {}", name, e))?;
let header_value = HeaderValue::from_str(value)
.map_err(|e| anyhow!("Invalid header value '{}': {}", value, e))?;
header_map.insert(header_name, header_value);
}
Ok(header_map)
}
}
/// A fast, async link checker
///
/// Finds broken URLs and mail addresses inside Markdown, HTML,
@ -235,14 +332,33 @@ impl LycheeOptions {
} else {
Some(self.config.exclude_path.clone())
};
let headers = HeaderMap::from_header_pairs(&self.config.header)?;
self.raw_inputs
.iter()
.map(|s| Input::new(s, None, self.config.glob_ignore_case, excluded.clone()))
.map(|s| {
Input::new(
s,
None,
self.config.glob_ignore_case,
excluded.clone(),
headers.clone(),
)
})
.collect::<Result<_, _>>()
.context("Cannot parse inputs from arguments")
}
}
/// Serde helper for the `header` field.
///
/// Reads a string-to-string table and flattens it into a list of
/// `(name, value)` pairs. The pairs come out in `HashMap` iteration order,
/// which is unspecified.
fn deserialize_headers<'de, D>(deserializer: D) -> Result<Vec<(String, String)>, D::Error>
where
    D: Deserializer<'de>,
{
    HashMap::<String, String>::deserialize(deserializer)
        .map(|table| table.into_iter().collect())
}
/// The main configuration for lychee
#[allow(clippy::struct_excessive_bools)]
#[derive(Parser, Debug, Deserialize, Clone, Default)]
@ -450,10 +566,27 @@ Example: --fallback-extensions html,htm,php,asp,aspx,jsp,cgi"
)]
pub(crate) fallback_extensions: Vec<String>,
/// Custom request header
#[arg(long)]
/// Set custom header for requests
#[arg(
short = 'H',
long = "header",
// Note: We use a `Vec<(String, String)>` for headers, which is
// unfortunate. The reason is that `clap::ArgAction::Append` collects
// multiple values, and `clap` cannot automatically convert these tuples
// into a `HashMap<String, String>`.
action = clap::ArgAction::Append,
value_parser = HeaderParser,
value_name = "HEADER:VALUE",
long_help = "Set custom header for requests
Some websites require custom headers to be passed in order to return valid responses.
You can specify custom headers in the format 'Name: Value'. For example, 'Accept: text/html'.
This is the same format that other tools like curl or wget use.
Multiple headers can be specified by using the flag multiple times."
)]
#[serde(default)]
pub(crate) header: Vec<String>,
#[serde(deserialize_with = "deserialize_headers")]
pub header: Vec<(String, String)>,
/// A List of accepted status codes for valid links
#[arg(
@ -581,6 +714,20 @@ separated list of accepted status codes. This example will accept 200, 201,
}
impl Config {
/// Special handling for merging headers
///
/// Overwrites existing headers in `self` with the values from `other`.
fn merge_headers(&mut self, other: &[(String, String)]) {
let self_map = self.header.iter().cloned().collect::<HashMap<_, _>>();
let other_map = other.iter().cloned().collect::<HashMap<_, _>>();
// Merge the two maps, with `other` taking precedence
let merged_map: HashMap<_, _> = self_map.into_iter().chain(other_map).collect();
// Convert the merged map back to a Vec of tuples
self.header = merged_map.into_iter().collect();
}
/// Load configuration from a file
pub(crate) fn load_from_file(path: &Path) -> Result<Config> {
// Read configuration file
@ -590,52 +737,57 @@ impl Config {
/// Merge the configuration from TOML into the CLI configuration
pub(crate) fn merge(&mut self, toml: Config) {
// Special handling for headers before fold_in!
self.merge_headers(&toml.header);
fold_in! {
// Destination and source configs
self, toml;
// Keys with defaults to assign
verbose: Verbosity::default();
cache: false;
no_progress: false;
max_redirects: DEFAULT_MAX_REDIRECTS;
max_retries: DEFAULT_MAX_RETRIES;
max_concurrency: DEFAULT_MAX_CONCURRENCY;
max_cache_age: humantime::parse_duration(DEFAULT_MAX_CACHE_AGE).unwrap();
cache_exclude_status: StatusCodeExcluder::default();
threads: None;
user_agent: DEFAULT_USER_AGENT;
insecure: false;
scheme: Vec::<String>::new();
include: Vec::<String>::new();
exclude: Vec::<String>::new();
exclude_file: Vec::<String>::new(); // deprecated
exclude_path: Vec::<PathBuf>::new();
exclude_all_private: false;
exclude_private: false;
exclude_link_local: false;
exclude_loopback: false;
format: StatsFormat::default();
remap: Vec::<String>::new();
fallback_extensions: Vec::<String>::new();
header: Vec::<String>::new();
timeout: DEFAULT_TIMEOUT_SECS;
retry_wait_time: DEFAULT_RETRY_WAIT_TIME_SECS;
method: DEFAULT_METHOD;
accept: StatusCodeSelector::default();
base_url: None;
basic_auth: None;
skip_missing: false;
include_verbatim: false;
include_mail: false;
glob_ignore_case: false;
output: None;
require_https: false;
cache_exclude_status: StatusCodeExcluder::default();
cache: false;
cookie_jar: None;
include_fragments: false;
accept: StatusCodeSelector::default();
exclude_all_private: false;
exclude_file: Vec::<String>::new(); // deprecated
exclude_link_local: false;
exclude_loopback: false;
exclude_path: Vec::<PathBuf>::new();
exclude_private: false;
exclude: Vec::<String>::new();
extensions: FileType::default_extensions();
fallback_extensions: Vec::<String>::new();
format: StatsFormat::default();
glob_ignore_case: false;
header: Vec::<(String, String)>::new();
include_fragments: false;
include_mail: false;
include_verbatim: false;
include: Vec::<String>::new();
insecure: false;
max_cache_age: humantime::parse_duration(DEFAULT_MAX_CACHE_AGE).unwrap();
max_concurrency: DEFAULT_MAX_CONCURRENCY;
max_redirects: DEFAULT_MAX_REDIRECTS;
max_retries: DEFAULT_MAX_RETRIES;
method: DEFAULT_METHOD;
no_progress: false;
output: None;
remap: Vec::<String>::new();
require_https: false;
retry_wait_time: DEFAULT_RETRY_WAIT_TIME_SECS;
scheme: Vec::<String>::new();
skip_missing: false;
threads: None;
timeout: DEFAULT_TIMEOUT_SECS;
user_agent: DEFAULT_USER_AGENT;
verbose: Verbosity::default();
}
// If the config file has a value for the GitHub token, but the CLI
// doesn't, use the token from the config file.
if self
.github_token
.as_ref()
@ -654,6 +806,8 @@ impl Config {
#[cfg(test)]
mod tests {
use std::collections::HashMap;
use super::*;
#[test]
@ -683,4 +837,93 @@ mod tests {
);
assert_eq!(cli.cache_exclude_status, StatusCodeExcluder::new());
}
#[test]
fn test_parse_custom_headers() {
    // A header without whitespace around the colon still splits cleanly
    let (name, value) = parse_single_header("accept:text/html").unwrap();

    assert_eq!(name, HeaderName::from_static("accept"));
    assert_eq!(value, HeaderValue::from_static("text/html"));
}
#[test]
fn test_parse_custom_header_multiple_colons() {
    // Only the first colon separates name and value; the rest stays in the value
    let (name, value) = parse_single_header("key:x-test:check=this").unwrap();

    assert_eq!(name, HeaderName::from_static("key"));
    assert_eq!(value, HeaderValue::from_static("x-test:check=this"));
}
#[test]
fn test_parse_custom_headers_with_equals() {
    // Equals signs carry no special meaning and are kept verbatim in the value
    let (name, value) = parse_single_header("key:x-test=check=this").unwrap();

    assert_eq!(name, HeaderName::from_static("key"));
    assert_eq!(value, HeaderValue::from_static("x-test=check=this"));
}
#[test]
fn test_header_parsing_and_merging() {
    // Simulated command line with the `--header` flag given twice
    let opts = crate::LycheeOptions::parse_from([
        "lychee",
        "--header",
        "Accept: text/html",
        "--header",
        "X-Test: check=this",
        "input.md",
    ]);

    // Each flag occurrence must show up as its own pair
    let parsed = &opts.config.header;
    assert_eq!(parsed.len(), 2);

    // Index the pairs by name so the checks below read naturally;
    // the names are expected in lowercase.
    let by_name: HashMap<String, String> = parsed.iter().cloned().collect();
    assert_eq!(by_name["accept"], "text/html");
    assert_eq!(by_name["x-test"], "check=this");
}
#[test]
fn test_merge_headers_with_config() {
    let pair = |name: &str, value: &str| (name.to_string(), value.to_string());

    let toml = Config {
        header: vec![pair("Accept", "text/html"), pair("X-Test", "check=this")],
        ..Default::default()
    };

    // The CLI side sets X-Test as well, so we can see whether it gets overwritten
    let mut cli = Config {
        header: vec![pair("X-Test", "check=that")],
        ..Default::default()
    };

    cli.merge(toml);
    assert_eq!(cli.header.len(), 2);

    // Normalise the order before comparing
    cli.header.sort();
    let expected = vec![pair("Accept", "text/html"), pair("X-Test", "check=this")];
    assert_eq!(cli.header, expected);
}
}

View file

@ -1,35 +1,12 @@
use anyhow::{anyhow, Context, Result};
use headers::{HeaderMap, HeaderName};
use anyhow::{Context, Result};
use lychee_lib::{remap::Remaps, Base};
use std::time::Duration;
/// Split a single HTTP header into a (key, value) tuple
fn read_header(input: &str) -> Result<(String, String), anyhow::Error> {
if let Some((key, value)) = input.split_once('=') {
Ok((key.to_string(), value.to_string()))
} else {
Err(anyhow!(
"Header value must be of the form key=value, got {}",
input
))
}
}
/// Convert a whole number of seconds into a [`Duration`].
pub(crate) const fn parse_duration_secs(secs: usize) -> Duration {
    let whole_secs = secs as u64;
    Duration::from_secs(whole_secs)
}
/// Parse HTTP headers into a `HeaderMap`
pub(crate) fn parse_headers<T: AsRef<str>>(headers: &[T]) -> Result<HeaderMap> {
let mut out = HeaderMap::new();
for header in headers {
let (key, val) = read_header(header.as_ref())?;
out.insert(HeaderName::from_bytes(key.as_bytes())?, val.parse()?);
}
Ok(out)
}
/// Parse URI remaps
pub(crate) fn parse_remaps(remaps: &[String]) -> Result<Remaps> {
Remaps::try_from(remaps)
@ -42,30 +19,10 @@ pub(crate) fn parse_base(src: &str) -> Result<Base, lychee_lib::ErrorKind> {
#[cfg(test)]
mod tests {
use headers::HeaderMap;
use regex::Regex;
use reqwest::header;
use super::*;
#[test]
fn test_parse_custom_headers() {
let mut custom = HeaderMap::new();
custom.insert(header::ACCEPT, "text/html".parse().unwrap());
assert_eq!(parse_headers(&["accept=text/html"]).unwrap(), custom);
}
#[test]
fn test_parse_custom_headers_with_equals() {
let mut custom_with_equals = HeaderMap::new();
custom_with_equals.insert("x-test", "check=this".parse().unwrap());
assert_eq!(
parse_headers(&["x-test=check=this"]).unwrap(),
custom_with_equals
);
}
#[test]
fn test_parse_remap() {
let remaps =

View file

@ -12,7 +12,7 @@ mod cli {
use anyhow::anyhow;
use assert_cmd::Command;
use assert_json_diff::assert_json_include;
use http::StatusCode;
use http::{Method, StatusCode};
use lychee_lib::{InputSource, ResponseBody};
use predicates::{
prelude::{predicate, PredicateBooleanExt},
@ -1950,6 +1950,118 @@ mod cli {
Ok(())
}
#[tokio::test]
async fn test_no_header_set_on_input() -> Result<()> {
    let mut cmd = main_command();
    let server = wiremock::MockServer::start().await;

    // Accept any GET request; it must arrive exactly once.
    let any_get = wiremock::Mock::given(wiremock::matchers::method("GET"))
        .respond_with(wiremock::ResponseTemplate::new(200))
        .expect(1);
    server.register(any_get).await;

    cmd.arg("--verbose").arg(server.uri()).assert().success();

    let received_requests = server.received_requests().await.unwrap();
    assert_eq!(received_requests.len(), 1);

    let request = &received_requests[0];
    assert_eq!(request.method, Method::GET);
    assert_eq!(request.url.path(), "/");

    // No `--header` flag was passed, so the custom header must be absent
    assert!(!request.headers.contains_key("X-Foo"));
    Ok(())
}
#[tokio::test]
async fn test_header_set_on_input() -> Result<()> {
    let mut cmd = main_command();
    let server = wiremock::MockServer::start().await;

    // The mock only matches a GET carrying the custom header,
    // and we expect it to be called exactly once.
    let header_mock = wiremock::Mock::given(wiremock::matchers::method("GET"))
        .and(wiremock::matchers::header("X-Foo", "Bar"))
        .respond_with(wiremock::ResponseTemplate::new(200))
        .expect(1)
        .named("GET expecting custom header");
    server.register(header_mock).await;

    cmd.arg("--verbose")
        .arg("--header")
        .arg("X-Foo: Bar")
        .arg(server.uri())
        .assert()
        .success();

    // Confirm that the server received the request with the header
    server.verify().await;
    Ok(())
}
#[tokio::test]
async fn test_multi_header_set_on_input() -> Result<()> {
    let mut cmd = main_command();
    let server = wiremock::MockServer::start().await;

    // The mock only matches a GET carrying both custom headers,
    // and we expect it to be called exactly once.
    let headers_mock = wiremock::Mock::given(wiremock::matchers::method("GET"))
        .and(wiremock::matchers::header("X-Foo", "Bar"))
        .and(wiremock::matchers::header("X-Bar", "Baz"))
        .respond_with(wiremock::ResponseTemplate::new(200))
        .expect(1)
        .named("GET expecting custom header");
    server.register(headers_mock).await;

    cmd.arg("--verbose")
        .arg("--header")
        .arg("X-Foo: Bar")
        .arg("--header")
        .arg("X-Bar: Baz")
        .arg(server.uri())
        .assert()
        .success();

    // Confirm that the server received the request with both headers
    server.verify().await;
    Ok(())
}
#[tokio::test]
async fn test_header_set_in_config() -> Result<()> {
    let mut cmd = main_command();
    let server = wiremock::MockServer::start().await;

    // The mock only matches a GET carrying both headers from the config file,
    // and we expect it to be called exactly once.
    let headers_mock = wiremock::Mock::given(wiremock::matchers::method("GET"))
        .and(wiremock::matchers::header("X-Foo", "Bar"))
        .and(wiremock::matchers::header("X-Bar", "Baz"))
        .respond_with(wiremock::ResponseTemplate::new(200))
        .expect(1)
        .named("GET expecting custom header");
    server.register(headers_mock).await;

    // Headers come from the fixture config instead of `--header` flags
    let config = fixtures_path().join("configs").join("headers.toml");
    cmd.arg("--verbose")
        .arg("--config")
        .arg(config)
        .arg(server.uri())
        .assert()
        .success();

    // Confirm that the server received the request with both headers
    server.verify().await;
    Ok(())
}
#[test]
fn test_sorted_error_output() -> Result<()> {
let test_files = ["TEST_GITHUB_404.md", "TEST_INVALID_URLS.html"];

View file

@ -181,7 +181,7 @@ impl Collector {
mod tests {
use std::{collections::HashSet, convert::TryFrom, fs::File, io::Write};
use http::StatusCode;
use http::{HeaderMap, StatusCode};
use reqwest::Url;
use super::*;
@ -230,7 +230,13 @@ mod tests {
// Treat as plaintext file (no extension)
let file_path = temp_dir.path().join("README");
let _file = File::create(&file_path).unwrap();
let input = Input::new(&file_path.as_path().display().to_string(), None, true, None)?;
let input = Input::new(
&file_path.as_path().display().to_string(),
None,
true,
None,
HeaderMap::new(),
)?;
let contents: Vec<_> = input
.get_contents(true, true, true, FileType::default_extensions())
.collect::<Vec<_>>()
@ -243,7 +249,7 @@ mod tests {
#[tokio::test]
async fn test_url_without_extension_is_html() -> Result<()> {
let input = Input::new("https://example.com/", None, true, None)?;
let input = Input::new("https://example.com/", None, true, None, HeaderMap::new())?;
let contents: Vec<_> = input
.get_contents(true, true, true, FileType::default_extensions())
.collect::<Vec<_>>()
@ -278,6 +284,7 @@ mod tests {
source: InputSource::String(TEST_STRING.to_owned()),
file_type_hint: None,
excluded_paths: None,
headers: HeaderMap::new(),
},
Input {
source: InputSource::RemoteUrl(Box::new(
@ -287,11 +294,13 @@ mod tests {
)),
file_type_hint: None,
excluded_paths: None,
headers: HeaderMap::new(),
},
Input {
source: InputSource::FsPath(file_path),
file_type_hint: None,
excluded_paths: None,
headers: HeaderMap::new(),
},
Input {
source: InputSource::FsGlob {
@ -300,6 +309,7 @@ mod tests {
},
file_type_hint: None,
excluded_paths: None,
headers: HeaderMap::new(),
},
];
@ -327,7 +337,8 @@ mod tests {
let input = Input {
source: InputSource::String("This is [a test](https://endler.dev). This is a relative link test [Relative Link Test](relative_link)".to_string()),
file_type_hint: Some(FileType::Markdown),
excluded_paths: None,
excluded_paths: None,
headers: HeaderMap::new(),
};
let links = collect(vec![input], None, Some(base)).await.ok().unwrap();
@ -354,6 +365,7 @@ mod tests {
),
file_type_hint: Some(FileType::Html),
excluded_paths: None,
headers: HeaderMap::new(),
};
let links = collect(vec![input], None, Some(base)).await.ok().unwrap();
@ -383,6 +395,7 @@ mod tests {
),
file_type_hint: Some(FileType::Html),
excluded_paths: None,
headers: HeaderMap::new(),
};
let links = collect(vec![input], None, Some(base)).await.ok().unwrap();
@ -409,6 +422,7 @@ mod tests {
),
file_type_hint: Some(FileType::Markdown),
excluded_paths: None,
headers: HeaderMap::new(),
};
let links = collect(vec![input], None, Some(base)).await.ok().unwrap();
@ -432,6 +446,7 @@ mod tests {
source: InputSource::String(input),
file_type_hint: Some(FileType::Html),
excluded_paths: None,
headers: HeaderMap::new(),
};
let links = collect(vec![input], None, Some(base)).await.ok().unwrap();
@ -464,6 +479,7 @@ mod tests {
source: InputSource::RemoteUrl(Box::new(server_uri.clone())),
file_type_hint: None,
excluded_paths: None,
headers: HeaderMap::new(),
};
let links = collect(vec![input], None, None).await.ok().unwrap();
@ -484,6 +500,7 @@ mod tests {
),
file_type_hint: None,
excluded_paths: None,
headers: HeaderMap::new(),
};
let links = collect(vec![input], None, None).await.ok().unwrap();
@ -514,6 +531,7 @@ mod tests {
)),
file_type_hint: Some(FileType::Html),
excluded_paths: None,
headers: HeaderMap::new(),
},
Input {
source: InputSource::RemoteUrl(Box::new(
@ -525,6 +543,7 @@ mod tests {
)),
file_type_hint: Some(FileType::Html),
excluded_paths: None,
headers: HeaderMap::new(),
},
];
@ -560,6 +579,7 @@ mod tests {
),
file_type_hint: Some(FileType::Html),
excluded_paths: None,
headers: HeaderMap::new(),
};
let links = collect(vec![input], None, Some(base)).await.ok().unwrap();

View file

@ -3,6 +3,7 @@ use crate::{utils, ErrorKind, Result};
use async_stream::try_stream;
use futures::stream::Stream;
use glob::glob_with;
use http::HeaderMap;
use ignore::WalkBuilder;
use reqwest::Url;
use serde::{Deserialize, Serialize};
@ -101,7 +102,7 @@ impl Display for InputSource {
}
/// Lychee Input with optional file hint for parsing
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct Input {
/// Origin of input
pub source: InputSource,
@ -109,6 +110,8 @@ pub struct Input {
pub file_type_hint: Option<FileType>,
/// Excluded paths that will be skipped when reading content
pub excluded_paths: Option<Vec<PathBuf>>,
/// Custom headers to be used when fetching remote URLs
pub headers: reqwest::header::HeaderMap,
}
impl Input {
@ -125,6 +128,7 @@ impl Input {
file_type_hint: Option<FileType>,
glob_ignore_case: bool,
excluded_paths: Option<Vec<PathBuf>>,
headers: reqwest::header::HeaderMap,
) -> Result<Self> {
let source = if value == STDIN {
InputSource::Stdin
@ -190,9 +194,20 @@ impl Input {
source,
file_type_hint,
excluded_paths,
headers,
})
}
/// Build an [`Input`] from a raw value using default settings: no file type
/// hint, case-sensitive globbing, no excluded paths and no custom headers.
///
/// # Errors
///
/// Returns an error if the input does not exist (i.e. invalid path)
/// and the input cannot be parsed as a URL.
pub fn from_value(value: &str) -> Result<Self> {
    let file_type_hint = None;
    let glob_ignore_case = false;
    let excluded_paths = None;
    let headers = HeaderMap::new();

    Self::new(
        value,
        file_type_hint,
        glob_ignore_case,
        excluded_paths,
        headers,
    )
}
/// Retrieve the contents from the input
///
/// If the input is a path, only search through files that match the given
@ -215,7 +230,7 @@ impl Input {
try_stream! {
match self.source {
InputSource::RemoteUrl(ref url) => {
let content = Self::url_contents(url).await;
let content = Self::url_contents(url, &self.headers).await;
match content {
Err(_) if skip_missing => (),
Err(e) => Err(e)?,
@ -313,7 +328,7 @@ impl Input {
}
}
async fn url_contents(url: &Url) -> Result<InputContent> {
async fn url_contents(url: &Url, headers: &HeaderMap) -> Result<InputContent> {
// Assume HTML for default paths
let file_type = if url.path().is_empty() || url.path() == "/" {
FileType::Html
@ -321,7 +336,12 @@ impl Input {
FileType::from(url.as_str())
};
let res = reqwest::get(url.clone())
let client = reqwest::Client::new();
let res = client
.get(url.clone())
.headers(headers.clone())
.send()
.await
.map_err(ErrorKind::NetworkRequest)?;
let input_content = InputContent {
@ -414,6 +434,14 @@ impl Input {
}
}
impl TryFrom<&str> for Input {
type Error = crate::ErrorKind;
fn try_from(value: &str) -> std::result::Result<Self, Self::Error> {
Self::from_value(value)
}
}
/// Function for path exclusion tests
///
/// This is a standalone function to allow for easier testing
@ -428,6 +456,8 @@ fn is_excluded_path(excluded_paths: &[PathBuf], path: &PathBuf) -> bool {
#[cfg(test)]
mod tests {
use http::HeaderMap;
use super::*;
#[test]
@ -438,14 +468,15 @@ mod tests {
assert!(path.exists());
assert!(path.is_relative());
let input = Input::new(test_file, None, false, None);
let input = Input::new(test_file, None, false, None, HeaderMap::new());
assert!(input.is_ok());
assert!(matches!(
input,
Ok(Input {
source: InputSource::FsPath(PathBuf { .. }),
file_type_hint: None,
excluded_paths: None
excluded_paths: None,
headers: _,
})
));
}
@ -458,7 +489,7 @@ mod tests {
assert!(!path.exists());
assert!(path.is_relative());
let input = Input::new(test_file, None, false, None);
let input = Input::from_value(test_file);
assert!(input.is_err());
assert!(matches!(input, Err(ErrorKind::InvalidFile(PathBuf { .. }))));
}
@ -490,7 +521,7 @@ mod tests {
#[test]
fn test_url_without_scheme() {
let input = Input::new("example.com", None, false, None);
let input = Input::from_value("example.com");
assert_eq!(
input.unwrap().source.to_string(),
String::from("http://example.com/")
@ -501,7 +532,7 @@ mod tests {
#[cfg(windows)]
#[test]
fn test_windows_style_filepath_not_existing() {
let input = Input::new("C:\\example\\project\\here", None, false, None);
let input = Input::from_value("C:\\example\\project\\here");
assert!(input.is_err());
let input = input.unwrap_err();
@ -521,7 +552,7 @@ mod tests {
let dir = temp_dir();
let file = NamedTempFile::new_in(dir).unwrap();
let path = file.path();
let input = Input::new(path.to_str().unwrap(), None, false, None).unwrap();
let input = Input::from_value(path.to_str().unwrap()).unwrap();
match input.source {
InputSource::FsPath(_) => (),
@ -533,33 +564,28 @@ mod tests {
fn test_url_scheme_check_succeeding() {
// Valid http and https URLs
assert!(matches!(
Input::new("http://example.com", None, false, None),
Input::from_value("http://example.com"),
Ok(Input {
source: InputSource::RemoteUrl(_),
..
})
));
assert!(matches!(
Input::new("https://example.com", None, false, None),
Input::from_value("https://example.com"),
Ok(Input {
source: InputSource::RemoteUrl(_),
..
})
));
assert!(matches!(
Input::new(
"http://subdomain.example.com/path?query=value",
None,
false,
None
),
Input::from_value("http://subdomain.example.com/path?query=value",),
Ok(Input {
source: InputSource::RemoteUrl(_),
..
})
));
assert!(matches!(
Input::new("https://example.com:8080", None, false, None),
Input::from_value("https://example.com:8080"),
Ok(Input {
source: InputSource::RemoteUrl(_),
..
@ -571,19 +597,19 @@ mod tests {
fn test_url_scheme_check_failing() {
// Invalid schemes
assert!(matches!(
Input::new("ftp://example.com", None, false, None),
Input::from_value("ftp://example.com"),
Err(ErrorKind::InvalidFile(_))
));
assert!(matches!(
Input::new("httpx://example.com", None, false, None),
Input::from_value("httpx://example.com"),
Err(ErrorKind::InvalidFile(_))
));
assert!(matches!(
Input::new("file:///path/to/file", None, false, None),
Input::from_value("file:///path/to/file"),
Err(ErrorKind::InvalidFile(_))
));
assert!(matches!(
Input::new("mailto:user@example.com", None, false, None),
Input::from_value("mailto:user@example.com"),
Err(ErrorKind::InvalidFile(_))
));
}
@ -592,11 +618,11 @@ mod tests {
fn test_non_url_inputs() {
// Non-URL inputs
assert!(matches!(
Input::new("./local/path", None, false, None),
Input::from_value("./local/path"),
Err(ErrorKind::InvalidFile(_))
));
assert!(matches!(
Input::new("*.md", None, false, None),
Input::from_value("*.md"),
Ok(Input {
source: InputSource::FsGlob { .. },
..
@ -604,7 +630,7 @@ mod tests {
));
// Assuming the current directory exists
assert!(matches!(
Input::new(".", None, false, None),
Input::from_value("."),
Ok(Input {
source: InputSource::FsPath(_),
..

View file

@ -69,7 +69,7 @@ require_https = false
method = "get"
# Custom request headers
header = ["name=value", "other=value"]
header = { "accept" = "text/html", "x-custom-header" = "value" }
# Remap URI matching pattern to different URI.
remap = ["https://example.com http://example.invalid"]