Source code for lineapy.execution.executor

from __future__ import annotations

import builtins
import logging
import operator
from dataclasses import dataclass, field
from datetime import datetime
from os import chdir, getcwd
from typing import (
    Callable,
    Dict,
    Hashable,
    Iterable,
    List,
    Optional,
    Tuple,
    cast,
)

from lineapy.data.graph import Graph
from lineapy.data.types import (
    CallNode,
    Execution,
    GlobalNode,
    ImportNode,
    LineaID,
    LiteralNode,
    LookupNode,
    MutateNode,
    Node,
)
from lineapy.db.db import RelationalLineaDB
from lineapy.editors.ipython_cell_storage import get_location_path
from lineapy.exceptions.db_exceptions import ArtifactSaveException
from lineapy.exceptions.l_import_error import LImportError
from lineapy.exceptions.user_exception import (
    AddFrame,
    RemoveFrames,
    RemoveFramesWhile,
    TracebackChange,
    UserException,
)
from lineapy.execution.context import set_context, teardown_context
from lineapy.execution.inspect_function import FunctionInspector
from lineapy.execution.side_effects import (
    ID,
    ExecutorPointer,
    ImplicitDependencyNode,
    MutatedNode,
    SideEffect,
    ViewOfNodes,
)
from lineapy.instrumentation.annotation_spec import (
    BoundSelfOfFunction,
    ExternalState,
    ImplicitDependencyValue,
    InspectFunctionSideEffect,
    KeywordArgument,
    MutatedValue,
    PositionalArg,
    Result,
    ValuePointer,
    ViewOfValues,
)
from lineapy.utils.lineabuiltins import LINEA_BUILTINS
from lineapy.utils.utils import get_new_id

try:
    from functools import singledispatchmethod  # type: ignore
except ImportError:  # pragma: no cover
    # this is the fallback for python < 3.8
    # https://stackoverflow.com/questions/24601722
    from lineapy.utils.deprecation_utils import (  # type: ignore
        singledispatchmethod,
    )


# Module-level logger, named after this module so its output can be
# filtered/configured via the standard logging hierarchy.
logger: logging.Logger = logging.getLogger(name=__name__)


