Skip to content

Module arti.backends.memory

In-memory Backend implementation: stores Graphs, GraphSnapshots, snapshot tags, and
StoragePartitions in process-local dictionaries (no persistence across processes).

View Source
from __future__ import annotations

from collections import defaultdict

from collections.abc import Iterator

from contextlib import contextmanager

from functools import partial

from pydantic import PrivateAttr

from arti import (

    Artifact,

    Backend,

    Connection,

    Fingerprint,

    Graph,

    GraphSnapshot,

    InputFingerprints,

    Storage,

    StoragePartition,

    StoragePartitions,

)

from arti.internal.utils import NoCopyMixin

def _ensure_fingerprinted(partitions: StoragePartitions) -> Iterator[StoragePartition]:

    for partition in partitions:

        yield partition.with_content_fingerprint(keep_existing=True)

# Internal type aliases for the MemoryBackend's in-memory indexes.

# Registered Graphs, indexed as ...[graph_name][graph_fingerprint].
_Graphs = dict[str, dict[Fingerprint, Graph]]

# Registered GraphSnapshots, indexed as ...[graph_name][snapshot_fingerprint].
_GraphSnapshots = dict[str, dict[Fingerprint, GraphSnapshot]]

# Partitions linked to a specific GraphSnapshot, per artifact key.
_GraphSnapshotPartitions = dict[
    GraphSnapshot, dict[str, set[StoragePartition]]
]  # ...[snapshot][artifact_key]

# Tagged snapshots, indexed as ...[graph_name][tag].
_SnapshotTags = dict[str, dict[str, GraphSnapshot]]  # ...[name][tag]

# All known partitions per Storage spec, across all snapshots.
_StoragePartitions = dict[Storage[StoragePartition], set[StoragePartition]]

class _NoCopyContainer(NoCopyMixin):
    """Holds the MemoryBackend's state while bypassing (deep)copying.

    The MemoryBackend is *intended* to be stateful, like a connection to an external database in
    other backends. However, we usually prefer immutable data structures and Pydantic models, which
    (deep)copy often. If we were to (deep)copy these data structures, then we wouldn't be able to
    track metadata between steps. Instead, this container holds the state and skips (deep)copying.

    We may also add threading locks around access (with some slight usage changes).
    """

    def __init__(self) -> None:
        # NOTE: lambdas are not pickleable, so use partial for any nested defaultdicts.
        partition_sets_by_key = partial(defaultdict, set[StoragePartition])
        # Registered Graphs and GraphSnapshots, keyed by name and then fingerprint.
        self.graphs: _Graphs = defaultdict(dict)
        self.snapshots: _GraphSnapshots = defaultdict(dict)
        # Tagged snapshots, keyed by graph name and then tag.
        self.snapshot_tags: _SnapshotTags = defaultdict(dict)
        # `snapshot_partitions` tracks all the partitions for a *specific* GraphSnapshot, while
        # `storage_partitions` tracks all partitions, across all snapshots. This separation is
        # important to allow for Literals to be used even after a snapshot change.
        self.snapshot_partitions: _GraphSnapshotPartitions = defaultdict(
            partition_sets_by_key  # type: ignore[arg-type]
        )
        self.storage_partitions: _StoragePartitions = defaultdict(set[StoragePartition])

