main.py
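"""Hydra entry point for the project.

Instantiates a dataset and a model from ``configs/main_config.yaml``, then
either trains and evaluates a tsai time-series model (when
``datasets.data_params.for_tsai`` is set) or an XGB model on a train/test
split, plots the predictions, and logs the results to MLflow.
"""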
import hydra
from hydra.utils import instantiate
from omegaconf import DictConfig
import numpy as np

from utils.mlflow_logs import log_results
from utils.plot_results import plot_predictions


@hydra.main(version_base='1.3.2', config_path="configs", config_name="main_config.yaml")
def main(cfg: DictConfig) -> None:
    dataset = instantiate(cfg.datasets.DataLoader)
    if cfg.datasets.data_params.for_tsai:
        params = cfg.datasets.data_params
        # load tsai data
        X, y, splitter, tfms, batch_tfms = dataset.get_tsai_data(**params)
        # load the model
        model = instantiate(cfg.models.model)
        # params for training
        params = {
            'X': X,
            'y': y,
            'splits': splitter,
            'tfms': tfms,
            'batch_tfms': batch_tfms,
            'arch': cfg.models.arch.model_name
        }
        # extend the params with cfg.training_params
        params.update(cfg.training_params)
        # train the model
        model.train_model(**params)
        # evaluate the model
        eval_params = {
            'arch': cfg.models.arch.model_name,
            'X': X,
            'y': y,
            'splits': splitter
        }
        outputDict = model.evaluate_model(**eval_params)
        # get fig
        fig = plot_predictions(cfg.models.arch.model_name, outputDict['target'], outputDict['preds'])
        outputDict['fig'] = fig
        # log results to mlflow
        log_results(cfg, outputDict)
    else:
        # load train and test data
        X_train, y_train, X_test, y_test = dataset.train_test_split(
            df=dataset.process_data(),
            **cfg.datasets.data_params
        )
        if cfg.models.arch.model_name == 'XGB':
            model = instantiate(cfg.models.model)
            # params
            params = {
                'X_train': X_train,
                'y_train': y_train,
                'arch': cfg.models.arch.model_name
            }
            # train the model
            model.train_model(**params)
            # evaluate the model
            eval_params = {
                'arch': cfg.models.arch.model_name,
                'X_test': X_test,
                'y_test': y_test
            }
            outputDict = model.evaluate_model(**eval_params)
            # get fig
            fig = plot_predictions(cfg.models.arch.model_name, outputDict['target'], outputDict['preds'])
            outputDict['fig'] = fig
            # log results to mlflow
            log_results(cfg, outputDict)
        else:
            return
    return


if __name__ == "__main__":
    main()
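

# ---------------------------------------------------------------------------
# Hedged sketch (not part of the original repository): a minimal
# configs/main_config.yaml layout that would satisfy the attribute accesses
# made by main() above. Only the key names are inferred from the code
# (datasets.DataLoader, datasets.data_params.for_tsai, models.model,
# models.arch.model_name, training_params); the _target_ paths and values
# shown are illustrative assumptions and may differ from the real configs.
#
#   datasets:
#     DataLoader:
#       _target_: datasets.loader.DataLoader    # hypothetical import path
#     data_params:
#       for_tsai: true                          # selects the tsai branch
#   models:
#     model:
#       _target_: models.tsai_model.TsaiModel   # hypothetical import path
#     arch:
#       model_name: InceptionTime               # hypothetical value
#   training_params:
#     n_epochs: 10                              # hypothetical value
# ---------------------------------------------------------------------------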