Skip to content

Commit

Permalink
fix: draw function vs init parameters (#115)
Browse files Browse the repository at this point in the history
* fix draw

* stray print
  • Loading branch information
ZanSara authored Sep 27, 2023
1 parent 424c6f5 commit e2f5187
Show file tree
Hide file tree
Showing 4 changed files with 35 additions and 60 deletions.
12 changes: 6 additions & 6 deletions canals/pipeline/draw/draw.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,13 @@
from canals.utils import _type_name

logger = logging.getLogger(__name__)
RenderingEngines = Literal["graphviz", "mermaid-img", "mermaid-text"]
RenderingEngines = Literal["graphviz", "mermaid-image", "mermaid-text"]


def _draw(
graph: networkx.MultiDiGraph,
path: Path,
engine: RenderingEngines = "mermaid-img",
engine: RenderingEngines = "mermaid-image",
style_map: Optional[Dict[str, str]] = None,
) -> None:
"""
Expand All @@ -31,7 +31,7 @@ def _draw(
if engine == "graphviz":
converted_graph.draw(path)

elif engine == "mermaid-img":
elif engine == "mermaid-image":
with open(path, "wb") as imagefile:
imagefile.write(converted_graph)

Expand All @@ -57,7 +57,7 @@ def _convert_for_debug(

def _convert(
graph: networkx.MultiDiGraph,
engine: RenderingEngines = "mermaid-img",
engine: RenderingEngines = "mermaid-image",
style_map: Optional[Dict[str, str]] = None,
) -> Any:
"""
Expand All @@ -68,7 +68,7 @@ def _convert(
if engine == "graphviz":
return _to_agraph(graph=graph)

if engine == "mermaid-img":
if engine == "mermaid-image":
return _to_mermaid_image(graph=graph)

if engine == "mermaid-text":
Expand All @@ -95,7 +95,7 @@ def _prepare_for_drawing(graph: networkx.MultiDiGraph, style_map: Dict[str, str]
graph.add_node("input")
for node, in_sockets in _find_pipeline_inputs(graph).items():
for in_socket in in_sockets:
if in_socket.sender is None:
if in_socket.sender is None and not in_socket.is_optional:
# If this socket has no sender it could be a socket that receives input
# directly when running the Pipeline. We can't know that for sure, in doubt
# we draw it as receiving input directly.
Expand Down
53 changes: 18 additions & 35 deletions canals/pipeline/draw/mermaid.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,30 +3,24 @@
# SPDX-License-Identifier: Apache-2.0
import logging
import base64
import json

import requests
import networkx

from canals.errors import PipelineDrawingError
from canals.utils import _type_name
from canals.serialization import component_to_dict

logger = logging.getLogger(__name__)


MERMAID_STYLED_TEMPLATE = """
%%{{ init: {{'theme': 'neutral' }} }}%%
stateDiagram-v2
{states}
{notes}
graph TD;
{connections}
classDef components text-align:center;
classDef component text-align:center;
"""


Expand Down Expand Up @@ -66,54 +60,43 @@ def _to_mermaid_text(graph: networkx.MultiDiGraph) -> str:
Converts a Networkx graph into Mermaid syntax. The output of this function can be used in the documentation
with `mermaid` codeblocks and it will be automatically rendered.
"""
init_params = {}
for name, comp in graph.nodes(data="instance"):
if name in ["input", "output"]:
continue
data = component_to_dict(comp)
params = [f"{k}={json.dumps(v)}" for k, v in data.get("init_parameters", {}).items()]
init_params[name] = ",<br>".join(params)
states = "\n".join(
[
f"{comp}:::components: <b>{comp}</b><br><small><i>{type(data['instance']).__name__}({init_params[comp]})</i></small>"
for comp, data in graph.nodes(data=True)
if comp not in ["input", "output"]
]
)
sockets = {
comp: "\n".join(
comp: "".join(
[
f"{name} <small><i>{_type_name(socket.type)}</i></small>"
f"<li>{name} ({_type_name(socket.type)})</li>"
for name, socket in data.get("input_sockets", {}).items()
if socket.is_optional and socket.sender is None
]
)
for comp, data in graph.nodes(data=True)
}
notes = "\n".join(
[
f"note left of {comp}\n {sockets[comp]}\nend note"
for comp in graph.nodes
if comp not in ["input", "output"] and sockets[comp] != ""
]
)
optional_inputs = {
comp: f"<br><br>Optional inputs:<ul style='text-align:left;'>{sockets}</ul>" if sockets else ""
for comp, sockets in sockets.items()
}

states = {
comp: f"{comp}[\"<b>{comp}</b><br><small><i>{type(data['instance']).__name__}{optional_inputs[comp]}</i></small>\"]:::component"
for comp, data in graph.nodes(data=True)
if comp not in ["input", "output"]
}

connections_list = [
f"{from_comp} --> {to_comp} : {conn_data['label']} <small><i>({conn_data['conn_type']})</i></small>"
f"{states[from_comp]} -- \"{conn_data['label']}<br><small><i>{conn_data['conn_type']}</i></small>\" --> {states[to_comp]}"
for from_comp, to_comp, conn_data in graph.edges(data=True)
if from_comp != "input" and to_comp != "output"
]
input_connections = [
f"[*] --> {to_comp} : {conn_data['label']} <small><i>({conn_data['conn_type']})</i></small>"
f"i{{*}} -- \"{conn_data['label']}<br><small><i>{conn_data['conn_type']}</i></small>\" --> {states[to_comp]}"
for _, to_comp, conn_data in graph.out_edges("input", data=True)
]
output_connections = [
f"{from_comp} --> [*] : {conn_data['label']} <small><i>({conn_data['conn_type']})</i></small>"
f"{states[from_comp]} -- \"{conn_data['label']}<br><small><i>{conn_data['conn_type']}</i></small>\"--> o{{*}}"
for from_comp, _, conn_data in graph.in_edges("output", data=True)
]
connections = "\n".join(connections_list + input_connections + output_connections)

graph_styled = MERMAID_STYLED_TEMPLATE.format(states=states, notes=notes, connections=connections)
graph_styled = MERMAID_STYLED_TEMPLATE.format(connections=connections)
logger.debug("Mermaid diagram:\n%s", graph_styled)

return graph_styled
8 changes: 4 additions & 4 deletions canals/pipeline/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -338,15 +338,15 @@ def get_component(self, name: str) -> Component:
except KeyError as exc:
raise ValueError(f"Component named {name} not found in the pipeline.") from exc

def draw(self, path: Path, engine: RenderingEngines = "mermaid-img") -> None:
def draw(self, path: Path, engine: RenderingEngines = "mermaid-image") -> None:
"""
Draws the pipeline. Requires either `graphviz` as a system dependency, or an internet connection for Mermaid.
Run `pip install canals[graphviz]` or `pip install canals[mermaid]` to install missing dependencies.
Args:
path: where to save the diagram.
engine: which format to save the graph as. Accepts 'graphviz', 'mermaid-text', 'mermaid-img'.
Default is 'mermaid-img'.
engine: which format to save the graph as. Accepts 'graphviz', 'mermaid-text', 'mermaid-image'.
Default is 'mermaid-image'.
Returns:
None
Expand All @@ -360,7 +360,7 @@ def draw(self, path: Path, engine: RenderingEngines = "mermaid-img") -> None:
for comp, data in self.graph.nodes(data=True)
}
print(sockets)
_draw(graph=deepcopy(self.graph), path=path, engine=engine)
_draw(graph=networkx.MultiDiGraph(self.graph), path=path, engine=engine)

def warm_up(self):
"""
Expand Down
22 changes: 7 additions & 15 deletions test/pipelines/unit/test_draw.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,14 +28,14 @@ def test_draw_pygraphviz(tmp_path, test_files):
assert filecmp.cmp(tmp_path / "test_pipe.jpg", test_files / "pipeline_draw" / "pygraphviz.jpg")


def test_draw_mermaid_img(tmp_path, test_files):
def test_draw_mermaid_image(tmp_path, test_files):
pipe = Pipeline()
pipe.add_component("comp1", Double())
pipe.add_component("comp2", Double())
pipe.connect("comp1", "comp2")
pipe.connect("comp2", "comp1")

_draw(pipe.graph, tmp_path / "test_pipe.jpg", engine="mermaid-img")
_draw(pipe.graph, tmp_path / "test_pipe.jpg", engine="mermaid-image")
assert os.path.exists(tmp_path / "test_pipe.jpg")
assert filecmp.cmp(tmp_path / "test_pipe.jpg", test_files / "mermaid_mock" / "test_response.png")

Expand All @@ -59,7 +59,7 @@ def raise_for_status(self):
mock_get.return_value = mock_response

with pytest.raises(PipelineDrawingError, match="There was an issue with https://mermaid.ink/"):
_draw(pipe.graph, tmp_path / "test_pipe.jpg", engine="mermaid-img")
_draw(pipe.graph, tmp_path / "test_pipe.jpg", engine="mermaid-image")


def test_draw_mermaid_txt(tmp_path):
Expand All @@ -76,20 +76,12 @@ def test_draw_mermaid_txt(tmp_path):
== """
%%{ init: {'theme': 'neutral' } }%%
stateDiagram-v2
graph TD;
comp1:::components: <b>comp1</b><br><small><i>AddFixedValue(add=3)</i></small>
comp2:::components: <b>comp2</b><br><small><i>Double()</i></small>
comp1["<b>comp1</b><br><small><i>AddFixedValue<br><br>Optional inputs:<ul style='text-align:left;'><li>add (Optional[int])</li></ul></i></small>"]:::component -- "result -> value<br><small><i>int</i></small>" --> comp2["<b>comp2</b><br><small><i>Double</i></small>"]:::component
comp2["<b>comp2</b><br><small><i>Double</i></small>"]:::component -- "value -> value<br><small><i>int</i></small>" --> comp1["<b>comp1</b><br><small><i>AddFixedValue<br><br>Optional inputs:<ul style='text-align:left;'><li>add (Optional[int])</li></ul></i></small>"]:::component
note left of comp1
add <small><i>Optional[int]</i></small>
end note
comp1 --> comp2 : result -> value <small><i>(int)</i></small>
comp2 --> comp1 : value -> value <small><i>(int)</i></small>
[*] --> comp1 : add <small><i>(Optional[int])</i></small>
classDef components text-align:center;
classDef component text-align:center;
"""
)

Expand Down

0 comments on commit e2f5187

Please sign in to comment.