class MemoryConnection(Connection):
    """Connection to a MemoryBackend, reading/writing a shared `_NoCopyContainer`."""

    def __init__(self, container: _NoCopyContainer) -> None:
        self.container = container

    def read_artifact_partitions(
        self, artifact: Artifact, input_fingerprints: InputFingerprints = InputFingerprints()
    ) -> StoragePartitions:
        """Read all known Partitions for this Artifact's Storage spec.

        If `input_fingerprints` is provided, only partitions whose `input_fingerprint` matches the
        entry for their keys are returned.
        """
        # The MemoryBackend is (obviously) not persistent, so there may be external data we don't
        # know about. If we haven't seen this storage before, we'll attempt to "warm" the cache.
        if artifact.storage not in self.container.storage_partitions:
            self.write_artifact_partitions(
                artifact, artifact.storage.discover_partitions(input_fingerprints)
            )
        partitions = self.container.storage_partitions[artifact.storage]
        if input_fingerprints:
            partitions = {
                partition
                for partition in partitions
                if input_fingerprints.get(partition.keys) == partition.input_fingerprint
            }
        return tuple(partitions)

    def write_artifact_partitions(self, artifact: Artifact, partitions: StoragePartitions) -> None:
        """Add partitions for this Artifact's Storage spec, fingerprinting content as needed."""
        self.container.storage_partitions[artifact.storage].update(
            _ensure_fingerprinted(partitions)
        )

    def read_graph(self, name: str, fingerprint: Fingerprint) -> Graph:
        """Fetch the named Graph with this fingerprint (KeyError if unknown)."""
        return self.container.graphs[name][fingerprint]

    def write_graph(self, graph: Graph) -> None:
        """Store the Graph, keyed by name and fingerprint."""
        self.container.graphs[graph.name][graph.fingerprint] = graph

    def read_snapshot(self, name: str, fingerprint: Fingerprint) -> GraphSnapshot:
        """Fetch the named GraphSnapshot with this fingerprint (KeyError if unknown)."""
        return self.container.snapshots[name][fingerprint]

    def write_snapshot(self, snapshot: GraphSnapshot) -> None:
        """Store the GraphSnapshot, keyed by name and fingerprint."""
        self.container.snapshots[snapshot.name][snapshot.fingerprint] = snapshot

    def read_snapshot_tag(self, name: str, tag: str) -> GraphSnapshot:
        """Fetch the GraphSnapshot the named tag points to.

        Raises ValueError if the tag is not set for this name.
        """
        if tag not in self.container.snapshot_tags[name]:
            raise ValueError(f"No known `{tag}` tag for GraphSnapshot `{name}`")
        return self.container.snapshot_tags[name][tag]

    def write_snapshot_tag(
        self, snapshot: GraphSnapshot, tag: str, overwrite: bool = False
    ) -> None:
        """Point the named tag at this GraphSnapshot.

        Raises ValueError if the tag is already set (even to the same snapshot), unless
        `overwrite` is True.
        """
        if (
            existing := self.container.snapshot_tags[snapshot.name].get(tag)
        ) is not None and not overwrite:
            raise ValueError(
                f"Existing `{tag}` tag for Graph `{snapshot.name}` points to {existing}"
            )
        self.container.snapshot_tags[snapshot.name][tag] = snapshot

    def read_snapshot_partitions(
        self, snapshot: GraphSnapshot, artifact_key: str, artifact: Artifact
    ) -> StoragePartitions:
        """Read the known Partitions for the named Artifact in a specific GraphSnapshot."""
        return tuple(self.container.snapshot_partitions[snapshot][artifact_key])

    def write_snapshot_partitions(
        self,
        snapshot: GraphSnapshot,
        artifact_key: str,
        artifact: Artifact,
        partitions: StoragePartitions,
    ) -> None:
        """Link the Partitions to the named Artifact in a specific GraphSnapshot."""
        self.container.snapshot_partitions[snapshot][artifact_key].update(
            _ensure_fingerprinted(partitions)
        )

class MemoryBackend(Backend[MemoryConnection]):
    """Ephemeral Backend storing all metadata in process memory.

    State lives in a private `_NoCopyContainer` (which skips Pydantic's (deep)copies), so it is
    shared across model copies but not persisted beyond the process.
    """

    # PrivateAttr keeps the shared mutable state out of Pydantic's field/copy machinery.
    _container: _NoCopyContainer = PrivateAttr(default_factory=_NoCopyContainer)

    @contextmanager
    def connect(self) -> Iterator[MemoryConnection]:
        # No real resource to open/close; just hand out a view over the shared container.
        yield MemoryConnection(self._container)

Classes

MemoryBackend

class MemoryBackend(
    __pydantic_self__,
    **data: Any
)
View Source
class MemoryBackend(Backend[MemoryConnection]):

    _container: _NoCopyContainer = PrivateAttr(default_factory=_NoCopyContainer)

    @contextmanager

    def connect(self) -> Iterator[MemoryConnection]:

        yield MemoryConnection(self._container)

Ancestors (in MRO)

  • arti.backends.Backend
  • arti.internal.models.Model
  • pydantic.main.BaseModel
  • pydantic.utils.Representation
  • typing.Generic

Class variables

Config

Static methods

construct

def construct(
    _fields_set: Optional[ForwardRef('SetStr')] = None,
    **values: Any
) -> 'Model'

Creates a new model setting dict and fields_set from trusted or pre-validated data.

Default values are respected, but no other validation is performed. Behaves as if Config.extra = 'allow' was set since it adds all passed values

from_orm

def from_orm(
    obj: Any
) -> 'Model'

parse_file

def parse_file(
    path: Union[str, pathlib.Path],
    *,
    content_type: 'unicode' = None,
    encoding: 'unicode' = 'utf8',
    proto: pydantic.parse.Protocol = None,
    allow_pickle: bool = False
) -> 'Model'

