from __future__ import annotations
import logging
import os
from pathlib import Path
from typing import Any, Dict, List, Optional, cast
from sqlalchemy import create_engine
from sqlalchemy.orm import defaultload, scoped_session, sessionmaker
from sqlalchemy.pool import StaticPool
from sqlalchemy.sql.expression import and_
from lineapy.data.types import (
Artifact,
CallNode,
Execution,
GlobalNode,
ImportNode,
JupyterCell,
KeywordArgument,
LineaID,
LiteralNode,
LiteralType,
LookupNode,
MutateNode,
Node,
NodeValue,
PositionalArgument,
SessionContext,
SourceCode,
SourceLocation,
)
from lineapy.db.relational import (
ArtifactORM,
Base,
BaseNodeORM,
CallNodeORM,
ExecutionORM,
GlobalNodeORM,
GlobalReferenceORM,
ImplicitDependencyORM,
ImportNodeORM,
KeywordArgORM,
LiteralNodeORM,
LookupNodeORM,
MutateNodeORM,
NodeORM,
NodeValueORM,
PositionalArgORM,
SessionContextORM,
SourceCodeORM,
)
from lineapy.db.utils import OVERRIDE_HELP_TEXT, resolve_db_url
from lineapy.exceptions.db_exceptions import ArtifactSaveException
from lineapy.exceptions.user_exception import UserException
from lineapy.utils.analytics.event_schemas import ExceptionEvent
from lineapy.utils.analytics.usage_tracking import track # circular dep issues
from lineapy.utils.constants import DB_SQLITE_PREFIX, SQLALCHEMY_ECHO
from lineapy.utils.utils import get_literal_value_from_string
# Module-level logger named after this module, so logging configuration can
# target (or silence) DB messages independently.
logger: logging.Logger = logging.getLogger(name=__name__)
class RelationalLineaDB:
    """
    Relational-database persistence layer for Linea graphs.

    - LineaDB coordinates with the asset manager and the relational db.
    - The asset manager deals with binaries (e.g., cached values).
      The relational db deals with more structured data,
      such as the Nodes and edges.
    - At some point we might add a "cache" so that readers
      don't have to go to the database if the data is already
      loaded, but that's low priority.
    """

    def __init__(self, url: str):
        """
        Create a linea DB, by connecting to a database url:
        https://docs.sqlalchemy.org/en/14/core/engines.html#database-urls

        All tables declared on ``Base.metadata`` are created if missing.
        """
        # create_engine params from
        # https://stackoverflow.com/questions/21766960/operationalerror-no-such-table-in-flask-with-sqlalchemy
        self.url: str = url
        echo = os.getenv(SQLALCHEMY_ECHO, default="false").lower() == "true"
        logger.debug("Connecting to Linea DB at %s", url)
        additional_args: Dict[str, Any] = {}
        if self._is_sqlite():
            # sqlite3 refuses to use a connection outside the thread that
            # created it unless this check is disabled; StaticPool hands the
            # same single connection to every caller.
            additional_args = {"check_same_thread": False}
        self.engine = create_engine(
            url,
            connect_args=additional_args,
            poolclass=StaticPool,
            echo=echo,
        )
        self.session = scoped_session(sessionmaker())
        self.session.configure(bind=self.engine)
        Base.metadata.create_all(self.engine)

    def _is_sqlite(self) -> bool:
        """Return True when ``self.url`` points at a SQLite database."""
        return self.url.startswith(DB_SQLITE_PREFIX)

    def renew_session(self):
        """
        On SQLite, commit pending changes and start a fresh scoped session.

        For other backends this does nothing; their changes are committed
        by an explicit ``commit()`` (which ``close()`` also performs).
        """
        if self._is_sqlite():
            self.commit()
            self.session = scoped_session(sessionmaker())
            self.session.configure(bind=self.engine)

    @classmethod
    def from_environment(cls, url: Optional[str] = None) -> RelationalLineaDB:
        # NOTE: an f-string is never treated as a docstring, so the text below
        # is not exposed through ``__doc__``; it only documents the source.
        f"""
        Creates a new database.
        url: {OVERRIDE_HELP_TEXT}
        """
        return cls(resolve_db_url(url))

    @staticmethod
    def get_type_of_literal_value(val: Any) -> LiteralType:
        """
        Map a Python literal value to its ``LiteralType``.

        ``bool`` is tested before ``int`` because ``bool`` is a subclass of
        ``int``.

        Raises NotImplementedError for any other value type.
        """
        if isinstance(val, str):
            return LiteralType.String
        if isinstance(val, bool):
            return LiteralType.Boolean
        if isinstance(val, int):
            return LiteralType.Integer
        if isinstance(val, float):
            return LiteralType.Float
        if val is None:
            return LiteralType.NoneType
        if val is ...:
            return LiteralType.Ellipsis
        raise NotImplementedError(f"Literal {val} is of type {type(val)}.")

    def write_context(self, context: SessionContext) -> None:
        """Persist a ``SessionContext`` as a ``SessionContextORM`` row."""
        args = context.dict()
        context_orm = SessionContextORM(**args)
        self.session.add(context_orm)
        if not self._is_sqlite():
            # Issue the INSERT now; presumably so later rows in the same
            # transaction can reference it -- TODO confirm.
            self.session.flush()
        self.renew_session()

    def commit(self) -> None:
        """
        End the transaction and commit the changes.

        On failure the transaction is rolled back and an
        ``ArtifactSaveException`` is raised, chained to the original error.
        """
        try:
            self.session.commit()
        except Exception as e:
            self.session.rollback()
            raise ArtifactSaveException() from e

    def close(self):
        """
        Commit pending changes and close the database session.
        """
        # Always close, even if error is raised
        # https://docs.sqlalchemy.org/en/14/orm/session_api.html#sqlalchemy.orm.sessionmaker
        try:
            self.commit()
        finally:
            self.session.close()

    def write_source_code(self, source_code: SourceCode) -> None:
        """
        Writes a source code object to the database.

        ``SourceCode.location`` is a union (a ``Path`` for scripts, otherwise
        a Jupyter cell); ``SourceCodeORM`` inlines both variants as columns,
        so only the active variant's columns are set and the rest are None.
        """
        source_code_orm = SourceCodeORM(
            id=source_code.id, code=source_code.code
        )
        location = source_code.location
        if isinstance(location, Path):
            source_code_orm.path = str(location)
            source_code_orm.jupyter_execution_count = None
            source_code_orm.jupyter_session_id = None
        else:
            source_code_orm.path = None
            source_code_orm.jupyter_execution_count = location.execution_count
            source_code_orm.jupyter_session_id = location.session_id
        self.session.add(source_code_orm)
        self.renew_session()

    def write_node(self, node: Node) -> None:
        """
        Convert a pydantic ``Node`` into its ORM row and persist it.

        id, session_id, node_type and the optional source location are common
        columns; the rest depend on the concrete node class. Any node that is
        not a Call/Import/Literal/Mutate/Global node is written as a
        ``LookupNodeORM``.
        """
        args = node.dict(include={"id", "session_id", "node_type"})
        s = node.source_location
        if s:
            args["lineno"] = s.lineno
            args["col_offset"] = s.col_offset
            args["end_lineno"] = s.end_lineno
            args["end_col_offset"] = s.end_col_offset
            args["source_code_id"] = s.source_code.id
        node_orm: NodeORM
        if isinstance(node, CallNode):
            node_orm = CallNodeORM(
                **args,
                function_id=node.function_id,
                # Positions are stored explicitly so order can be restored
                # when reading the node back.
                positional_args={
                    PositionalArgORM(
                        index=i, starred=v.starred, arg_node_id=v.id
                    )
                    for i, v in enumerate(node.positional_args)
                },
                keyword_args={
                    KeywordArgORM(
                        name=v.key, arg_node_id=v.value, starred=v.starred
                    )
                    for v in node.keyword_args
                },
                global_reads={
                    GlobalReferenceORM(
                        call_node_id=node.id,
                        variable_name=k,
                        variable_node_id=id_,
                    )
                    for k, id_ in node.global_reads.items()
                },
                implicit_dependencies={
                    ImplicitDependencyORM(index=k, arg_node_id=id_)
                    for k, id_ in enumerate(node.implicit_dependencies)
                },
            )
        elif isinstance(node, ImportNode):
            node_orm = ImportNodeORM(
                **args,
                name=node.name,
                version=node.version,
                package_name=node.package_name,
                path=node.path,
            )
        elif isinstance(node, LiteralNode):
            # The value_type is not currently used anywhere
            # Was used before for rendering to a web UI.
            # Keeping it for now for anticipation of platform work.
            node_orm = LiteralNodeORM(
                **args,
                value_type=RelationalLineaDB.get_type_of_literal_value(
                    node.value
                ),
                value=str(node.value),
            )
        elif isinstance(node, MutateNode):
            node_orm = MutateNodeORM(
                **args,
                call_id=node.call_id,
                source_id=node.source_id,
            )
        elif isinstance(node, GlobalNode):
            node_orm = GlobalNodeORM(
                **args, call_id=node.call_id, name=node.name
            )
        else:
            node_orm = LookupNodeORM(**args, name=node.name)
        self.session.add(node_orm)
        self.renew_session()

    def write_node_value(
        self,
        node_value: NodeValue,
    ) -> None:
        """Persist a ``NodeValue`` as a ``NodeValueORM`` row."""
        self.session.add(NodeValueORM(**node_value.dict()))
        self.renew_session()

    def write_artifact(self, artifact: Artifact) -> None:
        """Persist an ``Artifact`` as an ``ArtifactORM`` row."""
        artifact_orm = ArtifactORM(
            node_id=artifact.node_id,
            execution_id=artifact.execution_id,
            name=artifact.name,
            date_created=artifact.date_created,
            version=artifact.version,
        )
        self.session.add(artifact_orm)
        self.renew_session()

    def artifact_in_db(
        self, node_id: LineaID, execution_id: LineaID, name: str, version: int
    ) -> bool:
        """
        Returns true if an artifact with exactly this node, execution, name
        and version is already in the DB.
        """
        return self.session.query(
            self.session.query(ArtifactORM)
            .filter(
                and_(
                    ArtifactORM.node_id == node_id,
                    ArtifactORM.execution_id == execution_id,
                    ArtifactORM.name == name,
                    ArtifactORM.version == version,
                )
            )
            .exists()
        ).scalar()

    def write_execution(self, execution: Execution) -> None:
        """Persist an ``Execution`` as an ``ExecutionORM`` row."""
        execution_orm = ExecutionORM(
            id=execution.id,
            timestamp=execution.timestamp,
        )
        self.session.add(execution_orm)
        if not self._is_sqlite():
            # Same immediate-flush behavior as ``write_context``.
            self.session.flush()
        self.renew_session()

    # ------------------------------------------------------------------
    # Readers
    # ------------------------------------------------------------------

    def map_orm_to_pydantic(self, node: NodeORM) -> Node:
        """
        Convert a node ORM row back into the corresponding pydantic ``Node``.

        The concrete class is chosen from the ORM type; any row that is not a
        Literal/Import/Call/Mutate/Global node becomes a ``LookupNode``.
        """
        args: Dict[str, Any] = {
            "id": node.id,
            "session_id": node.session_id,
            "node_type": node.node_type,
        }
        if node.source_code:
            source_code = SourceCode(
                id=node.source_code_id,
                code=node.source_code.code,
                location=(
                    Path(node.source_code.path)
                    if node.source_code.path
                    else JupyterCell(
                        execution_count=node.source_code.jupyter_execution_count,
                        session_id=node.source_code.jupyter_session_id,
                    )
                ),
            )
            args["source_location"] = SourceLocation(
                lineno=node.lineno,
                col_offset=node.col_offset,
                end_lineno=node.end_lineno,
                end_col_offset=node.end_col_offset,
                source_code=source_code,
            )
        # cast string serialized values to their appropriate types
        if isinstance(node, LiteralNodeORM):
            return LiteralNode(
                value=get_literal_value_from_string(
                    node.value, node.value_type
                ),
                **args,
            )
        if isinstance(node, ImportNodeORM):
            return ImportNode(
                name=node.name,
                version=node.version,
                package_name=node.package_name,
                path=node.path,
                **args,
            )
        if isinstance(node, CallNodeORM):
            # ``write_node`` stores positional args and implicit dependencies
            # as sets with an explicit ``index``; sort on it to restore the
            # original order. (The cast is only for mypy, which considers
            # ``index`` optional.)
            positional_args = [
                PositionalArgument(id=p.arg_node_id, starred=p.starred)
                for p in sorted(
                    node.positional_args, key=lambda p: cast(int, p.index)
                )
            ]
            keyword_args = [
                KeywordArgument(
                    key=n.name, value=n.arg_node_id, starred=n.starred
                )
                for n in node.keyword_args
            ]
            global_reads = {
                gr.variable_name: gr.variable_node_id
                for gr in node.global_reads
            }
            implicit_dependencies = [
                n.arg_node_id
                for n in sorted(
                    node.implicit_dependencies,
                    key=lambda n: cast(int, n.index),
                )
            ]
            return CallNode(
                function_id=node.function_id,
                positional_args=positional_args,
                keyword_args=keyword_args,
                global_reads=global_reads,
                implicit_dependencies=implicit_dependencies,
                **args,
            )
        if isinstance(node, MutateNodeORM):
            return MutateNode(
                call_id=node.call_id,
                source_id=node.source_id,
                **args,
            )
        if isinstance(node, GlobalNodeORM):
            return GlobalNode(
                call_id=node.call_id,
                name=node.name,
                **args,
            )
        return LookupNode(name=node.name, **args)

    def get_node_by_id(self, linea_id: LineaID) -> Node:
        """
        Returns the node by looking up the database by ID.

        The query is made against ``BaseNodeORM``; SQLAlchemy returns the
        concrete ORM subclass, which is then mapped to pydantic.
        """
        node = (
            self.session.query(BaseNodeORM)
            .filter(BaseNodeORM.id == linea_id)
            .one()
        )
        return self.map_orm_to_pydantic(node)

    def get_session_context(self, linea_id: LineaID) -> SessionContext:
        """Load the ``SessionContext`` with the given id (``.one()`` raises
        if it is missing)."""
        query_obj = (
            self.session.query(SessionContextORM)
            .filter(SessionContextORM.id == linea_id)
            .one()
        )
        obj = SessionContext.from_orm(query_obj)
        return obj

    def get_node_value_from_db(
        self, node_id: LineaID, execution_id: LineaID
    ) -> Optional[NodeValueORM]:
        """Return the stored value row for a node in an execution, or None."""
        value_orm = (
            self.session.query(NodeValueORM)
            .filter(
                and_(
                    NodeValueORM.node_id == node_id,
                    NodeValueORM.execution_id == execution_id,
                )
            )
            .first()
        )
        return value_orm

    def get_node_value_path(
        self, node_id: LineaID, execution_id: LineaID
    ) -> Optional[str]:
        """
        Get the stored value (path) for ``node_id`` in ``execution_id``.

        Raises ValueError if no value has been saved for that node and
        execution.
        """
        value = self.get_node_value_from_db(node_id, execution_id)
        if not value:
            raise ValueError("No value saved for this node")
        return value.value

    def node_value_in_db(
        self, node_id: LineaID, execution_id: LineaID
    ) -> bool:
        """
        Returns true if the node value is already in the DB.
        """
        return self.session.query(
            self.session.query(NodeValueORM)
            .filter(
                and_(
                    NodeValueORM.node_id == node_id,
                    NodeValueORM.execution_id == execution_id,
                )
            )
            .exists()
        ).scalar()

    def get_libraries_for_session(
        self, session_id: LineaID
    ) -> List[ImportNodeORM]:
        """
        Gets all dependencies for a session, assuming all the libs in a
        particular session will be required to set up a new env.

        Only imports with a recorded version are returned.

        TODO: I think this distinct is still broken, because we want to
        make it distinct on a subset of columns: session_id, name, and version.
        (``distinct(col)`` renders DISTINCT ON only on PostgreSQL.)
        """
        return (
            self.session.query(ImportNodeORM)
            .filter(
                and_(
                    ImportNodeORM.session_id == session_id,
                    # Must be a SQL expression: a Python ``is not None``
                    # test on the column object is always True and would
                    # silently disable this filter.
                    ImportNodeORM.version.isnot(None),
                )
            )
            .distinct(ImportNodeORM.package_name)
            .all()
        )

    def get_artifacts_for_session(
        self, session_id: LineaID
    ) -> List[ArtifactORM]:
        """
        Gets all artifacts whose node belongs to the given session.
        """
        return (
            self.session.query(ArtifactORM)
            .filter(BaseNodeORM.session_id == session_id)
            .join(BaseNodeORM)
            # Don't include source code in query, since it's not needed
            .options(
                defaultload(ArtifactORM.node).raiseload(
                    BaseNodeORM.source_code
                )
            )
            .all()
        )

    def get_artifact_by_name(
        self, artifact_name: str, version: Optional[int] = None
    ) -> ArtifactORM:
        """
        Gets an artifact by name.

        If ``version`` is given (including 0), that exact version is
        returned; otherwise the artifact with the highest version number.

        Raises UserException (wrapping NameError) if nothing matches.
        """
        res_query = self.session.query(ArtifactORM).filter(
            ArtifactORM.name == artifact_name
        )
        # ``is not None`` rather than truthiness: version 0 is a valid
        # version and must not be treated as "latest".
        if version is not None:
            res_query = res_query.filter(ArtifactORM.version == version)
        res = res_query.order_by(ArtifactORM.version.desc()).first()
        if res is None:
            msg = (
                f"Artifact {artifact_name} (version {version})"
                if version is not None
                else f"Artifact {artifact_name}"
            )
            track(ExceptionEvent("UserException", "Artifact not found"))
            raise UserException(
                NameError(
                    f"{msg} not found. Perhaps there was a typo. Please try lineapy.catalog() to inspect all your artifacts."
                )
            )
        return res

    def get_latest_artifact_version(self, artifact_name: str) -> int:
        """
        Get the latest version number of an artifact.
        If the artifact does not exist, it will return -1
        """
        res = (
            self.session.query(ArtifactORM)
            .filter(ArtifactORM.name == artifact_name)
            .order_by(ArtifactORM.version.desc())
            .first()
        )
        return -1 if res is None else res.version

    def get_all_artifacts(self) -> List[ArtifactORM]:
        """
        Used by the catalog to get all the artifacts
        """
        results = self.session.query(ArtifactORM).all()
        return results

    def get_nodes_for_session(self, session_id: LineaID) -> List[Node]:
        """
        Get all the nodes associated with the session, which does
        NOT include things like SessionContext
        """
        node_orms = (
            self.session.query(BaseNodeORM)
            .filter(BaseNodeORM.session_id == session_id)
            .all()
        )
        return [self.map_orm_to_pydantic(node) for node in node_orms]

    def get_source_code_for_session(self, session_id: LineaID) -> str:
        """
        Return the source code executed in a session.

        For Jupyter sessions, all cells are concatenated in execution-count
        order. Otherwise, the code of the first source-code row joined to a
        node of the session is returned, or "" if there is none.
        """
        if (
            self.get_session_context(session_id).environment_type.name
            == "JUPYTER"
        ):
            jupyter_source_code_orms = (
                self.session.query(SourceCodeORM)
                .filter(SourceCodeORM.jupyter_session_id == session_id)
                .order_by(SourceCodeORM.jupyter_execution_count)
                .all()
            )
            return "".join(
                source_code.code for source_code in jupyter_source_code_orms
            )
        script_source_code_orms = (
            self.session.query(SourceCodeORM)
            .join(
                BaseNodeORM, SourceCodeORM.id == BaseNodeORM.source_code_id
            )
            .filter(BaseNodeORM.session_id == session_id)
            .first()
        )
        return (
            script_source_code_orms.code
            if script_source_code_orms is not None
            else ""
        )