Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add resource mechanism to Arena connector #420

Merged
merged 17 commits into from
May 11, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions python/dalex/NEWS.md
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,20 @@
development
----------------------------------------------------------------

#### breaking changes

* method `set_options` in Arena now takies `option_category` instead of `plot_type` (`SHAPValues` => `ShapleyValues`, `FeatureImportance` => `VariableImportance`)

#### fixes

* fixed wrong error value when no `predict_function` is found in `Explainer` ([77ca90d](/~https://github.com/ModelOriented/DALEX/commit/77ca90d))
* set multiprocessing context to 'spawn'

#### features

* add resource mechanism to Arena
* add ShapleyValuesImportance and ShapleyValuesDependence charts to Arena

v1.1.0 (18/04/2021)
----------------------------------------------------------------

Expand Down
17 changes: 17 additions & 0 deletions python/dalex/dalex/arena/_option_base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
class OptionBase:
"""
Base class providing methods for options. The goal of it is to
create common interface for PlotContainer and Resource.
"""
options_category = "base"
options = {}
def __init__(self, arena):
if type(arena).__name__ != 'Arena' or type(arena).__module__ != 'dalex.arena.object':
raise Exception('Invalid Arena argument')
self.arena = arena

def get_option(self, name):
return self.arena.get_option(self.__class__.options_category, name)

def set_option(self, name, value):
return self.arena.set_option(self.__class__.options_category, name, value)
93 changes: 87 additions & 6 deletions python/dalex/dalex/arena/_plot_container.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,68 @@
from ._option_base import OptionBase
from .params import ModelParam, DatasetParam, VariableParam, ObservationParam

class PlotContainer:
def __init__(self, arena):
if type(arena).__name__ != 'Arena' or type(arena).__module__ != 'dalex.arena.object':
raise Exception('Invalid Arena argument')

class PlotContainer(OptionBase):
"""
Class representing a chart.

Parameters
----------
arena : dalex.Arena
Instance of Arena.
cache : bool
If this object is allowed to use cache when requesting resources

Attributes
--------
arena : Arena
Instance of dalex.Arena
name : str
Display name of chart
plot_type : str
Identifier of chart type
plot_component : str
Identifier of Arena's component that should render this chart
plot_category : str
Name of category of chart
params : dict
Dictionary with required param types as keys and param labels as values.
This attribute is set when calling fit.
data : dict
Results of computations are placed there
progress : float
If progress is supprted, then value should be between [0,1]. For other situations -1 value must be set.
Progress of plot container is based of progress of used resources at the moment of calling fit method.
This value will not be updated.
use_cache : bool
If this object is allowed to use cache when requesting resources
"""
def __init__(self, arena, cache=True):
super().__init__(arena)
info = self.__class__.info
self.name = info.get('name')
self.plot_type = info.get('plotType')
self.plot_component = info.get('plotType')
self.plot_category = info.get('plotCategory')
self.params = {}
self.data = {}
self.arena = arena
self.is_done = None
self.progress = -1
# If plot class is allowed to use cache when requesting resources
self.use_cache = cache
def fit(self, params):
"""Function computes plots data for given params

Parameters
-----------
params : dict
Keys of this dict are params types (model, observation, variable, dataset)
and values are corresponding params values (class Param).

Returns
--------
PlotContainer object
"""
required_params = {}
for p in self.__class__.info.get('requiredParams'):
self.check_param(p, params.get(p))
Expand All @@ -22,21 +72,52 @@ def fit(self, params):
self._fit(**required_params)
return self
def serialize(self):
"""Saves important attributes of PlotContainer into a dict.
Returned dict is meant to be directly put into Arena data file.

Returns
--------
dict
"""
return {
'name': self.name,
'plotType': self.plot_type,
'plotComponent': self.plot_component,
'plotCategory': self.plot_category,
'params': self.params,
'data': self.data
'data': self.data,
'progress': self.progress,
'isDone': True if self.is_done is None else self.is_done
}
def set_message(self, msg, msg_type='info'):
"""Changes plot component to message and sets data with provided message

Parameters
-----------
msg : str
Text of message
msg_type : str
Type of message. One of ['info', 'error']
"""
if msg_type != 'info' and msg_type != 'error':
raise Exception('Invalid message type')
self.plot_component = 'Message'
self.data = {'message': msg, 'type': msg_type}

def check_param(self, param_type, value):
"""Function validates param values as param of given type

Parameters
-----------
param_type : str
One of ['model', 'variable', 'observation', 'dataset'].
value : Object
Function checks if this object have correct class.

Returns
--------
None if value is correct. Else raises exception
"""
correct_class = {'model': ModelParam, 'variable': VariableParam, 'observation': ObservationParam, 'dataset': DatasetParam}.get(param_type)
if not isinstance(value, correct_class):
raise Exception('Invalid param ' + str(param_type))
Expand Down
170 changes: 170 additions & 0 deletions python/dalex/dalex/arena/_plots_manager.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,170 @@
import time
from . import plots
from tqdm import tqdm
from ._plot_container import PlotContainer


class PlotsManager:
"""Creates PlotsManager object

This class should be only created by arena instance to manage its plots.

Parameters
----------
arena : dalex.Arena
Instance of Arena.

Attributes
--------
arena : Arena
Instance of dalex.Arena
cache : list of PlotContainer objects
List of already calculated plots
mutex : _thread.lock
Mutex for params, plots and resources cache. Should be common with mutex from Arena instance.
plots : list of classes extending PlotContainer
List of available plot classes
"""
def __init__(self, arena):
if type(arena).__name__ != 'Arena' or type(arena).__module__ != 'dalex.arena.object':
raise Exception('Invalid Arena argument')
self.arena = arena
self.cache = []
self.mutex = arena.mutex
self.plots = [vars(plots)[res] for res in getattr(plots, '__all__')]

def get_supported_plots(self):
"""Returns plots classes that can produce at least one valid chart for parent arena.

Returns
-----------
List of classes extending PlotContainer
"""
return [plot for plot in self.plots if plot.test_arena(self.arena)]

def clear_cache(self, plot_type=None):
"""Clears cache

Parameters
-----------
plot_type : str or None
If None all cache is cleared. Otherwise only plots with
provided plot_type are removed.

Notes
-------
This function must be called from mutex context
"""
if plot_type is None:
self.cache = []
else:
self.cache = list(filter(lambda p: p.plot_type != plot_type, self.cache))
self.arena.update_timestamp()

def find_in_cache(self, plot_type, params):
"""Function searches for cached plot

Parameters
-----------
plot_type : str
Value of plot_type field, that requested plot must have
params : dict
Keys of this dict are params types (model, observation, variable, dataset)
and values are corresponding params labels. Requested plot must have equal
params field.

Returns
--------
PlotContainer or None
"""

def _filter(p):
return p.plot_type == plot_type and params == p.params
with self.mutex:
return next(filter(_filter, self.cache), None)

def put_to_cache(self, plot_container):
"""Puts new plot to cache

Parameters
-----------
plot_container : PlotContainer
"""
if not isinstance(plot_container, PlotContainer):
raise Exception('Invalid plot container')
with self.mutex:
self.cache.append(plot_container)

def fill_cache(self, fixed_params={}):
"""Generates all available plots and cache them

This function tries to generate all plots that are not cached already and
put them to cache. Range of generated plots can be narrow using `fixed_params`

Parameters
-----------
fixed_params : dict
This dict specifies which plots should be generated. Only those with
all keys from `fixed_params` present and having the same value will be
calculated.
"""
if not isinstance(fixed_params, dict):
raise Exception('Params argument must be a dict')
for plot_class in self.get_supported_plots():
required_params = plot_class.info.get('requiredParams')
# Test if all params fixed by user are used in this plot. If not, then skip it.
# This list contains fixed params' types, that are not required by plot.
# Loop will be skipped if this list is not empty.
if len([k for k in fixed_params.keys() if k not in required_params]) > 0:
continue
available_params = self.arena.get_available_params()
iteration_pools = map(lambda p: available_params.get(p) if fixed_params.get(p) is None else [fixed_params.get(p)], required_params)
combinations = [[]]
for pool in iteration_pools:
combinations = [x + [y] for x in combinations for y in pool]
if self.arena.verbose and len(combinations) > 0:
combinations = tqdm(combinations)
combinations.set_description(plot_class.info.get('name'))
for params_values in combinations:
params = dict(zip(required_params, params_values))
self.get_plot(plot_type=plot_class.info.get('plotType'), params_values=params, wait=True)

def get_plot(self, plot_type, params_values, cache=True, wait=False):
"""Returns plot for specified type and params

Function serches for plot in cache, when not present creates
requested plot and put it to cache.

Parameters
-----------
plot_type : str
Type of plot to be generated
params_values : dict
Dict for param types as keys and Param objects as values
cache : bool
If serach for plot in cache and put calculated plot into cache.

Returns
--------
PlotContainer
"""
plot_class = next((c for c in self.plots if c.info.get('plotType') == plot_type), None)
if plot_class is None:
raise Exception('Not supported plot type')
plot_type = plot_class.info.get('plotType')
required_params_values = {}
required_params_labels = {}
for p in plot_class.info.get('requiredParams'):
if params_values.get(p) is None:
raise Exception('Required param is missing')
required_params_values[p] = params_values.get(p)
required_params_labels[p] = params_values.get(p).get_label()
result = self.find_in_cache(plot_type, required_params_labels) if cache else None
if result is None:
result = plot_class(self.arena, cache=cache).fit(required_params_values)
while wait and result.is_done == False:
time.sleep(0.5)
result = plot_class(self.arena, cache=cache).fit(required_params_values)
if cache and result.is_done != False:
self.put_to_cache(result)
return result
Loading