# Source code for lineapy.editors.ipython

"""
Transforms all executions in IPython to execute with lineapy, by adding to
`input_transformers_post`. See the IPython documentation on input transformers:
https://ipython.readthedocs.io/en/stable/config/inputtransforms.html
"""
from __future__ import annotations

from dataclasses import dataclass, field
from typing import List, Optional, Union

from IPython.core.interactiveshell import InteractiveShell
from IPython.display import DisplayHandle, DisplayObject, display

from lineapy.data.types import JupyterCell, SessionType
from lineapy.db.db import RelationalLineaDB
from lineapy.editors.ipython_cell_storage import cleanup_cells, get_cell_path
from lineapy.exceptions.excepthook import transform_except_hook_args
from lineapy.exceptions.flag import REWRITE_EXCEPTIONS
from lineapy.exceptions.user_exception import AddFrame
from lineapy.instrumentation.tracer import Tracer
from lineapy.transformer.node_transformer import transform
from lineapy.utils.analytics import send_lib_info_from_db
from lineapy.utils.logging_config import configure_logging

__all__ = ["_end_cell", "start", "stop", "visualize"]

# The state of the ipython extension works like:
# 1. Originally `STATE` is unbound (it is only annotated below, never assigned),
#    which plays the role of `None`: no ipython transformations are active.
# 2. After calling `start`, it transitions to `StartedState`, which means that
#    the transformer is registered and (when REWRITE_EXCEPTIONS is set) the
#    exceptions will be transformed.
# 3. After the first cell is executed, the state changes to `CellsExecutedState`,
#    which means we have connected to the database and have started saving traces.
#    Note: We wait to connect to the DB until the first cell is executed, so that
#    any logging printed during this connection, or errors raised during this
#    connection, are displayed to users instead of being lost, if `start` is
#    called during an extension load at ipython startup.
# SS: do not explicitly set the state to `None` here
# NOTE(review): the reason for leaving it unassigned is not visible in this file
# (presumably so re-executing the module keeps an existing state) — confirm.
# Because of `from __future__ import annotations`, this annotation is stored
# lazily, so referencing the state classes defined below is fine.
STATE: Union[None, StartedState, CellsExecutedState]


@dataclass
class StartedState:
    """
    Extension state between `start()` and the first transformed cell.

    The input transformer is registered, but no database connection or tracer
    has been created yet; those are built from these fields when the first
    cell runs through `input_transformer_post`.
    """

    # Save the ipython in the started state, because we can't look it
    # up during our transformation, and we need it to get the globals
    ipython: InteractiveShell

    # Optional overrides for the session name and DB URL
    session_name: Optional[str]
    db_url: Optional[str]


@dataclass
class CellsExecutedState:
    """
    Extension state once at least one cell has gone through the transformer.

    Holds the tracer recording the session, the source of the cell currently
    being executed, and an optional handle for a live-updating visualization.
    """

    # Tracer created (together with its DB connection) on the first cell.
    tracer: Tracer
    # Source text of the cell currently being executed.
    code: str
    # When set, this display is refreshed after every cell execution.
    visualize_display_handle: Optional[DisplayHandle] = None

    def create_visualize_display_object(self) -> DisplayObject:
        """
        Build a Jupyter display object showing the traced graph so far.
        """
        # Imported at call time rather than at module import.
        from lineapy.visualizer import Visualizer

        public_visualizer = Visualizer.for_public(self.tracer)
        return public_visualizer.ipython_display_object()


def start(
    session_name: Optional[str] = None,
    db_url: Optional[str] = None,
    ipython: Optional[InteractiveShell] = None,
) -> None:
    """
    Initialize the runtime so that subsequent cells are traced with lineapy.

    Parameters
    ----------
    session_name: Optional[str]
        Optional override for the name of the traced session.
    db_url: Optional[str]
        Optional override for the database URL used to store traces.
    ipython: Optional[InteractiveShell]
        Shell to install the input transformer on; defaults to the shell
        returned by ``get_ipython()``.

    Calling ``start`` more than once registers the input transformer only
    once, and replaces ``STATE`` with a fresh ``StartedState``.
    """
    global STATE
    ipython = ipython or get_ipython()  # type: ignore
    # IPython does not use exceptionhook, so instead we monkeypatch
    # how it processes the exceptions, in order to add our handler
    # that removes the outer frames.
    if REWRITE_EXCEPTIONS:
        InteractiveShell._get_exc_info = custom_get_exc_info
    # Registering twice would make the second pass receive our own
    # RETURNED_LINES as "user code" and trace those instead of the cell.
    if input_transformer_post not in ipython.input_transformers_post:
        ipython.input_transformers_post.append(input_transformer_post)
    # NOTE(review): if a previous session reached CellsExecutedState, its
    # tracer's DB is not closed here — confirm `stop()` is expected first.
    STATE = StartedState(ipython, session_name=session_name, db_url=db_url)


