Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Move get_path to get_or_create_path, use the best model of SFM / TabNet #328

Merged
merged 2 commits into from
Mar 11, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 2 additions & 3 deletions qlib/contrib/model/pytorch_alstm.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from ...utils import (
unpack_archive_with_buffer,
save_multiple_parts_file,
create_save_path,
get_or_create_path,
drop_nan_by_y_index,
)
from ...log import get_module_logger, TimeInspector
Expand Down Expand Up @@ -230,8 +230,7 @@ def fit(
x_train, y_train = df_train["feature"], df_train["label"]
x_valid, y_valid = df_valid["feature"], df_valid["label"]

if save_path == None:
save_path = create_save_path(save_path)
save_path = get_or_create_path(save_path)
stop_steps = 0
train_loss = 0
best_score = -np.inf
Expand Down
5 changes: 2 additions & 3 deletions qlib/contrib/model/pytorch_alstm_ts.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from ...utils import (
unpack_archive_with_buffer,
save_multiple_parts_file,
create_save_path,
get_or_create_path,
drop_nan_by_y_index,
)
from ...log import get_module_logger, TimeInspector
Expand Down Expand Up @@ -220,8 +220,7 @@ def fit(
dl_valid, batch_size=self.batch_size, shuffle=False, num_workers=self.n_jobs, drop_last=True
)

if save_path == None:
save_path = create_save_path(save_path)
save_path = get_or_create_path(save_path)

stop_steps = 0
train_loss = 0
Expand Down
5 changes: 2 additions & 3 deletions qlib/contrib/model/pytorch_gats.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from ...utils import (
unpack_archive_with_buffer,
save_multiple_parts_file,
create_save_path,
get_or_create_path,
drop_nan_by_y_index,
)
from ...log import get_module_logger, TimeInspector
Expand Down Expand Up @@ -248,8 +248,7 @@ def fit(
x_train, y_train = df_train["feature"], df_train["label"]
x_valid, y_valid = df_valid["feature"], df_valid["label"]

if save_path == None:
save_path = create_save_path(save_path)
save_path = get_or_create_path(save_path)
stop_steps = 0
best_score = -np.inf
best_epoch = 0
Expand Down
5 changes: 2 additions & 3 deletions qlib/contrib/model/pytorch_gats_ts.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from ...utils import (
unpack_archive_with_buffer,
save_multiple_parts_file,
create_save_path,
get_or_create_path,
drop_nan_by_y_index,
)
from ...log import get_module_logger, TimeInspector
Expand Down Expand Up @@ -264,8 +264,7 @@ def fit(
train_loader = DataLoader(dl_train, sampler=sampler_train, num_workers=self.n_jobs, drop_last=True)
valid_loader = DataLoader(dl_valid, sampler=sampler_valid, num_workers=self.n_jobs, drop_last=True)

if save_path == None:
save_path = create_save_path(save_path)
save_path = get_or_create_path(save_path)

stop_steps = 0
train_loss = 0
Expand Down
5 changes: 2 additions & 3 deletions qlib/contrib/model/pytorch_gru.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from ...utils import (
unpack_archive_with_buffer,
save_multiple_parts_file,
create_save_path,
get_or_create_path,
drop_nan_by_y_index,
)
from ...log import get_module_logger, TimeInspector
Expand Down Expand Up @@ -230,8 +230,7 @@ def fit(
x_train, y_train = df_train["feature"], df_train["label"]
x_valid, y_valid = df_valid["feature"], df_valid["label"]

if save_path == None:
save_path = create_save_path(save_path)
save_path = get_or_create_path(save_path)
stop_steps = 0
train_loss = 0
best_score = -np.inf
Expand Down
5 changes: 2 additions & 3 deletions qlib/contrib/model/pytorch_gru_ts.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from ...utils import (
unpack_archive_with_buffer,
save_multiple_parts_file,
create_save_path,
get_or_create_path,
drop_nan_by_y_index,
)
from ...log import get_module_logger, TimeInspector
Expand Down Expand Up @@ -220,8 +220,7 @@ def fit(
dl_valid, batch_size=self.batch_size, shuffle=False, num_workers=self.n_jobs, drop_last=True
)

if save_path == None:
save_path = create_save_path(save_path)
save_path = get_or_create_path(save_path)

stop_steps = 0
train_loss = 0
Expand Down
5 changes: 2 additions & 3 deletions qlib/contrib/model/pytorch_lstm.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from ...utils import (
unpack_archive_with_buffer,
save_multiple_parts_file,
create_save_path,
get_or_create_path,
drop_nan_by_y_index,
)
from ...log import get_module_logger, TimeInspector
Expand Down Expand Up @@ -226,8 +226,7 @@ def fit(
x_train, y_train = df_train["feature"], df_train["label"]
x_valid, y_valid = df_valid["feature"], df_valid["label"]

if save_path == None:
save_path = create_save_path(save_path)
save_path = get_or_create_path(save_path)
stop_steps = 0
train_loss = 0
best_score = -np.inf
Expand Down
5 changes: 2 additions & 3 deletions qlib/contrib/model/pytorch_lstm_ts.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from ...utils import (
unpack_archive_with_buffer,
save_multiple_parts_file,
create_save_path,
get_or_create_path,
drop_nan_by_y_index,
)
from ...log import get_module_logger, TimeInspector
Expand Down Expand Up @@ -216,8 +216,7 @@ def fit(
dl_valid, batch_size=self.batch_size, shuffle=False, num_workers=self.n_jobs, drop_last=True
)

if save_path == None:
save_path = create_save_path(save_path)
save_path = get_or_create_path(save_path)

stop_steps = 0
train_loss = 0
Expand Down
4 changes: 2 additions & 2 deletions qlib/contrib/model/pytorch_nn.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
from ...model.base import Model
from ...data.dataset import DatasetH
from ...data.dataset.handler import DataHandlerLP
from ...utils import unpack_archive_with_buffer, save_multiple_parts_file, create_save_path, drop_nan_by_y_index
from ...utils import unpack_archive_with_buffer, save_multiple_parts_file, get_or_create_path, drop_nan_by_y_index
from ...log import get_module_logger, TimeInspector
from ...workflow import R

Expand Down Expand Up @@ -176,7 +176,7 @@ def fit(
w_train = pd.DataFrame(np.ones_like(y_train.values), index=y_train.index)
w_valid = pd.DataFrame(np.ones_like(y_valid.values), index=y_valid.index)

save_path = create_save_path(save_path)
save_path = get_or_create_path(save_path)
stop_steps = 0
train_loss = 0
best_loss = np.inf
Expand Down
6 changes: 5 additions & 1 deletion qlib/contrib/model/pytorch_sfm.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from ...utils import (
unpack_archive_with_buffer,
save_multiple_parts_file,
create_save_path,
get_or_create_path,
drop_nan_by_y_index,
)
from ...log import get_module_logger, TimeInspector
Expand Down Expand Up @@ -380,6 +380,7 @@ def fit(
x_train, y_train = df_train["feature"], df_train["label"]
x_valid, y_valid = df_valid["feature"], df_valid["label"]

save_path = get_or_create_path(save_path)
stop_steps = 0
train_loss = 0
best_score = -np.inf
Expand Down Expand Up @@ -412,7 +413,10 @@ def fit(
if stop_steps >= self.early_stop:
self.logger.info("early stop")
break

self.logger.info("best score: %.6lf @ %d" % (best_score, best_epoch))
self.sfm_model.load_state_dict(best_param)
torch.save(best_param, save_path)
if self.device != "cpu":
torch.cuda.empty_cache()

Expand Down
12 changes: 7 additions & 5 deletions qlib/contrib/model/pytorch_tabnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from ...utils import (
unpack_archive_with_buffer,
save_multiple_parts_file,
create_save_path,
get_or_create_path,
drop_nan_by_y_index,
)
from ...log import get_module_logger, TimeInspector
Expand Down Expand Up @@ -117,10 +117,7 @@ def __init__(
raise NotImplementedError("optimizer {} is not supported!".format(optimizer))

def pretrain_fn(self, dataset=DatasetH, pretrain_file="./pretrain/best.model"):
# make the pretrain directory if it does not exist
if pretrain_file.startswith("./pretrain") and not os.path.exists("pretrain"):
self.logger.info("make folder to store model...")
os.makedirs("pretrain")
get_or_create_path(pretrain_file)

[df_train, df_valid] = dataset.prepare(
["pretrain", "pretrain_validation"],
Expand Down Expand Up @@ -181,6 +178,7 @@ def fit(
df_train.fillna(df_train.mean(), inplace=True)
x_train, y_train = df_train["feature"], df_train["label"]
x_valid, y_valid = df_valid["feature"], df_valid["label"]
save_path = get_or_create_path(save_path)

stop_steps = 0
train_loss = 0
Expand All @@ -207,12 +205,16 @@ def fit(
best_score = val_score
stop_steps = 0
best_epoch = epoch_idx
best_param = copy.deepcopy(self.tabnet_model.state_dict())
else:
stop_steps += 1
if stop_steps >= self.early_stop:
self.logger.info("early stop")
break

self.logger.info("best score: %.6lf @ %d" % (best_score, best_epoch))
self.tabnet_model.load_state_dict(best_param)
torch.save(best_param, save_path)

def predict(self, dataset):
if not self.fitted:
Expand Down
26 changes: 17 additions & 9 deletions qlib/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
import numpy as np
import pandas as pd
from pathlib import Path
from typing import Union, Tuple
from typing import Union, Tuple, Text, Optional

from ..config import C
from ..log import get_module_logger, set_log_with_config
Expand Down Expand Up @@ -276,23 +276,31 @@ def default(self, o):
return changes


def get_or_create_path(path: Optional[Text] = None, return_dir: bool = False):
    """Get or create a file path or a directory path.

    When ``path`` is given, the needed directories are created on disk and
    ``path`` itself is returned unchanged. When ``path`` is empty or None, a
    fresh temporary file or directory is created under ``~/tmp``.

    Parameters
    ----------
    path: Optional[Text]
        The target path, or None (or an empty string) to create a temporary
        file / directory under ``~/tmp``.
    return_dir: bool
        If True, ``path`` is treated as a directory and is created when it
        does not exist. If False, ``path`` is treated as a file: only its
        parent directory is created, the file itself is not touched.

    Returns
    -------
    Text
        ``path`` as given, or the path of the newly created temporary file
        or directory.
    """
    if path:
        if return_dir:
            if not os.path.exists(path):
                # exist_ok tolerates another process creating it in between.
                os.makedirs(path, exist_ok=True)
        else:
            # A file is requested, so only its parent directory has to exist.
            parent_dir = os.path.dirname(os.path.abspath(path))
            if not os.path.exists(parent_dir):
                os.makedirs(parent_dir, exist_ok=True)
    else:
        temp_dir = os.path.expanduser("~/tmp")
        if not os.path.exists(temp_dir):
            os.makedirs(temp_dir, exist_ok=True)
        if return_dir:
            # mkdtemp returns just the directory path (no descriptor tuple).
            path = tempfile.mkdtemp(dir=temp_dir)
        else:
            # mkstemp returns (fd, path); close the descriptor so it does not
            # leak, since callers reopen the file by path.
            fd, path = tempfile.mkstemp(dir=temp_dir)
            os.close(fd)
    return path


@contextlib.contextmanager
Expand Down