parse_obj

def parse_obj(
    obj: Any
) -> 'Model'

parse_raw

def parse_raw(
    b: Union[str, bytes],
    *,
    content_type: 'unicode' = None,
    encoding: 'unicode' = 'utf8',
    proto: pydantic.parse.Protocol = None,
    allow_pickle: bool = False
) -> 'Model'

schema

def schema(
    by_alias: bool = True,
    ref_template: 'unicode' = '#/definitions/{model}'
) -> 'DictStrAny'

schema_json

def schema_json(
    *,
    by_alias: bool = True,
    ref_template: 'unicode' = '#/definitions/{model}',
    **dumps_kwargs: Any
) -> 'unicode'

update_forward_refs

def update_forward_refs(
    **localns: Any
) -> None

Try to update ForwardRefs on fields based on this Model, globalns and localns.

validate

def validate(
    value: Any
) -> 'Model'

Instance variables

fingerprint

Methods

connect

def connect(
    self
) -> 'Iterator[MemoryConnection]'
View Source
    @contextmanager

    def connect(self) -> Iterator[MemoryConnection]:

        yield MemoryConnection(self._container)

copy

def copy(
    self,
    *,
    deep: 'bool' = False,
    validate: 'bool' = True,
    **kwargs: 'Any'
) -> 'Self'

Duplicate a model, optionally choose which fields to include, exclude and change.

Parameters:

Name Type Description Default
include None fields to include in new model None
exclude None fields to exclude from new model, as with values this takes precedence over include None
update None values to change/add in the new model. Note: the data is not validated before creating
the new model: you should trust this data None
deep None set to True to make a deep copy of the model None

Returns:

Type Description
None new model instance
View Source
    def copy(self, *, deep: bool = False, validate: bool = True, **kwargs: Any) -> Self:

        copy = super().copy(deep=deep, **kwargs)

        if validate:

            # NOTE: We set exclude_unset=False so that all existing defaulted fields are reused (as

            # is normal `.copy` behavior).

            #

            # To reduce `repr` noise, we'll reset .__fields_set__ to those of the pre-validation copy

            # (which includes those originally set + updated).

            fields_set = copy.__fields_set__

            copy = copy.validate(

                dict(copy._iter(to_dict=False, by_alias=False, exclude_unset=False))

            )

            # Use object.__setattr__ to bypass frozen model assignment errors

            object.__setattr__(copy, "__fields_set__", set(fields_set))

            # Copy over the private attributes, which are missing after validation (since we're only

            # passing the fields).

            for name in self.__private_attributes__:

                if (value := getattr(self, name, Undefined)) is not Undefined:

                    if deep:

                        value = deepcopy(value)

                    object.__setattr__(copy, name, value)

        return copy

dict

def dict(
    self,
    *,
    include: Union[ForwardRef('AbstractSetIntStr'), ForwardRef('MappingIntStrAny'), NoneType] = None,
    exclude: Union[ForwardRef('AbstractSetIntStr'), ForwardRef('MappingIntStrAny'), NoneType] = None,
    by_alias: bool = False,
    skip_defaults: Optional[bool] = None,
    exclude_unset: bool = False,
    exclude_defaults: bool = False,
    exclude_none: bool = False
) -> 'DictStrAny'

Generate a dictionary representation of the model, optionally specifying which fields to include or exclude.

json

def json(
    self,
    *,
    include: Union[ForwardRef('AbstractSetIntStr'), ForwardRef('MappingIntStrAny'), NoneType] = None,
    exclude: Union[ForwardRef('AbstractSetIntStr'), ForwardRef('MappingIntStrAny'), NoneType] = None,
    by_alias: bool = False,
    skip_defaults: Optional[bool] = None,
    exclude_unset: bool = False,
    exclude_defaults: bool = False,
    exclude_none: bool = False,
    encoder: Optional[Callable[[Any], Any]] = None,
    models_as_dict: bool = True,
    **dumps_kwargs: Any
) -> 'unicode'

Generate a JSON representation of the model, include and exclude arguments as per dict().

encoder is an optional function to supply as default to json.dumps(), other arguments as per json.dumps().

MemoryConnection