def input_transformer_post(lines: List[str]) -> List[str]:
    """
    Translate the lines of code for the cell provided by ipython.

    The cell's source is stashed on ``STATE`` and replaced with
    ``RETURNED_LINES``, which call back into ``_end_cell`` to trace and run it.
    On the first cell after ``start()``, this also configures logging, connects
    to the database and creates the tracer, moving ``STATE`` from
    ``StartedState`` to ``CellsExecutedState``.

    Raises
    ------
    RuntimeError
        If there is no active state (``start()`` has not been called).
    """
    global STATE
    # `STATE` is only annotated at module level, so it is unbound until
    # `start()` runs; look it up without risking a NameError.
    if not globals().get("STATE"):
        raise RuntimeError(
            "input_transformer_post shouldn't be called when we don't have an active tracer"
        )
    code = "".join(lines)
    # If we have just started, first start everything up
    if isinstance(STATE, StartedState):
        configure_logging()
        db = RelationalLineaDB.from_environment(STATE.db_url)
        # pass in globals from ipython so that `get_ipython()` works
        # and things like `!cat df.csv` work in the notebook
        ipython_globals = STATE.ipython.user_global_ns
        tracer = Tracer(
            db, SessionType.JUPYTER, STATE.session_name, ipython_globals
        )
        STATE = CellsExecutedState(tracer, code=code)
    else:
        STATE.code = code
    # Return a copy: later transformers receive this list and could mutate it
    # in place, which would corrupt the shared constant for every future cell.
    return list(RETURNED_LINES)


# We always return the same two lines for IPython to use as input.
# They will run our internal end cell function, which will return the proper
# return value for the cell, so ipython can display it.
# NOTE(review): an earlier comment said these lines also clean up the tracer
# if we stopped in that cell; `_end_cell` below does no such cleanup — confirm.
RETURNED_LINES = [
    "import lineapy.editors.ipython\n",
    "lineapy.editors.ipython._end_cell()\n",
]


def _end_cell() -> object:
    """
    Trace and execute the current cell, returning its last value so that
    ipython can render it as the cell's output.

    We also write each ipython cell to its own temporary file, so that if an
    exception is raised it will have proper tracebacks (this is how ipython
    handles error reporting as well). There are more details in the README
    file.

    Returns
    -------
    object
        The value of the cell's last node, or ``None`` if the cell's code ends
        in a semicolon or there is no last node.

    Raises
    ------
    ValueError
        If ``STATE`` is not a ``CellsExecutedState``.
    """
    global STATE
    if not isinstance(STATE, CellsExecutedState):
        raise ValueError("We need to be executing cells to get the last value")
    execution_count: int = get_ipython().execution_count  # type: ignore
    location = JupyterCell(
        execution_count=execution_count,
        session_id=STATE.tracer.get_session_id(),
    )
    code = STATE.code
    # Write the code text to a file for error reporting
    get_cell_path(location).write_text(code)

    last_node = transform(code, location, STATE.tracer)
    # Refresh the live visualization, if `visualize(live=True)` set one up.
    if STATE.visualize_display_handle:
        STATE.visualize_display_handle.update(
            STATE.create_visualize_display_object()
        )

    # Return the last value so it will be printed, if we don't end
    # in a semicolon
    ends_with_semicolon = code.strip().endswith(";")
    if not ends_with_semicolon and last_node:
        res = STATE.tracer.executor.get_value(last_node.id)
    else:
        res = None
    return res


def visualize(*, live: bool = False) -> None:
    """
    Display a visualization of the Linea graph from this session using Graphviz.

    If `live=True`, then this visualization will live update after cell execution.
    Note that this comes with a substantial performance penalty, so it is False
    by default.

    Note: If the visualization is not live, it will print out the visualization
    as of the previous cell execution, not the one where `visualize` is executed.

    Raises
    ------
    RuntimeError
        If no cell has been traced yet, including when `start()` was never
        called.
    """
    # `STATE` stays unbound until `start()` runs; use a lookup that cannot
    # raise NameError so callers get the intended RuntimeError instead.
    state = globals().get("STATE")
    if not isinstance(state, CellsExecutedState):
        raise RuntimeError(
            "Cannot visualize before we have started executing cells"
        )
    display_object = state.create_visualize_display_object()
    if not live:
        # Just display a one-off snapshot of the visualization.
        display(display_object)
    elif state.visualize_display_handle:
        # An existing live handle: display a new version of it.
        state.visualize_display_handle.display(display_object)
    else:
        # Create a handle that `_end_cell` will update after each cell.
        state.visualize_display_handle = display(
            display_object, display_id=True
        )


def stop() -> None:
    """
    Finish the lineapy session for this ipython shell.

    If any cell has been traced, send the library information for the session
    (at this point it is the most complete) and close the tracer's database.
    The per-cell files are removed via `cleanup_cells()` in every case.
    Safe to call even if `start()` was never called.
    """
    # NOTE(review): this does not unregister `input_transformer_post` or reset
    # `STATE`; confirm that cells run after `stop()` are not traced.
    # `STATE` stays unbound until `start()` runs, so avoid a NameError here.
    state = globals().get("STATE")
    if isinstance(state, CellsExecutedState):
        send_lib_info_from_db(state.tracer.db, state.tracer.get_session_id())
        state.tracer.db.close()
    cleanup_cells()


# Keep a reference to IPython's own `_get_exc_info`, captured at import time
# (before `start` may monkeypatch it), so the wrapper below can delegate to it.
original_get_exc_info = InteractiveShell._get_exc_info


def custom_get_exc_info(*args, **kwargs):
    """
    Drop-in replacement for ``InteractiveShell._get_exc_info``.

    Delegates to the original implementation, then rewrites the resulting
    exception info to remove the frames lineapy added around user code. An
    extra frame is pushed on top (via ``AddFrame``) because ipython strips the
    first one itself (this might change in future versions), probably for
    reasons similar to ours.
    """
    # NOTE(review): if this module is re-executed after `start()` patched
    # InteractiveShell, the capture above would pick up the previous wrapper
    # instead of IPython's original — confirm reloads cannot happen here.
    raw_exc_info = original_get_exc_info(*args, **kwargs)
    return transform_except_hook_args(raw_exc_info, AddFrame("", 0))