# Source code for lineapy.execution.inspect_function

from __future__ import annotations

import glob
import logging
import os
import sys
from collections import defaultdict
from dataclasses import dataclass, field
from io import IOBase
from types import ModuleType
from typing import Callable, Dict, Hashable, Iterable, List, Optional, Tuple

import yaml
from pydantic import ValidationError

from lineapy.instrumentation.annotation_spec import (
    AllPositionalArgs,
    Annotation,
    BoundSelfOfFunction,
    ClassMethodName,
    ClassMethodNames,
    Criteria,
    ExternalState,
    FunctionName,
    FunctionNames,
    InspectFunctionSideEffect,
    KeywordArgument,
    KeywordArgumentCriteria,
    ModuleAnnotation,
    MutatedValue,
    PositionalArg,
    Result,
    ValuePointer,
    ViewOfValues,
)

logger = logging.getLogger(__name__)

"""
helper functions
"""


def is_mutable(obj: object) -> bool:
    """
    Report whether ``obj`` should be treated as mutable.

    Unhashable objects count as mutable and hashable objects as immutable,
    except for a fixed set of types that are hashable yet mutable: modules,
    classes, list iterators, IO objects and -- when ``sklearn.base`` has
    already been imported -- scikit-learn estimators.

    Note that ``tempfile.NamedTemporaryFile()`` is currently reported as
    immutable, and that semantics is actually correct, because it doesn't end
    up changing the file system. However, the following registers as a normal
    file (which is mutable)::

        filename = NamedTemporaryFile().name
        handle = open(filename, "wb")
    """
    # There is no a-priori way to know whether a hashable class is mutable,
    # so such types are listed explicitly; this list could eventually move
    # into the annotation specs.
    hashable_but_mutable: Tuple[type, ...] = (
        ModuleType,
        type,
        type(iter([])),
        IOBase,
    )
    if "sklearn.base" in sys.modules:
        # Checked through sys.modules so that sklearn is never imported here.
        hashable_but_mutable = (
            *hashable_but_mutable,
            sys.modules["sklearn.base"].BaseEstimator,  # type: ignore
        )

    if isinstance(obj, hashable_but_mutable):
        return True

    try:
        hash(obj)
    except Exception:
        # Anything that refuses to be hashed is assumed to be mutable.
        return True
    return False
def validate(item: Dict) -> Optional[ModuleAnnotation]:
    """
    Parse one raw spec entry into a ``ModuleAnnotation``.

    Returns ``None`` (after logging a warning) when the entry fails pydantic
    validation, so a single malformed entry does not break loading the rest.

    Specs are deliberately not filtered by module here, because the module
    might be imported later. This makes function inspection a bit less
    efficient, which can be revisited if it becomes a problem.
    """
    try:
        spec = ModuleAnnotation(**item)
    except ValidationError as e:
        # Warn the user rather than raising.
        logger.warning(
            f"Validation failed parsing {item} as annotation spec: {e}"
        )
        return None
    return spec
def get_specs() -> Dict[str, List[Annotation]]:
    """
    Load the annotation specs shipped as ``*.annotations.yaml`` files.

    YAML specs are for non-built-in functions. Every ``*.annotations.yaml``
    file in the lineapy directory (the parent of this module's directory) is
    read, and each entry is checked with ``validate``. The valid annotations
    are grouped by the module they target.

    Skipped input:
    - Empty files are skipped.
    - Files whose top level is not a list, and entries that are not mappings,
      are skipped with a warning.
    - Entries that fail validation are dropped by ``validate``.

    Files are processed in sorted order, so the order of the returned
    annotations does not depend on the file system's directory listing order.
    """
    pattern = os.path.join(
        os.path.dirname(__file__), "..", "*.annotations.yaml"
    )
    valid_specs: Dict[str, List[Annotation]] = defaultdict(list)
    # glob's result order is file-system dependent; sort for determinism.
    for filename in sorted(glob.glob(pattern)):
        with open(filename, "r", encoding="utf-8") as f:
            # safe_load does not construct arbitrary Python objects.
            doc = yaml.safe_load(f)
        if doc is None:
            # An empty (or comment-only) YAML file parses to None.
            continue
        if not isinstance(doc, list):
            logger.warning(
                f"Skipping {filename}: expected a list of module "
                f"annotations, got {type(doc).__name__}"
            )
            continue
        for item in doc:
            if not isinstance(item, dict):
                # validate() expands the entry as keyword arguments, which
                # would raise TypeError for non-mapping entries.
                logger.warning(
                    f"Skipping non-mapping entry {item!r} in {filename}"
                )
                continue
            spec = validate(item)
            if spec is None:
                continue
            valid_specs[spec.module].extend(spec.annotations)
    return valid_specs
