Module arti.backends.memory
None
None
View Source
from __future__ import annotations
from collections import defaultdict
from collections.abc import Iterator
from contextlib import contextmanager
from functools import partial
from pydantic import PrivateAttr
from arti import (
Artifact,
Backend,
Connection,
Fingerprint,
Graph,
GraphSnapshot,
InputFingerprints,
Storage,
StoragePartition,
StoragePartitions,
)
from arti.internal.utils import NoCopyMixin
def _ensure_fingerprinted(partitions: StoragePartitions) -> Iterator[StoragePartition]:
for partition in partitions:
yield partition.with_content_fingerprint(keep_existing=True)
# Type aliases for the MemoryBackend's in-memory tables.
_Graphs = dict[str, dict[Fingerprint, Graph]]  # ...[name][fingerprint]
_GraphSnapshots = dict[str, dict[Fingerprint, GraphSnapshot]]  # ...[name][fingerprint]
_GraphSnapshotPartitions = dict[
    GraphSnapshot, dict[str, set[StoragePartition]]
]  # ...[snapshot][artifact_key]
_SnapshotTags = dict[str, dict[str, GraphSnapshot]]  # ...[name][tag]
# All partitions ever seen for a Storage, across all snapshots.
_StoragePartitions = dict[Storage[StoragePartition], set[StoragePartition]]
class _NoCopyContainer(NoCopyMixin):
    """State holder for MemoryBackend data that skips (deep)copying.

    The MemoryBackend is *intended* to be stateful, like a connection to an external database in
    other backends. However, we usually prefer immutable data structures and Pydantic models,
    which (deep)copy often. If these structures were (deep)copied, metadata could not be tracked
    between steps - so this container owns the state and bypasses (deep)copying entirely.
    We may also add threading locks around access (with some slight usage changes).
    """

    def __init__(self) -> None:
        self.graphs: _Graphs = defaultdict(dict)
        self.snapshots: _GraphSnapshots = defaultdict(dict)
        self.snapshot_tags: _SnapshotTags = defaultdict(dict)
        # `snapshot_partitions` tracks the partitions for a *specific* GraphSnapshot, while
        # `storage_partitions` tracks all partitions across all snapshots. This separation is
        # important to allow for Literals to be used even after a snapshot change.
        #
        # NOTE: the nested defaultdict uses partial (not a lambda) so the container stays
        # pickleable.
        self.snapshot_partitions: _GraphSnapshotPartitions = defaultdict(
            partial(defaultdict, set[StoragePartition])  # type: ignore[arg-type]
        )
        self.storage_partitions: _StoragePartitions = defaultdict(set[StoragePartition])