# Need to define first in file, even though private, since used as type param
# for single dispatch decorator and that requires typings to resolve
[docs]@dataclass class PrivateExecuteResult: value: object start_time: datetime end_time: datetime side_effects: List[SideEffect]
@dataclass
class Executor:
    """
    An executor that is responsible for executing a graph, either node by
    node as it is created, or in a batch, after the fact.

    To use the executor, you first instantiate it. Then you can execute
    nodes by calling `execute_node`, which returns a list of the side effects
    that executing that node causes.

    You can also query for the time a node took to execute or its value,
    using `get_execution_time` and `get_value`.
    """

    # The database to use for saving the execution
    db: RelationalLineaDB
    # The globals for this execution, to use when trying to lookup a value
    # Note: This is set in Jupyter so that `get_ipython` is defined
    _globals: Dict[str, object]

    # The __file__ for the module being executed
    # https://docs.python.org/3/reference/import.html#file__
    module_file: Optional[str] = None

    # The execution to record the values in.
    # This is accessed via the ExecutionContext, which is set when executing
    # a node, so that artifacts created during the execution know which
    # execution they should refer to.
    execution: Execution = field(init=False)

    _function_inspector: FunctionInspector = field(
        default_factory=FunctionInspector
    )
    # Value produced by each executed node, keyed by node ID
    _id_to_value: Dict[LineaID, object] = field(default_factory=dict)
    # (start, end) wall-clock times for each executed node, keyed by node ID
    _execution_time: Dict[LineaID, Tuple[datetime, datetime]] = field(
        default_factory=dict
    )

    # Mapping of bound method node ids to the ID of the instance they are
    # bound to
    _node_to_bound_self: Dict[LineaID, LineaID] = field(default_factory=dict)

    # Mapping of call node to the values of the globals that were updated.
    # Saved so that when we get these globals, we know their values
    # TODO: rename to variable
    _node_to_globals: Dict[LineaID, Dict[str, object]] = field(
        default_factory=dict
    )

    # Mapping of values to their nodes. Currently the only values
    # in here are external state values
    _value_to_node: Dict[Hashable, LineaID] = field(default_factory=dict)

    def __post_init__(self):
        # Each executor records into a fresh Execution, written to the DB
        # immediately so later writes can reference it.
        self.execution = Execution(
            id=get_new_id(),
            timestamp=datetime.now(),
        )
        self.db.write_execution(self.execution)

    def get_execution_time(
        self, node_id: LineaID
    ) -> Tuple[datetime, datetime]:
        """
        Returns the (start_time, end_time) for a node that was executed.

        A timing is recorded for every node run through `execute_node`.
        Global and mutate nodes reuse the timing of their associated call
        node.

        Raises KeyError if the node has not been executed.
        """
        return self._execution_time[node_id]

    def get_value(self, node_id: LineaID) -> object:
        """
        Gets the Python in-memory value for a node which was already
        executed.

        Raises KeyError if the node has not been executed.
        """
        return self._id_to_value[node_id]

    def lookup_external_state(self, state: ExternalState) -> Optional[LineaID]:
        """
        Returns the node ID if we have created a node already for some
        external state. Otherwise, returns None.
        """
        return self._value_to_node.get(state, None)

    def execute_node(
        self, node: Node, variables: Optional[Dict[str, LineaID]] = None
    ) -> Iterable[SideEffect]:
        """
        Executes a single node, records its value and execution time, and
        returns the side effects detected at runtime.

        `variables` is the mapping from local variable names to their nodes.
        It is passed in on the first execution. On re-executions (e.g. from
        `execute_graph`) it is None; at that point we already know which
        variables each call node depends on, since we captured that the first
        time we executed.

        Does the following:

        - Executes the node, dispatching on its type via `_execute`
        - Records the node's value and its (start, end) execution time, for
          every node type
        - When the node's source location is known, adds a frame for it to
          any user exception raised, to support error reporting. Without it,
          the traceback would be empty.
        - Returns the `SideEffects` of this node that are analyzed at runtime
          (hence in the executor).
        """
        logger.debug("Executing node %s", node)
        # To use if we need to raise an exception and change the frame
        default_changes: List[AddFrame] = []

        # If we know the source location, add that frame at the top
        if node.source_location:
            location = node.source_location.source_code.location
            default_changes.append(
                AddFrame(
                    str(get_location_path(location).absolute()),
                    node.source_location.lineno,
                )
            )

        res = self._execute(node, default_changes, variables)
        value = res.value
        self._id_to_value[node.id] = value
        self._execution_time[node.id] = res.start_time, res.end_time

        # If this is some external state node, save it by its value,
        # so we can look it up later if we try to access it
        if isinstance(value, ExternalState):
            # If we already know about this node, add an implicit
            # dependency from the old version to the new one
            if value in self._value_to_node:
                # However, don't add any edges for mutate nodes, since
                # they already should have it from the source
                if not isinstance(node, MutateNode):
                    res.side_effects.append(
                        ImplicitDependencyNode(ID(self._value_to_node[value]))
                    )
                # If this is a mutate node, then update the value to node to
                # the new value, so we always get the last one
                else:
                    self._value_to_node[value] = node.id
            # Otherwise, this is the first time we are seeing it, so
            # add it to our lookup
            else:
                self._value_to_node[value] = node.id
        return res.side_effects

    @singledispatchmethod
    def _execute(
        self,
        node: Node,
        changes: Iterable[TracebackChange],
        variables: Optional[Dict[str, LineaID]],
    ) -> PrivateExecuteResult:
        """
        Executes a node, returning the resulting value, the start and end
        times, and any side effects.

        Dispatches on the type of `node`; unregistered node types raise
        NotImplementedError.
        """
        raise NotImplementedError(
            f"Don't know how to execute node type {type(node)}"
        )

    @_execute.register
    def _execute_lookup(
        self,
        node: LookupNode,
        changes: Iterable[TracebackChange],
        variables: Optional[Dict[str, LineaID]],
    ) -> PrivateExecuteResult:
        # If we get a lookup error, change it to a name error to match python
        try:
            start_time = datetime.now()
            value = self._lookup_value(node.name)
            end_time = datetime.now()
        except KeyError:
            # Matches Python's message---our execution causes a KeyError,
            # but for the same user code, vanilla Python would throw NameError
            # which is why we needed to change the error type.
            message = f"name '{node.name}' is not defined"
            raise UserException(NameError(message), *changes)
        return PrivateExecuteResult(value, start_time, end_time, [])

    @_execute.register
    def _execute_call(
        self,
        node: CallNode,
        changes: Iterable[TracebackChange],
        variables: Optional[Dict[str, LineaID]],
    ) -> PrivateExecuteResult:
        fn = cast(Callable, self._id_to_value[node.function_id])

        # If we are getting an attribute, save the value in case
        # we later call it as a bound method and need to track its mutations
        # For example, for `a = [1]; a.append(2)`
        # We need to trace `a`, as opposed to `a.append` when tracking the
        # mutation.
        if fn is getattr:
            self._node_to_bound_self[node.id] = node.positional_args[0].id

        args: List[object] = []
        for p_arg in node.positional_args:
            if p_arg.starred:
                args.extend(cast(Iterable, self._id_to_value[p_arg.id]))
            else:
                args.append(self._id_to_value[p_arg.id])

        kwargs: Dict[str, object] = {}
        for k in node.keyword_args:
            if k.starred:
                # `**mapping` argument: merge all of its entries
                kwargs.update(cast(Dict, self._id_to_value[k.value]))
            else:
                kwargs[k.key] = self._id_to_value[k.value]

        logger.debug("Calling function %s %s %s", fn, args, kwargs)

        # Set up our execution context, with our globals and node
        set_context(self, variables, node)
        try:
            start_time = datetime.now()
            res = fn(*args, **kwargs)
            end_time = datetime.now()
        except ArtifactSaveException:
            # keep the error stack if its artifact save
            raise
        except LImportError as exc:
            # Remove all importlib frames
            # There are a different number depending on whether the import
            # can be resolved
            drop_frozen_importlib = RemoveFramesWhile(
                lambda frame: frame.f_code.co_filename.startswith(
                    "<frozen importlib"
                )
            )
            raise UserException(
                exc.__cause__,  # type: ignore
                # Remove the first two frames, which are always there
                RemoveFrames(2),
                # Then filter all frozen importlib frames
                drop_frozen_importlib,
                *changes,
            )
        except Exception as exc:
            # this is user error, so use the custom exception so we can clean
            # up our call stack
            raise UserException(exc, RemoveFrames(1), *changes)
        finally:
            logger.debug("Tearing down context")
            # Check what has been changed and accessed in the globals
            # Do this in a finally, so its always torn down even after
            # exceptions
            globals_result = teardown_context()
            self._node_to_globals[node.id] = globals_result.added_or_modified

        # Combine the side effects collected by the execution context during
        # the call with those reported by the function inspector. The
        # inspector's side effects refer to values (positional/keyword
        # arguments, the result, the bound self, or external state); each is
        # translated into an executor pointer (a node ID, or the external
        # state itself).
        #
        # For example, with `with open("...", "w") as f`, the inspector can
        # report a side effect on some external state. If the executor has no
        # node for that state yet (see `lookup_external_state`), the tracer is
        # presumably responsible for creating one.
        #
        # This list is built eagerly, once the call has finished.
        # NOTE: we have a near term eng goal to refactor how side-effect is
        # handled.
        logger.debug("Resolving side effects")
        side_effects = globals_result.side_effects + [
            self._translate_side_effect(node, e)
            for e in self._function_inspector.inspect(fn, args, kwargs, res)
        ]
        return PrivateExecuteResult(res, start_time, end_time, side_effects)

    @_execute.register
    def _execute_import(
        self,
        node: ImportNode,
        changes: Iterable[TracebackChange],
        variables: Optional[Dict[str, LineaID]],
    ) -> PrivateExecuteResult:
        # Dummy: import nodes produce no value and no side effects here
        return PrivateExecuteResult(
            value=None,
            start_time=datetime.now(),
            end_time=datetime.now(),
            side_effects=[],
        )

    @_execute.register
    def _execute_literal(
        self,
        node: LiteralNode,
        changes: Iterable[TracebackChange],
        variables: Optional[Dict[str, LineaID]],
    ) -> PrivateExecuteResult:
        return PrivateExecuteResult(
            node.value, datetime.now(), datetime.now(), []
        )

    @_execute.register
    def _execute_global(
        self,
        node: GlobalNode,
        changes: Iterable[TracebackChange],
        variables: Optional[Dict[str, LineaID]],
    ) -> PrivateExecuteResult:
        return PrivateExecuteResult(
            # An execute global is looking up a global set by a call node so,
            # Copy the result and the timing from the call node
            self._node_to_globals[node.call_id][node.name],
            *self._execution_time[node.call_id],
            [],
        )

    @_execute.register
    def _execute_mutate(
        self,
        node: MutateNode,
        changes: Iterable[TracebackChange],
        variables: Optional[Dict[str, LineaID]],
    ) -> PrivateExecuteResult:
        return PrivateExecuteResult(
            # Copy the value from the source node, and the timing from the
            # call node that caused the mutation
            self._id_to_value[node.source_id],
            *self._execution_time[node.call_id],
            [ViewOfNodes([ID(node.id), ID(node.source_id)])],
        )

    def execute_graph(self, graph: Graph) -> None:
        """
        Executes a graph in visit order, making sure to set up the working
        directory first and to restore the previous one afterwards (even if
        a node raises).

        TODO: Possibly move to graph instead of on executor, since it rather
        cleanly uses the executor's public API? Or move to function?
        """
        logger.debug("Executing graph %s", graph)
        prev_working_dir = getcwd()
        chdir(graph.session_context.working_directory)
        try:
            for node in graph.visit_order():
                self.execute_node(node, variables=None)
        finally:
            # Always restore the caller's working directory; otherwise a
            # failing node would leave the whole process in the graph's
            # directory.
            chdir(prev_working_dir)
        # Add executed nodes to DB
        self.db.session.commit()

    def _translate_pointer(
        self, node: CallNode, pointer: ValuePointer
    ) -> ExecutorPointer:
        """
        Maps from a pointer output by the inspect function, to one output by
        the executor.

        Raises ValueError if the pointer type is unknown, or if a keyword
        argument pointer names a keyword the call node was not given.
        """
        if isinstance(pointer, PositionalArg):
            return ID(
                node.positional_args[pointer.positional_argument_index].id
            )
        elif isinstance(pointer, KeywordArgument):
            # these come from annotation specs so should not need to worry
            # about ** dicts
            for k in node.keyword_args:
                if k.key == pointer.argument_keyword:
                    return ID(k.value)
            raise ValueError(
                f"Keyword argument {pointer.argument_keyword!r} not found "
                f"in call node {node.id}"
            )
        elif isinstance(pointer, Result):
            return ID(node.id)
        elif isinstance(pointer, BoundSelfOfFunction):
            return ID(self._node_to_bound_self[node.function_id])
        elif isinstance(pointer, ExternalState):
            return pointer
        raise ValueError(f"Unknown pointer {pointer}, of type {type(pointer)}")

    def _translate_side_effect(
        self, node: CallNode, e: InspectFunctionSideEffect
    ) -> SideEffect:
        """
        Converts a side effect reported by the function inspector (in terms
        of value pointers) into one in terms of executor pointers.
        """
        if isinstance(e, MutatedValue):
            return MutatedNode(self._translate_pointer(node, e.mutated_value))
        elif isinstance(e, ImplicitDependencyValue):
            return ImplicitDependencyNode(
                self._translate_pointer(node, e.dependency)
            )
        elif isinstance(e, ViewOfValues):
            return ViewOfNodes(
                [self._translate_pointer(node, ptr) for ptr in e.views]
            )
        raise NotImplementedError(
            f"Unknown side effect {e}, of type {type(e)}"
        )

    def _lookup_value(self, name: str) -> object:
        """
        Lookup a value from a string identifier.

        Resolution order: `__file__` (from `module_file`), then builtins,
        then the `operator` module, then Linea builtins, then the executor's
        globals. Raises KeyError if the name is not found in any of them,
        and ValueError for `__file__` when no module file is set.
        """
        if name == "__file__":
            if self.module_file:
                return self.module_file
            raise ValueError("No __file__ set")
        if hasattr(builtins, name):
            return getattr(builtins, name)
        if hasattr(operator, name):
            return getattr(operator, name)
        if name in LINEA_BUILTINS:
            return LINEA_BUILTINS[name]
        return self._globals[name]
