Skip to content
This repository has been archived by the owner on Dec 16, 2022. It is now read-only.

Commit

Permalink
Making model test case consistently random (#5278)
Browse files Browse the repository at this point in the history
  • Loading branch information
jacob-morrison authored Jun 22, 2021
1 parent 5a7844b commit 86504e6
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 6 deletions.
3 changes: 2 additions & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Fixed Broken link in `allennlp.fairness.fairness_metrics.Separation` docs
- Ensured all `allennlp` submodules are imported with `allennlp.common.plugins.import_plugins()`.
- Fixed `IndexOutOfBoundsException` in `MultiOptimizer` when checking if optimizer received any parameters.
- Removed confusing zero mask from VilBERT
- Removed confusing zero mask from VilBERT.
- Ensured `ensure_model_can_train_save_and_load` is consistently random.

### Changed

Expand Down
15 changes: 10 additions & 5 deletions allennlp/common/testing/model_test_case.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,11 +119,6 @@ def ensure_model_can_train_save_and_load(
Specifies which loss to test. For example, which_loss may be "adversary_loss" for
`adversarial_bias_mitigator`.
"""
if seed is not None:
random.seed(seed)
numpy.random.seed(seed)
torch.manual_seed(seed)

save_dir = self.TEST_DIR / "save_and_load_test"
archive_file = save_dir / "model.tar.gz"
model = train_model_from_file(param_file, save_dir, overrides=overrides)
Expand Down Expand Up @@ -158,12 +153,22 @@ def ensure_model_can_train_save_and_load(
data_loader_params["shuffle"] = False
data_loader_params2 = Params(copy.deepcopy(data_loader_params.as_dict()))

if seed is not None:
random.seed(seed)
numpy.random.seed(seed)
torch.manual_seed(seed)

print("Reading with original model")
data_loader = DataLoader.from_params(
params=data_loader_params, reader=reader, data_path=params["validation_data_path"]
)
data_loader.index_with(model.vocab)

if seed is not None:
random.seed(seed)
numpy.random.seed(seed)
torch.manual_seed(seed)

print("Reading with loaded model")
data_loader2 = DataLoader.from_params(
params=data_loader_params2, reader=reader, data_path=params["validation_data_path"]
Expand Down

0 comments on commit 86504e6

Please sign in to comment.