class MemoryConnection(Connection):
    """A Connection backed by a shared, in-process `_NoCopyContainer`."""

    def __init__(self, container: _NoCopyContainer) -> None:
        self.container = container

    def read_artifact_partitions(
        self, artifact: Artifact, input_fingerprints: InputFingerprints = InputFingerprints()
    ) -> StoragePartitions:
        """Read all known Partitions for this Artifact's Storage.

        If `input_fingerprints` is provided, the returned partitions are filtered to those whose
        `input_fingerprint` matches the fingerprint registered for their keys.
        """
        # The MemoryBackend is (obviously) not persistent, so there may be external data we don't
        # know about. If we haven't seen this storage before, we'll attempt to "warm" the cache.
        if artifact.storage not in self.container.storage_partitions:
            self.write_artifact_partitions(
                artifact, artifact.storage.discover_partitions(input_fingerprints)
            )
        partitions = self.container.storage_partitions[artifact.storage]
        if input_fingerprints:
            partitions = {
                partition
                for partition in partitions
                if input_fingerprints.get(partition.keys) == partition.input_fingerprint
            }
        return tuple(partitions)

    def write_artifact_partitions(self, artifact: Artifact, partitions: StoragePartitions) -> None:
        """Add more partitions for this Artifact's Storage, fingerprinting any that lack one."""
        self.container.storage_partitions[artifact.storage].update(
            _ensure_fingerprinted(partitions)
        )

    def read_graph(self, name: str, fingerprint: Fingerprint) -> Graph:
        """Fetch an instance of the named Graph."""
        return self.container.graphs[name][fingerprint]

    def write_graph(self, graph: Graph) -> None:
        """Store the Graph, keyed by name and fingerprint."""
        self.container.graphs[graph.name][graph.fingerprint] = graph

    def read_snapshot(self, name: str, fingerprint: Fingerprint) -> GraphSnapshot:
        """Fetch an instance of the named GraphSnapshot."""
        return self.container.snapshots[name][fingerprint]

    def write_snapshot(self, snapshot: GraphSnapshot) -> None:
        """Store the GraphSnapshot, keyed by name and fingerprint."""
        self.container.snapshots[snapshot.name][snapshot.fingerprint] = snapshot

    def read_snapshot_tag(self, name: str, tag: str) -> GraphSnapshot:
        """Fetch the GraphSnapshot for the named tag, raising ValueError if unknown."""
        if tag not in self.container.snapshot_tags[name]:
            raise ValueError(f"No known `{tag}` tag for GraphSnapshot `{name}`")
        return self.container.snapshot_tags[name][tag]

    def write_snapshot_tag(
        self, snapshot: GraphSnapshot, tag: str, overwrite: bool = False
    ) -> None:
        # NOTE: the original docstring was a copy-paste of read_snapshot_partitions' and
        # described reading partitions; corrected to describe tagging.
        """Tag the GraphSnapshot, raising ValueError if the tag exists unless `overwrite` is set."""
        if (
            existing := self.container.snapshot_tags[snapshot.name].get(tag)
        ) is not None and not overwrite:
            raise ValueError(
                f"Existing `{tag}` tag for Graph `{snapshot.name}` points to {existing}"
            )
        self.container.snapshot_tags[snapshot.name][tag] = snapshot

    def read_snapshot_partitions(
        self, snapshot: GraphSnapshot, artifact_key: str, artifact: Artifact
    ) -> StoragePartitions:
        """Read the known Partitions for the named Artifact in a specific GraphSnapshot."""
        return tuple(self.container.snapshot_partitions[snapshot][artifact_key])

    def write_snapshot_partitions(
        self,
        snapshot: GraphSnapshot,
        artifact_key: str,
        artifact: Artifact,
        partitions: StoragePartitions,
    ) -> None:
        """Link the Partitions to the named Artifact in a specific GraphSnapshot."""
        self.container.snapshot_partitions[snapshot][artifact_key].update(
            _ensure_fingerprinted(partitions)
        )
class MemoryBackend(Backend[MemoryConnection]):
    """A non-persistent Backend holding all metadata in process memory."""

    # Pydantic private attribute holding the shared mutable state for all connections.
    _container: _NoCopyContainer = PrivateAttr(default_factory=_NoCopyContainer)

    @contextmanager
    def connect(self) -> Iterator[MemoryConnection]:
        """Yield a MemoryConnection over this backend's shared container."""
        connection = MemoryConnection(self._container)
        yield connection
