Module arti.internal.dispatch
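Multiple-dispatch helpers built on the `multimethod` package: the `_multipledispatch` class and the `multipledispatch` decorator factory.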
View Source
from __future__ import annotations
import inspect
from collections.abc import Callable
from typing import Any, Optional, TypeVar, cast, overload
import multimethod as _multimethod # Minimize name confusion
from arti.internal.type_hints import lenient_issubclass, tidy_signature
RETURN = TypeVar("RETURN")
REGISTERED = TypeVar("REGISTERED", bound=Callable[..., Any])
# This may be less useful once mypy supports ParamSpecs - after that, we might be able to define
# multidispatch with a ParamSpec and have mypy check the handlers' arguments are covariant.
class _multipledispatch(_multimethod.multidispatch[RETURN]):
    """Select one of several registered handlers based on the types of the arguments.

    Used much like `@functools.singledispatch`. The function the dispatcher is built from is the
    "spec": handlers registered afterwards must use the same parameter names and kinds, and their
    annotations must be `Any` or subclasses of the spec's annotations.
    """

    # NOTE: Accepting extra (kw)args here would also require overriding `__new__`, which the
    # multimethod internals call for every *registered* func. The extra attributes are therefore
    # initialized to None and filled in after construction by the `multipledispatch` helper.
    def __init__(self, func: Callable[..., RETURN]) -> None:
        """Create the dispatcher from the spec `func` and cache its tidied signature."""
        super().__init__(func)
        self.canonical_name: Optional[str] = None
        self.discovery_func: Optional[Callable[[], None]] = None
        spec_signature = self.signature
        assert spec_signature is not None
        self.clean_signature = tidy_signature(func, spec_signature)

    def __missing__(self, types: tuple[Any, ...]) -> Callable[..., RETURN]:
        """Call the optional discovery hook, then defer to the parent's `__missing__`."""
        discover = self.discovery_func
        if discover is not None:
            # NOTE(review): presumably this registers additional handlers - confirm with callers.
            discover()
        return super().__missing__(types)

    def lookup(self, *args: Optional[type[Any]]) -> Callable[..., Any]:
        """Return the handler registered for the given argument types.

        Raises:
            ValueError: if the lookup fails with a TypeError, or if the resolved handler is
                marked `_abstract_` (ie: it is the base spec function).
        """
        # multimethod wraps Generics (eg: `list[int]`) in an internal helper, so the keys must be
        # wrapped the same way before indexing. Non-Generics pass through unchanged.
        key = tuple(_multimethod.subtype(arg) for arg in args)  # type: ignore[no-untyped-call]
        # NOTE: `key in self` is not usable here: multimethod doesn't override __contains__
        # (likely so __missing__ still runs), so it is False for subclasses of any registered arg.
        not_found = ValueError(f"No `{self.canonical_name}` implementation found for: {key}")
        try:
            handler = cast(Callable[..., Any], self[key])
        # As in __call__, multimethod signals a miss with TypeError rather than KeyError.
        except TypeError as error:  # pragma: no cover
            raise not_found from error
        # A handler flagged `_abstract_` is the base "NotImplementedError" spec, not a real match.
        if not getattr(handler, "_abstract_", False):
            return handler
        raise not_found

    def _check_against_spec(self, func: Any) -> None:
        """Raise TypeError unless `func`'s signature is compatible with the spec's signature."""
        sig = tidy_signature(func, inspect.signature(func))
        spec = self.clean_signature
        actual_names, expected_names = set(sig.parameters), set(spec.parameters)
        if actual_names != expected_names:
            raise TypeError(
                f"Expected `{func.__name__}` to have {sorted(expected_names)} parameters, got {sorted(actual_names)}"
            )
        for name, sig_param in sig.parameters.items():
            spec_param = spec.parameters[name]
            if sig_param.kind != spec_param.kind:
                raise TypeError(
                    f"Expected the `{func.__name__}.{name}` parameter to be {spec_param.kind}, got {sig_param.kind}"
                )
            # `Any` is always accepted; anything else must be a subclass of the spec's annotation.
            annotation_ok = sig_param.annotation is Any or lenient_issubclass(
                sig_param.annotation, spec_param.annotation
            )
            if not annotation_ok:
                raise TypeError(
                    f"Expected the `{func.__name__}.{name}` parameter to be a subclass of {spec_param.annotation}, got {sig_param.annotation}"
                )
        if not lenient_issubclass(sig.return_annotation, spec.return_annotation):
            raise TypeError(
                f"Expected the `{func.__name__}` return to match {spec.return_annotation}, got {sig.return_annotation}"
            )

    @overload
    def register(self, __func: REGISTERED) -> REGISTERED:
        ...

    @overload
    def register(self, *args: type) -> Callable[[REGISTERED], REGISTERED]:
        ...

    def register(self, *args: Any) -> Callable[[REGISTERED], REGISTERED]:
        """Register a handler, validating it against the spec when passed a function directly.

        A single argument that has `__annotations__` is treated as the handler itself and checked
        first (TypeError on mismatch). Any other arguments are forwarded to the parent's `register`
        without validation.
        """
        given_handler = len(args) == 1 and hasattr(args[0], "__annotations__")
        if given_handler:
            self._check_against_spec(args[0])
        return cast(Callable[..., Any], super().register(*args))
def multipledispatch(
    canonical_name: str, *, discovery_func: Optional[Callable[[], None]] = None
) -> Callable[[Callable[..., RETURN]], _multipledispatch[RETURN]]:
    """Return a decorator that turns a base function into a `_multipledispatch` dispatcher.

    Args:
        canonical_name: stored on the dispatcher as `canonical_name`.
        discovery_func: optional zero-argument callable stored on the dispatcher as
            `discovery_func`.
    """

    def decorator(func: Callable[..., RETURN]) -> _multipledispatch[RETURN]:
        # Flag the base handler, which is expected to `raise NotImplementedError`.
        func._abstract_ = True  # type: ignore[attr-defined]
        dispatcher = _multipledispatch(func)
        # These can't be passed to the constructor, so they are set on the instance afterwards.
        dispatcher.canonical_name = canonical_name
        dispatcher.discovery_func = discovery_func
        return dispatcher

    return decorator
Variables
REGISTERED
RETURN
Functions
multipledispatch
def multipledispatch(
canonical_name: 'str',
*,
discovery_func: 'Optional[Callable[[], None]]' = None
) -> 'Callable[[Callable[..., RETURN]], _multipledispatch[RETURN]]'
View Source
def multipledispatch(
    canonical_name: str, *, discovery_func: Optional[Callable[[], None]] = None
) -> Callable[[Callable[..., RETURN]], _multipledispatch[RETURN]]:
    """Return a decorator that turns a base function into a `_multipledispatch` dispatcher.

    Args:
        canonical_name: stored on the dispatcher as `canonical_name`.
        discovery_func: optional zero-argument callable stored on the dispatcher as
            `discovery_func`.
    """

    def decorator(func: Callable[..., RETURN]) -> _multipledispatch[RETURN]:
        # Flag the base handler, which is expected to `raise NotImplementedError`.
        func._abstract_ = True  # type: ignore[attr-defined]
        dispatcher = _multipledispatch(func)
        # These can't be passed to the constructor, so they are set on the instance afterwards.
        dispatcher.canonical_name = canonical_name
        dispatcher.discovery_func = discovery_func
        return dispatcher

    return decorator