diff --git a/CHANGELOG.md b/CHANGELOG.md index 95081b6942a..6c8e8627e5d 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -22,7 +22,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Fixed Broken link in `allennlp.fairness.fairness_metrics.Separation` docs - Ensured all `allennlp` submodules are imported with `allennlp.common.plugins.import_plugins()`. - Fixed `IndexOutOfBoundsException` in `MultiOptimizer` when checking if optimizer received any parameters. -- Removed confusing zero mask from VilBERT +- Removed confusing zero mask from VilBERT. +- Ensured `ensure_model_can_train_save_and_load` is consistently random. ### Changed diff --git a/allennlp/common/testing/model_test_case.py b/allennlp/common/testing/model_test_case.py index 48afe3307d8..c58c4f6777c 100644 --- a/allennlp/common/testing/model_test_case.py +++ b/allennlp/common/testing/model_test_case.py @@ -119,11 +119,6 @@ def ensure_model_can_train_save_and_load( Specifies which loss to test. For example, which_loss may be "adversary_loss" for `adversarial_bias_mitigator`. """ - if seed is not None: - random.seed(seed) - numpy.random.seed(seed) - torch.manual_seed(seed) - save_dir = self.TEST_DIR / "save_and_load_test" archive_file = save_dir / "model.tar.gz" model = train_model_from_file(param_file, save_dir, overrides=overrides) @@ -158,12 +153,22 @@ def ensure_model_can_train_save_and_load( data_loader_params["shuffle"] = False data_loader_params2 = Params(copy.deepcopy(data_loader_params.as_dict())) + if seed is not None: + random.seed(seed) + numpy.random.seed(seed) + torch.manual_seed(seed) + print("Reading with original model") data_loader = DataLoader.from_params( params=data_loader_params, reader=reader, data_path=params["validation_data_path"] ) data_loader.index_with(model.vocab) + if seed is not None: + random.seed(seed) + numpy.random.seed(seed) + torch.manual_seed(seed) + print("Reading with loaded model") data_loader2 = DataLoader.from_params( params=data_loader_params2, reader=reader, data_path=params["validation_data_path"]