def new_side_effect_without_all_positional_arg(
    side_effect: ViewOfValues,
    args: list,
) -> ViewOfValues:
    """
    Return a deep copy of ``side_effect`` with ``AllPositionalArgs`` expanded.

    The first ``AllPositionalArgs`` pointer is removed, and one
    ``PositionalArg`` per element of ``args`` is appended after the remaining
    views. Only the first occurrence is expanded.

    This method must NOT modify the original ``side_effect``. The annotations
    are shared across calls, while the expansion depends on the runtime
    arguments of this particular call.

    Note that something like "all keyword arguments" might be needed as well,
    but that use case hasn't come up yet.
    """
    expanded = ViewOfValues(views=[])
    # Mutate the constructed model's list in place, as before, rather than
    # passing pre-built pointers through the constructor.
    pointers = expanded.views
    pointers.extend(view.copy(deep=True) for view in side_effect.views)

    position = next(
        (
            idx
            for idx, pointer in enumerate(pointers)
            if isinstance(pointer, AllPositionalArgs)
        ),
        None,
    )
    if position is not None:
        del pointers[position]
        pointers.extend(
            PositionalArg(positional_argument_index=idx)
            for idx in range(len(args))
        )
    return expanded
def process_side_effect(
    side_effect: InspectFunctionSideEffect,
    args: list,
    kwargs: dict[str, object],
    result: object,
) -> Optional[InspectFunctionSideEffect]:
    """
    Specialize an annotated side effect to one concrete call.

    Behavior by side-effect type:
    - ``ViewOfValues``: ``AllPositionalArgs`` is expanded for ``args`` on a
      copy, so the shared annotation is left untouched. Pointers to immutable
      values are then dropped. ``None`` is returned if fewer than two pointers
      remain, because a view needs at least two participants.
    - ``MutatedValue``: returned unchanged if the mutated value is mutable,
      otherwise ``None``.
    - Any other side effect is returned unchanged.

    ``self`` (``BoundSelfOfFunction``) and ``ExternalState`` pointers always
    count as mutable. Pointers to a positional or keyword argument that was
    not supplied in this call count as not mutable, so they are ignored.

    Raises ``Exception`` if a pointer of an unhandled type is encountered.
    """

    def is_reference_mutable(p: ValuePointer) -> bool:
        """Return True if the value ``p`` points to in this call is mutable."""
        if isinstance(p, Result):
            return is_mutable(result)
        if isinstance(p, PositionalArg):
            if len(args) > p.positional_argument_index:
                return is_mutable(args[p.positional_argument_index])
            return False
        if isinstance(p, (BoundSelfOfFunction, ExternalState)):
            return True
        if isinstance(p, KeywordArgument):
            # Optional keyword arguments may be omitted by the caller; treat
            # them like missing positional arguments instead of raising
            # KeyError.
            if p.argument_keyword in kwargs:
                return is_mutable(kwargs[p.argument_keyword])
            return False
        raise Exception(f"ValuePointer {p} of type {type(p)} not handled.")

    if isinstance(side_effect, ViewOfValues):
        new_side_effect = new_side_effect_without_all_positional_arg(
            side_effect, args
        )
        new_side_effect.views = [
            view for view in new_side_effect.views if is_reference_mutable(view)
        ]
        # If we don't have at least two items to view each other, skip this one.
        if len(new_side_effect.views) < 2:
            return None
        return new_side_effect

    if isinstance(side_effect, MutatedValue):
        if is_reference_mutable(side_effect.mutated_value):
            return side_effect
        return None

    return side_effect