class MemoryConnection(
    container: '_NoCopyContainer'
)
View Source
class MemoryConnection(Connection):

    def __init__(self, container: _NoCopyContainer) -> None:

        self.container = container

    def read_artifact_partitions(

        self, artifact: Artifact, input_fingerprints: InputFingerprints = InputFingerprints()

    ) -> StoragePartitions:

        # The MemoryBackend is (obviously) not persistent, so there may be external data we don't

        # know about. If we haven't seen this storage before, we'll attempt to "warm" the cache.

        if artifact.storage not in self.container.storage_partitions:

            self.write_artifact_partitions(

                artifact, artifact.storage.discover_partitions(input_fingerprints)

            )

        partitions = self.container.storage_partitions[artifact.storage]

        if input_fingerprints:

            partitions = {

                partition

                for partition in partitions

                if input_fingerprints.get(partition.keys) == partition.input_fingerprint

            }

        return tuple(partitions)

    def write_artifact_partitions(self, artifact: Artifact, partitions: StoragePartitions) -> None:

        self.container.storage_partitions[artifact.storage].update(

            _ensure_fingerprinted(partitions)

        )

    def read_graph(self, name: str, fingerprint: Fingerprint) -> Graph:

        return self.container.graphs[name][fingerprint]

    def write_graph(self, graph: Graph) -> None:

        self.container.graphs[graph.name][graph.fingerprint] = graph

    def read_snapshot(self, name: str, fingerprint: Fingerprint) -> GraphSnapshot:

        return self.container.snapshots[name][fingerprint]

    def write_snapshot(self, snapshot: GraphSnapshot) -> None:

        self.container.snapshots[snapshot.name][snapshot.fingerprint] = snapshot

    def read_snapshot_tag(self, name: str, tag: str) -> GraphSnapshot:

        if tag not in self.container.snapshot_tags[name]:

            raise ValueError(f"No known `{tag}` tag for GraphSnapshot `{name}`")

        return self.container.snapshot_tags[name][tag]

    def write_snapshot_tag(

        self, snapshot: GraphSnapshot, tag: str, overwrite: bool = False

    ) -> None:

        """Read the known Partitions for the named Artifact in a specific GraphSnapshot."""

        if (

            existing := self.container.snapshot_tags[snapshot.name].get(tag)

        ) is not None and not overwrite:

            raise ValueError(

                f"Existing `{tag}` tag for Graph `{snapshot.name}` points to {existing}"

            )

        self.container.snapshot_tags[snapshot.name][tag] = snapshot

    def read_snapshot_partitions(

        self, snapshot: GraphSnapshot, artifact_key: str, artifact: Artifact

    ) -> StoragePartitions:

        return tuple(self.container.snapshot_partitions[snapshot][artifact_key])

    def write_snapshot_partitions(

        self,

        snapshot: GraphSnapshot,

        artifact_key: str,

        artifact: Artifact,

        partitions: StoragePartitions,

    ) -> None:

        self.container.snapshot_partitions[snapshot][artifact_key].update(

            _ensure_fingerprinted(partitions)

        )

Ancestors (in MRO)

  • arti.backends.Connection

Methods

connect

def connect(
    self
) -> 'Iterator[Self]'

Return self

This makes it easier to work with an Optional connection, eg: with (connection or backend).connect() as conn: ...

View Source
    @contextmanager

    def connect(self) -> Iterator[Self]:

        """Return self

        This makes it easier to work with an Optional connection, eg:

            with (connection or backend).connect() as conn:

                ...

        """

        yield self

read_artifact_partitions

def read_artifact_partitions(
    self,
    artifact: 'Artifact',
    input_fingerprints: 'InputFingerprints' = {}
) -> 'StoragePartitions'

Read all known Partitions for this Storage spec.

If input_fingerprints is provided, the returned partitions will be filtered accordingly.

NOTE: The returned partitions may not be associated with any particular Graph, unless input_fingerprints is provided matching those for a GraphSnapshot.

View Source
    def read_artifact_partitions(

        self, artifact: Artifact, input_fingerprints: InputFingerprints = InputFingerprints()

    ) -> StoragePartitions:

        # The MemoryBackend is (obviously) not persistent, so there may be external data we don't

        # know about. If we haven't seen this storage before, we'll attempt to "warm" the cache.

        if artifact.storage not in self.container.storage_partitions:

            self.write_artifact_partitions(

                artifact, artifact.storage.discover_partitions(input_fingerprints)

            )

        partitions = self.container.storage_partitions[artifact.storage]

        if input_fingerprints:

            partitions = {

                partition

                for partition in partitions

                if input_fingerprints.get(partition.keys) == partition.input_fingerprint

            }

        return tuple(partitions)

read_graph

def read_graph(
    self,
    name: 'str',
    fingerprint: 'Fingerprint'
) -> 'Graph'

Fetch an instance of the named Graph.

