diff --git a/canals/pipeline/draw/draw.py b/canals/pipeline/draw/draw.py index 33b7329837..5fc60ebefe 100644 --- a/canals/pipeline/draw/draw.py +++ b/canals/pipeline/draw/draw.py @@ -14,13 +14,13 @@ from canals.utils import _type_name logger = logging.getLogger(__name__) -RenderingEngines = Literal["graphviz", "mermaid-img", "mermaid-text"] +RenderingEngines = Literal["graphviz", "mermaid-image", "mermaid-text"] def _draw( graph: networkx.MultiDiGraph, path: Path, - engine: RenderingEngines = "mermaid-img", + engine: RenderingEngines = "mermaid-image", style_map: Optional[Dict[str, str]] = None, ) -> None: """ @@ -31,7 +31,7 @@ def _draw( if engine == "graphviz": converted_graph.draw(path) - elif engine == "mermaid-img": + elif engine == "mermaid-image": with open(path, "wb") as imagefile: imagefile.write(converted_graph) @@ -57,7 +57,7 @@ def _convert_for_debug( def _convert( graph: networkx.MultiDiGraph, - engine: RenderingEngines = "mermaid-img", + engine: RenderingEngines = "mermaid-image", style_map: Optional[Dict[str, str]] = None, ) -> Any: """ @@ -68,7 +68,7 @@ def _convert( if engine == "graphviz": return _to_agraph(graph=graph) - if engine == "mermaid-img": + if engine == "mermaid-image": return _to_mermaid_image(graph=graph) if engine == "mermaid-text": @@ -95,7 +95,7 @@ def _prepare_for_drawing(graph: networkx.MultiDiGraph, style_map: Dict[str, str] graph.add_node("input") for node, in_sockets in _find_pipeline_inputs(graph).items(): for in_socket in in_sockets: - if in_socket.sender is None: + if in_socket.sender is None and not in_socket.is_optional: # If this socket has no sender it could be a socket that receives input # directly when running the Pipeline. We can't know that for sure, in doubt # we draw it as receiving input directly. diff --git a/canals/pipeline/draw/mermaid.py b/canals/pipeline/draw/mermaid.py index 62584b3e84..0859072f74 100644 --- a/canals/pipeline/draw/mermaid.py +++ b/canals/pipeline/draw/mermaid.py @@ -3,14 +3,12 @@ # SPDX-License-Identifier: Apache-2.0 import logging import base64 -import json import requests import networkx from canals.errors import PipelineDrawingError from canals.utils import _type_name -from canals.serialization import component_to_dict logger = logging.getLogger(__name__) @@ -18,15 +16,11 @@ MERMAID_STYLED_TEMPLATE = """ %%{{ init: {{'theme': 'neutral' }} }}%% -stateDiagram-v2 - -{states} - -{notes} +graph TD; {connections} -classDef components text-align:center; +classDef component text-align:center; """ @@ -66,54 +60,43 @@ def _to_mermaid_text(graph: networkx.MultiDiGraph) -> str: Converts a Networkx graph into Mermaid syntax. The output of this function can be used in the documentation with `mermaid` codeblocks and it will be automatically rendered. """ - init_params = {} - for name, comp in graph.nodes(data="instance"): - if name in ["input", "output"]: - continue - data = component_to_dict(comp) - params = [f"{k}={json.dumps(v)}" for k, v in data.get("init_parameters", {}).items()] - init_params[name] = ",
".join(params) - states = "\n".join( - [ - f"{comp}:::components: {comp}
{type(data['instance']).__name__}({init_params[comp]})" - for comp, data in graph.nodes(data=True) - if comp not in ["input", "output"] - ] - ) sockets = { - comp: "\n".join( + comp: "".join( [ - f"{name} {_type_name(socket.type)}" + f"
  • {name} ({_type_name(socket.type)})
  • " for name, socket in data.get("input_sockets", {}).items() if socket.is_optional and socket.sender is None ] ) for comp, data in graph.nodes(data=True) } - notes = "\n".join( - [ - f"note left of {comp}\n {sockets[comp]}\nend note" - for comp in graph.nodes - if comp not in ["input", "output"] and sockets[comp] != "" - ] - ) + optional_inputs = { + comp: f"

    Optional inputs:" if sockets else "" + for comp, sockets in sockets.items() + } + + states = { + comp: f"{comp}[\"{comp}
    {type(data['instance']).__name__}{optional_inputs[comp]}\"]:::component" + for comp, data in graph.nodes(data=True) + if comp not in ["input", "output"] + } connections_list = [ - f"{from_comp} --> {to_comp} : {conn_data['label']} ({conn_data['conn_type']})" + f"{states[from_comp]} -- \"{conn_data['label']}
    {conn_data['conn_type']}\" --> {states[to_comp]}" for from_comp, to_comp, conn_data in graph.edges(data=True) if from_comp != "input" and to_comp != "output" ] input_connections = [ - f"[*] --> {to_comp} : {conn_data['label']} ({conn_data['conn_type']})" + f"i{{*}} -- \"{conn_data['label']}
    {conn_data['conn_type']}\" --> {states[to_comp]}" for _, to_comp, conn_data in graph.out_edges("input", data=True) ] output_connections = [ - f"{from_comp} --> [*] : {conn_data['label']} ({conn_data['conn_type']})" + f"{states[from_comp]} -- \"{conn_data['label']}
    {conn_data['conn_type']}\"--> o{{*}}" for from_comp, _, conn_data in graph.in_edges("output", data=True) ] connections = "\n".join(connections_list + input_connections + output_connections) - graph_styled = MERMAID_STYLED_TEMPLATE.format(states=states, notes=notes, connections=connections) + graph_styled = MERMAID_STYLED_TEMPLATE.format(connections=connections) logger.debug("Mermaid diagram:\n%s", graph_styled) return graph_styled diff --git a/canals/pipeline/pipeline.py b/canals/pipeline/pipeline.py index bb54947573..2ebb16f7a6 100644 --- a/canals/pipeline/pipeline.py +++ b/canals/pipeline/pipeline.py @@ -338,15 +338,15 @@ def get_component(self, name: str) -> Component: except KeyError as exc: raise ValueError(f"Component named {name} not found in the pipeline.") from exc - def draw(self, path: Path, engine: RenderingEngines = "mermaid-img") -> None: + def draw(self, path: Path, engine: RenderingEngines = "mermaid-image") -> None: """ Draws the pipeline. Requires either `graphviz` as a system dependency, or an internet connection for Mermaid. Run `pip install canals[graphviz]` or `pip install canals[mermaid]` to install missing dependencies. Args: path: where to save the diagram. - engine: which format to save the graph as. Accepts 'graphviz', 'mermaid-text', 'mermaid-img'. - Default is 'mermaid-img'. + engine: which format to save the graph as. Accepts 'graphviz', 'mermaid-text', 'mermaid-image'. + Default is 'mermaid-image'. Returns: None @@ -360,7 +360,7 @@ def draw(self, path: Path, engine: RenderingEngines = "mermaid-img") -> None: for comp, data in self.graph.nodes(data=True) } print(sockets) - _draw(graph=deepcopy(self.graph), path=path, engine=engine) + _draw(graph=networkx.MultiDiGraph(self.graph), path=path, engine=engine) def warm_up(self): """ diff --git a/test/pipelines/unit/test_draw.py b/test/pipelines/unit/test_draw.py index acfcdb9c55..0cad88ca5f 100644 --- a/test/pipelines/unit/test_draw.py +++ b/test/pipelines/unit/test_draw.py @@ -28,14 +28,14 @@ def test_draw_pygraphviz(tmp_path, test_files): assert filecmp.cmp(tmp_path / "test_pipe.jpg", test_files / "pipeline_draw" / "pygraphviz.jpg") -def test_draw_mermaid_img(tmp_path, test_files): +def test_draw_mermaid_image(tmp_path, test_files): pipe = Pipeline() pipe.add_component("comp1", Double()) pipe.add_component("comp2", Double()) pipe.connect("comp1", "comp2") pipe.connect("comp2", "comp1") - _draw(pipe.graph, tmp_path / "test_pipe.jpg", engine="mermaid-img") + _draw(pipe.graph, tmp_path / "test_pipe.jpg", engine="mermaid-image") assert os.path.exists(tmp_path / "test_pipe.jpg") assert filecmp.cmp(tmp_path / "test_pipe.jpg", test_files / "mermaid_mock" / "test_response.png") @@ -59,7 +59,7 @@ def raise_for_status(self): mock_get.return_value = mock_response with pytest.raises(PipelineDrawingError, match="There was an issue with https://mermaid.ink/"): - _draw(pipe.graph, tmp_path / "test_pipe.jpg", engine="mermaid-img") + _draw(pipe.graph, tmp_path / "test_pipe.jpg", engine="mermaid-image") def test_draw_mermaid_txt(tmp_path): @@ -76,20 +76,12 @@ def test_draw_mermaid_txt(tmp_path): == """ %%{ init: {'theme': 'neutral' } }%% -stateDiagram-v2 +graph TD; -comp1:::components: comp1
    AddFixedValue(add=3) -comp2:::components: comp2
    Double() +comp1["comp1
    AddFixedValue

    Optional inputs:
    "]:::component -- "result -> value
    int" --> comp2["comp2
    Double"]:::component +comp2["comp2
    Double"]:::component -- "value -> value
    int" --> comp1["comp1
    AddFixedValue

    Optional inputs:
    "]:::component -note left of comp1 - add Optional[int] -end note - -comp1 --> comp2 : result -> value (int) -comp2 --> comp1 : value -> value (int) -[*] --> comp1 : add (Optional[int]) - -classDef components text-align:center; +classDef component text-align:center; """ )