# Source code for lineapy.transformer.node_transformer

import ast
import logging
import sys
from pathlib import Path
from typing import Any, Iterable, Optional, cast

from lineapy.data.types import (
    CallNode,
    LiteralNode,
    Node,
    SourceCode,
    SourceCodeLocation,
    SourceLocation,
)
from lineapy.editors.ipython_cell_storage import get_location_path
from lineapy.exceptions.user_exception import RemoveFrames, UserException
from lineapy.instrumentation.tracer import Tracer
from lineapy.transformer.transformer_util import create_lib_attributes
from lineapy.utils.constants import (
    ADD,
    BITAND,
    BITOR,
    BITXOR,
    DEL_ATTR,
    DEL_ITEM,
    DIV,
    EQ,
    FLOORDIV,
    GET_ITEM,
    GETATTR,
    GT,
    GTE,
    IN,
    INVERT,
    IS,
    ISNOT,
    LSHIFT,
    LT,
    LTE,
    MATMUL,
    MOD,
    MULT,
    NEG,
    NOT,
    NOTEQ,
    POS,
    POW,
    RSHIFT,
    SET_ATTR,
    SET_ITEM,
    SUB,
)
from lineapy.utils.lineabuiltins import (
    l_alias,
    l_assert,
    l_dict,
    l_dict_kwargs_sentinel,
    l_exec_expr,
    l_exec_statement,
    l_list,
    l_unpack_ex,
    l_unpack_sequence,
)
from lineapy.utils.utils import get_new_id

# Module-scoped logger, named after this module's import path.
logger = logging.getLogger(name=__name__)


def transform(
    code: str, location: SourceCodeLocation, tracer: Tracer
) -> Optional[Node]:
    """
    Trace ``code``: parse it, execute it through ``tracer`` and commit the
    recorded graph to the tracer's database.

    Parameters
    ----------
    code
        The Python source text to trace.
    location
        Where the source came from; its resolved absolute path is used as
        the filename when parsing.
    tracer
        Tracer that executes the code and records the resulting nodes.

    Returns
    -------
    Optional[Node]
        The node produced by the last top-level statement (a node for an
        expression, ``None`` for a plain statement or for empty code).

    Raises
    ------
    UserException
        Wrapping the ``SyntaxError`` raised when ``code`` does not parse.
    """
    transformer = NodeTransformer(code, location, tracer)
    try:
        module_ast = ast.parse(
            code,
            str(get_location_path(location).absolute()),
        )
    except SyntaxError as err:
        raise UserException(err, RemoveFrames(2))

    if sys.version_info < (3, 8):
        # Python 3.7's ast nodes carry no end positions. asttokens first
        # annotates the tree with tokens, then SourceGiver copies the
        # end_lineno values from those tokens onto the nodes.
        from asttokens import ASTTokens

        from lineapy.transformer.source_giver import SourceGiver

        ASTTokens(code, parse=False, tree=module_ast)
        SourceGiver().transform(module_ast)

    transformer.visit(module_ast)
    tracer.db.commit()
    return transformer.last_statement_result
