# Source code for lineapy.data.graph

from queue import PriorityQueue, Queue
from typing import Callable, Dict, Iterator, List, Optional, Set, TypeVar

import networkx as nx

from lineapy.data.types import LineaID, Node, SessionContext
from lineapy.graph_reader.graph_printer import GraphPrinter
from lineapy.utils.utils import listify, prettify


class Graph(object):
    def __init__(self, nodes: List[Node], session_context: SessionContext):
        """
        Graph represents a directed acyclic graph (DAG) of nodes, with an edge
        from each parent ID to every node that lists it in ``parents()``.

        It is constructed based on the following variables:
        :param nodes: the nodes in the graph. Edges are derived from each
            node's ``parents()``; parent IDs that are not themselves in
            ``nodes`` still show up as bare vertices of ``nx_graph``.
        :param session_context: the session context associated with the graph.
        :raises AssertionError: if the resulting graph contains a cycle.

        NOTE:
        - It makes sense to include session_context in the constructor of
          the graph because the information in session_context is semantically
          important to the notion of a Graph. Concretely, we are starting
          to also use the code entry from the session_context.
        - TODO: Possibly remove session_context if it ends up unused.
        """
        self.nodes: List[Node] = nodes
        self.ids: Dict[LineaID, Node] = {n.id: n for n in nodes}
        self.nx_graph = nx.DiGraph()
        self.nx_graph.add_nodes_from([node.id for node in nodes])

        # One edge per (parent -> child) pair. networkx implicitly adds any
        # parent ID that is not in ``nodes`` as a vertex here.
        self.nx_graph.add_edges_from(
            [
                (parent_id, node.id)
                for node in nodes
                for parent_id in node.parents()
            ]
        )

        self.session_context = session_context

        # validation
        if not nx.is_directed_acyclic_graph(self.nx_graph):
            raise AssertionError("Graph should not be cyclic")

    def __eq__(self, other) -> bool:
        """
        Two graphs are equal when their dependency structures are isomorphic.

        Only the shape of ``nx_graph`` is compared, not the node contents.
        Comparing against a non-Graph defers to the other operand.
        """
        if not isinstance(other, Graph):
            return NotImplemented
        return nx.is_isomorphic(self.nx_graph, other.nx_graph)

    def print(self, **kwargs) -> str:
        """Render the graph as text via ``GraphPrinter``; kwargs are forwarded."""
        return GraphPrinter(self, **kwargs).print()

    @listify
    def visit_order(self) -> Iterator[Node]:
        """
        Yield the nodes in a topological order, tie-breaking with ``Node.__lt__``.

        Parents are always yielded before their children. When several nodes
        are ready at once, the smallest according to ``Node.__lt__`` (line
        number) comes first. We cannot just sort by line number because some
        nodes are created by us and do not have line numbers.

        Parents that are not part of ``self.nodes`` are ignored. This happens
        when we evaluate part of a graph, then another part: those parents
        are assumed to have already been executed.
        """
        # TODO: See if we could replace this with python's built in topological sort
        # https://docs.python.org/3/library/graphlib.html
        # It seems to suggest that resulting order is determined by insertion order
        # so this could possibly also implicitly sort by line number, if we insert
        # in that order.

        # We do a breadth first traversal, keeping our queue ordered by line
        # number. The sorting is done via the __lt__ method of the Node.
        queue: PriorityQueue[Node] = PriorityQueue()

        # All nodes already added to the queue, so that we don't add them again.
        seen: Set[LineaID] = set()

        # Number of parents (within self.nodes) of each node that have not
        # been visited yet.
        remaining_parents: Dict[LineaID, int] = {}

        for node in self.nodes:
            n_remaining_parents = sum(
                1
                for parent_id in self.nx_graph.pred[node.id]
                if parent_id in self.ids
            )
            # Seed the queue with every node that has no (known) parents.
            if n_remaining_parents == 0:
                seen.add(node.id)
                queue.put(node)
            remaining_parents[node.id] = n_remaining_parents

        while queue.qsize():
            # The queue may hold nodes that still have unvisited parents, so
            # take the first one (in priority order) whose parents are all done.
            node = queue_get_when(
                queue, lambda n: remaining_parents[n.id] == 0
            )

            yield node
            # Mark one visited parent for each child, and enqueue children the
            # first time we reach them.
            for child_id in self.get_children(node.id):
                remaining_parents[child_id] -= 1
                if child_id in seen:
                    continue
                child_node = self.ids[child_id]
                queue.put(child_node)
                seen.add(child_id)

    def get_parents(self, node_id: LineaID) -> List[LineaID]:
        """Direct parent IDs of ``node_id`` (may include IDs not in ``self.ids``)."""
        return list(self.nx_graph.predecessors(node_id))

    def get_ancestors(self, node_id: LineaID) -> List[LineaID]:
        """All IDs from which ``node_id`` is reachable."""
        return list(nx.ancestors(self.nx_graph, node_id))

    def get_children(self, node_id: LineaID) -> List[LineaID]:
        """Direct child IDs of ``node_id``."""
        return list(self.nx_graph.successors(node_id))

    def get_descendants(self, node_id: LineaID) -> List[LineaID]:
        """All IDs reachable from ``node_id``."""
        return list(nx.descendants(self.nx_graph, node_id))

    def get_leaf_nodes(self) -> List[LineaID]:
        """IDs of vertices with no outgoing edges (nothing depends on them)."""
        return [
            node
            for node in self.nx_graph.nodes
            if self.nx_graph.out_degree(node) == 0
        ]

    def get_node(self, node_id: Optional[LineaID]) -> Optional[Node]:
        """Look up a node by ID; ``None`` if ``node_id`` is None or unknown."""
        if node_id is not None and node_id in self.ids:
            return self.ids[node_id]
        return None

    def get_subgraph(self, nodes: List[Node]) -> "Graph":
        """
        Build a new Graph from ``nodes`` that shares this graph's session context.

        Parents of ``nodes`` that are not in ``nodes`` are kept as bare
        vertices, in the same way as in the constructor.
        """
        return Graph(nodes, self.session_context)

    def __str__(self):
        return prettify(
            self.print(
                include_source_location=False,
                include_id_field=True,
                include_session=False,
            )
        )

    def __repr__(self):
        return prettify(self.print())


T = TypeVar("T")


def queue_get_when(queue: "Queue[T]", filter_fn: Callable[[T], bool]) -> T:
    """
    Gets the first element in the queue that satisfies the filter function.

    Elements are popped in the queue's own order (priority order for a
    ``PriorityQueue``) until one satisfies ``filter_fn``. Every element
    popped before it is put back, so only the returned element is removed.

    :param queue: the queue to search.
    :param filter_fn: predicate selecting the element to return.
    :returns: the first popped element for which ``filter_fn`` is true.
    :raises Empty: (from the ``queue`` module) if no element satisfies
        ``filter_fn``. In that case, and also if ``filter_fn`` raises, all
        popped elements are put back before the exception propagates.
    """
    # We can't iterate through a queue, so we pop elements off one at a time
    # and hold on to them. ``get_nowait`` raises ``Empty`` instead of blocking
    # forever if we run out of elements (e.g. because of a bug upstream).
    held: List[T] = []
    try:
        while True:
            candidate = queue.get_nowait()
            # Hold the candidate before calling ``filter_fn`` so it is
            # restored even if the predicate raises.
            held.append(candidate)
            if filter_fn(candidate):
                held.pop()
                return candidate
    finally:
        # Put back everything we popped except the element being returned.
        for item in held:
            queue.put(item)