Skip to content

graph_printer

GraphPrinter dataclass

Pretty-prints a graph in a form similar to how you would construct it by hand.

The representation stays consistent even when the underlying UUIDs differ.

Source code in lineapy/graph_reader/graph_printer.py
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
@dataclass
class GraphPrinter:
    """
    Pretty prints a graph, in a similar way as how you would create it by hand.

    This representation should be consistent despite UUIDs being different:
    once an ID has been assigned a variable name, every later reference to
    that ID is rendered as ``<name>.id`` instead of the raw UUID.

    Use :meth:`print` for the whole text, or iterate :meth:`lines` to get it
    one line at a time.
    """

    graph: Graph

    # Whether to include the source locations from the graph
    include_source_location: bool = field(default=True)
    # Whether to print the ID fields for nodes
    include_id_field: bool = field(default=True)
    # Whether to include the session
    include_session: bool = field(default=True)
    # Whether to include the imports needed to run the file
    include_imports: bool = field(default=False)
    # Whether to include timing information
    include_timing: bool = field(default=True)

    # Set to True to nest node strings, when they only have one successor.
    nest_nodes: bool = field(default=True)
    # Maps an ID to the text used to refer to it in the output: a variable
    # name (e.g. "source_1") or, for nested nodes, the node's full
    # pretty-printed body.
    id_to_attribute_name: Dict[LineaID, str] = field(default_factory=dict)

    # Mapping of each node types to the count of nodes of that type printed
    # so far to create variables based on node type.
    node_type_to_count: Dict[NodeType, int] = field(
        default_factory=lambda: collections.defaultdict(lambda: 0)
    )

    # Number of SourceCode objects emitted so far; they are named
    # source_1, source_2, ...
    source_code_count: int = field(default=0)

    def print(self) -> str:
        """Return the whole pretty-printed graph as a single string."""
        return "\n".join(self.lines())

    def get_node_type_count(self, node_type: NodeType) -> int:
        """
        Increment and return the running count for ``node_type``.

        The first call for a given node type returns 1.
        """
        # ``.get`` also works when a caller passed a plain dict rather than
        # the default ``defaultdict``.
        count = self.node_type_to_count.get(node_type, 0) + 1
        self.node_type_to_count[node_type] = count
        return count

    def get_node_type_name(self, node_type: NodeType) -> str:
        """Allocate a fresh variable name ``<type>_<count>`` for ``node_type``."""
        return f"{pretty_print_node_type(node_type)}_{self.get_node_type_count(node_type)}"

    def get_cached_node_type_name(self, node: Node) -> str:
        """
        Return the name already associated with ``node.id``.

        If there is none, allocate one with :meth:`get_node_type_name`,
        record it in ``id_to_attribute_name`` and return it.
        """
        if node.id in self.id_to_attribute_name:
            return self.id_to_attribute_name[node.id]
        attr_name = self.get_node_type_name(node.node_type)
        self.id_to_attribute_name[node.id] = attr_name
        return attr_name

    def lines(self) -> Iterable[str]:
        """
        Yield the pretty-printed graph one line at a time.

        Output order:
        1. An optional import header.
        2. An optional ``session = (...)`` assignment.
        3. Every node in ``graph.visit_order()``. A node's ``SourceCode`` is
           emitted just before it, the first time that source is seen.

        When ``nest_nodes`` is set, a node with exactly one successor is not
        assigned to a variable. Its body is recorded and inlined wherever the
        node is referenced.

        NOTE: sets ``version`` to ``""`` on every ``ImportNode`` of the graph
        (the nodes themselves are mutated, not copies).
        """
        if self.include_imports:
            yield "import datetime"
            yield "from pathlib import *"
            yield "from lineapy.data.types import *"
            yield "from lineapy.utils.utils import get_new_id"
        if self.include_session:
            yield "session = ("
            yield from self.pretty_print_model(self.graph.session_context)
            yield ")"

        for node in self.graph.visit_order():
            node_id = node.id
            # Allocates (and caches) a name even for nodes that end up
            # nested, so the per-type numbering follows visit order.
            attr_name = self.get_cached_node_type_name(node)

            # If the node has source code we haven't printed before, print
            # that first so the node can just reference it by name.
            source_location = node.source_location
            if (
                source_location
                and source_location.source_code.id
                not in self.id_to_attribute_name
            ):
                self.source_code_count += 1
                source_name = f"source_{self.source_code_count}"
                self.id_to_attribute_name[
                    source_location.source_code.id
                ] = source_name
                yield f"{source_name} = ("
                yield from self.pretty_print_model(source_location.source_code)
                yield ")"

            if node.node_type == NodeType.ImportNode:
                # Do not track library version changes in the printed output.
                # NOTE(review): this mutates the node stored in the graph.
                node.version = ""  # type: ignore

            # If the node only has one successor, store its body as its
            # "attribute name" so it is inlined where it is referenced.
            if (
                self.nest_nodes
                and len(list(self.graph.nx_graph.successors(node_id))) == 1
            ):
                self.id_to_attribute_name[node_id] = "\n".join(
                    self.pretty_print_model(node)
                )
            else:
                yield f"{attr_name} = ("
                yield from self.pretty_print_model(node)
                yield ")"
                self.id_to_attribute_name[node_id] = attr_name

    def pretty_print_model(self, model: BaseModel) -> Iterable[str]:
        """Yield ``<ModelClass>(``, one line per printed field, then ``)``."""
        yield f"{type(model).__name__}("
        yield from self.pretty_print_node_lines(model)
        yield ")"

    def lookup_id(self, id: LineaID) -> str:
        """
        Return ``<name>.id`` when ``id`` has an assigned name.

        Otherwise return ``repr(id)``.
        """
        if id in self.id_to_attribute_name:
            return self.id_to_attribute_name[id] + ".id"
        return repr(id)

    def pretty_print_node_lines(self, node: BaseModel) -> Iterable[str]:
        """
        Yield one ``<field>=<value>,`` line per printable field of ``node``.

        Fields are visited in ``node.__fields__`` order. A field is skipped
        when any of these holds:
        - its value is ``None``;
        - it is ``node_type``;
        - it is an empty ``positional_args``, ``keyword_args``,
          ``global_reads`` or ``implicit_dependencies``;
        - the matching ``include_*`` flag is off (``source_location``,
          ``id``, ``session_id``, and datetime values for timing).
        """
        for key in node.__fields__.keys():
            value = getattr(node, key)

            # Ignore fields that are None
            if value is None:
                continue

            model_field = node.__fields__[key]
            tp = model_field.type_
            shape = model_field.shape
            v_str: str
            if key == "node_type":
                continue
            if key == "source_location" and not self.include_source_location:
                continue
            if key == "id" and not self.include_id_field:
                continue
            if key == "session_id" and not self.include_session:
                continue
            # don't print empty args, kwargs, reads or implicit dependencies
            if (
                key
                in {
                    "positional_args",
                    "keyword_args",
                    "global_reads",
                    "implicit_dependencies",
                }
            ) and not value:
                continue
            if tp == LineaID and shape == SHAPE_LIST:
                # Printed in stored order; these are NOT sorted.
                # NOTE(review): a previous comment claimed sorting for stable
                # diffs, but the code never sorted - confirm the stored order
                # is deterministic.
                v_str = (
                    "[" + ", ".join(self.lookup_id(id_) for id_ in value) + "]"
                )
            elif tp == PositionalArgument and shape == SHAPE_LIST:
                # Special case for positional arguments to support starred
                # args: a starred argument gets a "*" prefix, e.g.
                # positional_args=[call_1.id] vs positional_args=[*call_1.id].
                # Positional order is significant, so they are never sorted.
                args = [
                    arg.starred * "*" + str(self.lookup_id(arg.id))
                    for arg in value
                ]
                v_str = "[" + ", ".join(args) + "]"
            elif tp == KeywordArgument and shape == SHAPE_LIST:
                # Sort kwargs on printing for consistent ordering
                args = [
                    f"{repr(kwa.key)}: {self.lookup_id(kwa.value)}"
                    for kwa in sorted(value, key=lambda x: x.key)
                ]
                v_str = "{" + ", ".join(args) + "}"
            elif tp == LineaID and shape == SHAPE_DICT:
                # Sort by key on printing for consistent ordering
                args = [
                    f"{repr(name)}: {self.lookup_id(id_)}"
                    for name, id_ in sorted(
                        cast(Dict[str, LineaID], value).items(),
                        key=lambda x: x[0],
                    )
                ]
                v_str = "{" + ", ".join(args) + "}"
            # Singleton NewTypes get cast to str by pydantic, so we can't
            # differentiate at the field level between them and strings; we
            # just check whether the value is a known ID.
            elif isinstance(value, str) and value in self.id_to_attribute_name:
                v_str = self.lookup_id(value)  # type: ignore
            elif isinstance(value, datetime.datetime) and not self.include_timing:
                continue
            else:
                v_str = "\n".join(self.pretty_print_value(value))
            yield f"{key}={v_str},"

    def pretty_print_value(self, v: object) -> Iterable[str]:
        """
        Yield the lines that render ``v`` as Python source.

        - ``SourceCode``: a reference produced by :meth:`lookup_id`.
        - Enum members: ``EnumType.MEMBER``.
        - Other pydantic models: recursively via :meth:`pretty_print_model`.
        - Lists: ``[``, each element followed by a ``,`` line, then ``]``.
        - Strings: a reference if the string is the ID of a node in the
          graph (allocating a name if needed); otherwise a string literal
          from ``pretty_print_str``.
        - Anything else: ``repr(v)`` if it compiles as Python, otherwise
          that repr wrapped in a string literal.
        """
        if isinstance(v, SourceCode):
            yield self.lookup_id(v.id)
        elif isinstance(v, enum.Enum):
            yield f"{type(v).__name__}.{v.name}"
        elif isinstance(v, BaseModel):
            yield from self.pretty_print_model(v)
        elif isinstance(v, list):
            yield "["
            for item in v:
                yield from self.pretty_print_value(item)
                yield ","
            yield "]"
        elif isinstance(v, str):
            v_lid = LineaID(v)
            node = self.graph.get_node(v_lid)
            if node is not None:
                attr_name = self.get_cached_node_type_name(node)
                self.id_to_attribute_name[v_lid] = attr_name
                yield self.lookup_id(v_lid)
            else:
                yield pretty_print_str(v)
        else:
            value = repr(v)
            # Try parsing as Python code; if we can't, wrap it in a string.
            # Some Python versions raise ValueError (not SyntaxError) for
            # source containing null bytes, which a custom __repr__ may emit.
            try:
                compile(value, "", "exec")
            except (SyntaxError, ValueError):
                value = repr(value)
            yield value