View Source
    def read_graph(self, name: str, fingerprint: Fingerprint) -> Graph:

        return self.container.graphs[name][fingerprint]

read_snapshot

def read_snapshot(
    self,
    name: 'str',
    fingerprint: 'Fingerprint'
) -> 'GraphSnapshot'

Fetch an instance of the named GraphSnapshot.

View Source
    def read_snapshot(self, name: str, fingerprint: Fingerprint) -> GraphSnapshot:

        return self.container.snapshots[name][fingerprint]

read_snapshot_partitions

def read_snapshot_partitions(
    self,
    snapshot: 'GraphSnapshot',
    artifact_key: 'str',
    artifact: 'Artifact'
) -> 'StoragePartitions'

Read the known Partitions for the named Artifact in a specific GraphSnapshot.

View Source
    def read_snapshot_partitions(

        self, snapshot: GraphSnapshot, artifact_key: str, artifact: Artifact

    ) -> StoragePartitions:

        return tuple(self.container.snapshot_partitions[snapshot][artifact_key])

read_snapshot_tag

def read_snapshot_tag(
    self,
    name: 'str',
    tag: 'str'
) -> 'GraphSnapshot'

Fetch the GraphSnapshot for the named tag.

View Source
    def read_snapshot_tag(self, name: str, tag: str) -> GraphSnapshot:

        if tag not in self.container.snapshot_tags[name]:

            raise ValueError(f"No known `{tag}` tag for GraphSnapshot `{name}`")

        return self.container.snapshot_tags[name][tag]

write_artifact_and_graph_partitions

def write_artifact_and_graph_partitions(
    self,
    snapshot: 'GraphSnapshot',
    artifact_key: 'str',
    artifact: 'Artifact',
    partitions: 'StoragePartitions'
) -> 'None'
View Source
    def write_artifact_and_graph_partitions(

        self,

        snapshot: GraphSnapshot,

        artifact_key: str,

        artifact: Artifact,

        partitions: StoragePartitions,

    ) -> None:

        self.write_artifact_partitions(artifact, partitions)

        self.write_snapshot_partitions(snapshot, artifact_key, artifact, partitions)

write_artifact_partitions

def write_artifact_partitions(
    self,
    artifact: 'Artifact',
    partitions: 'StoragePartitions'
) -> 'None'

Add more partitions for a Storage spec.

View Source
    def write_artifact_partitions(self, artifact: Artifact, partitions: StoragePartitions) -> None:

        self.container.storage_partitions[artifact.storage].update(

            _ensure_fingerprinted(partitions)

        )

write_graph

def write_graph(
    self,
    graph: 'Graph'
) -> 'None'

Write the Graph and all linked Artifacts and Producers to the database.

View Source
    def write_graph(self, graph: Graph) -> None:

        self.container.graphs[graph.name][graph.fingerprint] = graph

write_snapshot

def write_snapshot(
    self,
    snapshot: 'GraphSnapshot'
) -> 'None'

Write the GraphSnapshot to the database.

View Source
    def write_snapshot(self, snapshot: GraphSnapshot) -> None:

        self.container.snapshots[snapshot.name][snapshot.fingerprint] = snapshot

write_snapshot_partitions

def write_snapshot_partitions(
    self,
    snapshot: 'GraphSnapshot',
    artifact_key: 'str',
    artifact: 'Artifact',
    partitions: 'StoragePartitions'
) -> 'None'

Link the Partitions to the named Artifact in a specific GraphSnapshot.

View Source
    def write_snapshot_partitions(

        self,

        snapshot: GraphSnapshot,

        artifact_key: str,

        artifact: Artifact,

        partitions: StoragePartitions,

    ) -> None:

        self.container.snapshot_partitions[snapshot][artifact_key].update(

            _ensure_fingerprinted(partitions)

        )

write_snapshot_tag

def write_snapshot_tag(
    self,
    snapshot: 'GraphSnapshot',
    tag: 'str',
    overwrite: 'bool' = False
) -> 'None'

Point the named tag at this GraphSnapshot; fails if the tag is already set, unless overwrite is True.

View Source
    def write_snapshot_tag(

        self, snapshot: GraphSnapshot, tag: str, overwrite: bool = False

    ) -> None:

        """Read the known Partitions for the named Artifact in a specific GraphSnapshot."""

        if (

            existing := self.container.snapshot_tags[snapshot.name].get(tag)

        ) is not None and not overwrite:

            raise ValueError(

                f"Existing `{tag}` tag for Graph `{snapshot.name}` points to {existing}"

            )

        self.container.snapshot_tags[snapshot.name][tag] = snapshot