class NodeTransformer(ast.NodeTransformer):
    """
    Walk a parsed module and replay it through a :class:`Tracer`.

    Dedicated ``visit_*`` methods turn literals, names, calls, operators,
    subscripts, attributes, imports, deletions and assignments into explicit
    tracer calls. Any other statement or expression is executed from its
    source text as a black box by :meth:`generic_visit`.

    .. note::
        - Need to be careful about the order by which these calls are
          invoked so that the transformation does not get called more than
          once.
    """

    def __init__(
        self,
        code: str,
        location: SourceCodeLocation,
        tracer: Tracer,
    ):
        """
        Record ``code`` in the tracer's database and prepare to visit it.

        Parameters
        ----------
        code
            Full source text of the module that will be visited.
        location
            Origin of the source. When it is a :class:`~pathlib.Path`, the
            executor's ``module_file`` is set to it.
        tracer
            Tracer used to emit nodes and execute code.
        """
        self.source_code = SourceCode(
            id=get_new_id(), code=code, location=location
        )
        tracer.db.write_source_code(self.source_code)
        self.tracer = tracer
        # Set __file__ to the pathname of the file
        if isinstance(location, Path):
            tracer.executor.module_file = str(location)
        # The result of the last line, a node if it was an expression,
        # None if it was a statement. Used by ipython to grab the last value
        self.last_statement_result: Optional[Node] = None

    def _get_code_from_node(self, node: ast.AST) -> Optional[str]:
        """Return the (padded) source text spanned by ``node``, if known."""
        if sys.version_info < (3, 8):
            from lineapy.utils.deprecation_utils import get_source_segment

            return get_source_segment(self.source_code.code, node, padded=True)
        else:
            return ast.get_source_segment(
                self.source_code.code, node, padded=True
            )

    def generic_visit(self, node: ast.AST):
        """
        Capture any construct without a dedicated visitor as a black box.

        Statements are executed via ``l_exec_statement`` and return None.
        Expressions are evaluated via ``l_exec_expr`` and return the
        resulting node.

        Raises
        ------
        NotImplementedError
            If ``node`` is neither a statement nor an expression.
        """
        if isinstance(node, ast.stmt):
            # Since Python 3.8 the ``lineno`` of a decorated definition points
            # at the ``def``/``class`` keyword, not at the first decorator.
            # Widen the start so the executed source keeps the decorators.
            # This applies to classes and async functions as well.
            if (
                isinstance(
                    node,
                    (ast.FunctionDef, ast.AsyncFunctionDef, ast.ClassDef),
                )
                and node.decorator_list
            ):
                min_decorator_line = min(
                    decorator.lineno for decorator in node.decorator_list
                )
                node.lineno = min(min_decorator_line, node.lineno)
            return self._exec_statement(node)
        elif isinstance(node, ast.expr):
            return self._exec_expression(node)
        else:
            raise NotImplementedError(
                f"Don't know how to transform {type(node).__name__}"
            )

    # The annotations of the legacy-node visitors below are quoted because
    # ``ast.Ellipsis``/``ast.Str``/``ast.Num``/``ast.NameConstant`` no
    # longer exist on recent Python versions. Unquoted, they would break
    # class creation there.

    def visit_Ellipsis(self, node: "ast.Ellipsis") -> LiteralNode:
        """
        Trace an ``...`` literal.

        Note
        ----
        Deprecated in Python 3.8 (``visit_Constant`` handles it there).
        """
        if sys.version_info >= (3, 8):
            raise NotImplementedError(
                "Ellipsis nodes are deprecated since Python 3.8"
            )
        else:
            return self.tracer.literal(..., self.get_source(node))

    def visit_Str(self, node: "ast.Str") -> LiteralNode:
        """
        Trace a string literal.

        Note
        ----
        Deprecated in Python 3.8 (``visit_Constant`` handles it there).
        """
        if sys.version_info >= (3, 8):
            raise NotImplementedError(
                "Str nodes are deprecated since Python 3.8"
            )
        else:
            return self.tracer.literal(node.s, self.get_source(node))

    def visit_Num(self, node: "ast.Num") -> LiteralNode:
        """
        Trace a numeric literal.

        Note
        ----
        Deprecated in Python 3.8 (``visit_Constant`` handles it there).
        """
        if sys.version_info >= (3, 8):
            raise NotImplementedError(
                "Num nodes are deprecated since Python 3.8"
            )
        else:
            return self.tracer.literal(node.n, self.get_source(node))

    def visit_NameConstant(self, node: "ast.NameConstant") -> LiteralNode:
        """
        Trace a ``True``/``False``/``None`` literal.

        Note
        ----
        Deprecated in Python 3.8 (``visit_Constant`` handles it there).
        """
        if sys.version_info >= (3, 8):
            raise NotImplementedError(
                "NameConstant nodes are deprecated since Python 3.8"
            )
        else:
            return self.tracer.literal(node.value, self.get_source(node))

    # FIXME - this is deprecated
    def visit_Starred(self, node: ast.Starred) -> Iterable[LiteralNode]:
        """
        Yield one literal node per element of a starred expression.

        NOTE(review): this is a generator, so ``self.visit`` on a Starred
        node returns a generator object rather than a Node. The visitors in
        this class unwrap or black-box starred elements instead of visiting
        them, so this method is not reached from them.
        """
        elemlist: Iterable = []
        if isinstance(node.value, ast.Constant):
            elemlist = cast(Iterable, node.value.value)
        elif isinstance(node.value, ast.Name):
            elemlist = cast(Iterable, self.tracer.values[node.value.id])
        elif isinstance(node.value, ast.Str):
            elemlist = cast(Iterable, node.value.s)
        elem_nodes = [self.visit(ast.Constant(ele)) for ele in iter(elemlist)]
        yield from elem_nodes

    def visit_Raise(self, node: ast.Raise) -> None:
        """Execute a ``raise`` statement as a black-box statement."""
        # ast.NodeTransformer defines no visit_Raise to defer to, so run the
        # statement through the generic black-box statement path.
        self._exec_statement(node)

    def visit_Module(self, node: ast.Module) -> Any:
        """Visit each top-level statement, remembering the last result."""
        for stmt in node.body:
            self.last_statement_result = self.visit(stmt)

    def visit_Expr(self, node: ast.Expr) -> Node:
        """Trace an expression statement by tracing its expression."""
        return self.visit(node.value)

    def visit_Assert(self, node: ast.Assert) -> None:
        """Trace ``assert test[, msg]`` as a call to ``l_assert``."""
        args = [self.visit(node.test)]
        if node.msg is not None:
            args.append(self.visit(node.msg))
        self.tracer.call(
            self.tracer.lookup_node(l_assert.__name__),
            self.get_source(node),
            *args,
        )

    def visit_Import(self, node: ast.Import) -> None:
        """
        Trace ``import a[ as b], ...``, one ``trace_import`` per name.

        Similar to ``visit_ImportFrom``, slightly different class syntax.
        """
        for lib in node.names:
            self.tracer.trace_import(
                lib.name,
                self.get_source(node),
                alias=lib.asname,
            )

    def visit_ImportFrom(self, node: ast.ImportFrom) -> None:
        """
        Trace ``from module import names``.

        Relative imports (``from . import x``, ``from .mod import x``) are
        executed as black-box statements. ``trace_import`` only receives
        ``node.module`` and not ``node.level``, so it cannot tell
        ``from .mod import x`` apart from a top-level ``mod``.
        """
        if node.level or node.module is None:
            self._exec_statement(node)
            return
        self.tracer.trace_import(
            node.module,
            self.get_source(node),
            attributes=create_lib_attributes(node.names),
        )

    def visit_Index(self, node: "ast.Index") -> Node:
        """
        Trace a subscript index by tracing the wrapped value.

        Note
        ----
        Deprecated in Python 3.9
        """
        if sys.version_info >= (3, 9):
            raise NotImplementedError(
                "Index nodes are deprecated in Python 3.9"
            )
        else:
            return self.visit(node.value)

    def visit_ExtSlice(self, node: "ast.ExtSlice") -> Node:
        """
        Trace a multi-dimensional slice as a tuple of its dimensions.

        Note
        ----
        Deprecated in Python 3.9
        """
        if sys.version_info >= (3, 9):
            raise NotImplementedError(
                "ExtSlice nodes are deprecated in Python 3.9"
            )
        else:
            elem_nodes = [self.visit(elem) for elem in node.dims]
            return self.tracer.tuple(
                *elem_nodes,
                source_location=self.get_source(node),
            )

    def visit_Name(self, node: ast.Name) -> Node:
        """Resolve a variable reference to its current node."""
        return self.tracer.lookup_node(node.id, self.get_source(node))

    def visit_Call(self, node: ast.Call) -> Optional[CallNode]:
        """
        Trace a function call.

        A positional ``*iterable`` argument is passed to the tracer as a
        ``(True, node)`` pair, which marks it as starred. A ``**mapping``
        keyword argument is given the synthetic key ``unpack_<i>``, where
        ``i`` is its position among the keywords.
        """
        # Python evaluates the callee before its arguments (``f()(g())``
        # calls ``f`` first), so visit the function expression first.
        function_node = self.visit(node.func)
        argument_nodes = []
        for arg in node.args:
            if isinstance(arg, ast.Starred):
                argument_nodes.append((True, self.visit(arg.value)))
            else:
                argument_nodes.append(self.visit(arg))
        keyword_argument_nodes = {
            (
                cast(str, arg.arg) if arg.arg is not None else f"unpack_{i}"
            ): self.visit(arg.value)
            for i, arg in enumerate(node.keywords)
        }
        return self.tracer.call(
            function_node,
            self.get_source(node),
            *argument_nodes,
            **keyword_argument_nodes,
        )

    def visit_Delete(self, node: ast.Delete) -> None:
        """
        Trace ``del`` on subscripts (``DEL_ITEM``) and attributes
        (``DEL_ATTR``).

        Raises
        ------
        NotImplementedError
            For plain variable names and any other target kind.
        """
        # ``del a[0], b.c`` has several targets. Python deletes them left
        # to right, so trace every one instead of only the first.
        for target in node.targets:
            if isinstance(target, ast.Name):
                raise NotImplementedError(
                    "We do not support un-assigning a variable"
                )
            elif isinstance(target, ast.Subscript):
                self.tracer.call(
                    self.tracer.lookup_node(DEL_ITEM),
                    self.get_source(node),
                    self.visit(target.value),
                    self.visit(target.slice),
                )
            elif isinstance(target, ast.Attribute):
                self.tracer.call(
                    self.tracer.lookup_node(DEL_ATTR),
                    self.get_source(node),
                    self.visit(target.value),
                    self.visit(ast.Constant(value=target.attr)),
                )
            else:
                raise NotImplementedError(
                    f"We do not support deleting {type(target)}"
                )

    def visit_Constant(self, node: ast.Constant) -> Node:
        """Trace a literal constant."""
        return self.tracer.literal(
            node.value,
            self.get_source(node),
        )

    def visit_Assign(self, node: ast.Assign) -> None:
        """
        Trace ``target_1 = ... = target_n = value``.

        ``x = y`` (name to name) is recorded through ``l_alias``. Every
        other form visits the right-hand side once and hands the resulting
        node to :meth:`visit_assign_value` for each target.

        TODO
        ----
        - None variable assignment, should be turned into a setattr call
          not an assignment, so we might need to change the return
          signature from ast.Expr.
        """
        # Targets are assigned left to right: x = y = z -> x = z, y = z.
        # Python evaluates the right-hand side only once (``a = b = f()``
        # calls ``f`` once), so the visited value node is reused.
        value_visited = False
        value_node: Optional[Node] = None
        for target in node.targets:
            if isinstance(target, ast.Name) and isinstance(
                node.value, ast.Name
            ):
                # special case of assigning aliases, e.g. x = y
                new_node = self.tracer.call(
                    self.tracer.lookup_node(l_alias.__name__),
                    self.get_source(node),
                    self.visit(node.value),
                )
                self.tracer.assign(
                    target.id,
                    new_node,
                )
                continue
            if not value_visited:
                value_node = self.visit(node.value)
                value_visited = True
            self.visit_assign_value(
                target,
                cast(Node, value_node),
                self.get_source(node),
            )

    def visit_assign_value(
        self,
        target: ast.AST,
        value_node: Node,
        source_location: Optional[SourceLocation] = None,
    ) -> None:
        """
        Assign an already-traced ``value_node`` to an assignment ``target``.

        This is extracted out of ``visit_Assign`` so it can recurse when
        unpacking into a tuple/list target, passing each element as a node
        instead of as AST.

        Supported targets:

        - Subscript, e.g., ``ls[0] = 1`` (``SET_ITEM``)
        - Attribute, e.g., ``x.y = 10`` (``SET_ATTR``)
        - Tuple/List, e.g., ``a, *b = value`` (unpack, then recurse)
        - Name, e.g., ``a = foo()``

        Raises
        ------
        NotImplementedError
            For any other target kind.
        """
        if isinstance(target, ast.Subscript):
            # note: isinstance(index, ast.List) only works for pandas,
            # not Python lists
            self.tracer.call(
                self.tracer.lookup_node(SET_ITEM),
                source_location,
                self.visit(target.value),
                self.visit(target.slice),
                value_node,
            )
        elif isinstance(target, ast.Attribute):
            self.tracer.call(
                self.tracer.lookup_node(SET_ATTR),
                source_location,
                self.visit(target.value),
                self.visit(ast.Constant(target.attr)),
                value_node,
            )
        elif isinstance(target, (ast.List, ast.Tuple)):
            # Assigning to a tuple or list of values is like unpacking the
            # value and then assigning each element to its sub-target.
            starred_index = next(
                (
                    i
                    for i, target_el in enumerate(target.elts)
                    if isinstance(target_el, ast.Starred)
                ),
                None,
            )
            if starred_index is not None:
                # ``a, *b, c = value``: l_unpack_ex gets the number of
                # targets before and after the starred one.
                unpacked_nodes = self.tracer.call(
                    self.tracer.lookup_node(l_unpack_ex.__name__),
                    source_location,
                    value_node,
                    self.tracer.literal(starred_index),
                    self.tracer.literal(len(target.elts) - starred_index - 1),
                )
            else:
                unpacked_nodes = self.tracer.call(
                    self.tracer.lookup_node(l_unpack_sequence.__name__),
                    source_location,
                    value_node,
                    self.tracer.literal(len(target.elts)),
                )
            # assign element i of the unpacked sequence to the i-th target
            for i, target_el in enumerate(target.elts):
                if isinstance(target_el, ast.Starred):
                    target_el = target_el.value
                self.visit_assign_value(
                    target_el,
                    self.tracer.call(
                        self.tracer.lookup_node(GET_ITEM),
                        source_location,
                        unpacked_nodes,
                        self.tracer.literal(i),
                    ),
                    source_location,
                )
        elif isinstance(target, ast.Name):
            self.tracer.assign(
                target.id,
                value_node,
            )
        else:
            raise NotImplementedError(
                "Other assignment types are not supported"
            )
        return None

    def visit_UnaryOp(self, node: ast.UnaryOp) -> CallNode:
        """Trace ``~x``, ``not x``, ``+x`` and ``-x``."""
        ast_to_op_map = {
            ast.Invert: INVERT,
            ast.Not: NOT,
            ast.UAdd: POS,
            ast.USub: NEG,
        }
        return self.tracer.call(
            self.tracer.lookup_node(ast_to_op_map[type(node.op)]),
            self.get_source(node),
            self.visit(node.operand),
        )

    def visit_List(self, node: ast.List) -> Node:
        """
        Trace a list display as a call to ``l_list``.

        Displays that contain ``*iterable`` elements are evaluated as a
        black box instead. Visiting a ``Starred`` element goes through the
        generator ``visit_Starred`` and does not produce a node.
        """
        if any(isinstance(elem, ast.Starred) for elem in node.elts):
            return self._exec_expression(node)
        elem_nodes = [self.visit(elem) for elem in node.elts]
        return self.tracer.call(
            self.tracer.lookup_node(l_list.__name__),
            self.get_source(node),
            *elem_nodes,
        )

    def visit_Tuple(self, node: ast.Tuple) -> Node:
        """
        Trace a tuple display via ``tracer.tuple``.

        Displays that contain ``*iterable`` elements are evaluated as a
        black box instead. Visiting a ``Starred`` element goes through the
        generator ``visit_Starred`` and does not produce a node.
        """
        if any(isinstance(elem, ast.Starred) for elem in node.elts):
            return self._exec_expression(node)
        elem_nodes = [self.visit(elem) for elem in node.elts]
        return self.tracer.tuple(
            *elem_nodes,
            source_location=self.get_source(node),
        )

    def visit_BinOp(self, node: ast.BinOp) -> CallNode:
        """Trace a binary arithmetic/bitwise operator as a call."""
        ast_to_op_map = {
            ast.Add: ADD,
            ast.Sub: SUB,
            ast.Mult: MULT,
            ast.Div: DIV,
            ast.FloorDiv: FLOORDIV,
            ast.Mod: MOD,
            ast.Pow: POW,
            ast.LShift: LSHIFT,
            ast.RShift: RSHIFT,
            ast.BitOr: BITOR,
            ast.BitXor: BITXOR,
            ast.BitAnd: BITAND,
            ast.MatMult: MATMUL,
        }
        op = ast_to_op_map[node.op.__class__]
        argument_nodes = [self.visit(node.left), self.visit(node.right)]
        return self.tracer.call(
            self.tracer.lookup_node(op),
            self.get_source(node),
            *argument_nodes,
        )

    def visit_BoolOp(self, node: ast.BoolOp) -> Node:
        """
        Evaluate ``a and b ...`` / ``a or b ...`` as a black-box expression.

        ``and``/``or`` short-circuit and return one of their operands. For
        example, ``None or 5`` is ``5`` and ``x and x.attr`` never touches
        ``x.attr`` when ``x`` is falsy. Visiting every operand eagerly and
        combining them with the bitwise ``|``/``&`` operators gets both of
        these wrong. It would also pass three or more operands to a binary
        operator.
        """
        return self._exec_expression(node)

    def visit_Compare(self, node: ast.Compare) -> Node:
        """
        Trace a single comparison (``a < b``, ``a in b``, ...) as a call.

        Chained comparisons such as ``a < b <= c`` mean
        ``a < b and b <= c``: the middle operand is evaluated once and
        evaluation short-circuits. Folding them left to right would compute
        ``(a < b) <= c`` instead. They are evaluated as a black box.
        """
        if len(node.ops) > 1:
            return self._exec_expression(node)

        ast_to_op_map = {
            ast.Eq: EQ,
            ast.NotEq: NOTEQ,
            ast.Lt: LT,
            ast.LtE: LTE,
            ast.Gt: GT,
            ast.GtE: GTE,
            ast.Is: IS,
            ast.IsNot: ISNOT,
            ast.In: IN,
        }
        op = node.ops[0]
        left = self.visit(node.left)
        right = self.visit(node.comparators[0])
        if isinstance(op, (ast.In, ast.NotIn)):
            # flip left and right since in(a, b) = b.contains(a)
            left, right = right, left
        if isinstance(op, ast.NotIn):
            # ``a not in b`` is ``not (a in b)``
            inside = self.tracer.call(
                self.tracer.lookup_node(IN),
                self.get_source(node),
                left,
                right,
            )
            return self.tracer.call(
                self.tracer.lookup_node(NOT),
                self.get_source(node),
                inside,
            )
        return self.tracer.call(
            self.tracer.lookup_node(ast_to_op_map[type(op)]),
            self.get_source(node),
            left,
            right,
        )

    def visit_Slice(self, node: ast.Slice) -> CallNode:
        """
        Trace ``lower:upper:step`` as a call to the ``slice`` builtin.

        From https://docs.python.org/3/library/functions.html#slice, slice
        can be called in two ways:

        1. ``slice(stop)`` when the start and step are None
        2. ``slice(start, stop[, step])`` otherwise

        Bounds are visited in source order (lower, upper, step), matching
        Python's left-to-right evaluation.
        """
        if node.lower is None and node.step is None:
            args = [
                self.visit(node.upper)
                if node.upper is not None
                else self.tracer.literal(None)
            ]
        else:
            start_node = (
                self.visit(node.lower)
                if node.lower is not None
                else self.tracer.literal(None)
            )
            stop_node = (
                self.visit(node.upper)
                if node.upper is not None
                else self.tracer.literal(None)
            )
            args = [start_node, stop_node]
            if node.step is not None:
                args.append(self.visit(node.step))
        return self.tracer.call(
            self.tracer.lookup_node(slice.__name__),
            self.get_source(node),
            *args,
        )

    def visit_Subscript(self, node: ast.Subscript) -> CallNode:
        """Trace ``value[index]`` as a ``GET_ITEM`` call."""
        return self.tracer.call(
            self.tracer.lookup_node(GET_ITEM),
            self.get_source(node),
            self.visit(node.value),
            self.visit(node.slice),
        )

    def visit_Attribute(self, node: ast.Attribute) -> CallNode:
        """Trace ``value.attr`` as a ``GETATTR`` call."""
        return self.tracer.call(
            self.tracer.lookup_node(GETATTR),
            self.get_source(node),
            self.visit(node.value),
            self.visit(ast.Constant(value=node.attr)),
        )

    def _exec_statement(self, node: ast.AST) -> None:
        """Execute ``node``'s source text via ``l_exec_statement``."""
        code = self._get_code_from_node(node)
        assert code
        self.tracer.call(
            self.tracer.lookup_node(l_exec_statement.__name__),
            self.get_source(node),
            self.tracer.literal(code),
        )

    def _exec_expression(self, node: ast.AST) -> Node:
        """Evaluate ``node``'s source text via ``l_exec_expr``."""
        code = self._get_code_from_node(node)
        assert code
        return self.tracer.call(
            self.tracer.lookup_node(l_exec_expr.__name__),
            self.get_source(node),
            self.tracer.literal(code),
        )

    def visit_Dict(self, node: ast.Dict) -> CallNode:
        """
        Trace a dict display as ``l_dict`` over ``(key, value)`` tuples.

        A ``None`` key (a ``**mapping`` entry) is replaced by the node
        returned from calling ``l_dict_kwargs_sentinel``.
        """
        return self.tracer.call(
            self.tracer.lookup_node(l_dict.__name__),
            self.get_source(node),
            *(
                self.tracer.tuple(
                    self.visit(k)
                    if k is not None
                    else self.tracer.call(
                        self.tracer.lookup_node(
                            l_dict_kwargs_sentinel.__name__
                        ),
                        None,
                    ),
                    self.visit(v),
                )
                for k, v in zip(node.keys, node.values)
            ),
        )

    def get_source(self, node: ast.AST) -> Optional[SourceLocation]:
        """Return ``node``'s span in the source, or None if it has none."""
        if not hasattr(node, "lineno"):
            return None
        return SourceLocation(
            source_code=self.source_code,
            lineno=node.lineno,
            col_offset=node.col_offset,
            end_lineno=node.end_lineno,  # type: ignore
            end_col_offset=node.end_col_offset,  # type: ignore
        )