from __future__ import annotations
from dataclasses import InitVar, dataclass, field
try:
import graphviz
except ModuleNotFoundError:
raise ModuleNotFoundError(
"graphviz is not installed, please install graphviz in your local environment to visualize artifacts"
) from None
from IPython.display import HTML, DisplayObject
from lineapy.data.graph import Graph
from lineapy.instrumentation.tracer import Tracer
from lineapy.visualizer.graphviz import to_graphviz
from lineapy.visualizer.optimize_svg import optimize_svg
from lineapy.visualizer.visual_graph import VisualGraphOptions
@dataclass
class Visualizer:
    """
    Stores a rendered graphviz digraph. Has helper classmethods to use
    for construction, as well as methods for output as different useful
    formats.

    The digraph is built exactly once, in ``__post_init__``, from the
    ``VisualGraphOptions`` passed to the constructor. The options object is
    an init-only variable and is not kept on the instance.
    """

    # Init-only: consumed by __post_init__ to build ``digraph``; not stored.
    options: InitVar[VisualGraphOptions]
    # Built from ``options`` in __post_init__; callers cannot pass it to __init__.
    digraph: graphviz.Digraph = field(init=False)

    def __post_init__(self, options: VisualGraphOptions):
        self.digraph = to_graphviz(options)

    def render_pdf_file(self, filename: str = "tracer", *, view: bool = True) -> None:
        """
        Renders a PDF file for the graph and, by default, tries to open it.

        :param filename: name handed to ``graphviz.Digraph.render`` for the
            output file(s).
        :param view: when True (the default, and the historical behavior),
            ask graphviz to open the rendered PDF. Pass False to only write
            the file, e.g. on a headless machine.
        """
        self.digraph.render(filename, view=view, format="pdf")

    def render_svg(self) -> str:
        """
        Render the graph to SVG markup and return it after ``optimize_svg``.

        graphviz returns the piped output as bytes; it is decoded with the
        default (UTF-8) codec before optimization.
        """
        return optimize_svg(self.digraph.pipe(format="svg").decode())

    def ipython_display_object(self) -> DisplayObject:
        """
        Build an IPython ``HTML`` object that shows the graph as a zoomable,
        pannable SVG using the svg-pan-zoom JavaScript library.
        """
        svg_text = self.render_svg()
        # We emit this HTML to get a zoomable SVG
        # Copied from https://github.com/jupyterlab/jupyterlab/issues/7497#issuecomment-557334236
        # Which references https://github.com/pygraphkit/graphtik/blob/56a513c665e26e7bf3e81b6fb07d9475c5bf1614/graphtik/plot.py#L144-L183
        # Which uses this library: https://github.com/bumbu/svg-pan-zoom
        #
        # The CSS type selector must be lowercase ``svg``. In HTML documents,
        # selectors are lowercased only when matched against HTML-namespace
        # elements. An inline <svg> lives in the SVG namespace, so an
        # uppercase ``SVG`` selector would never match it.
        #
        # NOTE(review): the inline script assumes svg-pan-zoom has already
        # loaded and that the last entry in ``document.scripts`` is this
        # script tag; confirm this holds in the notebook frontends we target.
        html_text = f"""
<div class="svg_container">
    <style>
        .svg_container svg {{
            width: 100%;
            height: 100%;
        }}
    </style>
    <script src="https://bumbu.me/svg-pan-zoom/dist/svg-pan-zoom.min.js"></script>
    <script type="text/javascript">
        var scriptTag = document.scripts[document.scripts.length - 1];
        var parentTag = scriptTag.parentNode;
        var svg_el = parentTag.querySelector(".svg_container svg");
        svgPanZoom(svg_el, {{
            controlIconsEnabled: true,
            fit: true,
            zoomScaleSensitivity: 0.2,
            minZoom: 0.1,
            maxZoom: 10
        }});
    </script>
    {svg_text}
</div>
"""
        return HTML(html_text)

    @classmethod
    def for_test_snapshot(cls, tracer: Tracer) -> Visualizer:
        """
        Create a graph for saving as a snapshot, to help with visual diffs in PRs.

        Implied mutations and views are hidden; artifacts and variables are shown.
        """
        options = VisualGraphOptions(
            tracer.graph,
            tracer,
            highlight_node=None,
            # This is generally repetitive, and we can avoid it.
            show_implied_mutations=False,
            # Views are too verbose to show in the test output
            show_views=False,
            show_artifacts=True,
            show_variables=True,
        )
        return cls(options)

    @classmethod
    def for_test_cli(cls, tracer: Tracer) -> Visualizer:
        """
        Create a graph to use when visualizing after passing in `--visualize`
        during testing.

        Show as much as we can for debugging: every ``show_*`` option is on.
        """
        options = VisualGraphOptions(
            tracer.graph,
            tracer,
            highlight_node=None,
            show_implied_mutations=True,
            show_views=True,
            show_artifacts=True,
            show_variables=True,
        )
        return cls(options)

    @classmethod
    def for_public(cls, tracer: Tracer) -> Visualizer:
        """
        Create a graph for our public API, when showing the whole graph.

        Implied mutations and views are hidden; artifacts and variables are shown.
        """
        options = VisualGraphOptions(
            tracer.graph,
            tracer,
            highlight_node=None,
            show_implied_mutations=False,
            show_views=False,
            show_artifacts=True,
            show_variables=True,
        )
        return cls(options)

    @classmethod
    def for_public_node(cls, graph: Graph, node_id: str) -> Visualizer:
        """
        Create a graph for our public API, when showing a single node.

        ``node_id`` is passed as ``highlight_node``; artifacts, variables,
        views and implied mutations are all hidden.

        Note: The tracer won't be passed in this case, since it is happening
        inside the executor and we don't have access to the tracer.
        """
        options = VisualGraphOptions(
            graph,
            tracer=None,
            highlight_node=node_id,
            show_implied_mutations=False,
            show_views=False,
            show_artifacts=False,
            show_variables=False,
        )
        return cls(options)