diff --git a/canals/pipeline/draw/draw.py b/canals/pipeline/draw/draw.py
index 33b7329837..5fc60ebefe 100644
--- a/canals/pipeline/draw/draw.py
+++ b/canals/pipeline/draw/draw.py
@@ -14,13 +14,13 @@
from canals.utils import _type_name
logger = logging.getLogger(__name__)
-RenderingEngines = Literal["graphviz", "mermaid-img", "mermaid-text"]
+RenderingEngines = Literal["graphviz", "mermaid-image", "mermaid-text"]
def _draw(
graph: networkx.MultiDiGraph,
path: Path,
- engine: RenderingEngines = "mermaid-img",
+ engine: RenderingEngines = "mermaid-image",
style_map: Optional[Dict[str, str]] = None,
) -> None:
"""
@@ -31,7 +31,7 @@ def _draw(
if engine == "graphviz":
converted_graph.draw(path)
- elif engine == "mermaid-img":
+ elif engine == "mermaid-image":
with open(path, "wb") as imagefile:
imagefile.write(converted_graph)
@@ -57,7 +57,7 @@ def _convert_for_debug(
def _convert(
graph: networkx.MultiDiGraph,
- engine: RenderingEngines = "mermaid-img",
+ engine: RenderingEngines = "mermaid-image",
style_map: Optional[Dict[str, str]] = None,
) -> Any:
"""
@@ -68,7 +68,7 @@ def _convert(
if engine == "graphviz":
return _to_agraph(graph=graph)
- if engine == "mermaid-img":
+ if engine == "mermaid-image":
return _to_mermaid_image(graph=graph)
if engine == "mermaid-text":
@@ -95,7 +95,7 @@ def _prepare_for_drawing(graph: networkx.MultiDiGraph, style_map: Dict[str, str]
graph.add_node("input")
for node, in_sockets in _find_pipeline_inputs(graph).items():
for in_socket in in_sockets:
- if in_socket.sender is None:
+ if in_socket.sender is None and not in_socket.is_optional:
# If this socket has no sender it could be a socket that receives input
# directly when running the Pipeline. We can't know that for sure, in doubt
# we draw it as receiving input directly.
diff --git a/canals/pipeline/draw/mermaid.py b/canals/pipeline/draw/mermaid.py
index 62584b3e84..0859072f74 100644
--- a/canals/pipeline/draw/mermaid.py
+++ b/canals/pipeline/draw/mermaid.py
@@ -3,14 +3,12 @@
# SPDX-License-Identifier: Apache-2.0
import logging
import base64
-import json
import requests
import networkx
from canals.errors import PipelineDrawingError
from canals.utils import _type_name
-from canals.serialization import component_to_dict
logger = logging.getLogger(__name__)
@@ -18,15 +16,11 @@
MERMAID_STYLED_TEMPLATE = """
%%{{ init: {{'theme': 'neutral' }} }}%%
-stateDiagram-v2
-
-{states}
-
-{notes}
+graph TD;
{connections}
-classDef components text-align:center;
+classDef component text-align:center;
"""
@@ -66,54 +60,43 @@ def _to_mermaid_text(graph: networkx.MultiDiGraph) -> str:
Converts a Networkx graph into Mermaid syntax. The output of this function can be used in the documentation
with `mermaid` codeblocks and it will be automatically rendered.
"""
- init_params = {}
- for name, comp in graph.nodes(data="instance"):
- if name in ["input", "output"]:
- continue
- data = component_to_dict(comp)
- params = [f"{k}={json.dumps(v)}" for k, v in data.get("init_parameters", {}).items()]
- init_params[name] = ",
".join(params)
- states = "\n".join(
- [
- f"{comp}:::components: {comp}
{type(data['instance']).__name__}({init_params[comp]})"
- for comp, data in graph.nodes(data=True)
- if comp not in ["input", "output"]
- ]
- )
sockets = {
- comp: "\n".join(
+ comp: "".join(
[
- f"{name} {_type_name(socket.type)}"
+ f"
{name} ({_type_name(socket.type)})"
for name, socket in data.get("input_sockets", {}).items()
if socket.is_optional and socket.sender is None
]
)
for comp, data in graph.nodes(data=True)
}
- notes = "\n".join(
- [
- f"note left of {comp}\n {sockets[comp]}\nend note"
- for comp in graph.nodes
- if comp not in ["input", "output"] and sockets[comp] != ""
- ]
- )
+ optional_inputs = {
+ comp: f"
Optional inputs:" if sockets else ""
+ for comp, sockets in sockets.items()
+ }
+
+ states = {
+ comp: f"{comp}[\"{comp}
{type(data['instance']).__name__}{optional_inputs[comp]}\"]:::component"
+ for comp, data in graph.nodes(data=True)
+ if comp not in ["input", "output"]
+ }
connections_list = [
- f"{from_comp} --> {to_comp} : {conn_data['label']} ({conn_data['conn_type']})"
+ f"{states[from_comp]} -- \"{conn_data['label']}
{conn_data['conn_type']}\" --> {states[to_comp]}"
for from_comp, to_comp, conn_data in graph.edges(data=True)
if from_comp != "input" and to_comp != "output"
]
input_connections = [
- f"[*] --> {to_comp} : {conn_data['label']} ({conn_data['conn_type']})"
+ f"i{{*}} -- \"{conn_data['label']}
{conn_data['conn_type']}\" --> {states[to_comp]}"
for _, to_comp, conn_data in graph.out_edges("input", data=True)
]
output_connections = [
- f"{from_comp} --> [*] : {conn_data['label']} ({conn_data['conn_type']})"
+ f"{states[from_comp]} -- \"{conn_data['label']}
{conn_data['conn_type']}\"--> o{{*}}"
for from_comp, _, conn_data in graph.in_edges("output", data=True)
]
connections = "\n".join(connections_list + input_connections + output_connections)
- graph_styled = MERMAID_STYLED_TEMPLATE.format(states=states, notes=notes, connections=connections)
+ graph_styled = MERMAID_STYLED_TEMPLATE.format(connections=connections)
logger.debug("Mermaid diagram:\n%s", graph_styled)
return graph_styled
diff --git a/canals/pipeline/pipeline.py b/canals/pipeline/pipeline.py
index bb54947573..2ebb16f7a6 100644
--- a/canals/pipeline/pipeline.py
+++ b/canals/pipeline/pipeline.py
@@ -338,15 +338,15 @@ def get_component(self, name: str) -> Component:
except KeyError as exc:
raise ValueError(f"Component named {name} not found in the pipeline.") from exc
- def draw(self, path: Path, engine: RenderingEngines = "mermaid-img") -> None:
+ def draw(self, path: Path, engine: RenderingEngines = "mermaid-image") -> None:
"""
Draws the pipeline. Requires either `graphviz` as a system dependency, or an internet connection for Mermaid.
Run `pip install canals[graphviz]` or `pip install canals[mermaid]` to install missing dependencies.
Args:
path: where to save the diagram.
- engine: which format to save the graph as. Accepts 'graphviz', 'mermaid-text', 'mermaid-img'.
- Default is 'mermaid-img'.
+ engine: which format to save the graph as. Accepts 'graphviz', 'mermaid-text', 'mermaid-image'.
+ Default is 'mermaid-image'.
Returns:
None
@@ -360,7 +360,7 @@ def draw(self, path: Path, engine: RenderingEngines = "mermaid-img") -> None:
for comp, data in self.graph.nodes(data=True)
}
print(sockets)
- _draw(graph=deepcopy(self.graph), path=path, engine=engine)
+ _draw(graph=networkx.MultiDiGraph(self.graph), path=path, engine=engine)
def warm_up(self):
"""
diff --git a/test/pipelines/unit/test_draw.py b/test/pipelines/unit/test_draw.py
index acfcdb9c55..0cad88ca5f 100644
--- a/test/pipelines/unit/test_draw.py
+++ b/test/pipelines/unit/test_draw.py
@@ -28,14 +28,14 @@ def test_draw_pygraphviz(tmp_path, test_files):
assert filecmp.cmp(tmp_path / "test_pipe.jpg", test_files / "pipeline_draw" / "pygraphviz.jpg")
-def test_draw_mermaid_img(tmp_path, test_files):
+def test_draw_mermaid_image(tmp_path, test_files):
pipe = Pipeline()
pipe.add_component("comp1", Double())
pipe.add_component("comp2", Double())
pipe.connect("comp1", "comp2")
pipe.connect("comp2", "comp1")
- _draw(pipe.graph, tmp_path / "test_pipe.jpg", engine="mermaid-img")
+ _draw(pipe.graph, tmp_path / "test_pipe.jpg", engine="mermaid-image")
assert os.path.exists(tmp_path / "test_pipe.jpg")
assert filecmp.cmp(tmp_path / "test_pipe.jpg", test_files / "mermaid_mock" / "test_response.png")
@@ -59,7 +59,7 @@ def raise_for_status(self):
mock_get.return_value = mock_response
with pytest.raises(PipelineDrawingError, match="There was an issue with https://mermaid.ink/"):
- _draw(pipe.graph, tmp_path / "test_pipe.jpg", engine="mermaid-img")
+ _draw(pipe.graph, tmp_path / "test_pipe.jpg", engine="mermaid-image")
def test_draw_mermaid_txt(tmp_path):
@@ -76,20 +76,12 @@ def test_draw_mermaid_txt(tmp_path):
== """
%%{ init: {'theme': 'neutral' } }%%
-stateDiagram-v2
+graph TD;
-comp1:::components: comp1
AddFixedValue(add=3)
-comp2:::components: comp2
Double()
+comp1["comp1
AddFixedValue
Optional inputs:"]:::component -- "result -> value
int" --> comp2["comp2
Double"]:::component
+comp2["comp2
Double"]:::component -- "value -> value
int" --> comp1["comp1
AddFixedValue
Optional inputs:"]:::component
-note left of comp1
- add Optional[int]
-end note
-
-comp1 --> comp2 : result -> value (int)
-comp2 --> comp1 : value -> value (int)
-[*] --> comp1 : add (Optional[int])
-
-classDef components text-align:center;
+classDef component text-align:center;
"""
)