Classes
MemoryBackend
class MemoryBackend(
__pydantic_self__,
**data: Any
)
View Source
class MemoryBackend(Backend[MemoryConnection]):
_container: _NoCopyContainer = PrivateAttr(default_factory=_NoCopyContainer)
@contextmanager
def connect(self) -> Iterator[MemoryConnection]:
yield MemoryConnection(self._container)
Ancestors (in MRO)
- arti.backends.Backend
- arti.internal.models.Model
- pydantic.main.BaseModel
- pydantic.utils.Representation
- typing.Generic
Class variables
Config
Static methods
construct
def construct(
_fields_set: Optional[ForwardRef('SetStr')] = None,
**values: Any
) -> 'Model'
Creates a new model setting dict and fields_set from trusted or pre-validated data.
Default values are respected, but no other validation is performed.
Behaves as if Config.extra = 'allow'
was set since it adds all passed values
from_orm
def from_orm(
obj: Any
) -> 'Model'
parse_file
def parse_file(
path: Union[str, pathlib.Path],
*,
content_type: 'unicode' = None,
encoding: 'unicode' = 'utf8',
proto: pydantic.parse.Protocol = None,
allow_pickle: bool = False
) -> 'Model'
parse_obj
def parse_obj(
obj: Any
) -> 'Model'
parse_raw
def parse_raw(
b: Union[str, bytes],
*,
content_type: 'unicode' = None,
encoding: 'unicode' = 'utf8',
proto: pydantic.parse.Protocol = None,
allow_pickle: bool = False
) -> 'Model'
schema
def schema(
by_alias: bool = True,
ref_template: 'unicode' = '#/definitions/{model}'
) -> 'DictStrAny'
schema_json
def schema_json(
*,
by_alias: bool = True,
ref_template: 'unicode' = '#/definitions/{model}',
**dumps_kwargs: Any
) -> 'unicode'
update_forward_refs
def update_forward_refs(
**localns: Any
) -> None
Try to update ForwardRefs on fields based on this Model, globalns and localns.
validate
def validate(
value: Any
) -> 'Model'
Instance variables
fingerprint
Methods
connect
def connect(
self
) -> 'Iterator[MemoryConnection]'
View Source
@contextmanager
def connect(self) -> Iterator[MemoryConnection]:
yield MemoryConnection(self._container)
copy
def copy(
self,
*,
deep: 'bool' = False,
validate: 'bool' = True,
**kwargs: 'Any'
) -> 'Self'
Duplicate a model, optionally choose which fields to include, exclude and change.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
include | None | fields to include in new model | None |
exclude | None | fields to exclude from new model, as with values this takes precedence over include | None |
update | None | values to change/add in the new model. Note: the data is not validated before creating | |
the new model: you should trust this data | None | ||
deep | None | set to True to make a deep copy of the model |
None |
Returns:
Type | Description |
---|---|
None | new model instance |
View Source
def copy(self, *, deep: bool = False, validate: bool = True, **kwargs: Any) -> Self:
copy = super().copy(deep=deep, **kwargs)
if validate:
# NOTE: We set exclude_unset=False so that all existing defaulted fields are reused (as
# is normal `.copy` behavior).
#
# To reduce `repr` noise, we'll reset .__fields_set__ to those of the pre-validation copy
# (which includes those originally set + updated).
fields_set = copy.__fields_set__
copy = copy.validate(
dict(copy._iter(to_dict=False, by_alias=False, exclude_unset=False))
)
# Use object.__setattr__ to bypass frozen model assignment errors
object.__setattr__(copy, "__fields_set__", set(fields_set))
# Copy over the private attributes, which are missing after validation (since we're only
# passing the fields).
for name in self.__private_attributes__:
if (value := getattr(self, name, Undefined)) is not Undefined:
if deep:
value = deepcopy(value)
object.__setattr__(copy, name, value)
return copy
dict
def dict(
self,
*,
include: Union[ForwardRef('AbstractSetIntStr'), ForwardRef('MappingIntStrAny'), NoneType] = None,
exclude: Union[ForwardRef('AbstractSetIntStr'), ForwardRef('MappingIntStrAny'), NoneType] = None,
by_alias: bool = False,
skip_defaults: Optional[bool] = None,
exclude_unset: bool = False,
exclude_defaults: bool = False,
exclude_none: bool = False
) -> 'DictStrAny'
Generate a dictionary representation of the model, optionally specifying which fields to include or exclude.
json
def json(
self,
*,
include: Union[ForwardRef('AbstractSetIntStr'), ForwardRef('MappingIntStrAny'), NoneType] = None,
exclude: Union[ForwardRef('AbstractSetIntStr'), ForwardRef('MappingIntStrAny'), NoneType] = None,
by_alias: bool = False,
skip_defaults: Optional[bool] = None,
exclude_unset: bool = False,
exclude_defaults: bool = False,
exclude_none: bool = False,
encoder: Optional[Callable[[Any], Any]] = None,
models_as_dict: bool = True,
**dumps_kwargs: Any
) -> 'unicode'
Generate a JSON representation of the model, include
and exclude
arguments as per dict()
.
encoder
is an optional function to supply as default
to json.dumps(), other arguments as per json.dumps()
.
MemoryConnection
class MemoryConnection(
container: '_NoCopyContainer'
)
View Source
class MemoryConnection(Connection):
def __init__(self, container: _NoCopyContainer) -> None:
self.container = container
def read_artifact_partitions(
self, artifact: Artifact, input_fingerprints: InputFingerprints = InputFingerprints()
) -> StoragePartitions:
# The MemoryBackend is (obviously) not persistent, so there may be external data we don't
# know about. If we haven't seen this storage before, we'll attempt to "warm" the cache.
if artifact.storage not in self.container.storage_partitions:
self.write_artifact_partitions(
artifact, artifact.storage.discover_partitions(input_fingerprints)
)
partitions = self.container.storage_partitions[artifact.storage]
if input_fingerprints:
partitions = {
partition
for partition in partitions
if input_fingerprints.get(partition.keys) == partition.input_fingerprint
}
return tuple(partitions)
def write_artifact_partitions(self, artifact: Artifact, partitions: StoragePartitions) -> None:
self.container.storage_partitions[artifact.storage].update(
_ensure_fingerprinted(partitions)
)
def read_graph(self, name: str, fingerprint: Fingerprint) -> Graph:
return self.container.graphs[name][fingerprint]
def write_graph(self, graph: Graph) -> None:
self.container.graphs[graph.name][graph.fingerprint] = graph
def read_snapshot(self, name: str, fingerprint: Fingerprint) -> GraphSnapshot:
return self.container.snapshots[name][fingerprint]
def write_snapshot(self, snapshot: GraphSnapshot) -> None:
self.container.snapshots[snapshot.name][snapshot.fingerprint] = snapshot
def read_snapshot_tag(self, name: str, tag: str) -> GraphSnapshot:
if tag not in self.container.snapshot_tags[name]:
raise ValueError(f"No known `{tag}` tag for GraphSnapshot `{name}`")
return self.container.snapshot_tags[name][tag]
def write_snapshot_tag(
self, snapshot: GraphSnapshot, tag: str, overwrite: bool = False
) -> None:
"""Read the known Partitions for the named Artifact in a specific GraphSnapshot."""
if (
existing := self.container.snapshot_tags[snapshot.name].get(tag)
) is not None and not overwrite:
raise ValueError(
f"Existing `{tag}` tag for Graph `{snapshot.name}` points to {existing}"
)
self.container.snapshot_tags[snapshot.name][tag] = snapshot
def read_snapshot_partitions(
self, snapshot: GraphSnapshot, artifact_key: str, artifact: Artifact
) -> StoragePartitions:
return tuple(self.container.snapshot_partitions[snapshot][artifact_key])
def write_snapshot_partitions(
self,
snapshot: GraphSnapshot,
artifact_key: str,
artifact: Artifact,
partitions: StoragePartitions,
) -> None:
self.container.snapshot_partitions[snapshot][artifact_key].update(
_ensure_fingerprinted(partitions)
)
Ancestors (in MRO)
- arti.backends.Connection
Methods
connect
def connect(
self
) -> 'Iterator[Self]'
Return self
This makes it easier to work with an Optional connection, eg: with (connection or backend).connect() as conn: ...
View Source
@contextmanager
def connect(self) -> Iterator[Self]:
"""Return self
This makes it easier to work with an Optional connection, eg:
with (connection or backend).connect() as conn:
...
"""
yield self
read_artifact_partitions
def read_artifact_partitions(
self,
artifact: 'Artifact',
input_fingerprints: 'InputFingerprints' = {}
) -> 'StoragePartitions'
Read all known Partitions for this Storage spec.
If input_fingerprints
is provided, the returned partitions will be filtered accordingly.
NOTE: The returned partitions may not be associated with any particular Graph, unless
input_fingerprints
is provided matching those for a GraphSnapshot.
View Source
def read_artifact_partitions(
self, artifact: Artifact, input_fingerprints: InputFingerprints = InputFingerprints()
) -> StoragePartitions:
# The MemoryBackend is (obviously) not persistent, so there may be external data we don't
# know about. If we haven't seen this storage before, we'll attempt to "warm" the cache.
if artifact.storage not in self.container.storage_partitions:
self.write_artifact_partitions(
artifact, artifact.storage.discover_partitions(input_fingerprints)
)
partitions = self.container.storage_partitions[artifact.storage]
if input_fingerprints:
partitions = {
partition
for partition in partitions
if input_fingerprints.get(partition.keys) == partition.input_fingerprint
}
return tuple(partitions)
read_graph
def read_graph(
self,
name: 'str',
fingerprint: 'Fingerprint'
) -> 'Graph'
Fetch an instance of the named Graph.
View Source
def read_graph(self, name: str, fingerprint: Fingerprint) -> Graph:
return self.container.graphs[name][fingerprint]
read_snapshot
def read_snapshot(
self,
name: 'str',
fingerprint: 'Fingerprint'
) -> 'GraphSnapshot'
Fetch an instance of the named GraphSnapshot.
View Source
def read_snapshot(self, name: str, fingerprint: Fingerprint) -> GraphSnapshot:
return self.container.snapshots[name][fingerprint]
read_snapshot_partitions
def read_snapshot_partitions(
self,
snapshot: 'GraphSnapshot',
artifact_key: 'str',
artifact: 'Artifact'
) -> 'StoragePartitions'
Read the known Partitions for the named Artifact in a specific GraphSnapshot.
View Source
def read_snapshot_partitions(
self, snapshot: GraphSnapshot, artifact_key: str, artifact: Artifact
) -> StoragePartitions:
return tuple(self.container.snapshot_partitions[snapshot][artifact_key])
read_snapshot_tag
def read_snapshot_tag(
self,
name: 'str',
tag: 'str'
) -> 'GraphSnapshot'
Fetch the GraphSnapshot for the named tag.
View Source
def read_snapshot_tag(self, name: str, tag: str) -> GraphSnapshot:
if tag not in self.container.snapshot_tags[name]:
raise ValueError(f"No known `{tag}` tag for GraphSnapshot `{name}`")
return self.container.snapshot_tags[name][tag]
write_artifact_and_graph_partitions
def write_artifact_and_graph_partitions(
self,
snapshot: 'GraphSnapshot',
artifact_key: 'str',
artifact: 'Artifact',
partitions: 'StoragePartitions'
) -> 'None'
View Source
def write_artifact_and_graph_partitions(
self,
snapshot: GraphSnapshot,
artifact_key: str,
artifact: Artifact,
partitions: StoragePartitions,
) -> None:
self.write_artifact_partitions(artifact, partitions)
self.write_snapshot_partitions(snapshot, artifact_key, artifact, partitions)
write_artifact_partitions
def write_artifact_partitions(
self,
artifact: 'Artifact',
partitions: 'StoragePartitions'
) -> 'None'
Add more partitions for a Storage spec.
View Source
def write_artifact_partitions(self, artifact: Artifact, partitions: StoragePartitions) -> None:
self.container.storage_partitions[artifact.storage].update(
_ensure_fingerprinted(partitions)
)
write_graph
def write_graph(
self,
graph: 'Graph'
) -> 'None'
Write the Graph and all linked Artifacts and Producers to the database.
View Source
def write_graph(self, graph: Graph) -> None:
self.container.graphs[graph.name][graph.fingerprint] = graph
write_snapshot
def write_snapshot(
self,
snapshot: 'GraphSnapshot'
) -> 'None'
Write the GraphSnapshot to the database.
View Source
def write_snapshot(self, snapshot: GraphSnapshot) -> None:
self.container.snapshots[snapshot.name][snapshot.fingerprint] = snapshot
write_snapshot_partitions
def write_snapshot_partitions(
self,
snapshot: 'GraphSnapshot',
artifact_key: 'str',
artifact: 'Artifact',
partitions: 'StoragePartitions'
) -> 'None'
Link the Partitions to the named Artifact in a specific GraphSnapshot.
View Source
def write_snapshot_partitions(
self,
snapshot: GraphSnapshot,
artifact_key: str,
artifact: Artifact,
partitions: StoragePartitions,
) -> None:
self.container.snapshot_partitions[snapshot][artifact_key].update(
_ensure_fingerprinted(partitions)
)
write_snapshot_tag
def write_snapshot_tag(
self,
snapshot: 'GraphSnapshot',
tag: 'str',
overwrite: 'bool' = False
) -> 'None'
Tag the GraphSnapshot, raising ValueError if the tag exists unless `overwrite` is set.
View Source
def write_snapshot_tag(
self, snapshot: GraphSnapshot, tag: str, overwrite: bool = False
) -> None:
"""Read the known Partitions for the named Artifact in a specific GraphSnapshot."""
if (
existing := self.container.snapshot_tags[snapshot.name].get(tag)
) is not None and not overwrite:
raise ValueError(
f"Existing `{tag}` tag for Graph `{snapshot.name}` points to {existing}"
)
self.container.snapshot_tags[snapshot.name][tag] = snapshot