Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

visualizer interface #1062

Merged
merged 5 commits into from
Oct 8, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 17 additions & 0 deletions python/src/robyn/visualization/allocator_visualizer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
from typing import Dict, Any
import matplotlib.pyplot as plt
from .base_visualizer import BaseVisualizer

class AllocatorVisualizer(BaseVisualizer):
def __init__(self, allocator_data: Dict[str, Any]):
super().__init__()
self.allocator_data = allocator_data

def plot_allocator(self) -> plt.Figure:
"""
Plot allocator's output.

Returns:
plt.Figure: The generated figure.
"""
pass
9 changes: 9 additions & 0 deletions python/src/robyn/visualization/base_visualizer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
import matplotlib.pyplot as plt

class BaseVisualizer:
def __init__(self):
pass

def _setup_plot(self):
# Common plot setup logic
pass
41 changes: 41 additions & 0 deletions python/src/robyn/visualization/input_visualizer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
from typing import Dict, Any
import matplotlib.pyplot as plt
from .base_visualizer import BaseVisualizer

class InputVisualizer(BaseVisualizer):
def __init__(self, input_data: Dict[str, Any]):
super().__init__()
self.input_data = input_data

def plot_adstock(self) -> plt.Figure:
"""
Create example plots for adstock hyperparameters.

Returns:
plt.Figure: The generated figure.
"""
fig, ax = plt.subplots()
# Add plotting logic here
return fig

def plot_saturation(self) -> plt.Figure:
"""
Create example plots for saturation hyperparameters.

Returns:
plt.Figure: The generated figure.
"""
fig, ax = plt.subplots()
# Add plotting logic here
return fig

def plot_spend_exposure_fit(self) -> Dict[str, plt.Figure]:
"""
Check spend exposure fit if available.

Returns:
Dict[str, plt.Figure]: A dictionary of generated figures.
"""
figures = {}
# Add plotting logic here
return figures
49 changes: 49 additions & 0 deletions python/src/robyn/visualization/model_visualizer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
from typing import Dict, Any
import matplotlib.pyplot as plt
from .base_visualizer import BaseVisualizer

class ModelVisualizer(BaseVisualizer):
def __init__(self, model_data: Dict[str, Any]):
super().__init__()
self.model_data = model_data

def plot_moo_distribution(self) -> plt.Figure:
"""
Plot MOO (multi-objective optimization) distribution.

Returns:
plt.Figure: The generated figure.
"""
pass

def plot_moo_cloud(self) -> plt.Figure:
"""
Plot MOO (multi-objective optimization) cloud.

Returns:
plt.Figure: The generated figure.
"""
pass

def plot_ts_validation(self) -> plt.Figure:
"""
Plot time-series validation.

Returns:
plt.Figure: The generated figure.
"""
pass

def plot_onepager(self, input_collect: Dict[str, Any], output_collect: Dict[str, Any], select_model: str) -> Dict[str, plt.Figure]:
"""
Generate one-pager plots for a selected model.

Args:
input_collect (Dict[str, Any]): The input collection data.
output_collect (Dict[str, Any]): The output collection data.
select_model (str): The selected model identifier.

Returns:
Dict[str, plt.Figure]: A dictionary of generated figures.
"""
pass
26 changes: 26 additions & 0 deletions python/src/robyn/visualization/response_visualizer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
from typing import Dict, Any
import matplotlib.pyplot as plt
from .base_visualizer import BaseVisualizer

class ResponseVisualizer(BaseVisualizer):
def __init__(self, response_data: Dict[str, Any]):
super().__init__()
self.response_data = response_data

def plot_response(self) -> plt.Figure:
"""
Plot response curves.

Returns:
plt.Figure: The generated figure.
"""
pass

def plot_marginal_response(self) -> plt.Figure:
"""
Plot marginal response curves.

Returns:
plt.Figure: The generated figure.
"""
pass
52 changes: 52 additions & 0 deletions python/src/robyn/visualization/robyn_visualizer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
from typing import Dict, Any
import matplotlib.pyplot as plt
from .input_visualizer import InputVisualizer
from .model_visualizer import ModelVisualizer
from .allocator_visualizer import AllocatorVisualizer
from .response_visualizer import ResponseVisualizer

class RobynVisualizer:
def __init__(self):
self.input_visualizer = None
self.model_visualizer = None
self.allocator_visualizer = None
self.response_visualizer = None

def set_input_data(self, input_data: Dict[str, Any]):
self.input_visualizer = InputVisualizer(input_data)

def set_model_data(self, model_data: Dict[str, Any]):
self.model_visualizer = ModelVisualizer(model_data)

def set_allocator_data(self, allocator_data: Dict[str, Any]):
self.allocator_visualizer = AllocatorVisualizer(allocator_data)

def set_response_data(self, response_data: Dict[str, Any]):
self.response_visualizer = ResponseVisualizer(response_data)

def plot_adstock(self) -> plt.Figure:
return self.input_visualizer.plot_adstock()

def plot_saturation(self) -> plt.Figure:
return self.input_visualizer.plot_saturation()

def plot_moo_distribution(self) -> plt.Figure:
return self.model_visualizer.plot_moo_distribution()

def plot_moo_cloud(self) -> plt.Figure:
return self.model_visualizer.plot_moo_cloud()

def plot_ts_validation(self) -> plt.Figure:
return self.model_visualizer.plot_ts_validation()

def plot_onepager(self, input_collect: Dict[str, Any], output_collect: Dict[str, Any], select_model: str) -> Dict[str, plt.Figure]:
return self.model_visualizer.plot_onepager(input_collect, output_collect, select_model)

def plot_allocator(self) -> plt.Figure:
return self.allocator_visualizer.plot_allocator()

def plot_response(self) -> plt.Figure:
return self.response_visualizer.plot_response()

def plot_marginal_response(self) -> plt.Figure:
return self.response_visualizer.plot_marginal_response()