@dataclass
class FunctionInspectorParsed:
    """
    Contains the parsed function inspector criteria.

    Each annotation criteria is resolved against its (already imported) module
    and stored in one of three lookup tables. ``inspect`` consults them in
    order: exact callables, then method names, then method keyword arguments.
    """

    # Function criteria: callable object -> side effects.
    function_to_side_effects: Dict[
        Callable, List[InspectFunctionSideEffect]
    ] = field(default_factory=lambda: defaultdict(list))
    # Method criteria: method name -> class -> side effects. Classes are
    # matched with isinstance, so they also apply to subclasses.
    method_name_to_type_to_side_effects: Dict[
        str, Dict[type, List[InspectFunctionSideEffect]]
    ] = field(default_factory=lambda: defaultdict(lambda: defaultdict(list)))
    # Method keyword argument criteria:
    # (keyword name, keyword value) -> class -> side effects.
    keyword_name_and_value_to_type_to_side_effects: Dict[
        Tuple[str, Hashable], Dict[type, List[InspectFunctionSideEffect]]
    ] = field(default_factory=lambda: defaultdict(lambda: defaultdict(list)))

    @staticmethod
    def _is_hashable(value: object) -> bool:
        """Return True if ``value`` can be used as (part of) a dict key."""
        # isinstance(value, Hashable) is not sufficient: a tuple containing a
        # list passes that check, but hash() raises TypeError on it.
        try:
            hash(value)
        except Exception:
            return False
        return True

    def inspect(
        self, fn: Callable, kwargs: Dict[str, object]
    ) -> Optional[List[InspectFunctionSideEffect]]:
        """
        Inspect a function call and return its side effects, if it matches
        any of the annotations.

        ``fn`` is treated as a method when it has a ``__self__`` that is not a
        module. Functions defined in C, like ``operator.setitem``, also have a
        ``__self__``, but it is the module they were defined in, hence the
        module check.

        - Functions are looked up directly in ``function_to_side_effects``.
          Unhashable callables never match.
        - Methods are matched first by method name, requiring the bound
          object to be an instance of an annotated class. They are then
          matched by each hashable ``(keyword, value)`` pair in ``kwargs``.

        Returns the first matching list of side effects, or ``None``. Lookups
        never add entries to the tables.
        """
        obj = getattr(fn, "__self__", None)
        is_method = obj is not None and not isinstance(obj, ModuleType)

        # If it's a function, look for an exactly-equal saved function.
        if not is_method:
            if not self._is_hashable(fn):
                return None
            return self.function_to_side_effects.get(fn, None)

        # The lookups below use .get() instead of indexing. Indexing the
        # nested defaultdicts would insert an empty entry for every method
        # name and keyword value ever inspected, growing the tables (and
        # keeping user values alive) without bound.

        # For a bound method, superclasses matter, so look up by name first
        # and then check isinstance.
        method_name = getattr(fn, "__name__", None)
        by_type = self.method_name_to_type_to_side_effects.get(method_name, {})
        for tp, side_effects in by_type.items():
            if isinstance(obj, tp):
                return side_effects

        # Finally, try the keyword name/value mapping on the method.
        for k, v in kwargs.items():
            if not self._is_hashable(v):
                continue
            by_type = self.keyword_name_and_value_to_type_to_side_effects.get(
                (k, v), {}
            )
            for tp, side_effects in by_type.items():
                if isinstance(obj, tp):
                    return side_effects
        return None

    def add_annotations(
        self, module: ModuleType, annotations: List[Annotation]
    ) -> None:
        """
        Parse a list of annotations, look up their criteria on ``module`` and
        add them to our parsed criteria.
        """
        for annotation in annotations:
            self._add_annotation(
                module, annotation.criteria, annotation.side_effects
            )

    def _add_annotation(
        self,
        module: ModuleType,
        criteria: Criteria,
        side_effects: List[InspectFunctionSideEffect],
    ) -> None:
        """
        Resolve one criteria against ``module`` and register its side effects.

        Names are resolved with ``getattr`` on ``module``, so a name missing
        from the module raises ``AttributeError``. A later registration for
        the same key overwrites an earlier one. Raises
        ``NotImplementedError`` for unknown criteria types.
        """
        if isinstance(criteria, KeywordArgumentCriteria):
            class_ = getattr(module, criteria.class_instance)
            self.keyword_name_and_value_to_type_to_side_effects[
                (criteria.keyword_arg_name, criteria.keyword_arg_value)
            ][class_] = side_effects
        elif isinstance(criteria, FunctionNames):
            for name in criteria.function_names:
                fn = getattr(module, name)
                self.function_to_side_effects[fn] = side_effects
        elif isinstance(criteria, FunctionName):
            fn = getattr(module, criteria.function_name)
            self.function_to_side_effects[fn] = side_effects
        elif isinstance(criteria, ClassMethodName):
            tp = getattr(module, criteria.class_instance)
            self.method_name_to_type_to_side_effects[
                criteria.class_method_name
            ][tp] = side_effects
        elif isinstance(criteria, ClassMethodNames):
            tp = getattr(module, criteria.class_instance)
            for name in criteria.class_method_names:
                self.method_name_to_type_to_side_effects[name][
                    tp
                ] = side_effects
        else:
            raise NotImplementedError(criteria)