def _translate_pointer( self, node: CallNode, pointer: ValuePointer ) -> ExecutorPointer: """ Maps from a pointer output by the inspect function, to one output by the executor. """ if isinstance(pointer, PositionalArg): return ID( node.positional_args[pointer.positional_argument_index].id ) elif isinstance(pointer, KeywordArgument): # these come from annotation specs so should not need to worry about ** dicts for k in node.keyword_args: if k.key == pointer.argument_keyword: return ID(k.value) elif isinstance(pointer, Result): return ID(node.id) elif isinstance(pointer, BoundSelfOfFunction): return ID(self._node_to_bound_self[node.function_id]) elif isinstance(pointer, ExternalState): return pointer raise ValueError(f"Unknown pointer {pointer}, of type {type(pointer)}") def _translate_side_effect( self, node: CallNode, e: InspectFunctionSideEffect ) -> SideEffect: if isinstance(e, MutatedValue): return MutatedNode(self._translate_pointer(node, e.mutated_value)) elif isinstance(e, ImplicitDependencyValue): return ImplicitDependencyNode( self._translate_pointer(node, e.dependency) ) elif isinstance(e, ViewOfValues): return ViewOfNodes( [self._translate_pointer(node, ptr) for ptr in e.views] ) raise NotImplementedError( f"Unknown side effect {e}, of type {type(e)}" ) def _lookup_value(self, name: str) -> object: """ Lookup a value from a string identifier. """ if name == "__file__": if self.module_file: return self.module_file raise ValueError("No __file__ set") if hasattr(builtins, name): return getattr(builtins, name) if hasattr(operator, name): return getattr(operator, name) if name in LINEA_BUILTINS: return LINEA_BUILTINS[name] return self._globals[name]