"""
This file contains the ORM versions of the graph nodes defined in types.py.
Pydantic lets us extract a dataclass-like object from an ORM instance,
but it does not let us write directly to the ORM.
Relationships
-------------
Warning
-------
The list below is not exhaustive.
SessionContext
- ImportNode (One to Many)
- HardwareSpec (Many to One)
Node
- SessionContext (Many to One)
CallNode
- Node (Many to Many)
"""
from __future__ import annotations
from datetime import datetime
from typing import Union
from sqlalchemy import (
Boolean,
CheckConstraint,
Column,
DateTime,
Enum,
ForeignKey,
Integer,
String,
)
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import relationship
from lineapy.data.types import (
LineaID,
LiteralType,
NodeType,
SessionType,
ValueType,
)
from lineapy.utils.constants import ARTIFACT_NAME_PLACEHOLDER
# Declarative base shared by every ORM class in this module; the tables of
# all classes deriving from it are registered on ``Base.metadata``.
Base = declarative_base()
class SessionContextORM(Base):
    """
    ORM row describing one session.

    Records the session type (``environment_type``), when and where the
    session was started, optional descriptive metadata, and the execution
    the session is associated with.
    """

    __tablename__ = "session_context"

    id = Column(String, primary_key=True)
    environment_type = Column(Enum(SessionType))
    creation_time = Column(DateTime)
    working_directory = Column(String)

    # Optional, free-form metadata.
    session_name = Column(String, nullable=True)
    user_name = Column(String, nullable=True)
    # Stored as a plain string column, not as a foreign key.
    hardware_spec = Column(String, nullable=True)

    execution_id = Column(String, ForeignKey("execution.id"))
class ArtifactORM(Base):
    """
    A named, versioned pointer to a node within a given execution.

    The primary key is composite: ``(node_id, execution_id, name, version)``.
    """

    __tablename__ = "artifact"

    node_id: LineaID = Column(String, ForeignKey("node.id"), primary_key=True)
    execution_id: LineaID = Column(
        String,
        ForeignKey("execution.id"),
        primary_key=True,
    )
    # Falls back to a placeholder name when none is supplied on insert.
    name = Column(
        String,
        primary_key=True,
        nullable=False,
        default=ARTIFACT_NAME_PLACEHOLDER,
    )
    date_created = Column(DateTime, nullable=False)
    version = Column(Integer, primary_key=True, nullable=False)

    # Both targets are eager-loaded with an INNER JOIN; node_id and
    # execution_id are primary-key columns, so they are never NULL.
    node: BaseNodeORM = relationship(
        "BaseNodeORM", innerjoin=True, lazy="joined", uselist=False
    )
    execution: ExecutionORM = relationship(
        "ExecutionORM", innerjoin=True, lazy="joined", uselist=False
    )
class ExecutionORM(Base):
    """
    One Python interpreter invocation, during which some number of nodes
    were executed.
    """

    __tablename__ = "execution"

    id = Column(String, primary_key=True)
    # Filled with a naive UTC ``datetime`` on insert when no value is given.
    timestamp = Column(DateTime, default=datetime.utcnow, nullable=True)
class NodeValueORM(Base):
    """
    The value a node had during one particular execution.

    Rows are keyed by the pair (``node_id``, ``execution_id``).
    Expected invariant (not enforced by this table's schema)::

        value.node.session == value.execution.session
    """

    __tablename__ = "node_value"

    node_id = Column(String, ForeignKey("node.id"), primary_key=True)
    execution_id = Column(
        String, ForeignKey("execution.id"), primary_key=True
    )

    # String form of the value (may be NULL); see ``value_type``.
    value = Column(String, nullable=True)
    value_type = Column(Enum(ValueType))

    # Optional timing information.
    start_time = Column(DateTime, nullable=True)
    end_time = Column(DateTime, nullable=True)
class BaseNodeORM(Base):
    """
    Parent table for every graph node.

    Concrete node kinds are joined-table subclasses, discriminated by the
    ``node_type`` column.

    ``node.source_code`` has a path value if
    ``node.session.environment_type == "script"``; otherwise the environment
    type is "jupyter" and it has a jupyter execution count and session id,
    the latter being equal to ``node.session``.

    NOTE
    ----
    - Since all node classes inherit from BaseNodeORM, a node of any kind
      can be fetched by id through this class::

          session.query(BaseNodeORM).filter(BaseNodeORM.id == linea_id)

    - The location columns ``lineno``, ``col_offset``, ``end_lineno``,
      ``end_col_offset`` and ``source_code_id`` must be either all non-null
      or all null; a CHECK constraint on the table enforces this.
    """

    __tablename__ = "node"

    id = Column(String, primary_key=True)
    session_id: LineaID = Column(String)
    node_type = Column(Enum(NodeType))

    # Source span of the node: lines are 1-indexed, columns are 0-indexed.
    lineno = Column(Integer, nullable=True)
    col_offset = Column(Integer, nullable=True)
    end_lineno = Column(Integer, nullable=True)
    end_col_offset = Column(Integer, nullable=True)
    source_code_id = Column(
        String,
        ForeignKey("source_code.id"),
        nullable=True,
    )
    source_code: SourceCodeORM = relationship("SourceCodeORM", lazy="joined")

    __table_args__ = (
        # All-or-nothing: every location column is NULL, or none is.
        CheckConstraint(
            "(lineno IS NULL) = (col_offset is NULL) and "
            "(col_offset is NULL) = (end_lineno is NULL) and "
            "(end_lineno is NULL) = (end_col_offset is NULL) and "
            "(end_col_offset is NULL) = (source_code_id is NULL)"
        ),
    )

    # Joined-table inheritance keyed on ``node_type``:
    # https://docs.sqlalchemy.org/en/14/orm/inheritance.html#joined-table-inheritance
    __mapper_args__ = {
        "polymorphic_identity": NodeType.Node,
        "polymorphic_on": node_type,
    }
