From 7f8ea7d644daf13aaa20601ec7d7478711401ce7 Mon Sep 17 00:00:00 2001
From: fengcunguang
Date: Tue, 2 Nov 2021 10:47:04 +0800
Subject: [PATCH] Add A New Baseline: TCN

---
 README.md                                 |   1 +
 examples/benchmarks/README.md             |   2 +
 examples/benchmarks/TCN/requirements.txt  |   4 +
 .../TCN/workflow_config_tcn_Alpha158.yaml | 100 ++++++
 .../TCN/workflow_config_tcn_Alpha360.yaml |  90 +++++
 qlib/contrib/model/__init__.py            |   3 +-
 qlib/contrib/model/pytorch_tcn.py         | 317 ++++++++++++++++++
 qlib/contrib/model/pytorch_tcn_ts.py      | 300 +++++++++++++++++
 qlib/contrib/model/tcn.py                 |  77 +++++
 9 files changed, 893 insertions(+), 1 deletion(-)
 create mode 100644 examples/benchmarks/TCN/requirements.txt
 create mode 100755 examples/benchmarks/TCN/workflow_config_tcn_Alpha158.yaml
 create mode 100644 examples/benchmarks/TCN/workflow_config_tcn_Alpha360.yaml
 create mode 100755 qlib/contrib/model/pytorch_tcn.py
 create mode 100755 qlib/contrib/model/pytorch_tcn_ts.py
 create mode 100644 qlib/contrib/model/tcn.py

diff --git a/README.md b/README.md
index 0f9cccceb2..f38e8949fa 100644
--- a/README.md
+++ b/README.md
@@ -294,6 +294,7 @@ Here is a list of models built on `Qlib`.
 - [Transformer based on pytorch (Ashish Vaswani, et al. NeurIPS 2017)](qlib/contrib/model/pytorch_transformer.py)
 - [Localformer based on pytorch (Juyong Jiang, et al.)](qlib/contrib/model/pytorch_localformer.py)
 - [TRA based on pytorch (Hengxu, Dong, et al. KDD 2021)](qlib/contrib/model/pytorch_tra.py)
+- [TCN based on pytorch (Shaojie Bai, et al. 2018)](qlib/contrib/model/pytorch_tcn.py)
 
 Your PR of new Quant models is highly welcomed.

diff --git a/examples/benchmarks/README.md b/examples/benchmarks/README.md
index dfc4f492ec..0088128df0 100644
--- a/examples/benchmarks/README.md
+++ b/examples/benchmarks/README.md
@@ -34,6 +34,7 @@ The numbers shown below demonstrate the performance of the entire `workflow` of
 | MLP | Alpha158 | 0.0376±0.00 | 0.2846±0.02 | 0.0429±0.00 | 0.3220±0.01 | 0.0895±0.02 | 1.1408±0.23 | -0.1103±0.02 |
 | LightGBM(Guolin Ke, et al.) | Alpha158 | 0.0448±0.00 | 0.3660±0.00 | 0.0469±0.00 | 0.3877±0.00 | 0.0901±0.00 | 1.0164±0.00 | -0.1038±0.00 |
 | DoubleEnsemble(Chuheng Zhang, et al.) | Alpha158 | 0.0544±0.00 | 0.4340±0.00 | 0.0523±0.00 | 0.4284±0.01 | 0.1168±0.01 | 1.3384±0.12 | -0.1036±0.01 |
+| TCN(Shaojie Bai, et al.) | Alpha158 | 0.0275±0.00 | 0.2157±0.01 | 0.0411±0.00 | 0.3379±0.01 | 0.0190±0.02 | 0.2887±0.27 | -0.1202±0.03 |
@@ -55,6 +56,7 @@ The numbers shown below demonstrate the performance of the entire `workflow` of
 | GATs (Petar Velickovic, et al.) | Alpha360 | 0.0476±0.00 | 0.3508±0.02 | 0.0598±0.00 | 0.4604±0.01 | 0.0824±0.02 | 1.1079±0.26 | -0.0894±0.03 |
 | TCTS(Xueqing Wu, et al.) | Alpha360 | 0.0508±0.00 | 0.3931±0.04 | 0.0599±0.00 | 0.4756±0.03 | 0.0893±0.03 | 1.2256±0.36 | -0.0857±0.02 |
 | TRA(Hengxu Lin, et al.) | Alpha360 | 0.0485±0.00 | 0.3787±0.03 | 0.0587±0.00 | 0.4756±0.03 | 0.0920±0.03 | 1.2789±0.42 | -0.0834±0.02 |
+| TCN(Shaojie Bai, et al.) | Alpha360 | 0.0441±0.00 | 0.3301±0.02 | 0.0519±0.00 | 0.4130±0.01 | 0.0604±0.02 | 0.8295±0.34 | -0.1018±0.03 |
 - The selected 20 features are based on the feature importance of a lightgbm-based model.
 - The base model of DoubleEnsemble is LGBM.
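The new benchmark can be run end to end with Qlib's workflow runner, e.g. `qrun examples/benchmarks/TCN/workflow_config_tcn_Alpha360.yaml`. For reference, a minimal sketch of the equivalent Python API usage, assuming the standard `cn_data` bundle from the configs below is already downloaded (the YAML's processor settings are omitted here for brevity, so results will differ slightly):

```python
# Minimal sketch (not part of this patch): train and score the new TCN
# baseline through the Python API, mirroring workflow_config_tcn_Alpha360.yaml.
import qlib
from qlib.contrib.data.handler import Alpha360
from qlib.contrib.model.pytorch_tcn import TCN
from qlib.data.dataset import DatasetH

qlib.init(provider_uri="~/.qlib/qlib_data/cn_data", region="cn")

handler = Alpha360(
    instruments="csi300",
    start_time="2008-01-01",
    end_time="2020-08-01",
    fit_start_time="2008-01-01",
    fit_end_time="2014-12-31",
)
dataset = DatasetH(
    handler,
    segments={
        "train": ("2008-01-01", "2014-12-31"),
        "valid": ("2015-01-01", "2016-12-31"),
        "test": ("2017-01-01", "2020-08-01"),
    },
)

model = TCN(d_feat=6, num_layers=5, n_chans=128, kernel_size=3, dropout=0.5, GPU=0)
model.fit(dataset)             # early-stops on the valid segment
pred = model.predict(dataset)  # pd.Series indexed by (datetime, instrument)
```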
diff --git a/examples/benchmarks/TCN/requirements.txt b/examples/benchmarks/TCN/requirements.txt
new file mode 100644
index 0000000000..1fc2779c0f
--- /dev/null
+++ b/examples/benchmarks/TCN/requirements.txt
@@ -0,0 +1,4 @@
+numpy==1.17.4
+pandas==1.1.2
+scikit_learn==0.23.2
+torch==1.7.0
diff --git a/examples/benchmarks/TCN/workflow_config_tcn_Alpha158.yaml b/examples/benchmarks/TCN/workflow_config_tcn_Alpha158.yaml
new file mode 100755
index 0000000000..c6f663f948
--- /dev/null
+++ b/examples/benchmarks/TCN/workflow_config_tcn_Alpha158.yaml
@@ -0,0 +1,100 @@
+qlib_init:
+    provider_uri: "~/.qlib/qlib_data/cn_data"
+    region: cn
+market: &market csi300
+benchmark: &benchmark SH000300
+data_handler_config: &data_handler_config
+    start_time: 2008-01-01
+    end_time: 2020-08-01
+    fit_start_time: 2008-01-01
+    fit_end_time: 2014-12-31
+    instruments: *market
+    infer_processors:
+        - class: FilterCol
+          kwargs:
+              fields_group: feature
+              col_list: ["RESI5", "WVMA5", "RSQR5", "KLEN", "RSQR10", "CORR5", "CORD5", "CORR10",
+                         "ROC60", "RESI10", "VSTD5", "RSQR60", "CORR60", "WVMA60", "STD5",
+                         "RSQR20", "CORD60", "CORD10", "CORR20", "KLOW"
+                        ]
+        - class: RobustZScoreNorm
+          kwargs:
+              fields_group: feature
+              clip_outlier: true
+        - class: Fillna
+          kwargs:
+              fields_group: feature
+    learn_processors:
+        - class: DropnaLabel
+        - class: CSRankNorm
+          kwargs:
+              fields_group: label
+    label: ["Ref($close, -2) / Ref($close, -1) - 1"]
+
+port_analysis_config: &port_analysis_config
+    strategy:
+        class: TopkDropoutStrategy
+        module_path: qlib.contrib.strategy
+        kwargs:
+            model: <MODEL>
+            dataset: <DATASET>
+            topk: 50
+            n_drop: 5
+    backtest:
+        start_time: 2017-01-01
+        end_time: 2020-08-01
+        account: 100000000
+        benchmark: *benchmark
+        exchange_kwargs:
+            limit_threshold: 0.095
+            deal_price: close
+            open_cost: 0.0005
+            close_cost: 0.0015
+            min_cost: 5
+task:
+    model:
+        class: TCN
+        module_path: qlib.contrib.model.pytorch_tcn_ts
+        kwargs:
+            d_feat: 20
+            num_layers: 5
+            n_chans: 32
+            kernel_size: 7
+            dropout: 0.5
+            n_epochs: 200
+            lr: 1e-4
+            early_stop: 20
+            batch_size: 2000
+            metric: loss
+            loss: mse
+            optimizer: adam
+            n_jobs: 20
+            GPU: 0
+    dataset:
+        class: TSDatasetH
+        module_path: qlib.data.dataset
+        kwargs:
+            handler:
+                class: Alpha158
+                module_path: qlib.contrib.data.handler
+                kwargs: *data_handler_config
+            segments:
+                train: [2008-01-01, 2014-12-31]
+                valid: [2015-01-01, 2016-12-31]
+                test: [2017-01-01, 2020-08-01]
+            step_len: 20
+    record:
+        - class: SignalRecord
+          module_path: qlib.workflow.record_temp
+          kwargs:
+            model: <MODEL>
+            dataset: <DATASET>
+        - class: SigAnaRecord
+          module_path: qlib.workflow.record_temp
+          kwargs:
+            ana_long_short: False
+            ann_scaler: 252
+        - class: PortAnaRecord
+          module_path: qlib.workflow.record_temp
+          kwargs:
+            config: *port_analysis_config
diff --git a/examples/benchmarks/TCN/workflow_config_tcn_Alpha360.yaml b/examples/benchmarks/TCN/workflow_config_tcn_Alpha360.yaml
new file mode 100644
index 0000000000..e383662fc1
--- /dev/null
+++ b/examples/benchmarks/TCN/workflow_config_tcn_Alpha360.yaml
@@ -0,0 +1,90 @@
+qlib_init:
+    provider_uri: "~/.qlib/qlib_data/cn_data"
+    region: cn
+market: &market csi300
+benchmark: &benchmark SH000300
+data_handler_config: &data_handler_config
+    start_time: 2008-01-01
+    end_time: 2020-08-01
+    fit_start_time: 2008-01-01
+    fit_end_time: 2014-12-31
+    instruments: *market
+    infer_processors:
+        - class: RobustZScoreNorm
+          kwargs:
+              fields_group: feature
+              clip_outlier: true
+        - class: Fillna
+          kwargs:
+              fields_group: feature
+    learn_processors:
+        - class: DropnaLabel
+        - class: CSRankNorm
+          kwargs:
+              fields_group: label
+    label: ["Ref($close, -2) / Ref($close, -1) - 1"]
+port_analysis_config: &port_analysis_config
+    strategy:
+        class: TopkDropoutStrategy
+        module_path: qlib.contrib.strategy
+        kwargs:
+            model: <MODEL>
+            dataset: <DATASET>
+            topk: 50
+            n_drop: 5
+    backtest:
+        start_time: 2017-01-01
+        end_time: 2020-08-01
+        account: 100000000
+        benchmark: *benchmark
+        exchange_kwargs:
+            limit_threshold: 0.095
+            deal_price: close
+            open_cost: 0.0005
+            close_cost: 0.0015
+            min_cost: 5
+task:
+    model:
+        class: TCN
+        module_path: qlib.contrib.model.pytorch_tcn
+        kwargs:
+            d_feat: 6
+            num_layers: 5
+            n_chans: 128
+            kernel_size: 3
+            dropout: 0.5
+            n_epochs: 200
+            lr: 1e-3
+            early_stop: 20
+            batch_size: 2000
+            metric: loss
+            loss: mse
+            optimizer: adam
+            GPU: 0
+    dataset:
+        class: DatasetH
+        module_path: qlib.data.dataset
+        kwargs:
+            handler:
+                class: Alpha360
+                module_path: qlib.contrib.data.handler
+                kwargs: *data_handler_config
+            segments:
+                train: [2008-01-01, 2014-12-31]
+                valid: [2015-01-01, 2016-12-31]
+                test: [2017-01-01, 2020-08-01]
+    record:
+        - class: SignalRecord
+          module_path: qlib.workflow.record_temp
+          kwargs:
+            model: <MODEL>
+            dataset: <DATASET>
+        - class: SigAnaRecord
+          module_path: qlib.workflow.record_temp
+          kwargs:
+            ana_long_short: False
+            ann_scaler: 252
+        - class: PortAnaRecord
+          module_path: qlib.workflow.record_temp
+          kwargs:
+            config: *port_analysis_config
diff --git a/qlib/contrib/model/__init__.py b/qlib/contrib/model/__init__.py
index 09b0c929b6..b691db1560 100644
--- a/qlib/contrib/model/__init__.py
+++ b/qlib/contrib/model/__init__.py
@@ -30,8 +30,9 @@
     from .pytorch_nn import DNNModelPytorch
     from .pytorch_tabnet import TabnetModel
     from .pytorch_sfm import SFM_Model
+    from .pytorch_tcn import TCN
 
-    pytorch_classes = (ALSTM, GATs, GRU, LSTM, DNNModelPytorch, TabnetModel, SFM_Model)
+    pytorch_classes = (ALSTM, GATs, GRU, LSTM, DNNModelPytorch, TabnetModel, SFM_Model, TCN)
 except ModuleNotFoundError:
     pytorch_classes = ()
     print("Please install necessary libs for PyTorch models.")
diff --git a/qlib/contrib/model/pytorch_tcn.py b/qlib/contrib/model/pytorch_tcn.py
new file mode 100755
index 0000000000..c649dfa0b3
--- /dev/null
+++ b/qlib/contrib/model/pytorch_tcn.py
@@ -0,0 +1,317 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT License.
+
+
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+import pandas as pd
+from typing import Text, Union
+import copy
+from ...utils import get_or_create_path
+from ...log import get_module_logger
+
+import torch
+import torch.nn as nn
+import torch.optim as optim
+
+from .pytorch_utils import count_parameters
+from ...model.base import Model
+from ...data.dataset import DatasetH
+from ...data.dataset.handler import DataHandlerLP
+from .tcn import TemporalConvNet
+
+
+class TCN(Model):
+    """TCN Model
+
+    Parameters
+    ----------
+    d_feat : int
+        input dimension for each time step
+    n_chans : int
+        number of channels in each temporal block
+    kernel_size : int
+        convolution kernel size
+    num_layers : int
+        number of stacked temporal blocks
+    dropout : float
+        dropout rate
+    n_epochs : int
+        maximum number of training epochs
+    lr : float
+        learning rate
+    metric : str
+        the evaluation metric used for early stopping
+    batch_size : int
+        number of samples per training batch
+    early_stop : int
+        epochs without improvement before training stops
+    loss : str
+        loss function; only "mse" is supported
+    optimizer : str
+        optimizer name; "adam" or "gd"
+    GPU : int
+        the GPU ID used for training; falls back to CPU if unavailable
+    seed : int
+        random seed
+    """
+
+    def __init__(
+        self,
+        d_feat=6,
+        n_chans=128,
+        kernel_size=5,
+        num_layers=5,
+        dropout=0.5,
+        n_epochs=200,
+        lr=0.0001,
+        metric="",
+        batch_size=2000,
+        early_stop=20,
+        loss="mse",
+        optimizer="adam",
+        GPU=0,
+        seed=None,
+        **kwargs
+    ):
+        # Set logger.
+        self.logger = get_module_logger("TCN")
+        self.logger.info("TCN pytorch version...")
+
+        # set hyper-parameters.
+        self.d_feat = d_feat
+        self.n_chans = n_chans
+        self.kernel_size = kernel_size
+        self.num_layers = num_layers
+        self.dropout = dropout
+        self.n_epochs = n_epochs
+        self.lr = lr
+        self.metric = metric
+        self.batch_size = batch_size
+        self.early_stop = early_stop
+        self.optimizer = optimizer.lower()
+        self.loss = loss
+        self.device = torch.device("cuda:%d" % (GPU) if torch.cuda.is_available() and GPU >= 0 else "cpu")
+        self.seed = seed
+
+        self.logger.info(
+            "TCN parameters setting:"
+            "\nd_feat : {}"
+            "\nn_chans : {}"
+            "\nkernel_size : {}"
+            "\nnum_layers : {}"
+            "\ndropout : {}"
+            "\nn_epochs : {}"
+            "\nlr : {}"
+            "\nmetric : {}"
+            "\nbatch_size : {}"
+            "\nearly_stop : {}"
+            "\noptimizer : {}"
+            "\nloss_type : {}"
+            "\nvisible_GPU : {}"
+            "\nuse_GPU : {}"
+            "\nseed : {}".format(
+                d_feat,
+                n_chans,
+                kernel_size,
+                num_layers,
+                dropout,
+                n_epochs,
+                lr,
+                metric,
+                batch_size,
+                early_stop,
+                optimizer.lower(),
+                loss,
+                GPU,
+                self.use_gpu,
+                seed,
+            )
+        )
+
+        if self.seed is not None:
+            np.random.seed(self.seed)
+            torch.manual_seed(self.seed)
+
+        self.tcn_model = TCNModel(
+            num_input=self.d_feat,
+            output_size=1,
+            num_channels=[self.n_chans] * self.num_layers,
+            kernel_size=self.kernel_size,
+            dropout=self.dropout,
+        )
+        self.logger.info("model:\n{:}".format(self.tcn_model))
+        self.logger.info("model size: {:.4f} MB".format(count_parameters(self.tcn_model)))
+
+        if optimizer.lower() == "adam":
+            self.train_optimizer = optim.Adam(self.tcn_model.parameters(), lr=self.lr)
+        elif optimizer.lower() == "gd":
+            self.train_optimizer = optim.SGD(self.tcn_model.parameters(), lr=self.lr)
+        else:
+            raise NotImplementedError("optimizer {} is not supported!".format(optimizer))
+
+        self.fitted = False
+        self.tcn_model.to(self.device)
+
+    @property
+    def use_gpu(self):
+        return self.device != torch.device("cpu")
+
+    def mse(self, pred, label):
+        loss = (pred - label) ** 2
+        return torch.mean(loss)
+
+    def loss_fn(self, pred, label):
+        mask = ~torch.isnan(label)
+
+        if self.loss == "mse":
+            return self.mse(pred[mask], label[mask])
+
+        raise ValueError("unknown loss `%s`" % self.loss)
+
+    def metric_fn(self, pred, label):
+        mask = torch.isfinite(label)
+
+        if self.metric == "" or self.metric == "loss":
+            return -self.loss_fn(pred[mask], label[mask])
+
+        raise ValueError("unknown metric `%s`" % self.metric)
+
+    def train_epoch(self, x_train, y_train):
+        x_train_values = x_train.values
+        y_train_values = np.squeeze(y_train.values)
+
+        self.tcn_model.train()
+
+        indices = np.arange(len(x_train_values))
+        np.random.shuffle(indices)
+
+        for i in range(len(indices))[:: self.batch_size]:
+            # drop the last incomplete batch
+            if len(indices) - i < self.batch_size:
+                break
+
+            feature = torch.from_numpy(x_train_values[indices[i : i + self.batch_size]]).float().to(self.device)
+            label = torch.from_numpy(y_train_values[indices[i : i + self.batch_size]]).float().to(self.device)
+
+            pred = self.tcn_model(feature)
+            loss = self.loss_fn(pred, label)
+
+            self.train_optimizer.zero_grad()
+            loss.backward()
+            torch.nn.utils.clip_grad_value_(self.tcn_model.parameters(), 3.0)
+            self.train_optimizer.step()
+
+    def test_epoch(self, data_x, data_y):
+        x_values = data_x.values
+        y_values = np.squeeze(data_y.values)
+
+        self.tcn_model.eval()
+
+        scores = []
+        losses = []
+
+        indices = np.arange(len(x_values))
+
+        for i in range(len(indices))[:: self.batch_size]:
+            if len(indices) - i < self.batch_size:
+                break
+
+            feature = torch.from_numpy(x_values[indices[i : i + self.batch_size]]).float().to(self.device)
+            label = torch.from_numpy(y_values[indices[i : i + self.batch_size]]).float().to(self.device)
+
+            with torch.no_grad():
+                pred = self.tcn_model(feature)
+                loss = self.loss_fn(pred, label)
+                losses.append(loss.item())
+
+                score = self.metric_fn(pred, label)
+                scores.append(score.item())
+
+        return np.mean(losses), np.mean(scores)
+
+    def fit(
+        self,
+        dataset: DatasetH,
+        evals_result=dict(),
+        save_path=None,
+    ):
+        # the test segment is not needed during fitting
+        df_train, df_valid = dataset.prepare(
+            ["train", "valid"],
+            col_set=["feature", "label"],
+            data_key=DataHandlerLP.DK_L,
+        )
+
+        x_train, y_train = df_train["feature"], df_train["label"]
+        x_valid, y_valid = df_valid["feature"], df_valid["label"]
+
+        save_path = get_or_create_path(save_path)
+        stop_steps = 0
+        best_score = -np.inf
+        best_epoch = 0
+        evals_result["train"] = []
+        evals_result["valid"] = []
+
+        # train
+        self.logger.info("training...")
+        self.fitted = True
+
+        for step in range(self.n_epochs):
+            self.logger.info("Epoch%d:", step)
+            self.logger.info("training...")
+            self.train_epoch(x_train, y_train)
+            self.logger.info("evaluating...")
+            train_loss, train_score = self.test_epoch(x_train, y_train)
+            val_loss, val_score = self.test_epoch(x_valid, y_valid)
+            self.logger.info("train %.6f, valid %.6f" % (train_score, val_score))
+            evals_result["train"].append(train_score)
+            evals_result["valid"].append(val_score)
+
+            if val_score > best_score:
+                best_score = val_score
+                stop_steps = 0
+                best_epoch = step
+                best_param = copy.deepcopy(self.tcn_model.state_dict())
+            else:
+                stop_steps += 1
+                if stop_steps >= self.early_stop:
+                    self.logger.info("early stop")
+                    break
+
+        self.logger.info("best score: %.6lf @ %d" % (best_score, best_epoch))
+        self.tcn_model.load_state_dict(best_param)
+        torch.save(best_param, save_path)
+
+        if self.use_gpu:
+            torch.cuda.empty_cache()
+
+    def predict(self, dataset: DatasetH, segment: Union[Text, slice] = "test"):
+        if not self.fitted:
+            raise ValueError("model is not fitted yet!")
+
+        x_test = dataset.prepare(segment, col_set="feature", data_key=DataHandlerLP.DK_I)
+        index = x_test.index
+        self.tcn_model.eval()
+        x_values = x_test.values
+        sample_num = x_values.shape[0]
+        preds = []
+
+        for begin in range(sample_num)[:: self.batch_size]:
+            if sample_num - begin < self.batch_size:
+                end = sample_num
+            else:
+                end = begin + self.batch_size
+
+            x_batch = torch.from_numpy(x_values[begin:end]).float().to(self.device)
+
+            with torch.no_grad():
+                pred = self.tcn_model(x_batch).detach().cpu().numpy()
+
+            preds.append(pred)
+
+        return pd.Series(np.concatenate(preds), index=index)
+
+
+class TCNModel(nn.Module):
+    def __init__(self, num_input, output_size, num_channels, kernel_size, dropout):
+        super().__init__()
+        self.num_input = num_input
+        self.tcn = TemporalConvNet(num_input, num_channels, kernel_size, dropout=dropout)
+        self.linear = nn.Linear(num_channels[-1], output_size)
+
+    def forward(self, x):
+        # Alpha360-style flat input: (batch, d_feat * T) -> (batch, d_feat, T),
+        # the (N, C, L) layout expected by Conv1d
+        x = x.reshape(x.shape[0], self.num_input, -1)
+        output = self.tcn(x)
+        # use the representation of the last time step for prediction
+        output = self.linear(output[:, :, -1])
+        return output.squeeze()
diff --git a/qlib/contrib/model/pytorch_tcn_ts.py b/qlib/contrib/model/pytorch_tcn_ts.py
new file mode 100755
index 0000000000..3e0a15e046
--- /dev/null
+++ b/qlib/contrib/model/pytorch_tcn_ts.py
@@ -0,0 +1,300 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT License.
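+# NOTE: Time-Series variant of pytorch_tcn.py. It consumes TSDatasetH samples
+# through a DataLoader as (batch, step_len, d_feat + 1) tensors, where the
+# last column holds the label, instead of the flat Alpha360-style feature
+# vectors handled by pytorch_tcn.py.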
+
+
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+import pandas as pd
+import copy
+from ...utils import get_or_create_path
+from ...log import get_module_logger
+
+import torch
+import torch.nn as nn
+import torch.optim as optim
+from torch.utils.data import DataLoader
+
+from .pytorch_utils import count_parameters
+from ...model.base import Model
+from ...data.dataset.handler import DataHandlerLP
+from .tcn import TemporalConvNet
+
+
+class TCN(Model):
+    """TCN Model
+
+    Parameters
+    ----------
+    d_feat : int
+        input dimension for each time step
+    n_chans : int
+        number of channels in each temporal block
+    kernel_size : int
+        convolution kernel size
+    num_layers : int
+        number of stacked temporal blocks
+    dropout : float
+        dropout rate
+    metric : str
+        the evaluation metric used for early stopping
+    optimizer : str
+        optimizer name; "adam" or "gd"
+    n_jobs : int
+        number of DataLoader workers
+    GPU : int
+        the GPU ID used for training; falls back to CPU if unavailable
+    """
+
+    def __init__(
+        self,
+        d_feat=6,
+        n_chans=128,
+        kernel_size=5,
+        num_layers=2,
+        dropout=0.0,
+        n_epochs=200,
+        lr=0.001,
+        metric="",
+        batch_size=2000,
+        early_stop=20,
+        loss="mse",
+        optimizer="adam",
+        n_jobs=10,
+        GPU=0,
+        seed=None,
+        **kwargs
+    ):
+        # Set logger.
+        self.logger = get_module_logger("TCN")
+        self.logger.info("TCN pytorch version...")
+
+        # set hyper-parameters.
+        self.d_feat = d_feat
+        self.n_chans = n_chans
+        self.kernel_size = kernel_size
+        self.num_layers = num_layers
+        self.dropout = dropout
+        self.n_epochs = n_epochs
+        self.lr = lr
+        self.metric = metric
+        self.batch_size = batch_size
+        self.early_stop = early_stop
+        self.optimizer = optimizer.lower()
+        self.loss = loss
+        self.device = torch.device("cuda:%d" % (GPU) if torch.cuda.is_available() and GPU >= 0 else "cpu")
+        self.n_jobs = n_jobs
+        self.seed = seed
+
+        self.logger.info(
+            "TCN parameters setting:"
+            "\nd_feat : {}"
+            "\nn_chans : {}"
+            "\nkernel_size : {}"
+            "\nnum_layers : {}"
+            "\ndropout : {}"
+            "\nn_epochs : {}"
+            "\nlr : {}"
+            "\nmetric : {}"
+            "\nbatch_size : {}"
+            "\nearly_stop : {}"
+            "\noptimizer : {}"
+            "\nloss_type : {}"
+            "\ndevice : {}"
+            "\nn_jobs : {}"
+            "\nuse_GPU : {}"
+            "\nseed : {}".format(
+                d_feat,
+                n_chans,
+                kernel_size,
+                num_layers,
+                dropout,
+                n_epochs,
+                lr,
+                metric,
+                batch_size,
+                early_stop,
+                optimizer.lower(),
+                loss,
+                self.device,
+                n_jobs,
+                self.use_gpu,
+                seed,
+            )
+        )
+
+        if self.seed is not None:
+            np.random.seed(self.seed)
+            torch.manual_seed(self.seed)
+
+        self.TCN_model = TCNModel(
+            num_input=self.d_feat,
+            output_size=1,
+            num_channels=[self.n_chans] * self.num_layers,
+            kernel_size=self.kernel_size,
+            dropout=self.dropout,
+        )
+        self.logger.info("model:\n{:}".format(self.TCN_model))
+        self.logger.info("model size: {:.4f} MB".format(count_parameters(self.TCN_model)))
+
+        if optimizer.lower() == "adam":
+            self.train_optimizer = optim.Adam(self.TCN_model.parameters(), lr=self.lr)
+        elif optimizer.lower() == "gd":
+            self.train_optimizer = optim.SGD(self.TCN_model.parameters(), lr=self.lr)
+        else:
+            raise NotImplementedError("optimizer {} is not supported!".format(optimizer))
+
+        self.fitted = False
+        self.TCN_model.to(self.device)
+
+    @property
+    def use_gpu(self):
+        return self.device != torch.device("cpu")
+
+    def mse(self, pred, label):
+        loss = (pred - label) ** 2
+        return torch.mean(loss)
+
+    def loss_fn(self, pred, label):
+        mask = ~torch.isnan(label)
+
+        if self.loss == "mse":
+            return self.mse(pred[mask], label[mask])
+
+        raise ValueError("unknown loss `%s`" % self.loss)
+
+    def metric_fn(self, pred, label):
+        mask = torch.isfinite(label)
+
+        if self.metric == "" or self.metric == "loss":
+            return -self.loss_fn(pred[mask], label[mask])
+
+        raise ValueError("unknown metric `%s`" % self.metric)
+
+    def train_epoch(self, data_loader):
+        self.TCN_model.train()
+
+        for data in data_loader:
+            # data: (batch, step_len, d_feat + 1); the last column is the label
+            feature = data[:, :, 0:-1].to(self.device)
+            # use the label of the last time step
+            label = data[:, -1, -1].to(self.device)
+
+            pred = self.TCN_model(feature.float())
+            loss = self.loss_fn(pred, label)
+
+            self.train_optimizer.zero_grad()
+            loss.backward()
+            torch.nn.utils.clip_grad_value_(self.TCN_model.parameters(), 3.0)
+            self.train_optimizer.step()
+
+    def test_epoch(self, data_loader):
+        self.TCN_model.eval()
+
+        scores = []
+        losses = []
+
+        for data in data_loader:
+            feature = data[:, :, 0:-1].to(self.device)
+            label = data[:, -1, -1].to(self.device)
+
+            with torch.no_grad():
+                pred = self.TCN_model(feature.float())
+                loss = self.loss_fn(pred, label)
+                losses.append(loss.item())
+
+                score = self.metric_fn(pred, label)
+                scores.append(score.item())
+
+        return np.mean(losses), np.mean(scores)
+
+    def fit(
+        self,
+        dataset,
+        evals_result=dict(),
+        save_path=None,
+    ):
+        dl_train = dataset.prepare("train", col_set=["feature", "label"], data_key=DataHandlerLP.DK_L)
+        dl_valid = dataset.prepare("valid", col_set=["feature", "label"], data_key=DataHandlerLP.DK_L)
+
+        # process nan brought by dataloader
+        dl_train.config(fillna_type="ffill+bfill")
+        dl_valid.config(fillna_type="ffill+bfill")
+
+        train_loader = DataLoader(
+            dl_train, batch_size=self.batch_size, shuffle=True, num_workers=self.n_jobs, drop_last=True
+        )
+        valid_loader = DataLoader(
+            dl_valid, batch_size=self.batch_size, shuffle=False, num_workers=self.n_jobs, drop_last=True
+        )
+
+        save_path = get_or_create_path(save_path)
+
+        stop_steps = 0
+        best_score = -np.inf
+        best_epoch = 0
+        evals_result["train"] = []
+        evals_result["valid"] = []
+
+        # train
+        self.logger.info("training...")
+        self.fitted = True
+
+        for step in range(self.n_epochs):
+            self.logger.info("Epoch%d:", step)
+            self.logger.info("training...")
+            self.train_epoch(train_loader)
+            self.logger.info("evaluating...")
+            train_loss, train_score = self.test_epoch(train_loader)
+            val_loss, val_score = self.test_epoch(valid_loader)
+            self.logger.info("train %.6f, valid %.6f" % (train_score, val_score))
+            evals_result["train"].append(train_score)
+            evals_result["valid"].append(val_score)
+
+            if val_score > best_score:
+                best_score = val_score
+                stop_steps = 0
+                best_epoch = step
+                best_param = copy.deepcopy(self.TCN_model.state_dict())
+            else:
+                stop_steps += 1
+                if stop_steps >= self.early_stop:
+                    self.logger.info("early stop")
+                    break
+
+        self.logger.info("best score: %.6lf @ %d" % (best_score, best_epoch))
+        self.TCN_model.load_state_dict(best_param)
+        torch.save(best_param, save_path)
+
+        if self.use_gpu:
+            torch.cuda.empty_cache()
+
+    def predict(self, dataset):
+        if not self.fitted:
+            raise ValueError("model is not fitted yet!")
+
+        dl_test = dataset.prepare("test", col_set=["feature", "label"], data_key=DataHandlerLP.DK_I)
+        dl_test.config(fillna_type="ffill+bfill")
+        test_loader = DataLoader(dl_test, batch_size=self.batch_size, num_workers=self.n_jobs)
+        self.TCN_model.eval()
+        preds = []
+
+        for data in test_loader:
+            feature = data[:, :, 0:-1].to(self.device)
+
+            with torch.no_grad():
+                pred = self.TCN_model(feature.float()).detach().cpu().numpy()
+
+            preds.append(pred)
+
+        return pd.Series(np.concatenate(preds), index=dl_test.get_index())
+
+
+class TCNModel(nn.Module):
+    def __init__(self, num_input, output_size, num_channels, kernel_size, dropout):
+        super().__init__()
+        self.num_input = num_input
+        self.tcn = TemporalConvNet(num_input, num_channels, kernel_size, dropout=dropout)
+        self.linear = nn.Linear(num_channels[-1], output_size)
+
+    def forward(self, x):
+        # x: (batch, step_len, d_feat); transpose to the (N, C, L) layout
+        # expected by Conv1d before feeding the temporal blocks
+        output = self.tcn(x.transpose(1, 2))
+        output = self.linear(output[:, :, -1])
+        return output.squeeze()
diff --git a/qlib/contrib/model/tcn.py b/qlib/contrib/model/tcn.py
new file mode 100644
index 0000000000..ba6a85b8f7
--- /dev/null
+++ b/qlib/contrib/model/tcn.py
@@ -0,0 +1,77 @@
+# MIT License
+# Copyright (c) 2018 CMU Locus Lab
+import torch
+import torch.nn as nn
+from torch.nn.utils import weight_norm
+
+
+class Chomp1d(nn.Module):
+    """Remove the trailing padding so that the convolution stays causal."""
+
+    def __init__(self, chomp_size):
+        super(Chomp1d, self).__init__()
+        self.chomp_size = chomp_size
+
+    def forward(self, x):
+        return x[:, :, : -self.chomp_size].contiguous()
+
+
+class TemporalBlock(nn.Module):
+    def __init__(self, n_inputs, n_outputs, kernel_size, stride, dilation, padding, dropout=0.2):
+        super(TemporalBlock, self).__init__()
+        self.conv1 = weight_norm(
+            nn.Conv1d(n_inputs, n_outputs, kernel_size, stride=stride, padding=padding, dilation=dilation)
+        )
+        self.chomp1 = Chomp1d(padding)
+        self.relu1 = nn.ReLU()
+        self.dropout1 = nn.Dropout(dropout)
+
+        self.conv2 = weight_norm(
+            nn.Conv1d(n_outputs, n_outputs, kernel_size, stride=stride, padding=padding, dilation=dilation)
+        )
+        self.chomp2 = Chomp1d(padding)
+        self.relu2 = nn.ReLU()
+        self.dropout2 = nn.Dropout(dropout)
+
+        self.net = nn.Sequential(
+            self.conv1, self.chomp1, self.relu1, self.dropout1, self.conv2, self.chomp2, self.relu2, self.dropout2
+        )
+        # a 1x1 convolution matches the channel count for the residual connection
+        self.downsample = nn.Conv1d(n_inputs, n_outputs, 1) if n_inputs != n_outputs else None
+        self.relu = nn.ReLU()
+        self.init_weights()
+
+    def init_weights(self):
+        self.conv1.weight.data.normal_(0, 0.01)
+        self.conv2.weight.data.normal_(0, 0.01)
+        if self.downsample is not None:
+            self.downsample.weight.data.normal_(0, 0.01)
+
+    def forward(self, x):
+        out = self.net(x)
+        res = x if self.downsample is None else self.downsample(x)
+        return self.relu(out + res)
+
+
+class TemporalConvNet(nn.Module):
+    def __init__(self, num_inputs, num_channels, kernel_size=2, dropout=0.2):
+        super(TemporalConvNet, self).__init__()
+        layers = []
+        num_levels = len(num_channels)
+        for i in range(num_levels):
+            # dilation doubles at each level, so the receptive field grows
+            # exponentially with depth
+            dilation_size = 2 ** i
+            in_channels = num_inputs if i == 0 else num_channels[i - 1]
+            out_channels = num_channels[i]
+            layers += [
+                TemporalBlock(
+                    in_channels,
+                    out_channels,
+                    kernel_size,
+                    stride=1,
+                    dilation=dilation_size,
+                    padding=(kernel_size - 1) * dilation_size,
+                    dropout=dropout,
+                )
+            ]
+
+        self.network = nn.Sequential(*layers)
+
+    def forward(self, x):
+        return self.network(x)
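A quick way to sanity-check the stack above is to confirm that `TemporalConvNet` preserves the time dimension and that the dilated stack covers the configured lookback windows. The snippet below is an illustrative check (not part of the patch), using the hyper-parameters from the Alpha360 config:

```python
# Illustrative smoke test (not part of the patch) for TemporalConvNet.
import torch
from qlib.contrib.model.tcn import TemporalConvNet

net = TemporalConvNet(num_inputs=6, num_channels=[128] * 5, kernel_size=3, dropout=0.0)
x = torch.randn(4, 6, 60)  # (batch, channels=d_feat, length=time steps)
y = net(x)
assert y.shape == (4, 128, 60)  # causal padding + Chomp1d preserve the length

# With dilations 1, 2, 4, 8, 16 and two convolutions per block, the receptive
# field is 1 + 2 * (kernel_size - 1) * (2**num_levels - 1) = 1 + 2 * 2 * 31
# = 125 steps, more than the 60-step Alpha360 window; the Alpha158 config
# (kernel_size 7, 5 levels) covers its 20-step window by an even wider margin.
```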