Skip to content

Commit

Permalink
Handle retries in the aws client, use adaptive backoff (#241)
Browse files Browse the repository at this point in the history
  • Loading branch information
undfined authored Feb 19, 2025
1 parent 17242c3 commit 6c643eb
Show file tree
Hide file tree
Showing 3 changed files with 48 additions and 19 deletions.
2 changes: 1 addition & 1 deletion src/deduper.rs
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,7 @@ fn write_attributes(
label_temp: bool,
) -> Result<(), io::Error> {
let cache = FileCache {
s3_client: Box::new(s3_util::new_client(None)?),
s3_client: Box::new(s3_util::new_client(None, None)?),
work: work_dirs.clone(),
};

Expand Down
55 changes: 42 additions & 13 deletions src/s3_util.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
use std::io;
use std::path::Path;

use aws_sdk_s3::config::retry::RetryConfig;
use aws_sdk_s3::config::Region;
use aws_sdk_s3::error::ProvideErrorMetadata;
use aws_sdk_s3::primitives::ByteStream;
Expand Down Expand Up @@ -284,7 +285,18 @@ pub fn find_objects_matching_patterns(
Ok(stream_inputs)
}

pub fn new_client(region_name: Option<String>) -> Result<S3Client, io::Error> {
pub fn new_client(
region_name: Option<String>,
retry_attempts: Option<u32>,
) -> Result<S3Client, io::Error> {
// Check that retry_attempts is greater than 0
let retry_attempts = retry_attempts.unwrap_or(3); // Default to 3 if not provided
if retry_attempts < 1 {
return Err(io::Error::new(
io::ErrorKind::Other,
"retry_attempts must be greater than or equal to 1",
));
}
let rt = tokio::runtime::Builder::new_current_thread()
.enable_all()
.build()
Expand All @@ -295,7 +307,19 @@ pub fn new_client(region_name: Option<String>) -> Result<S3Client, io::Error> {
.unwrap_or_else(|_| region_name.unwrap_or_else(|| String::from("us-east-1"))),
);

let config = rt.block_on(aws_config::from_env().region(region).load());
let retry_config = RetryConfig::adaptive() // Use adaptive retry strategy instead of naive default
.with_max_attempts(retry_attempts) // Set maximum number of retry attempts
.with_initial_backoff(Duration::from_millis(100)) // Initial delay between retries
.with_max_backoff(Duration::from_secs(10));

// Load the AWS configuration with the custom retry config
let config = rt.block_on(
aws_config::from_env()
.retry_config(retry_config)
.region(region)
.load(),
);

let s3_client = S3Client::new(&config);
Ok(s3_client)
}
Expand All @@ -310,6 +334,7 @@ mod test {
use std::io;
use std::io::{BufRead, BufReader};
use std::path::Path;
use std::println as info;

use flate2::read::MultiGzDecoder;

Expand Down Expand Up @@ -393,7 +418,7 @@ mod test {
.enable_all()
.build()
.unwrap();
let s3_client = new_client(None)?;
let s3_client = new_client(None, None)?;
let s3_prefix = get_dolma_test_prefix();

let s3_dest = "/pretraining-data/tests/mixer/inputs/v0/documents/head/0000.json.gz";
Expand All @@ -407,7 +432,7 @@ mod test {
Path::new(local_source_file),
s3_bucket,
s3_key,
Some(3), // number of attempts
Some(1),
))?;

// check the size matches expected
Expand All @@ -427,7 +452,7 @@ mod test {
.enable_all()
.build()
.unwrap();
let s3_client = new_client(None)?;
let s3_client = new_client(None, None)?;

let s3_prefix = get_dolma_test_prefix();
let s3_dest = "/pretraining-data/tests/mixer/inputs/v0/documents/head/0000.json.gz";
Expand All @@ -441,7 +466,7 @@ mod test {
Path::new(local_source_file),
s3_bucket,
s3_key,
Some(3), // number of attempts
Some(1),
))?;

// download the file back from s3
Expand All @@ -453,7 +478,7 @@ mod test {
s3_bucket,
s3_key,
Path::new(local_output_file),
Some(3), // number of attempts
Some(1),
))?;

// compare the contents of the two files
Expand All @@ -471,7 +496,7 @@ mod test {
.enable_all()
.build()
.unwrap();
let s3_client = new_client(None)?;
let s3_client = new_client(None, None)?;

let s3_prefix = get_dolma_test_prefix();
let s3_dest = "/foo/bar/baz.json.gz";
Expand Down Expand Up @@ -499,15 +524,19 @@ mod test {
s3_bucket,
s3_key,
Path::new(local_output_file),
Some(3), // number of attempts
Some(1),
));

assert!(resp_no_such_location.is_err());
let exp_msg = format!(
"All 3 attempts to download '{}' to '{}' failed",
"All 1 attempts to download '{}' to '{}' failed",
s3_path, local_output_file
);
assert_eq!(resp_no_such_location.unwrap_err().to_string(), exp_msg);
let error_string = resp_no_such_location.unwrap_err().to_string();
let actual: &str = error_string.as_str();
info!("actual: {}", actual);

assert_eq!(actual, exp_msg);
Ok(())
}

Expand All @@ -522,7 +551,7 @@ mod test {
.build()
.unwrap();

let s3_client = new_client(None)?;
let s3_client = new_client(None, None)?;
let s3_prefix = get_dolma_test_prefix();

let local_source_dir = "tests/data/expected";
Expand All @@ -547,7 +576,7 @@ mod test {
Path::new(local_source_file.to_str().unwrap()),
s3_bucket,
s3_key,
Some(3), // number of attempts
Some(1),
))?;
}

Expand Down
10 changes: 5 additions & 5 deletions src/shard.rs
Original file line number Diff line number Diff line change
Expand Up @@ -187,7 +187,7 @@ impl Shard {
// Upload the output file to S3.
pub fn process(&self, work_dirs: WorkDirConfig) -> Result<(), IoError> {
let cache: FileCache = FileCache {
s3_client: Box::new(s3_util::new_client(None)?),
s3_client: Box::new(s3_util::new_client(None, None)?),
work: work_dirs.clone(),
};
let min_text_length = self.min_text_length.clone().unwrap_or(0);
Expand Down Expand Up @@ -738,7 +738,7 @@ impl FileCache {
bucket,
key,
&path,
Some(3), // retry twice if fail
Some(1),
))?;
log::info!("Download complete.");
Ok(path.clone())
Expand Down Expand Up @@ -802,7 +802,7 @@ impl FileCache {
&path,
bucket,
key,
Some(3), // retry twice if fail
Some(1),
))?;
std::fs::remove_file(&path)?;
{
Expand Down Expand Up @@ -831,7 +831,7 @@ pub fn find_objects_matching_patterns(patterns: &Vec<String>) -> Result<Vec<Stri
}
Ok(matches)
} else if s3_url_count == patterns.len() {
let s3_client = s3_util::new_client(None)?;
let s3_client = s3_util::new_client(None, None)?;
s3_util::find_objects_matching_patterns(&s3_client, patterns)
} else {
Err(IoError::new(
Expand All @@ -855,7 +855,7 @@ pub fn get_object_sizes(locations: &Vec<String>) -> Result<Vec<usize>, IoError>
.collect();
Ok(sizes)
} else if s3_url_count == locations.len() {
let s3_client = s3_util::new_client(None)?;
let s3_client = s3_util::new_client(None, None)?;
let rt = tokio::runtime::Builder::new_current_thread()
.enable_all()
.build()
Expand Down

0 comments on commit 6c643eb

Please sign in to comment.