Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update repr for dataset/workflow classes and add uri kwarg for QlibRecorder #302

Merged
merged 8 commits into from
Mar 8, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -34,3 +34,5 @@ tags

.pytest_cache/
.vscode/

*.swp
17 changes: 11 additions & 6 deletions qlib/data/dataset/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from ...utils.serial import Serializable
from typing import Union, List, Tuple
from typing import Union, List, Tuple, Dict, Text, Optional
from ...utils import init_instance_by_config, np_ffill
from ...log import get_module_logger
from .handler import DataHandler, DataHandlerLP
Expand Down Expand Up @@ -76,7 +76,7 @@ class DatasetH(Dataset):
- The processing is related to data split.
"""

def __init__(self, handler: Union[dict, DataHandler], segments: dict):
def __init__(self, handler: Union[Dict, DataHandler], segments: Dict):
"""
Parameters
----------
Expand All @@ -87,7 +87,7 @@ def __init__(self, handler: Union[dict, DataHandler], segments: dict):
"""
super().__init__(handler, segments)

def init(self, handler_kwargs: dict = None, segment_kwargs: dict = None):
def init(self, handler_kwargs: Optional[Dict] = None, segment_kwargs: Optional[Dict] = None):
"""
Initialize the DatasetH

Expand Down Expand Up @@ -124,7 +124,7 @@ def init(self, handler_kwargs: dict = None, segment_kwargs: dict = None):
raise TypeError(f"param handler_kwargs must be type dict, not {type(segment_kwargs)}")
self.segments = segment_kwargs.copy()

def setup_data(self, handler: Union[dict, DataHandler], segments: dict):
def setup_data(self, handler: Union[Dict, DataHandler], segments: Dict[Text, Tuple]):
"""
Setup the underlying data.

Expand Down Expand Up @@ -156,6 +156,11 @@ def setup_data(self, handler: Union[dict, DataHandler], segments: dict):
self.handler = init_instance_by_config(handler, accept_types=DataHandler)
self.segments = segments.copy()

def __repr__(self):
return "{name}(handler={handler}, segments={segments})".format(
name=self.__class__.__name__, handler=self.handler, segments=self.segments
)

def _prepare_seg(self, slc: slice, **kwargs):
"""
Give a slice, retrieve the according data
Expand All @@ -168,7 +173,7 @@ def _prepare_seg(self, slc: slice, **kwargs):