pretty_print_node_type(type)

Turns a node type into something that can be used as a variable name.

Source code in lineapy/graph_reader/graph_printer.py
257
258
259
260
261
def pretty_print_node_type(type: NodeType) -> str:
    """
    Build a variable-name stem from a node type.

    Every occurrence of "Node" is removed from the enum member's name, and
    the remainder is passed through ``camel_to_snake_case``.
    """
    stem = type.name.replace("Node", "")
    return camel_to_snake_case(stem)

pretty_print_str(s)

Pretty-prints a string; if it contains a newline, it is printed as a triple-quoted string.

Source code in lineapy/graph_reader/graph_printer.py
246
247
248
249
250
251
252
253
254
def pretty_print_str(s: str) -> str:
    """
    Pretty prints a string, so that if it has a newline, prints it as a triple
    quoted string.

    The result is a Python string literal that evaluates back to ``s``.
    Strings without a newline are rendered with ``repr``.

    For multi-line strings, these characters are escaped:
    - backslashes, otherwise an existing backslash sequence in ``s`` would be
      reinterpreted, and a trailing backslash would swallow the closing quote;
    - single quotes, so they cannot terminate the literal;
    - carriage returns, because the parser normalises a raw CR inside source
      text to a plain newline.
    """
    if "\n" in s:
        escaped = (
            # Backslashes must be escaped first so the escapes added below
            # are not themselves doubled.
            s.replace("\\", "\\\\")
            .replace("'", "\\'")
            .replace("\r", "\\r")
        )
        return f"'''{escaped}'''"
    return repr(s)

Was this helpful?

Help us improve docs with your feedback!