-
Notifications
You must be signed in to change notification settings - Fork 3
/
Copy pathtest.py
98 lines (78 loc) · 3.72 KB
/
test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
import os
import logging
import torch
torch.set_float32_matmul_precision("medium")
import pytorch_lightning as pl
from pytorch_lightning.loggers import TensorBoardLogger
from pytorch_lightning.callbacks import ModelCheckpoint
from data_loader import DataModule
from trainer import MusicClassifier
import yaml
from omegaconf import DictConfig
import hydra
from hydra.utils import to_absolute_path
from hydra.core.hydra_config import HydraConfig
from pytorch_lightning.utilities.combined_loader import CombinedLoader
log = logging.getLogger(__name__)
def get_latest_version(log_dir):
version_dirs = [d for d in os.listdir(log_dir) if d.startswith('version_')]
version_dirs.sort(key=lambda x: int(x.split('_')[-1])) # Sort by version number
return version_dirs[-1] if version_dirs else None
def save_metrics_and_checkpoint(metrics, checkpoint, output_file):
data = {
'checkpoint': checkpoint,
'metrics': metrics
}
with open(output_file, 'w') as f:
yaml.dump(data, f)
def read_best_checkpoint_info(file_path, dataset_type=None):
"""Read the best checkpoint file."""
if not os.path.exists(file_path):
raise FileNotFoundError(f"Checkpoint info file not found: {file_path}")
with open(file_path, 'r') as f:
lines = f.readlines()
if dataset_type == "mood":
checkpoint_line = next((line for line in lines if line.startswith("Best checkpoint (mood):")), None)
elif dataset_type == "va":
checkpoint_line = next((line for line in lines if line.startswith("Best checkpoint (va):")), None)
else:
checkpoint_line = next((line for line in lines if line.startswith("Best checkpoint:")), None)
if not checkpoint_line:
raise ValueError(f"No checkpoint found for dataset type '{dataset_type}' in the file.")
return checkpoint_line.split(": ")[-1].strip()
@hydra.main(version_base=None, config_path="config", config_name="test_config")
def main(config: DictConfig):
log.info("Testing starts")
log_base_dir = 'tb_logs/train_audio_classification'
# log_base_dir = to_absolute_path('tb_logs/train_audio_classification')
latest_version = get_latest_version(log_base_dir)
if not latest_version:
raise FileNotFoundError("No version directories found in log base directory.")
version_log_dir = os.path.join(log_base_dir, latest_version)
output_file = os.path.join(version_log_dir, 'test_metrics.txt')
if config.checkpoint_latest:
if config.multitask:
dataset_type = config.dataset_type # Expecting 'mood' or 'va'
best_checkpoint_file = os.path.join(version_log_dir, 'best_checkpoint.txt')
ckpt = read_best_checkpoint_info(best_checkpoint_file, dataset_type)
else:
best_checkpoint_file = os.path.join(version_log_dir, 'best_checkpoint.txt')
ckpt = read_best_checkpoint_info(best_checkpoint_file)
else:
ckpt = config.checkpoint
if not os.path.exists(ckpt):
raise FileNotFoundError(f"Checkpoint file not found: {ckpt}")
log.info(f"Using checkpoint: {ckpt}")
data_module = DataModule( config )
data_module.setup()
testloaders = {dataset_name: loader for dataset_name, loader in zip(config.datasets, data_module.test_dataloader())}
combined_test_loader = CombinedLoader(testloaders, mode="max_size")
model = MusicClassifier.load_from_checkpoint(ckpt, cfg=config, output_file=output_file)
logger = TensorBoardLogger(save_dir=log_base_dir,
name="",
version=latest_version)
trainer = pl.Trainer(**config.trainer,
logger=logger)
trainer.test(model, combined_test_loader)
if __name__ == '__main__':
main()