Source code for lineapy.execution.globals_dict

from __future__ import annotations

import builtins
from dataclasses import dataclass, field
from typing import Dict, List, Optional


class GlobalsDict(Dict[str, object]):
    """
    A dict subclass that records which keys are read through ``__getitem__``.

    It is intended to be passed as the globals when executing code, so that
    we can work out which of the input globals that code accessed.

    Expected usage:

    1. Create an empty instance with ``GlobalsDict()``.
    2. Call ``setup_globals(d)`` to load the input globals into it.
    3. Execute code with this object as its globals. Writes go through the
       normal dict ``__setitem__``; reads go through the ``__getitem__``
       defined here.
    4. Call ``teardown_globals()`` to obtain a ``GlobalsDictResult`` holding
       the names of the original globals that were accessed and any globals
       that were added or re-assigned.

    ``__setitem__`` is deliberately not overridden: Python does not respect
    an overridden ``__setitem__`` on a custom globals mapping, whereas an
    overridden ``__getitem__`` is honoured. See
    https://stackoverflow.com/a/12185315/907060 which refers to
    https://bugs.python.org/issue14385
    """

    def __init__(self):
        super().__init__()
        # Tracking state; only populated between setup_globals() and
        # teardown_globals().
        self._state: Optional[State] = None

    def __getitem__(self, k):
        """
        Return the value stored under ``k`` and report the read to the
        tracking state.

        The lookup happens first, so a missing key raises ``KeyError`` even
        when the dict has not been set up. A successful lookup on a dict
        that has not been set up raises ``RuntimeError``.
        """
        value = super().__getitem__(k)
        tracker = self._state
        if not tracker:
            raise RuntimeError("GlobalsDict not setup")
        tracker.process_getitem(k, value)
        return value

    def setup_globals(self, inputs: Dict[str, object]) -> None:
        """
        Start tracking: remember ``inputs``, copy its entries into this dict
        and set ``__builtins__`` to the ``builtins`` module.
        """
        self._state = State(inputs)
        self.update(inputs)
        self["__builtins__"] = builtins

    def teardown_globals(self) -> GlobalsDictResult:
        """
        Stop tracking, empty this dict and return what was observed.

        The result holds the accessed input names and every global that is
        new or bound to a different object than the one passed in.

        Raises ``RuntimeError`` if ``setup_globals`` has not been called.
        """
        tracker = self._state
        if not tracker:
            raise RuntimeError("GlobalsDict not setup")

        original = tracker.inputs
        changed: Dict[str, object] = {}
        for name, value in self.items():
            if name == "__builtins__":
                continue
            # Compare by identity, not by value: we care whether the global
            # was re-assigned, not whether the object it refers to was
            # mutated. Entries still bound to their input object are
            # unchanged; everything else is new or re-assigned.
            if name in original and original[name] is value:
                continue
            changed[name] = value

        self._state = None
        self.clear()
        return GlobalsDictResult(tracker.accessed_inputs, changed)
@dataclass
class State:
    """
    Tracking data for one run: the input globals and which were accessed.
    """

    # The mapping of input globals, as given at setup time.
    inputs: Dict[str, object]
    # Keys of `inputs` that have been read while still bound to their
    # original object, in first-access order.
    accessed_inputs: List[str] = field(default_factory=list)

    def process_getitem(self, k: str, v: object) -> None:
        """
        Record ``k`` as an accessed input, if it qualifies.

        A key is recorded only when it is not ``__builtins__``, has not
        been recorded already, is one of the input globals, and ``v`` is the
        very object stored for it in the inputs (so it has not been
        overwritten since).
        """
        if k == "__builtins__":
            return
        if k in self.accessed_inputs:
            # Already recorded; keep the list free of duplicates.
            return
        if k not in self.inputs:
            # Not an input global, so there is nothing to attribute.
            return
        if self.inputs[k] is not v:
            # The name has been rebound to a different object.
            return
        self.accessed_inputs.append(k)
@dataclass
class GlobalsDictResult:
    """
    Plain record of accessed input names and added or modified globals.
    """

    # Names of the input globals that were accessed.
    accessed_inputs: List[str]
    # Globals that were added or modified, keyed by name.
    added_or_modified: Dict[str, object]