Source code for lineapy.execution.context

"""
This module contains a number of globals, which are set by the execution when 
it is processing call nodes.

They are used as a side channel to pass values from the executor to special functions
which need to know more about the execution context, like in the `exec` to know
the source code of the current node.

This module exposes three global functions, which are meant to be used like:

1. The `executor` calls `set_context` before executing every call node.
2. The function being called can call `get_context` to get the current context.
3. The `executor` calls `teardown_context` after its finished executing

I.e. the context is created for every call.
"""
from __future__ import annotations

from dataclasses import dataclass, field
from types import ModuleType
from typing import TYPE_CHECKING, Dict, Iterable, List, Mapping, Optional

from lineapy.execution.globals_dict import GlobalsDict, GlobalsDictResult
from lineapy.execution.inspect_function import is_mutable
from lineapy.execution.side_effects import (
    ID,
    AccessedGlobals,
    ExecutorPointer,
    MutatedNode,
    SideEffect,
    SideEffects,
    Variable,
    ViewOfNodes,
)
from lineapy.system_tracing.function_call import FunctionCall
from lineapy.system_tracing.function_calls_to_side_effects import (
    function_calls_to_side_effects,
)
from lineapy.utils.analytics import ExceptionEvent, track

if TYPE_CHECKING:
    from lineapy.data.types import CallNode, LineaID
    from lineapy.execution.executor import Executor
"""
Use the same globals for all executions, so that a function created during 
  one execution will have the same globals as when it is later called, 
  so we can see what globals have been written during that.
  Consider the following motivating example:
```python
a = 10
def f():
    print(a)
a = 15
f() # should print 15, instead of 10
```
"""

_global_variables: GlobalsDict = GlobalsDict()
_current_context: Optional[ExecutionContext] = None


[docs]@dataclass class ExecutionContext: """ This class is available during execution of CallNodes to the functions which are being called. It is used as a side channel to pass in metadata about the execution, such as the current node, and other global nodes (used during exec). The `side_effects` property is read after the function is finished, by the executor, so that the function can pass additional side effects that were triggered back to it indirectly. This is also used by the exec functions. """ # The current node being executed node: CallNode # The executor that is running executor: Executor # Mapping from each input global name to its ID _input_node_ids: Mapping[str, LineaID] # Mapping from each input global name to whether it is mutable _input_globals_mutable: Mapping[str, bool] # Mapping of input node IDs to their values. # Used by the exec function to understand what side effects to emit, by knowing the nodes associated with each global value used. input_nodes: Mapping[LineaID, object] # Additional function calls made in this call, to be processed for side effects at the end. # The exec function will add to this and we will retrieve it at the end. function_calls: Optional[List[FunctionCall]] = field(default=None) @property def global_variables(self) -> Dict[str, object]: """ The current globals dictionary """ return _global_variables
[docs]def set_context( executor: Executor, variables: Optional[Dict[str, LineaID]], node: CallNode, ) -> None: """ Sets the context of the executor to the given node. """ global _current_context if _current_context: raise RuntimeError("Context already set") # Set up our global variables, by first clearing out the old, and then # by updating with our new inputs # Note: We need to save our inputs so that we can check what has changed # at the end assert not _global_variables # The first time this is run, variables is set, and we know # the scoping, so we set all of the variables we know. # The subsequent times, we only use those that were recorded input_node_ids = variables or node.global_reads global_name_to_value = { k: executor._id_to_value[id_] for k, id_ in input_node_ids.items() } _global_variables.setup_globals(global_name_to_value) global_node_id_to_value = { id_: executor._id_to_value[id_] for id_ in input_node_ids.values() } _current_context = ExecutionContext( _input_node_ids=input_node_ids, _input_globals_mutable={ # Don't consider modules or classes as mutable inputs, so that any code which uses a module # we assume it doesn't mutate it. k: is_mutable(v) and not isinstance(v, (ModuleType, type)) for k, v in global_name_to_value.items() }, node=node, executor=executor, input_nodes=global_node_id_to_value, )
def get_context() -> ExecutionContext: if not _current_context: track(ExceptionEvent("DBError", "No context set")) raise RuntimeError("No context set") return _current_context
[docs]def teardown_context() -> ContextResult: """ Tearsdown the context, returning the nodes that were accessed and a mapping variables to new values that were added or crated """ global _current_context if not _current_context: raise RuntimeError("No context set") res = _global_variables.teardown_globals() # If we didn't trace some function calls, use the legacy worst case assumptions for side effects if _current_context.function_calls is None: side_effects = list(_compute_side_effects(_current_context, res)) else: # Compute the side effects based on the function calls that happened, to understand what input nodes # were mutated, what views were added, and what other side effects were created. side_effects = list( function_calls_to_side_effects( _current_context.executor._function_inspector, _current_context.function_calls, _current_context.input_nodes, res.added_or_modified, ) ) if res.accessed_inputs or res.added_or_modified: # Record that this execution accessed and saved certain globals, as first side effect side_effects.insert( 0, AccessedGlobals( res.accessed_inputs, list(res.added_or_modified.keys()), ), ) _current_context = None return ContextResult(res.added_or_modified, side_effects)
def _compute_side_effects( context: ExecutionContext, globals_result: GlobalsDictResult ) -> Iterable[SideEffect]: """ This is the legacy worst case side effect computation and is applied when settrace's bytecode is not supported. Currently, anything related to generators is not supported, e.g., ```python f(*x) # if we read the `next` function it will exhaust the generator and change # the semantics of the code. ``` We can remove once we support tracking globals updated during calling user defined functions. """ # Any nodes that we retrieved that were mutable, assume were mutated # Filter the vars by if the value is mutable mutable_input_vars: List[ExecutorPointer] = [ ID(context._input_node_ids[name]) for name in globals_result.accessed_inputs if context._input_globals_mutable[name] ] yield from map(MutatedNode, mutable_input_vars) # Now we mark all the mutable input variables as well as mutable # output variables as all views of one another # Note that we retrieve the input IDs before executing, # and refer to the output IDs as variables after executing # This will mean any accessed mutable variables will resolve # to the node they pointed to before the function call # and mutated ones will refer to the new GlobalNodes that # were created when processing `AccessedGlobals` mutable_output_vars: List[ExecutorPointer] = [ Variable(k) for k, v in globals_result.added_or_modified.items() if is_mutable(v) ] if mutable_input_vars or mutable_output_vars: yield ViewOfNodes(mutable_input_vars + mutable_output_vars)
[docs]@dataclass class ContextResult: # Mapping of global name to value for every global which was added or updated added_or_modified: Dict[str, object] # List of side effects added during the execution. side_effects: SideEffects