From 87099b9a029fbe096840204ec2b32e4fb4381aea Mon Sep 17 00:00:00 2001 From: Zanie Blue Date: Mon, 8 Apr 2024 15:58:51 -0500 Subject: [PATCH] Allow the request URL to be used for subsequent responses --- src/lib.rs | 86 ++++++++++++++++++++++++++++++++++++++++++++++++++---- 1 file changed, 80 insertions(+), 6 deletions(-) diff --git a/src/lib.rs b/src/lib.rs index 2b99aa0..5fad4b4 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -72,7 +72,8 @@ pub use error::AsyncHttpRangeReaderError; /// if response.status() == reqwest::StatusCode::NOT_MODIFIED { /// Ok(None) /// } else { -/// let reader = AsyncHttpRangeReader::from_head_response(client, response, HeaderMap::default()).await?; +/// let url = response.url().clone(); +/// let reader = AsyncHttpRangeReader::from_head_response(client, response, url, HeaderMap::default()).await?; /// Ok(Some(reader)) /// } /// } @@ -131,6 +132,15 @@ pub enum CheckSupportMethod { Head, } +/// Which URL should be used for subsequent range requests? +pub enum RangeRequestUrlSource { + /// Use the initial request URL + Request, + + /// Use the initial response URL + Response, +} + fn error_for_status(response: reqwest::Response) -> reqwest_middleware::Result { response .error_for_status() @@ -143,6 +153,7 @@ impl AsyncHttpRangeReader { client: impl Into, url: reqwest::Url, check_method: CheckSupportMethod, + range_request_url_source: RangeRequestUrlSource, extra_headers: HeaderMap, ) -> Result<(Self, HeaderMap), AsyncHttpRangeReaderError> { let client = client.into(); @@ -156,7 +167,11 @@ impl AsyncHttpRangeReader { ) .await?; let response_headers = response.headers().clone(); - let self_ = Self::from_tail_response(client, response, extra_headers).await?; + let url = match range_request_url_source { + RangeRequestUrlSource::Request => url, + RangeRequestUrlSource::Response => response.url().clone(), + }; + let self_ = Self::from_tail_response(client, response, url, extra_headers).await?; Ok((self_, response_headers)) } CheckSupportMethod::Head => { @@ -164,7 +179,11 @@ impl AsyncHttpRangeReader { Self::initial_head_request(client.clone(), url.clone(), HeaderMap::default()) .await?; let response_headers = response.headers().clone(); - let self_ = Self::from_head_response(client, response, extra_headers).await?; + let url = match range_request_url_source { + RangeRequestUrlSource::Request => url, + RangeRequestUrlSource::Response => response.url().clone(), + }; + let self_ = Self::from_head_response(client, response, url, extra_headers).await?; Ok((self_, response_headers)) } } @@ -200,6 +219,7 @@ impl AsyncHttpRangeReader { pub async fn from_tail_response( client: impl Into, tail_request_response: Response, + url: Url, extra_headers: HeaderMap, ) -> Result { let client = client.into(); @@ -245,7 +265,7 @@ impl AsyncHttpRangeReader { let (state_tx, state_rx) = watch::channel(StreamerState::default()); tokio::spawn(run_streamer( client, - tail_request_response.url().clone(), + url, extra_headers, Some((tail_request_response, start)), memory_map, @@ -300,6 +320,7 @@ impl AsyncHttpRangeReader { pub async fn from_head_response( client: impl Into, head_response: Response, + url: Url, extra_headers: HeaderMap, ) -> Result { let client = client.into(); @@ -345,7 +366,7 @@ impl AsyncHttpRangeReader { let (state_tx, state_rx) = watch::channel(StreamerState::default()); tokio::spawn(run_streamer( client, - head_response.url().clone(), + url, extra_headers, None, memory_map, @@ -688,6 +709,7 @@ mod test { Client::new(), server.url().join("andes-1.8.3-pyhd8ed1ab_0.conda").unwrap(), check_method, + RangeRequestUrlSource::Response, HeaderMap::default(), ) .await @@ -728,7 +750,7 @@ mod test { ); // Prefetch the data for the metadata.json file - let entry = reader.file().entries().get(0).unwrap(); + let entry = reader.file().entries().first().unwrap(); let offset = entry.header_offset(); // Get the size of the entry plus the header + size of the filename. We should also actually // include bytes for the extra fields but we don't have that information. @@ -783,6 +805,57 @@ mod test { Client::new(), server.url().join("andes-1.8.3-pyhd8ed1ab_0.conda").unwrap(), check_method, + RangeRequestUrlSource::Response, + HeaderMap::default(), + ) + .await + .expect("bla"); + + // Also open a simple file reader + let mut file = tokio::fs::File::open(path.join("andes-1.8.3-pyhd8ed1ab_0.conda")) + .await + .unwrap(); + + // Read until the end and make sure that the contents matches + let mut range_read = vec![0; 64 * 1024]; + let mut file_read = vec![0; 64 * 1024]; + loop { + // Read with the async reader + let range_read_bytes = range.read(&mut range_read).await.unwrap(); + + // Read directly from the file + let file_read_bytes = file + .read_exact(&mut file_read[0..range_read_bytes]) + .await + .unwrap(); + + assert_eq!(range_read_bytes, file_read_bytes); + assert_eq!( + range_read[0..range_read_bytes], + file_read[0..file_read_bytes] + ); + + if file_read_bytes == 0 && range_read_bytes == 0 { + break; + } + } + } + + #[rstest] + #[case(RangeRequestUrlSource::Request)] + #[case(RangeRequestUrlSource::Response)] + #[tokio::test] + async fn async_range_reader_url_source(#[case] url_source: RangeRequestUrlSource) { + // Spawn a static file server + let path = Path::new(&std::env::var("CARGO_MANIFEST_DIR").unwrap()).join("test-data"); + let server = StaticDirectoryServer::new(&path); + + // Construct an AsyncRangeReader + let (mut range, _) = AsyncHttpRangeReader::new( + Client::new(), + server.url().join("andes-1.8.3-pyhd8ed1ab_0.conda").unwrap(), + CheckSupportMethod::Head, + url_source, HeaderMap::default(), ) .await @@ -825,6 +898,7 @@ mod test { Client::new(), server.url().join("not-found").unwrap(), CheckSupportMethod::Head, + RangeRequestUrlSource::Response, HeaderMap::default(), ) .await