class SourceCodeORM(Base):
    """
    A piece of source code that nodes point into.

    Exactly one kind of origin is recorded: either ``path`` (script
    sessions) or the pair (``jupyter_execution_count``,
    ``jupyter_session_id``) for Jupyter sessions.
    """

    __tablename__ = "source_code"

    id = Column(String, primary_key=True)
    code = Column(String)
    path = Column(String, nullable=True)
    jupyter_execution_count = Column(Integer, nullable=True)
    jupyter_session_id = Column(String, nullable=True)

    __table_args__ = (
        # Either path is set, or jupyter_execution_count and
        # jupyter_session_id are both set -- never both kinds of origin.
        # Both jupyter columns are tested here so this constraint states its
        # rule on its own instead of relying on the one below.
        CheckConstraint(
            "(path IS NOT NULL) != ((jupyter_execution_count IS NOT NULL) AND "
            "(jupyter_session_id IS NOT NULL))"
        ),
        # If one jupyter column is provided, both must be.
        CheckConstraint(
            "(jupyter_execution_count IS NULL) = (jupyter_session_id is NULL)"
        ),
    )
class LookupNodeORM(BaseNodeORM):
    """Joined-table subclass of BaseNodeORM for ``NodeType.LookupNode``."""

    __tablename__ = "lookup"

    # Shares its primary key with the parent ``node`` row.
    id = Column(String, ForeignKey("node.id"), primary_key=True)
    name = Column(String, nullable=False)

    __mapper_args__ = {"polymorphic_identity": NodeType.LookupNode}
class ImportNodeORM(BaseNodeORM):
    """Joined-table subclass of BaseNodeORM for ``NodeType.ImportNode``."""

    __tablename__ = "import_node"

    # Shares its primary key with the parent ``node`` row.
    id = Column(String, ForeignKey("node.id"), primary_key=True)
    name = Column(String)
    # Optional metadata; presumably describes the imported package.
    package_name = Column(String, nullable=True)
    version = Column(String, nullable=True)
    path = Column(String, nullable=True)

    __mapper_args__ = {"polymorphic_identity": NodeType.ImportNode}
# Association objects implement the many-to-many relationships between call
# nodes and the nodes they use as arguments, globals, or dependencies:
# https://docs.sqlalchemy.org/en/14/orm/basic_relationships.html#association-object
class PositionalArgORM(Base):
    """
    Association row linking a call node to one of its positional arguments.

    Because ``index`` is part of the composite key, the same argument node
    may appear at several positions of one call.
    """

    __tablename__ = "positional_arg"

    # No explicit type: each column takes the type of the column its
    # ForeignKey references.
    call_node_id: str = Column(
        ForeignKey("call_node.id"), nullable=False, primary_key=True
    )
    arg_node_id: LineaID = Column(
        ForeignKey("node.id"), nullable=False, primary_key=True
    )
    # NOTE(review): presumably marks ``*arg`` unpacking at the call site.
    starred: bool = Column(Boolean, default=False, nullable=False)
    index = Column(Integer, nullable=False, primary_key=True)

    argument = relationship(BaseNodeORM, uselist=False)
class KeywordArgORM(Base):
    """
    Association row linking a call node to one of its keyword arguments.

    Keyed by ``(call_node_id, arg_node_id, name)``.
    """

    __tablename__ = "keyword_arg"

    # No explicit type: each column takes the type of the column its
    # ForeignKey references.
    call_node_id: str = Column(
        ForeignKey("call_node.id"), nullable=False, primary_key=True
    )
    arg_node_id: LineaID = Column(
        ForeignKey("node.id"), nullable=False, primary_key=True
    )
    # NOTE(review): presumably marks ``**kwargs`` unpacking at the call site.
    starred: bool = Column(Boolean, default=False, nullable=False)
    name: str = Column(String, nullable=False, primary_key=True)

    argument = relationship(BaseNodeORM, uselist=False)