@dataclass
class FunctionInspector:
    """
    Maps function calls to the side effects declared in the annotation specs.

    Loading happens in two stages:

    1. ``get_specs`` reads all the specs from disk. It is the default factory
       of ``specs``, so this happens once, when the object is created.
    2. ``_parse`` "parses" the specs of every module that has already been
       imported, turning their criteria into in-memory objects (``parsed``)
       that calls can be compared against. This runs on initialization and
       again before every ``inspect`` call.
    """

    # Specs not parsed yet, keyed by module name, because the corresponding
    # modules have not been imported.
    specs: Dict[str, List[Annotation]] = field(default_factory=get_specs)
    # Annotations already parsed, because their modules have been imported.
    parsed: FunctionInspectorParsed = field(
        default_factory=FunctionInspectorParsed
    )

    def __post_init__(self):
        self._parse()

    def _parse(self) -> None:
        """Parse all pending specs whose modules are currently imported."""
        for name in list(self.specs):
            imported = get_imported_module(name)
            if not imported:
                continue
            # Pop before adding so each module's specs are only handled once.
            pending = self.specs.pop(name)
            self.parsed.add_annotations(imported, pending)

    def inspect(
        self,
        function: Callable,
        args: list[object],
        kwargs: dict[str, object],
        result: object,
    ) -> Iterable[InspectFunctionSideEffect]:
        """
        Lazily yield how calling ``function`` with ``args``/``kwargs``
        (producing ``result``) mutates its values and creates view
        relationships between them.
        """
        # Re-parse on each call, in case other modules were imported since.
        self._parse()
        matched = self.parsed.inspect(function, kwargs) or []
        processed = (
            process_side_effect(candidate, args, kwargs, result)
            for candidate in matched
        )
        # Side effects that do not apply to this call come back falsy (None).
        yield from filter(None, processed)
def get_imported_module(name: str) -> Optional[ModuleType]:
    """
    Return the module ``name`` if it has been imported, otherwise ``None``.

    This also handles the corner case where a submodule has not been imported
    but is accessible as an attribute of its parent module. The dotted name is
    resolved from the longest prefix present in ``sys.modules`` by walking the
    remaining parts with ``getattr``; ``None`` is returned as soon as a step
    yields a falsy value. This is needed for ``tensorflow.keras.utils``, which
    is not imported when importing ``tensorflow`` but is accessible as a
    property of ``tensorflow``. Because of the attribute walk, the returned
    object is not guaranteed to be a module.
    """
    parts = name.split(".")
    # Find the longest dotted prefix that is registered in sys.modules.
    for cut in range(len(parts), 0, -1):
        prefix = ".".join(parts[:cut])
        if prefix in sys.modules:
            current = sys.modules[prefix]
            break
    else:
        return None

    for attribute in parts[cut:]:
        if not current:
            return None
        current = getattr(current, attribute, None)
    return current