diff --git a/src/main.py b/src/main.py
index fd2b06d..d50945d 100644
--- a/src/main.py
+++ b/src/main.py
@@ -640,26 +640,36 @@ def progressive_extension(
 def short_context_recovery(model, data, base_length, lambda_factors_base, n_hat_base):
     """
-    Recover performance on shorter context lengths.
+    Ensure the model maintains good performance on shorter contexts (4k and 8k)
+    even after being extended to very long contexts.
+
     Args:
-        model (nn.Module): LongRoPE model.
-        data (list): List of input sequences.
-        base_length (int): Base context window length.
-        lambda_factors_base (list): Base lambda factors.
-        n_hat_base (int): Base n_hat.
+        model (nn.Module): Extended LongRoPE model.
+        data (list): List of input sequences for fine-tuning and evaluation.
+        base_length (int): Original context window length of the model.
+        lambda_factors_base (list): Base lambda factors for the extended model.
+        n_hat_base (int): Base n_hat for the extended model.
 
     Returns:
-        nn.Module: Recovered LongRoPE model.
+        nn.Module: LongRoPE model with recovered short context performance.
     """
-    short_lengths = [base_length // 2, base_length // 4]
+    short_lengths = [4096, 8192]  # Specific lengths mentioned in the paper
 
     for length in short_lengths:
         extension_ratio = length / base_length
         lambda_factors, n_hat = search_lambda_factors(
-            model, data, extension_ratio, max_length=length
+            model,
+            data,
+            extension_ratio,
+            population_size=64,
+            num_mutations=16,
+            num_crossovers=16,
+            max_iterations=40,
         )
-        model = fine_tune(model, data, length, lambda_factors, n_hat)
+        # Fine-tune for short context recovery
+        model = fine_tune(model, data, length, lambda_factors, n_hat, steps=100)
 
+    # Store base factors for use during inference
     model.lambda_factors_base = lambda_factors_base
     model.n_hat_base = n_hat_base
 
     return model
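Reviewer note: the new call passes explicit evolutionary-search hyperparameters (`population_size=64`, `num_mutations=16`, `num_crossovers=16`, `max_iterations=40`) to `search_lambda_factors`. For anyone reading this diff without the rest of `src/main.py`, the sketch below illustrates how such knobs typically interact in an evolutionary search over per-dimension RoPE rescale factors. It is a minimal toy under assumptions, not the repo's actual `search_lambda_factors`: `evolutionary_search`, its `fitness` callback, and the sampling ranges are hypothetical.

```python
import random


def evolutionary_search(fitness, dim, population_size=64, num_mutations=16,
                        num_crossovers=16, max_iterations=40):
    """Toy sketch of an evolutionary search over rescale factors.

    `fitness` scores a candidate factor list (lower is better, e.g. perplexity
    on held-out data at the target length); `dim` is the number of RoPE
    frequency dimensions. Hypothetical, for illustration only.
    """
    # Start from random candidates around the identity scaling.
    population = [[random.uniform(0.5, 2.0) for _ in range(dim)]
                  for _ in range(population_size)]
    for _ in range(max_iterations):
        # Keep the best half of the population as parents.
        population.sort(key=fitness)
        parents = population[: population_size // 2]
        children = []
        for _ in range(num_mutations):
            # Mutation: perturb one dimension of a random parent.
            child = list(random.choice(parents))
            i = random.randrange(dim)
            child[i] *= random.uniform(0.9, 1.1)
            children.append(child)
        for _ in range(num_crossovers):
            # Crossover: mix two parents dimension-by-dimension.
            a, b = random.sample(parents, 2)
            children.append([random.choice(pair) for pair in zip(a, b)])
        population = parents + children
    return min(population, key=fitness)


# Usage with a stand-in fitness (real code would measure perplexity):
best = evolutionary_search(lambda c: sum((x - 1.0) ** 2 for x in c), dim=64)
```

Note that the LongRoPE paper additionally constrains the search space (e.g. requiring the rescale factors to be non-decreasing across frequency dimensions), which this toy omits for brevity.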