def prepare(
self,
segments: Union[List[str], Tuple[str], str, slice],
segments: Union[List[Text], Tuple[Text], Text, slice],
col_set=DataHandler.CS_ALL,
data_key=DataHandlerLP.DK_I,
**kwargs,
Expand All @@ -178,7 +183,7 @@ def prepare(

Parameters
----------
segments : Union[List[str], Tuple[str], str, slice]
segments : Union[List[Text], Tuple[Text], Text, slice]
Describe the scope of the data to be prepared
Here are some examples:

Expand Down
6 changes: 3 additions & 3 deletions qlib/data/dataset/handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ class DataHandler(Serializable):
The data handler try to maintain a handler with 2 level.
`datetime` & `instruments`.

Any order of the index level can be suported(The order will implied in the data).
Any order of the index level can be suported (The order will be implied in the data).
The order <`datetime`, `instruments`> will be used when the dataframe index name is missed.

Example of the data:
Expand All @@ -47,8 +47,8 @@ class DataHandler(Serializable):
$close $volume Ref($close, 1) Mean($close, 3) $high-$low LABEL0
datetime instrument
2010-01-04 SH600000 81.807068 17145150.0 83.737389 83.016739 2.741058 0.0032
SH600004 13.313329 11800983.0 13.313329 13.317701 0.183632 0.0042
SH600005 37.796539 12231662.0 38.258602 37.919757 0.970325 0.0289
SH600004 13.313329 11800983.0 13.313329 13.317701 0.183632 0.0042
SH600005 37.796539 12231662.0 38.258602 37.919757 0.970325 0.0289

"""

Expand Down
2 changes: 1 addition & 1 deletion qlib/data/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -1489,7 +1489,7 @@ def __init__(self, feature_left, feature_right, N):
]


class OpsWrapper(object):
class OpsWrapper:
"""Ops Wrapper"""

def __init__(self):
Expand Down
14 changes: 11 additions & 3 deletions qlib/workflow/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,11 @@ class QlibRecorder:
def __init__(self, exp_manager):
self.exp_manager = exp_manager

def __repr__(self):
return "{name}(manager={manager})".format(name=self.__class__.__name__, manager=self.exp_manager)

@contextmanager
def start(self, experiment_name=None, recorder_name=None):
def start(self, experiment_name=None, recorder_name=None, uri=None):
Derek-Wds marked this conversation as resolved.
Show resolved Hide resolved
"""
Method to start an experiment. This method can only be called within a Python's `with` statement. Here is the example code:

Expand All @@ -34,8 +37,13 @@ def start(self, experiment_name=None, recorder_name=None):
name of the experiment one wants to start.
recorder_name : str
name of the recorder under the experiment one wants to start.
uri : str
D-X-Y marked this conversation as resolved.
Show resolved Hide resolved
The tracking uri of the experiment, where all the artifacts/metrics etc. will be stored.
The default uri is set in the qlib.config. Note that this uri argument will not change the one defined in the config file.
Therefore, the next time when users call this function in the same experiment,
they have to also specify this argument with the same value. Otherwise, inconsistent uri may occur.
"""
run = self.start_exp(experiment_name, recorder_name)
run = self.start_exp(experiment_name, recorder_name, uri)
try:
yield run
except Exception as e:
Expand Down Expand Up @@ -272,7 +280,7 @@ def get_uri(self):
-------
The uri of current experiment manager.
"""
return self.exp_manager.get_uri()
return self.exp_manager.uri

def get_recorder(self, recorder_id=None, recorder_name=None, experiment_name=None):
"""
Expand Down
17 changes: 7 additions & 10 deletions qlib/workflow/exp.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ def __init__(self, id, name):
self.active_recorder = None # only one recorder can running each time

def __repr__(self):
return str(self.info)
return "{name}(info={info})".format(name=self.__class__.__name__, info=self.info)

def __str__(self):
return str(self.info)
Expand Down Expand Up @@ -173,11 +173,9 @@ def __init__(self, id, name, uri):
self._uri = uri
self._default_name = None
self._default_rec_name = "mlflow_recorder"
self.client = mlflow.tracking.MlflowClient(tracking_uri=self._uri)
self._client = mlflow.tracking.MlflowClient(tracking_uri=self._uri)

def start(self, recorder_name=None):
# set the active experiment
D-X-Y marked this conversation as resolved.
Show resolved Hide resolved
mlflow.set_experiment(self.name)
logger.info(f"Experiment {self.id} starts running ...")
# set up recorder
recorder = self.create_recorder(recorder_name)
Expand Down Expand Up @@ -210,7 +208,6 @@ def get_recorder(self, recorder_id=None, recorder_name=None, create=True):
else:
recorder, is_new = self._get_recorder(recorder_id=recorder_id, recorder_name=recorder_name), False
if is_new:
mlflow.set_experiment(self.name)
self.active_recorder = recorder
# start the recorder
self.active_recorder.start_run()
Expand Down Expand Up @@ -239,7 +236,7 @@ def _get_recorder(self, recorder_id=None, recorder_name=None):
), "Please input at least one of recorder id or name before retrieving recorder."
if recorder_id is not None:
try:
run = self.client.get_run(recorder_id)
run = self._client.get_run(recorder_id)
recorder = MLflowRecorder(self.id, self._uri, mlflow_run=run)
return recorder
except MlflowException:
Expand All @@ -260,18 +257,18 @@ def search_records(self, **kwargs):
max_results = 100000 if kwargs.get("max_results") is None else kwargs.get("max_results")
order_by = kwargs.get("order_by")

return self.client.search_runs([self.id], filter_string, run_view_type, max_results, order_by)
return self._client.search_runs([self.id], filter_string, run_view_type, max_results, order_by)

def delete_recorder(self, recorder_id=None, recorder_name=None):
assert (
recorder_id is not None or recorder_name is not None
), "Please input a valid recorder id or name before deleting."
try:
if recorder_id is not None:
self.client.delete_run(recorder_id)
self._client.delete_run(recorder_id)
else:
recorder = self._get_recorder(recorder_name=recorder_name)
self.client.delete_run(recorder.id)
self._client.delete_run(recorder.id)
except MlflowException as e:
raise Exception(
f"Error: {e}. Something went wrong when deleting recorder. Please check if the name/id of the recorder is correct."
Expand All @@ -280,7 +277,7 @@ def delete_recorder(self, recorder_id=None, recorder_name=None):
UNLIMITED = 50000 # FIXME: Mlflow can only list 50000 records at most!!!!!!!

def list_recorders(self, max_results=UNLIMITED):
runs = self.client.search_runs(self.id, run_view_type=ViewType.ACTIVE_ONLY, max_results=max_results)[::-1]
runs = self._client.search_runs(self.id, run_view_type=ViewType.ACTIVE_ONLY, max_results=max_results)[::-1]
recorders = dict()
for i in range(len(runs)):
recorder = MLflowRecorder(self.id, self._uri, mlflow_run=runs[i])
Expand Down
87 changes: 66 additions & 21 deletions qlib/workflow/expm.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,10 @@
import os
from pathlib import Path
from contextlib import contextmanager
from typing import Optional, Text

from .exp import MLflowExperiment, Experiment
from .recorder import Recorder, MLflowRecorder
from .recorder import Recorder
from ..log import get_module_logger

logger = get_module_logger("workflow", "INFO")
Expand All @@ -20,12 +22,24 @@ class ExpManager:
(The link: https://mlflow.org/docs/latest/python_api/mlflow.html)
"""

def __init__(self, uri, default_exp_name):
self.uri = uri
def __init__(self, uri: Text, default_exp_name: Optional[Text]):
self._default_uri = uri
self._current_uri = None
self.default_exp_name = default_exp_name
self.active_experiment = None # only one experiment can active each time

def start_exp(self, experiment_name=None, recorder_name=None, uri=None, **kwargs):
def __repr__(self):
Derek-Wds marked this conversation as resolved.
Show resolved Hide resolved
return "{name}(default_uri={duri}, current_uri={curi})".format(
name=self.__class__.__name__, duri=self._default_uri, curi=self._current_uri
)

def start_exp(
self,
experiment_name: Optional[Text] = None,
recorder_name: Optional[Text] = None,
uri: Optional[Text] = None,
**kwargs,
):
"""
Start an experiment. This method includes first get_or_create an experiment, and then
set it to be active.
Expand All @@ -45,7 +59,7 @@ def start_exp(self, experiment_name=None, recorder_name=None, uri=None, **kwargs
"""
raise NotImplementedError(f"Please implement the `start_exp` method.")

def end_exp(self, recorder_status: str = Recorder.STATUS_S, **kwargs):
def end_exp(self, recorder_status: Text = Recorder.STATUS_S, **kwargs):
"""
End an active experiment.

Expand All @@ -58,7 +72,7 @@ def end_exp(self, recorder_status: str = Recorder.STATUS_S, **kwargs):
"""
raise NotImplementedError(f"Please implement the `end_exp` method.")

def create_exp(self, experiment_name=None):
def create_exp(self, experiment_name: Optional[Text] = None):
"""
Create an experiment.

Expand Down Expand Up @@ -203,15 +217,40 @@ def delete_exp(self, experiment_id=None, experiment_name=None):
"""
raise NotImplementedError(f"Please implement the `delete_exp` method.")

def get_uri(self):
@property
def uri(self):
"""
Get the default tracking URI or current URI.

Returns
-------
The tracking URI string.
"""
return self.uri
return self._current_uri or self._default_uri

def set_uri(self, uri: Optional[Text] = None):
"""
Set the current tracking URI and the corresponding variables.

Parameters
----------
uri : str

"""
if uri is None:
logger.info("No tracking URI is provided. Use the default tracking URI.")
self._current_uri = self._default_uri
else:
# Temporarily re-set the current uri as the uri argument.
self._current_uri = uri
# Customized features for subclasses.
self._set_uri()

def _set_uri(self):
"""
Customized features for subclasses' set_uri function.
"""
raise NotImplementedError(f"Please implement the `_set_uri` method.")

def list_experiments(self):
"""
Expand All @@ -229,37 +268,43 @@ class MLflowExpManager(ExpManager):
Use mlflow to implement ExpManager.
"""

def __init__(self, uri, default_exp_name):
def __init__(self, uri: Text, default_exp_name: Optional[Text]):
super(MLflowExpManager, self).__init__(uri, default_exp_name)
self._client = None

def _set_uri(self):
self._client = mlflow.tracking.MlflowClient(tracking_uri=self.uri)
logger.info("{:}".format(self._client))

@property
def client(self):
# Delay the creation of mlflow client in case of creating `mlruns` folder when importing qlib
if not hasattr(self, "_client"):
if self._client is None:
self._client = mlflow.tracking.MlflowClient(tracking_uri=self.uri)
return self._client

def start_exp(self, experiment_name=None, recorder_name=None, uri=None):
# set the tracking uri
if uri is None:
logger.info("No tracking URI is provided. Use the default tracking URI.")
else:
self.uri = uri
# create experiment
def start_exp(
self, experiment_name: Optional[Text] = None, recorder_name: Optional[Text] = None, uri: Optional[Text] = None
):
# Set the tracking uri
self.set_uri(uri)
# Create experiment
experiment, _ = self._get_or_create_exp(experiment_name=experiment_name)
# set up active experiment
# Set up active experiment
self.active_experiment = experiment
# start the experiment
# Start the experiment
self.active_experiment.start(recorder_name)

return self.active_experiment

def end_exp(self, recorder_status: str = Recorder.STATUS_S):
def end_exp(self, recorder_status: Text = Recorder.STATUS_S):
if self.active_experiment is not None:
self.active_experiment.end(recorder_status)
self.active_experiment = None
# When an experiment end, we will release the current uri.
self._current_uri = None

def create_exp(self, experiment_name=None):
def create_exp(self, experiment_name: Optional[Text] = None):
assert experiment_name is not None
# init experiment
experiment_id = self.client.create_experiment(experiment_name)
Expand Down
2 changes: 1 addition & 1 deletion qlib/workflow/recorder.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ def __init__(self, experiment_id, name):
self.status = Recorder.STATUS_S

def __repr__(self):
return str(self.info)
return "{name}(info={info})".format(name=self.__class__.__name__, info=self.info)

def __str__(self):
return str(self.info)
Expand Down
Loading