class GlobalReferenceORM(Base):
    """
    Association row backing ``CallNodeORM.global_reads``.

    Links a call node to the node bound to ``variable_name``; keyed by
    ``(call_node_id, variable_node_id, variable_name)``.
    """

    __tablename__ = "global_reference"

    # No explicit type: each column takes the type of the column its
    # ForeignKey references.
    call_node_id: str = Column(
        ForeignKey("call_node.id"), nullable=False, primary_key=True
    )
    variable_node_id: str = Column(
        ForeignKey("node.id"), nullable=False, primary_key=True
    )
    variable_name = Column(String, nullable=False, primary_key=True)

    variable_node = relationship(BaseNodeORM, uselist=False)
class ImplicitDependencyORM(Base):
    """
    Association row backing ``CallNodeORM.implicit_dependencies``.

    Keyed by ``(call_node_id, arg_node_id, index)``.
    """

    __tablename__ = "implicit_dependency"

    # No explicit type: each column takes the type of the column its
    # ForeignKey references.
    call_node_id: str = Column(
        ForeignKey("call_node.id"), nullable=False, primary_key=True
    )
    arg_node_id: str = Column(
        ForeignKey("node.id"), nullable=False, primary_key=True
    )
    index = Column(Integer, nullable=False, primary_key=True)

    argument = relationship(BaseNodeORM, uselist=False)
class CallNodeORM(BaseNodeORM):
    """
    Joined-table subclass of BaseNodeORM for ``NodeType.CallNode``.

    ``function_id`` references another node (presumably the callee). The
    positional args, keyword args, global reads and implicit dependencies
    live in their own association tables and are eager-loaded as sets.
    """

    __tablename__ = "call_node"

    id = Column(String, ForeignKey("node.id"), primary_key=True)
    function_id = Column(String, ForeignKey("node.id"))

    positional_args = relationship(
        PositionalArgORM, lazy="joined", collection_class=set
    )
    keyword_args = relationship(
        KeywordArgORM, lazy="joined", collection_class=set
    )
    global_reads = relationship(
        GlobalReferenceORM, lazy="joined", collection_class=set
    )
    implicit_dependencies = relationship(
        ImplicitDependencyORM, lazy="joined", collection_class=set
    )

    __mapper_args__ = {
        "polymorphic_identity": NodeType.CallNode,
        # This table has two foreign keys to node.id (id and function_id),
        # so the parent/child join must be stated explicitly.
        # https://stackoverflow.com/a/39518177/907060
        "inherit_condition": id == BaseNodeORM.id,
    }
class LiteralNodeORM(BaseNodeORM):
    """Joined-table subclass of BaseNodeORM for ``NodeType.LiteralNode``."""

    __tablename__ = "literal_assign_node"

    # Shares its primary key with the parent ``node`` row.
    id = Column(String, ForeignKey("node.id"), primary_key=True)
    value_type: LiteralType = Column(Enum(LiteralType))
    # Literal value serialized as a string; always present.
    value: str = Column(String, nullable=False)

    __mapper_args__ = {"polymorphic_identity": NodeType.LiteralNode}
class MutateNodeORM(BaseNodeORM):
    """
    Joined-table subclass of BaseNodeORM for ``NodeType.MutateNode``.

    ``source_id`` and ``call_id`` both reference other nodes.
    """

    __tablename__ = "mutate_node"

    id = Column(String, ForeignKey("node.id"), primary_key=True)
    source_id = Column(String, ForeignKey("node.id"))
    call_id = Column(String, ForeignKey("node.id"))

    __mapper_args__ = {
        "polymorphic_identity": NodeType.MutateNode,
        # Several columns reference node.id; pin inheritance to ``id``.
        "inherit_condition": id == BaseNodeORM.id,
    }
class GlobalNodeORM(BaseNodeORM):
    """
    Joined-table subclass of BaseNodeORM for ``NodeType.GlobalNode``.

    Stores a ``name`` and the node referenced by ``call_id``.
    """

    __tablename__ = "global_node"

    id = Column(String, ForeignKey("node.id"), primary_key=True)
    name = Column(String)
    call_id = Column(String, ForeignKey("node.id"))

    __mapper_args__ = {
        "polymorphic_identity": NodeType.GlobalNode,
        # Two columns reference node.id; pin inheritance to ``id``.
        "inherit_condition": id == BaseNodeORM.id,
    }
# Explicit union of the BaseNodeORM subclasses defined in this module.
# Annotating a value as NodeORM (instead of BaseNodeORM) lets a type checker
# verify that code dispatching on the node kind covers every case.
# Keep this list in sync whenever a new node ORM class is added.
NodeORM = Union[
    LookupNodeORM,
    ImportNodeORM,
    CallNodeORM,
    LiteralNodeORM,
    MutateNodeORM,
    GlobalNodeORM,
]