automated terminal push

This commit is contained in:
lenape
2025-06-27 16:06:02 +00:00
parent 511dd3b36b
commit 6a954eb013
4221 changed files with 2916190 additions and 1 deletions

View File

@@ -0,0 +1,55 @@
from typing import Final
from .converters import BaseConverter, Converter, GenConverter, UnstructureStrategy
from .errors import (
AttributeValidationNote,
BaseValidationError,
ClassValidationError,
ForbiddenExtraKeysError,
IterableValidationError,
IterableValidationNote,
StructureHandlerNotFoundError,
)
from .gen import override
from .v import transform_error
__all__ = [
"structure",
"unstructure",
"get_structure_hook",
"get_unstructure_hook",
"register_structure_hook_func",
"register_structure_hook",
"register_unstructure_hook_func",
"register_unstructure_hook",
"structure_attrs_fromdict",
"structure_attrs_fromtuple",
"global_converter",
"BaseConverter",
"Converter",
"AttributeValidationNote",
"BaseValidationError",
"ClassValidationError",
"ForbiddenExtraKeysError",
"GenConverter",
"IterableValidationError",
"IterableValidationNote",
"override",
"StructureHandlerNotFoundError",
"transform_error",
"UnstructureStrategy",
]
#: The global converter. Prefer creating your own if customizations are required.
global_converter: Final = Converter()
unstructure = global_converter.unstructure
structure = global_converter.structure
structure_attrs_fromtuple = global_converter.structure_attrs_fromtuple
structure_attrs_fromdict = global_converter.structure_attrs_fromdict
register_structure_hook = global_converter.register_structure_hook
register_structure_hook_func = global_converter.register_structure_hook_func
register_unstructure_hook = global_converter.register_unstructure_hook
register_unstructure_hook_func = global_converter.register_unstructure_hook_func
get_structure_hook: Final = global_converter.get_structure_hook
get_unstructure_hook: Final = global_converter.get_unstructure_hook

View File

@@ -0,0 +1,579 @@
import sys
from collections import deque
from collections.abc import Mapping as AbcMapping
from collections.abc import MutableMapping as AbcMutableMapping
from collections.abc import MutableSet as AbcMutableSet
from collections.abc import Set as AbcSet
from dataclasses import MISSING, Field, is_dataclass
from dataclasses import fields as dataclass_fields
from functools import partial
from inspect import signature as _signature
from typing import AbstractSet as TypingAbstractSet
from typing import (
Any,
Deque,
Dict,
Final,
FrozenSet,
List,
Literal,
NewType,
Optional,
Protocol,
Tuple,
Type,
Union,
get_args,
get_origin,
get_type_hints,
)
from typing import Mapping as TypingMapping
from typing import MutableMapping as TypingMutableMapping
from typing import MutableSequence as TypingMutableSequence
from typing import MutableSet as TypingMutableSet
from typing import Sequence as TypingSequence
from typing import Set as TypingSet
from attrs import NOTHING, Attribute, Factory, resolve_types
from attrs import fields as attrs_fields
from attrs import fields_dict as attrs_fields_dict
__all__ = [
"ANIES",
"adapted_fields",
"fields_dict",
"ExceptionGroup",
"ExtensionsTypedDict",
"get_type_alias_base",
"has",
"is_type_alias",
"is_typeddict",
"TypeAlias",
"TypedDict",
]
try:
from typing_extensions import TypedDict as ExtensionsTypedDict
except ImportError: # pragma: no cover
ExtensionsTypedDict = None
if sys.version_info >= (3, 11):
from builtins import ExceptionGroup
else:
from exceptiongroup import ExceptionGroup
try:
from typing_extensions import is_typeddict as _is_typeddict
except ImportError: # pragma: no cover
assert sys.version_info >= (3, 10)
from typing import is_typeddict as _is_typeddict
try:
from typing_extensions import TypeAlias
except ImportError: # pragma: no cover
assert sys.version_info >= (3, 11)
from typing import TypeAlias
LITERALS = {Literal}
try:
from typing_extensions import Literal as teLiteral
LITERALS.add(teLiteral)
except ImportError: # pragma: no cover
pass
# On some Python versions, `typing_extensions.Any` is different than
# `typing.Any`.
try:
from typing_extensions import Any as teAny
ANIES = frozenset([Any, teAny])
except ImportError: # pragma: no cover
ANIES = frozenset([Any])
NoneType = type(None)
def is_optional(typ: Type) -> bool:
return is_union_type(typ) and NoneType in typ.__args__ and len(typ.__args__) == 2
def is_typeddict(cls):
"""Thin wrapper around typing(_extensions).is_typeddict"""
return _is_typeddict(getattr(cls, "__origin__", cls))
def is_type_alias(type: Any) -> bool:
"""Is this a PEP 695 type alias?"""
return False
def get_type_alias_base(type: Any) -> Any:
"""
What is this a type alias of?
Works only on 3.12+.
"""
return type.__value__
def has(cls):
return hasattr(cls, "__attrs_attrs__") or hasattr(cls, "__dataclass_fields__")
def has_with_generic(cls):
"""Test whether the class if a normal or generic attrs or dataclass."""
return has(cls) or has(get_origin(cls))
def fields(type):
try:
return type.__attrs_attrs__
except AttributeError:
return dataclass_fields(type)
def fields_dict(type) -> Dict[str, Union[Attribute, Field]]:
"""Return the fields_dict for attrs and dataclasses."""
if is_dataclass(type):
return {f.name: f for f in dataclass_fields(type)}
return attrs_fields_dict(type)
def adapted_fields(cl) -> List[Attribute]:
"""Return the attrs format of `fields()` for attrs and dataclasses."""
if is_dataclass(cl):
attrs = dataclass_fields(cl)
if any(isinstance(a.type, str) for a in attrs):
# Do this conditionally in case `get_type_hints` fails, so
# users can resolve on their own first.
type_hints = get_type_hints(cl)
else:
type_hints = {}
return [
Attribute(
attr.name,
(
attr.default
if attr.default is not MISSING
else (
Factory(attr.default_factory)
if attr.default_factory is not MISSING
else NOTHING
)
),
None,
True,
None,
True,
attr.init,
True,
type=type_hints.get(attr.name, attr.type),
alias=attr.name,
kw_only=getattr(attr, "kw_only", False),
)
for attr in attrs
]
attribs = attrs_fields(cl)
if any(isinstance(a.type, str) for a in attribs):
# PEP 563 annotations - need to be resolved.
resolve_types(cl)
attribs = attrs_fields(cl)
return attribs
def is_subclass(obj: type, bases) -> bool:
"""A safe version of issubclass (won't raise)."""
try:
return issubclass(obj, bases)
except TypeError:
return False
def is_hetero_tuple(type: Any) -> bool:
origin = getattr(type, "__origin__", None)
return origin is tuple and ... not in type.__args__
def is_protocol(type: Any) -> bool:
return is_subclass(type, Protocol) and getattr(type, "_is_protocol", False)
def is_bare_final(type) -> bool:
return type is Final
def get_final_base(type) -> Optional[type]:
"""Return the base of the Final annotation, if it is Final."""
if type is Final:
return Any
if type.__class__ is _GenericAlias and type.__origin__ is Final:
return type.__args__[0]
return None
OriginAbstractSet = AbcSet
OriginMutableSet = AbcMutableSet
signature = _signature
if sys.version_info >= (3, 10):
signature = partial(_signature, eval_str=True)
if sys.version_info >= (3, 9):
from collections import Counter
from collections.abc import MutableSequence as AbcMutableSequence
from collections.abc import MutableSet as AbcMutableSet
from collections.abc import Sequence as AbcSequence
from collections.abc import Set as AbcSet
from types import GenericAlias
from typing import (
Annotated,
Generic,
TypedDict,
Union,
_AnnotatedAlias,
_GenericAlias,
_SpecialGenericAlias,
_UnionGenericAlias,
)
from typing import Counter as TypingCounter
try:
# Not present on 3.9.0, so we try carefully.
from typing import _LiteralGenericAlias
def is_literal(type) -> bool:
return type in LITERALS or (
isinstance(
type, (_GenericAlias, _LiteralGenericAlias, _SpecialGenericAlias)
)
and type.__origin__ in LITERALS
)
except ImportError: # pragma: no cover
def is_literal(_) -> bool:
return False
Set = AbcSet
AbstractSet = AbcSet
MutableSet = AbcMutableSet
Sequence = AbcSequence
MutableSequence = AbcMutableSequence
MutableMapping = AbcMutableMapping
Mapping = AbcMapping
FrozenSetSubscriptable = frozenset
TupleSubscriptable = tuple
def is_annotated(type) -> bool:
return getattr(type, "__class__", None) is _AnnotatedAlias
def is_tuple(type):
return (
type in (Tuple, tuple)
or (type.__class__ is _GenericAlias and is_subclass(type.__origin__, Tuple))
or (getattr(type, "__origin__", None) is tuple)
)
if sys.version_info >= (3, 12):
from typing import TypeAliasType
def is_type_alias(type: Any) -> bool:
"""Is this a PEP 695 type alias?"""
return isinstance(type, TypeAliasType)
if sys.version_info >= (3, 10):
def is_union_type(obj):
from types import UnionType
return (
obj is Union
or (isinstance(obj, _UnionGenericAlias) and obj.__origin__ is Union)
or isinstance(obj, UnionType)
)
def get_newtype_base(typ: Any) -> Optional[type]:
if typ is NewType or isinstance(typ, NewType):
return typ.__supertype__
return None
if sys.version_info >= (3, 11):
from typing import NotRequired, Required
else:
from typing_extensions import NotRequired, Required
else:
from typing_extensions import NotRequired, Required
def is_union_type(obj):
return (
obj is Union
or isinstance(obj, _UnionGenericAlias)
and obj.__origin__ is Union
)
def get_newtype_base(typ: Any) -> Optional[type]:
supertype = getattr(typ, "__supertype__", None)
if (
supertype is not None
and getattr(typ, "__qualname__", "") == "NewType.<locals>.new_type"
and typ.__module__ in ("typing", "typing_extensions")
):
return supertype
return None
def get_notrequired_base(type) -> "Union[Any, Literal[NOTHING]]":
if is_annotated(type):
# Handle `Annotated[NotRequired[int]]`
type = get_args(type)[0]
if get_origin(type) in (NotRequired, Required):
return get_args(type)[0]
return NOTHING
def is_sequence(type: Any) -> bool:
"""A predicate function for sequences.
Matches lists, sequences, mutable sequences, deques and homogenous
tuples.
"""
origin = getattr(type, "__origin__", None)
return (
type
in (
List,
list,
TypingSequence,
TypingMutableSequence,
AbcMutableSequence,
tuple,
Tuple,
deque,
Deque,
)
or (
type.__class__ is _GenericAlias
and (
(origin is not tuple)
and is_subclass(origin, TypingSequence)
or origin is tuple
and type.__args__[1] is ...
)
)
or (origin in (list, deque, AbcMutableSequence, AbcSequence))
or (origin is tuple and type.__args__[1] is ...)
)
def is_deque(type):
return (
type in (deque, Deque)
or (type.__class__ is _GenericAlias and is_subclass(type.__origin__, deque))
or (getattr(type, "__origin__", None) is deque)
)
def is_mutable_set(type: Any) -> bool:
"""A predicate function for (mutable) sets.
Matches built-in sets and sets from the typing module.
"""
return (
type in (TypingSet, TypingMutableSet, set)
or (
type.__class__ is _GenericAlias
and is_subclass(type.__origin__, TypingMutableSet)
)
or (getattr(type, "__origin__", None) in (set, AbcMutableSet, AbcSet))
)
def is_frozenset(type: Any) -> bool:
"""A predicate function for frozensets.
Matches built-in frozensets and frozensets from the typing module.
"""
return (
type in (FrozenSet, frozenset)
or (
type.__class__ is _GenericAlias
and is_subclass(type.__origin__, FrozenSet)
)
or (getattr(type, "__origin__", None) is frozenset)
)
def is_bare(type):
return isinstance(type, _SpecialGenericAlias) or (
not hasattr(type, "__origin__") and not hasattr(type, "__args__")
)
def is_mapping(type: Any) -> bool:
"""A predicate function for mappings."""
return (
type in (dict, Dict, TypingMapping, TypingMutableMapping, AbcMutableMapping)
or (
type.__class__ is _GenericAlias
and is_subclass(type.__origin__, TypingMapping)
)
or is_subclass(
getattr(type, "__origin__", type), (dict, AbcMutableMapping, AbcMapping)
)
)
def is_counter(type):
return (
type in (Counter, TypingCounter)
or getattr(type, "__origin__", None) is Counter
)
def is_generic(type) -> bool:
"""Whether `type` is a generic type."""
# Inheriting from protocol will inject `Generic` into the MRO
# without `__orig_bases__`.
return isinstance(type, (_GenericAlias, GenericAlias)) or (
is_subclass(type, Generic) and hasattr(type, "__orig_bases__")
)
def copy_with(type, args):
"""Replace a generic type's arguments."""
if is_annotated(type):
# typing.Annotated requires a special case.
return Annotated[args]
if isinstance(args, tuple) and len(args) == 1:
# Some annotations can't handle 1-tuples.
args = args[0]
return type.__origin__[args]
def get_full_type_hints(obj, globalns=None, localns=None):
return get_type_hints(obj, globalns, localns, include_extras=True)
else:
# 3.8
Set = TypingSet
AbstractSet = TypingAbstractSet
MutableSet = TypingMutableSet
Sequence = TypingSequence
MutableSequence = TypingMutableSequence
MutableMapping = TypingMutableMapping
Mapping = TypingMapping
FrozenSetSubscriptable = FrozenSet
TupleSubscriptable = Tuple
from collections import Counter as ColCounter
from typing import Counter, Generic, TypedDict, Union, _GenericAlias
from typing_extensions import Annotated, NotRequired, Required
from typing_extensions import get_origin as te_get_origin
def is_annotated(type) -> bool:
return te_get_origin(type) is Annotated
def is_tuple(type):
return type in (Tuple, tuple) or (
type.__class__ is _GenericAlias and is_subclass(type.__origin__, Tuple)
)
def is_union_type(obj):
return (
obj is Union or isinstance(obj, _GenericAlias) and obj.__origin__ is Union
)
def get_newtype_base(typ: Any) -> Optional[type]:
supertype = getattr(typ, "__supertype__", None)
if (
supertype is not None
and getattr(typ, "__qualname__", "") == "NewType.<locals>.new_type"
and typ.__module__ in ("typing", "typing_extensions")
):
return supertype
return None
def is_sequence(type: Any) -> bool:
return type in (List, list, Tuple, tuple) or (
type.__class__ is _GenericAlias
and (
type.__origin__ not in (Union, Tuple, tuple)
and is_subclass(type.__origin__, TypingSequence)
)
or (type.__origin__ in (Tuple, tuple) and type.__args__[1] is ...)
)
def is_deque(type: Any) -> bool:
return (
type in (deque, Deque)
or (type.__class__ is _GenericAlias and is_subclass(type.__origin__, deque))
or type.__origin__ is deque
)
def is_mutable_set(type) -> bool:
return type in (set, TypingAbstractSet) or (
type.__class__ is _GenericAlias
and is_subclass(type.__origin__, (MutableSet, TypingAbstractSet))
)
def is_frozenset(type):
return type is frozenset or (
type.__class__ is _GenericAlias and is_subclass(type.__origin__, FrozenSet)
)
def is_mapping(type: Any) -> bool:
"""A predicate function for mappings."""
return (
type in (TypingMapping, dict)
or (
type.__class__ is _GenericAlias
and is_subclass(type.__origin__, TypingMapping)
)
or is_subclass(
getattr(type, "__origin__", type), (dict, AbcMutableMapping, AbcMapping)
)
)
bare_generic_args = {
List.__args__,
TypingSequence.__args__,
TypingMapping.__args__,
Dict.__args__,
TypingMutableSequence.__args__,
Tuple.__args__,
None, # non-parametrized containers do not have `__args__ attribute in py3.7-8
}
def is_bare(type):
return getattr(type, "__args__", None) in bare_generic_args
def is_counter(type):
return (
type in (Counter, ColCounter)
or getattr(type, "__origin__", None) is ColCounter
)
def is_literal(type) -> bool:
return type in LITERALS or (
isinstance(type, _GenericAlias) and type.__origin__ in LITERALS
)
def is_generic(obj):
return isinstance(obj, _GenericAlias) or (
is_subclass(obj, Generic) and hasattr(obj, "__orig_bases__")
)
def copy_with(type, args):
"""Replace a generic type's arguments."""
return type.copy_with(args)
def get_notrequired_base(type) -> "Union[Any, Literal[NOTHING]]":
if is_annotated(type):
# Handle `Annotated[NotRequired[int]]`
type = get_origin(type)
if get_origin(type) in (NotRequired, Required):
return get_args(type)[0]
return NOTHING
def get_full_type_hints(obj, globalns=None, localns=None):
return get_type_hints(obj, globalns, localns)
def is_generic_attrs(type) -> bool:
"""Return True for both specialized (A[int]) and unspecialized (A) generics."""
return is_generic(type) and has(type.__origin__)

View File

@@ -0,0 +1,24 @@
from typing import Any, Mapping
from ._compat import copy_with, get_args, is_annotated, is_generic
def deep_copy_with(t, mapping: Mapping[str, Any]):
args = get_args(t)
rest = ()
if is_annotated(t) and args:
# If we're dealing with `Annotated`, we only map the first type parameter
rest = tuple(args[1:])
args = (args[0],)
new_args = (
tuple(
(
mapping[a.__name__]
if hasattr(a, "__name__") and a.__name__ in mapping
else (deep_copy_with(a, mapping) if is_generic(a) else a)
)
for a in args
)
+ rest
)
return copy_with(t, new_args) if new_args != args else t

View File

@@ -0,0 +1,289 @@
"""Utility functions for collections."""
from __future__ import annotations
from sys import version_info
from typing import (
TYPE_CHECKING,
Any,
Iterable,
Literal,
NamedTuple,
Tuple,
TypeVar,
get_type_hints,
)
from attrs import NOTHING, Attribute
from ._compat import ANIES, is_bare, is_frozenset, is_mapping, is_sequence, is_subclass
from ._compat import is_mutable_set as is_set
from .dispatch import StructureHook, UnstructureHook
from .errors import IterableValidationError, IterableValidationNote
from .fns import identity
from .gen import (
AttributeOverride,
already_generating,
make_dict_structure_fn_from_attrs,
make_dict_unstructure_fn_from_attrs,
make_hetero_tuple_unstructure_fn,
mapping_structure_factory,
)
from .gen import make_iterable_unstructure_fn as iterable_unstructure_factory
if TYPE_CHECKING:
from .converters import BaseConverter
__all__ = [
"is_any_set",
"is_frozenset",
"is_namedtuple",
"is_mapping",
"is_set",
"is_sequence",
"iterable_unstructure_factory",
"list_structure_factory",
"namedtuple_structure_factory",
"namedtuple_unstructure_factory",
"namedtuple_dict_structure_factory",
"namedtuple_dict_unstructure_factory",
"mapping_structure_factory",
]
def is_any_set(type) -> bool:
"""A predicate function for both mutable and frozensets."""
return is_set(type) or is_frozenset(type)
if version_info[:2] >= (3, 9):
def is_namedtuple(type: Any) -> bool:
"""A predicate function for named tuples."""
if is_subclass(type, tuple):
for cl in type.mro():
orig_bases = cl.__dict__.get("__orig_bases__", ())
if NamedTuple in orig_bases:
return True
return False
else:
def is_namedtuple(type: Any) -> bool:
"""A predicate function for named tuples."""
# This is tricky. It may not be possible for this function to be 100%
# accurate, since it doesn't seem like we can distinguish between tuple
# subclasses and named tuples reliably.
if is_subclass(type, tuple):
for cl in type.mro():
if cl is tuple:
# No point going further.
break
if "_fields" in cl.__dict__:
return True
return False
def _is_passthrough(type: type[tuple], converter: BaseConverter) -> bool:
"""If all fields would be passed through, this class should not be processed
either.
"""
return all(
converter.get_unstructure_hook(t) == identity
for t in type.__annotations__.values()
)
T = TypeVar("T")
def list_structure_factory(type: type, converter: BaseConverter) -> StructureHook:
"""A hook factory for structuring lists.
Converts any given iterable into a list.
"""
if is_bare(type) or type.__args__[0] in ANIES:
def structure_list(obj: Iterable[T], _: type = type) -> list[T]:
return list(obj)
return structure_list
elem_type = type.__args__[0]
try:
handler = converter.get_structure_hook(elem_type)
except RecursionError:
# Break the cycle by using late binding.
handler = converter.structure
if converter.detailed_validation:
def structure_list(
obj: Iterable[T], _: type = type, _handler=handler, _elem_type=elem_type
) -> list[T]:
errors = []
res = []
ix = 0 # Avoid `enumerate` for performance.
for e in obj:
try:
res.append(handler(e, _elem_type))
except Exception as e:
msg = IterableValidationNote(
f"Structuring {type} @ index {ix}", ix, elem_type
)
e.__notes__ = [*getattr(e, "__notes__", []), msg]
errors.append(e)
finally:
ix += 1
if errors:
raise IterableValidationError(
f"While structuring {type!r}", errors, type
)
return res
else:
def structure_list(
obj: Iterable[T], _: type = type, _handler=handler, _elem_type=elem_type
) -> list[T]:
return [_handler(e, _elem_type) for e in obj]
return structure_list
def namedtuple_unstructure_factory(
cl: type[tuple], converter: BaseConverter, unstructure_to: Any = None
) -> UnstructureHook:
"""A hook factory for unstructuring namedtuples.
:param unstructure_to: Force unstructuring to this type, if provided.
"""
if unstructure_to is None and _is_passthrough(cl, converter):
return identity
return make_hetero_tuple_unstructure_fn(
cl,
converter,
unstructure_to=tuple if unstructure_to is None else unstructure_to,
type_args=tuple(cl.__annotations__.values()),
)
def namedtuple_structure_factory(
cl: type[tuple], converter: BaseConverter
) -> StructureHook:
"""A hook factory for structuring namedtuples from iterables."""
# We delegate to the existing infrastructure for heterogenous tuples.
hetero_tuple_type = Tuple[tuple(cl.__annotations__.values())]
base_hook = converter.get_structure_hook(hetero_tuple_type)
return lambda v, _: cl(*base_hook(v, hetero_tuple_type))
def _namedtuple_to_attrs(cl: type[tuple]) -> list[Attribute]:
"""Generate pseudo attributes for a namedtuple."""
return [
Attribute(
name,
cl._field_defaults.get(name, NOTHING),
None,
False,
False,
False,
True,
False,
type=a,
alias=name,
)
for name, a in get_type_hints(cl).items()
]
def namedtuple_dict_structure_factory(
cl: type[tuple],
converter: BaseConverter,
detailed_validation: bool | Literal["from_converter"] = "from_converter",
forbid_extra_keys: bool = False,
use_linecache: bool = True,
/,
**kwargs: AttributeOverride,
) -> StructureHook:
"""A hook factory for hooks structuring namedtuples from dictionaries.
:param forbid_extra_keys: Whether the hook should raise a `ForbiddenExtraKeysError`
if unknown keys are encountered.
:param use_linecache: Whether to store the source code in the Python linecache.
.. versionadded:: 24.1.0
"""
try:
working_set = already_generating.working_set
except AttributeError:
working_set = set()
already_generating.working_set = working_set
else:
if cl in working_set:
raise RecursionError()
working_set.add(cl)
try:
return make_dict_structure_fn_from_attrs(
_namedtuple_to_attrs(cl),
cl,
converter,
_cattrs_forbid_extra_keys=forbid_extra_keys,
_cattrs_use_detailed_validation=detailed_validation,
_cattrs_use_linecache=use_linecache,
**kwargs,
)
finally:
working_set.remove(cl)
if not working_set:
del already_generating.working_set
def namedtuple_dict_unstructure_factory(
cl: type[tuple],
converter: BaseConverter,
omit_if_default: bool = False,
use_linecache: bool = True,
/,
**kwargs: AttributeOverride,
) -> UnstructureHook:
"""A hook factory for hooks unstructuring namedtuples to dictionaries.
:param omit_if_default: When true, attributes equal to their default values
will be omitted in the result dictionary.
:param use_linecache: Whether to store the source code in the Python linecache.
.. versionadded:: 24.1.0
"""
try:
working_set = already_generating.working_set
except AttributeError:
working_set = set()
already_generating.working_set = working_set
if cl in working_set:
raise RecursionError()
working_set.add(cl)
try:
return make_dict_unstructure_fn_from_attrs(
_namedtuple_to_attrs(cl),
cl,
converter,
_cattrs_omit_if_default=omit_if_default,
_cattrs_use_linecache=use_linecache,
**kwargs,
)
finally:
working_set.remove(cl)
if not working_set:
del already_generating.working_set

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,205 @@
"""Utilities for union (sum type) disambiguation."""
from __future__ import annotations
from collections import defaultdict
from dataclasses import MISSING
from functools import reduce
from operator import or_
from typing import TYPE_CHECKING, Any, Callable, Literal, Mapping, Union
from attrs import NOTHING, Attribute, AttrsInstance
from ._compat import (
NoneType,
adapted_fields,
fields_dict,
get_args,
get_origin,
has,
is_literal,
is_union_type,
)
from .gen import AttributeOverride
if TYPE_CHECKING:
from .converters import BaseConverter
__all__ = ["is_supported_union", "create_default_dis_func"]
def is_supported_union(typ: Any) -> bool:
"""Whether the type is a union of attrs classes."""
return is_union_type(typ) and all(
e is NoneType or has(get_origin(e) or e) for e in typ.__args__
)
def create_default_dis_func(
converter: BaseConverter,
*classes: type[AttrsInstance],
use_literals: bool = True,
overrides: (
dict[str, AttributeOverride] | Literal["from_converter"]
) = "from_converter",
) -> Callable[[Mapping[Any, Any]], type[Any] | None]:
"""Given attrs classes or dataclasses, generate a disambiguation function.
The function is based on unique fields without defaults or unique values.
:param use_literals: Whether to try using fields annotated as literals for
disambiguation.
:param overrides: Attribute overrides to apply.
.. versionchanged:: 24.1.0
Dataclasses are now supported.
"""
if len(classes) < 2:
raise ValueError("At least two classes required.")
if overrides == "from_converter":
overrides = [
getattr(converter.get_structure_hook(c), "overrides", {}) for c in classes
]
else:
overrides = [overrides for _ in classes]
# first, attempt for unique values
if use_literals:
# requirements for a discriminator field:
# (... TODO: a single fallback is OK)
# - it must always be enumerated
cls_candidates = [
{
at.name
for at in adapted_fields(get_origin(cl) or cl)
if is_literal(at.type)
}
for cl in classes
]
# literal field names common to all members
discriminators: set[str] = cls_candidates[0]
for possible_discriminators in cls_candidates:
discriminators &= possible_discriminators
best_result = None
best_discriminator = None
for discriminator in discriminators:
# maps Literal values (strings, ints...) to classes
mapping = defaultdict(list)
for cl in classes:
for key in get_args(
fields_dict(get_origin(cl) or cl)[discriminator].type
):
mapping[key].append(cl)
if best_result is None or max(len(v) for v in mapping.values()) <= max(
len(v) for v in best_result.values()
):
best_result = mapping
best_discriminator = discriminator
if (
best_result
and best_discriminator
and max(len(v) for v in best_result.values()) != len(classes)
):
final_mapping = {
k: v[0] if len(v) == 1 else Union[tuple(v)]
for k, v in best_result.items()
}
def dis_func(data: Mapping[Any, Any]) -> type | None:
if not isinstance(data, Mapping):
raise ValueError("Only input mappings are supported.")
return final_mapping[data[best_discriminator]]
return dis_func
# next, attempt for unique keys
# NOTE: This could just as well work with just field availability and not
# uniqueness, returning Unions ... it doesn't do that right now.
cls_and_attrs = [
(cl, *_usable_attribute_names(cl, override))
for cl, override in zip(classes, overrides)
]
# For each class, attempt to generate a single unique required field.
uniq_attrs_dict: dict[str, type] = {}
# We start from classes with the largest number of unique fields
# so we can do easy picks first, making later picks easier.
cls_and_attrs.sort(key=lambda c_a: len(c_a[1]), reverse=True)
fallback = None # If none match, try this.
for cl, cl_reqs, back_map in cls_and_attrs:
# We do not have to consider classes we've already processed, since
# they will have been eliminated by the match dictionary already.
other_classes = [
c_and_a
for c_and_a in cls_and_attrs
if c_and_a[0] is not cl and c_and_a[0] not in uniq_attrs_dict.values()
]
other_reqs = reduce(or_, (c_a[1] for c_a in other_classes), set())
uniq = cl_reqs - other_reqs
# We want a unique attribute with no default.
cl_fields = fields_dict(get_origin(cl) or cl)
for maybe_renamed_attr_name in uniq:
orig_name = back_map[maybe_renamed_attr_name]
if cl_fields[orig_name].default in (NOTHING, MISSING):
break
else:
if fallback is None:
fallback = cl
continue
raise TypeError(f"{cl} has no usable non-default attributes")
uniq_attrs_dict[maybe_renamed_attr_name] = cl
if fallback is None:
def dis_func(data: Mapping[Any, Any]) -> type[AttrsInstance] | None:
if not isinstance(data, Mapping):
raise ValueError("Only input mappings are supported")
for k, v in uniq_attrs_dict.items():
if k in data:
return v
raise ValueError("Couldn't disambiguate")
else:
def dis_func(data: Mapping[Any, Any]) -> type[AttrsInstance] | None:
if not isinstance(data, Mapping):
raise ValueError("Only input mappings are supported")
for k, v in uniq_attrs_dict.items():
if k in data:
return v
return fallback
return dis_func
create_uniq_field_dis_func = create_default_dis_func
def _overriden_name(at: Attribute, override: AttributeOverride | None) -> str:
if override is None or override.rename is None:
return at.name
return override.rename
def _usable_attribute_names(
cl: type[Any], overrides: dict[str, AttributeOverride]
) -> tuple[set[str], dict[str, str]]:
"""Return renamed fields and a mapping to original field names."""
res = set()
mapping = {}
for at in adapted_fields(get_origin(cl) or cl):
res.add(n := _overriden_name(at, overrides.get(at.name)))
mapping[n] = at.name
return res, mapping

View File

@@ -0,0 +1,194 @@
from __future__ import annotations
from functools import lru_cache, singledispatch
from typing import TYPE_CHECKING, Any, Callable, Generic, Literal, TypeVar
from attrs import Factory, define
from ._compat import TypeAlias
from .fns import Predicate
if TYPE_CHECKING:
from .converters import BaseConverter
TargetType: TypeAlias = Any
UnstructuredValue: TypeAlias = Any
StructuredValue: TypeAlias = Any
StructureHook: TypeAlias = Callable[[UnstructuredValue, TargetType], StructuredValue]
UnstructureHook: TypeAlias = Callable[[StructuredValue], UnstructuredValue]
Hook = TypeVar("Hook", StructureHook, UnstructureHook)
HookFactory: TypeAlias = Callable[[TargetType], Hook]
@define
class _DispatchNotFound:
"""A dummy object to help signify a dispatch not found."""
@define
class FunctionDispatch:
"""
FunctionDispatch is similar to functools.singledispatch, but
instead dispatches based on functions that take the type of the
first argument in the method, and return True or False.
objects that help determine dispatch should be instantiated objects.
:param converter: A converter to be used for factories that require converters.
.. versionchanged:: 24.1.0
Support for factories that require converters, hence this requires a
converter when creating.
"""
_converter: BaseConverter
_handler_pairs: list[tuple[Predicate, Callable[[Any, Any], Any], bool, bool]] = (
Factory(list)
)
def register(
self,
predicate: Predicate,
func: Callable[..., Any],
is_generator=False,
takes_converter=False,
) -> None:
self._handler_pairs.insert(0, (predicate, func, is_generator, takes_converter))
def dispatch(self, typ: Any) -> Callable[..., Any] | None:
"""
Return the appropriate handler for the object passed.
"""
for can_handle, handler, is_generator, takes_converter in self._handler_pairs:
# can handle could raise an exception here
# such as issubclass being called on an instance.
# it's easier to just ignore that case.
try:
ch = can_handle(typ)
except Exception: # noqa: S112
continue
if ch:
if is_generator:
if takes_converter:
return handler(typ, self._converter)
return handler(typ)
return handler
return None
def get_num_fns(self) -> int:
return len(self._handler_pairs)
def copy_to(self, other: FunctionDispatch, skip: int = 0) -> None:
other._handler_pairs = self._handler_pairs[:-skip] + other._handler_pairs
@define(init=False)
class MultiStrategyDispatch(Generic[Hook]):
"""
MultiStrategyDispatch uses a combination of exact-match dispatch,
singledispatch, and FunctionDispatch.
:param converter: A converter to be used for factories that require converters.
:param fallback_factory: A hook factory to be called when a hook cannot be
produced.
.. versionchanged:: 23.2.0
Fallbacks are now factories.
.. versionchanged:: 24.1.0
Support for factories that require converters, hence this requires a
converter when creating.
"""
_fallback_factory: HookFactory[Hook]
_converter: BaseConverter
_direct_dispatch: dict[TargetType, Hook]
_function_dispatch: FunctionDispatch
_single_dispatch: Any
dispatch: Callable[[TargetType, BaseConverter], Hook]
def __init__(
self, fallback_factory: HookFactory[Hook], converter: BaseConverter
) -> None:
self._fallback_factory = fallback_factory
self._direct_dispatch = {}
self._function_dispatch = FunctionDispatch(converter)
self._single_dispatch = singledispatch(_DispatchNotFound)
self.dispatch = lru_cache(maxsize=None)(self.dispatch_without_caching)
def dispatch_without_caching(self, typ: TargetType) -> Hook:
"""Dispatch on the type but without caching the result."""
try:
dispatch = self._single_dispatch.dispatch(typ)
if dispatch is not _DispatchNotFound:
return dispatch
except Exception: # noqa: S110
pass
direct_dispatch = self._direct_dispatch.get(typ)
if direct_dispatch is not None:
return direct_dispatch
res = self._function_dispatch.dispatch(typ)
return res if res is not None else self._fallback_factory(typ)
def register_cls_list(self, cls_and_handler, direct: bool = False) -> None:
"""Register a class to direct or singledispatch."""
for cls, handler in cls_and_handler:
if direct:
self._direct_dispatch[cls] = handler
else:
self._single_dispatch.register(cls, handler)
self.clear_direct()
self.dispatch.cache_clear()
def register_func_list(
self,
pred_and_handler: list[
tuple[Predicate, Any]
| tuple[Predicate, Any, bool]
| tuple[Predicate, Callable[[Any, BaseConverter], Any], Literal["extended"]]
],
):
"""
Register a predicate function to determine if the handler
should be used for the type.
:param pred_and_handler: The list of predicates and their associated
handlers. If a handler is registered in `extended` mode, it's a
factory that requires a converter.
"""
for tup in pred_and_handler:
if len(tup) == 2:
func, handler = tup
self._function_dispatch.register(func, handler)
else:
func, handler, is_gen = tup
if is_gen == "extended":
self._function_dispatch.register(
func, handler, is_generator=is_gen, takes_converter=True
)
else:
self._function_dispatch.register(func, handler, is_generator=is_gen)
self.clear_direct()
self.dispatch.cache_clear()
def clear_direct(self) -> None:
"""Clear the direct dispatch."""
self._direct_dispatch.clear()
def clear_cache(self) -> None:
"""Clear all caches."""
self._direct_dispatch.clear()
self.dispatch.cache_clear()
def get_num_fns(self) -> int:
return self._function_dispatch.get_num_fns()
def copy_to(self, other: MultiStrategyDispatch, skip: int = 0) -> None:
self._function_dispatch.copy_to(other._function_dispatch, skip=skip)
for cls, fn in self._single_dispatch.registry.items():
other._single_dispatch.register(cls, fn)
other.clear_cache()

View File

@@ -0,0 +1,129 @@
from typing import Any, List, Optional, Set, Tuple, Type, Union
from cattrs._compat import ExceptionGroup
class StructureHandlerNotFoundError(Exception):
"""
Error raised when structuring cannot find a handler for converting inputs into
:attr:`type_`.
"""
def __init__(self, message: str, type_: Type) -> None:
super().__init__(message)
self.type_ = type_
class BaseValidationError(ExceptionGroup):
cl: Type
def __new__(cls, message, excs, cl: Type):
obj = super().__new__(cls, message, excs)
obj.cl = cl
return obj
def derive(self, excs):
return ClassValidationError(self.message, excs, self.cl)
class IterableValidationNote(str):
"""Attached as a note to an exception when an iterable element fails structuring."""
index: Union[int, str] # Ints for list indices, strs for dict keys
type: Any
def __new__(
cls, string: str, index: Union[int, str], type: Any
) -> "IterableValidationNote":
instance = str.__new__(cls, string)
instance.index = index
instance.type = type
return instance
def __getnewargs__(self) -> Tuple[str, Union[int, str], Any]:
return (str(self), self.index, self.type)
class IterableValidationError(BaseValidationError):
"""Raised when structuring an iterable."""
def group_exceptions(
self,
) -> Tuple[List[Tuple[Exception, IterableValidationNote]], List[Exception]]:
"""Split the exceptions into two groups: with and without validation notes."""
excs_with_notes = []
other_excs = []
for subexc in self.exceptions:
if hasattr(subexc, "__notes__"):
for note in subexc.__notes__:
if note.__class__ is IterableValidationNote:
excs_with_notes.append((subexc, note))
break
else:
other_excs.append(subexc)
else:
other_excs.append(subexc)
return excs_with_notes, other_excs
class AttributeValidationNote(str):
"""Attached as a note to an exception when an attribute fails structuring."""
name: str
type: Any
def __new__(cls, string: str, name: str, type: Any) -> "AttributeValidationNote":
instance = str.__new__(cls, string)
instance.name = name
instance.type = type
return instance
def __getnewargs__(self) -> Tuple[str, str, Any]:
return (str(self), self.name, self.type)
class ClassValidationError(BaseValidationError):
"""Raised when validating a class if any attributes are invalid."""
def group_exceptions(
self,
) -> Tuple[List[Tuple[Exception, AttributeValidationNote]], List[Exception]]:
"""Split the exceptions into two groups: with and without validation notes."""
excs_with_notes = []
other_excs = []
for subexc in self.exceptions:
if hasattr(subexc, "__notes__"):
for note in subexc.__notes__:
if note.__class__ is AttributeValidationNote:
excs_with_notes.append((subexc, note))
break
else:
other_excs.append(subexc)
else:
other_excs.append(subexc)
return excs_with_notes, other_excs
class ForbiddenExtraKeysError(Exception):
"""
Raised when `forbid_extra_keys` is activated and such extra keys are detected
during structuring.
The attribute `extra_fields` is a sequence of those extra keys, which were the
cause of this error, and `cl` is the class which was structured with those extra
keys.
"""
def __init__(
self, message: Optional[str], cl: Type, extra_fields: Set[str]
) -> None:
self.cl = cl
self.extra_fields = extra_fields
cln = cl.__name__
super().__init__(
message
or f"Extra fields in constructor for {cln}: {', '.join(extra_fields)}"
)

View File

@@ -0,0 +1,22 @@
"""Useful internal functions."""
from typing import Any, Callable, NoReturn, Type, TypeVar
from ._compat import TypeAlias
from .errors import StructureHandlerNotFoundError
T = TypeVar("T")
Predicate: TypeAlias = Callable[[Any], bool]
"""A predicate function determines if a type can be handled."""
def identity(obj: T) -> T:
"""The identity function."""
return obj
def raise_error(_, cl: Type) -> NoReturn:
"""At the bottom of the condition stack, we explode if we can't handle it."""
msg = f"Unsupported type: {cl!r}. Register a structure hook for it."
raise StructureHandlerNotFoundError(msg, type_=cl)

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,19 @@
from __future__ import annotations
from threading import local
from typing import Any, Callable
from attrs import frozen
@frozen
class AttributeOverride:
omit_if_default: bool | None = None
rename: str | None = None
omit: bool | None = None # Omit the field completely.
struct_hook: Callable[[Any, Any], Any] | None = None # Structure hook to use.
unstruct_hook: Callable[[Any], Any] | None = None # Structure hook to use.
neutral = AttributeOverride()
already_generating = local()

View File

@@ -0,0 +1,79 @@
from __future__ import annotations
from typing import TypeVar
from .._compat import get_args, get_origin, is_generic
def _tvar_has_default(tvar) -> bool:
"""Does `tvar` have a default?
In CPython 3.13+ and typing_extensions>=4.12.0:
- TypeVars have a `no_default()` method for detecting
if a TypeVar has a default
- TypeVars with `default=None` have `__default__` set to `None`
- TypeVars with no `default` parameter passed
have `__default__` set to `typing(_extensions).NoDefault
On typing_exensions<4.12.0:
- TypeVars do not have a `no_default()` method for detecting
if a TypeVar has a default
- TypeVars with `default=None` have `__default__` set to `NoneType`
- TypeVars with no `default` parameter passed
have `__default__` set to `typing(_extensions).NoDefault
"""
try:
return tvar.has_default()
except AttributeError:
# compatibility for typing_extensions<4.12.0
return getattr(tvar, "__default__", None) is not None
def generate_mapping(cl: type, old_mapping: dict[str, type] = {}) -> dict[str, type]:
"""Generate a mapping of typevars to actual types for a generic class."""
mapping = dict(old_mapping)
origin = get_origin(cl)
if origin is not None:
# To handle the cases where classes in the typing module are using
# the GenericAlias structure but aren't a Generic and hence
# end up in this function but do not have an `__parameters__`
# attribute. These classes are interface types, for example
# `typing.Hashable`.
parameters = getattr(get_origin(cl), "__parameters__", None)
if parameters is None:
return dict(old_mapping)
for p, t in zip(parameters, get_args(cl)):
if isinstance(t, TypeVar):
continue
mapping[p.__name__] = t
elif is_generic(cl):
# Origin is None, so this may be a subclass of a generic class.
orig_bases = cl.__orig_bases__
for base in orig_bases:
if not hasattr(base, "__args__"):
continue
base_args = base.__args__
if hasattr(base.__origin__, "__parameters__"):
base_params = base.__origin__.__parameters__
elif any(_tvar_has_default(base_arg) for base_arg in base_args):
# TypeVar with a default e.g. PEP 696
# https://www.python.org/dev/peps/pep-0696/
# Extract the defaults for the TypeVars and insert
# them into the mapping
mapping_params = [
(base_arg, base_arg.__default__)
for base_arg in base_args
if _tvar_has_default(base_arg)
]
base_params, base_args = zip(*mapping_params)
else:
continue
for param, arg in zip(base_params, base_args):
mapping[param.__name__] = arg
return mapping

View File

@@ -0,0 +1,29 @@
"""Line-cache functionality."""
import linecache
from typing import List
def generate_unique_filename(cls: type, func_name: str, lines: List[str] = []) -> str:
"""
Create a "filename" suitable for a function being generated.
If *lines* are provided, insert them in the first free spot or stop
if a duplicate is found.
"""
extra = ""
count = 1
while True:
unique_filename = "<cattrs generated {} {}.{}{}>".format(
func_name, cls.__module__, getattr(cls, "__qualname__", cls.__name__), extra
)
if not lines:
return unique_filename
cache_line = (len("\n".join(lines)), None, lines, unique_filename)
if linecache.cache.setdefault(unique_filename, cache_line) == cache_line:
return unique_filename
# Looks like this spot is taken. Try again.
count += 1
extra = f"-{count}"

View File

@@ -0,0 +1,58 @@
from __future__ import annotations
from typing import TYPE_CHECKING, Any
from attrs import NOTHING, Attribute, Factory
from .._compat import is_bare_final
from ..dispatch import StructureHook
from ..fns import raise_error
if TYPE_CHECKING:
from ..converters import BaseConverter
def find_structure_handler(
a: Attribute, type: Any, c: BaseConverter, prefer_attrs_converters: bool = False
) -> StructureHook | None:
"""Find the appropriate structure handler to use.
Return `None` if no handler should be used.
"""
try:
if a.converter is not None and prefer_attrs_converters:
# If the user as requested to use attrib converters, use nothing
# so it falls back to that.
handler = None
elif (
a.converter is not None and not prefer_attrs_converters and type is not None
):
handler = c.get_structure_hook(type, cache_result=False)
if handler == raise_error:
handler = None
elif type is not None:
if (
is_bare_final(type)
and a.default is not NOTHING
and not isinstance(a.default, Factory)
):
# This is a special case where we can use the
# type of the default to dispatch on.
type = a.default.__class__
handler = c.get_structure_hook(type, cache_result=False)
if handler == c._structure_call:
# Finals can't really be used with _structure_call, so
# we wrap it so the rest of the toolchain doesn't get
# confused.
def handler(v, _, _h=handler):
return _h(v, type)
else:
handler = c.get_structure_hook(type, cache_result=False)
else:
handler = c.structure
return handler
except RecursionError:
# This means we're dealing with a reference cycle, so use late binding.
return c.structure

View File

@@ -0,0 +1,611 @@
from __future__ import annotations
import re
import sys
from typing import TYPE_CHECKING, Any, Callable, Literal, TypeVar
from attrs import NOTHING, Attribute
try:
from inspect import get_annotations
def get_annots(cl) -> dict[str, Any]:
return get_annotations(cl, eval_str=True)
except ImportError:
# https://docs.python.org/3/howto/annotations.html#accessing-the-annotations-dict-of-an-object-in-python-3-9-and-older
def get_annots(cl) -> dict[str, Any]:
if isinstance(cl, type):
ann = cl.__dict__.get("__annotations__", {})
else:
ann = getattr(cl, "__annotations__", {})
return ann
try:
from typing_extensions import _TypedDictMeta
except ImportError:
_TypedDictMeta = None
from .._compat import (
TypedDict,
get_full_type_hints,
get_notrequired_base,
get_origin,
is_annotated,
is_bare,
is_generic,
)
from .._generics import deep_copy_with
from ..errors import (
AttributeValidationNote,
ClassValidationError,
ForbiddenExtraKeysError,
StructureHandlerNotFoundError,
)
from ..fns import identity
from . import AttributeOverride
from ._consts import already_generating, neutral
from ._generics import generate_mapping
from ._lc import generate_unique_filename
from ._shared import find_structure_handler
if TYPE_CHECKING:
from ..converters import BaseConverter
__all__ = ["make_dict_unstructure_fn", "make_dict_structure_fn"]
T = TypeVar("T", bound=TypedDict)
def make_dict_unstructure_fn(
cl: type[T],
converter: BaseConverter,
_cattrs_use_linecache: bool = True,
**kwargs: AttributeOverride,
) -> Callable[[T], dict[str, Any]]:
"""
Generate a specialized dict unstructuring function for a TypedDict.
:param cl: A `TypedDict` class.
:param converter: A Converter instance to use for unstructuring nested fields.
:param kwargs: A mapping of field names to an `AttributeOverride`, for
customization.
:param _cattrs_detailed_validation: Whether to store the generated code in the
_linecache_, for easier debugging and better stack traces.
"""
origin = get_origin(cl)
attrs = _adapted_fields(origin or cl) # type: ignore
req_keys = _required_keys(origin or cl)
mapping = {}
if is_generic(cl):
mapping = generate_mapping(cl, mapping)
for base in getattr(origin, "__orig_bases__", ()):
if is_generic(base) and not str(base).startswith("typing.Generic"):
mapping = generate_mapping(base, mapping)
break
# It's possible for origin to be None if this is a subclass
# of a generic class.
if origin is not None:
cl = origin
cl_name = cl.__name__
fn_name = "unstructure_typeddict_" + cl_name
globs = {}
lines = []
internal_arg_parts = {}
# We keep track of what we're generating to help with recursive
# class graphs.
try:
working_set = already_generating.working_set
except AttributeError:
working_set = set()
already_generating.working_set = working_set
if cl in working_set:
raise RecursionError()
working_set.add(cl)
try:
# We want to short-circuit in certain cases and return the identity
# function.
# We short-circuit if all of these are true:
# * no attributes have been overridden
# * all attributes resolve to `converter._unstructure_identity`
for a in attrs:
attr_name = a.name
override = kwargs.get(attr_name, neutral)
if override != neutral:
break
handler = None
t = a.type
if isinstance(t, TypeVar):
if t.__name__ in mapping:
t = mapping[t.__name__]
else:
# Unbound typevars use late binding.
handler = converter.unstructure
elif is_generic(t) and not is_bare(t) and not is_annotated(t):
t = deep_copy_with(t, mapping)
if handler is None:
nrb = get_notrequired_base(t)
if nrb is not NOTHING:
t = nrb
try:
handler = converter.get_unstructure_hook(t)
except RecursionError:
# There's a circular reference somewhere down the line
handler = converter.unstructure
is_identity = handler == identity
if not is_identity:
break
else:
# We've not broken the loop.
return identity
for ix, a in enumerate(attrs):
attr_name = a.name
override = kwargs.get(attr_name, neutral)
if override.omit:
lines.append(f" res.pop('{attr_name}', None)")
continue
if override.rename is not None:
# We also need to pop when renaming, since we're copying
# the original.
lines.append(f" res.pop('{attr_name}', None)")
kn = attr_name if override.rename is None else override.rename
attr_required = attr_name in req_keys
# For each attribute, we try resolving the type here and now.
# If a type is manually overwritten, this function should be
# regenerated.
handler = None
if override.unstruct_hook is not None:
handler = override.unstruct_hook
else:
t = a.type
if isinstance(t, TypeVar):
if t.__name__ in mapping:
t = mapping[t.__name__]
else:
handler = converter.unstructure
elif is_generic(t) and not is_bare(t) and not is_annotated(t):
t = deep_copy_with(t, mapping)
if handler is None:
nrb = get_notrequired_base(t)
if nrb is not NOTHING:
t = nrb
try:
handler = converter.get_unstructure_hook(t)
except RecursionError:
# There's a circular reference somewhere down the line
handler = converter.unstructure
is_identity = handler == identity
if not is_identity:
unstruct_handler_name = f"__c_unstr_{ix}"
globs[unstruct_handler_name] = handler
internal_arg_parts[unstruct_handler_name] = handler
invoke = f"{unstruct_handler_name}(instance['{attr_name}'])"
elif override.rename is None:
# We're not doing anything to this attribute, so
# it'll already be present in the input dict.
continue
else:
# Probably renamed, we just fetch it.
invoke = f"instance['{attr_name}']"
if attr_required:
# No default or no override.
lines.append(f" res['{kn}'] = {invoke}")
else:
lines.append(f" if '{attr_name}' in instance: res['{kn}'] = {invoke}")
internal_arg_line = ", ".join([f"{i}={i}" for i in internal_arg_parts])
if internal_arg_line:
internal_arg_line = f", {internal_arg_line}"
for k, v in internal_arg_parts.items():
globs[k] = v
total_lines = [
f"def {fn_name}(instance{internal_arg_line}):",
" res = instance.copy()",
*lines,
" return res",
]
script = "\n".join(total_lines)
fname = generate_unique_filename(
cl, "unstructure", lines=total_lines if _cattrs_use_linecache else []
)
eval(compile(script, fname, "exec"), globs)
finally:
working_set.remove(cl)
if not working_set:
del already_generating.working_set
return globs[fn_name]
def make_dict_structure_fn(
cl: Any,
converter: BaseConverter,
_cattrs_forbid_extra_keys: bool | Literal["from_converter"] = "from_converter",
_cattrs_use_linecache: bool = True,
_cattrs_detailed_validation: bool | Literal["from_converter"] = "from_converter",
**kwargs: AttributeOverride,
) -> Callable[[dict, Any], Any]:
"""Generate a specialized dict structuring function for typed dicts.
:param cl: A `TypedDict` class.
:param converter: A Converter instance to use for structuring nested fields.
:param kwargs: A mapping of field names to an `AttributeOverride`, for
customization.
:param _cattrs_detailed_validation: Whether to use a slower mode that produces
more detailed errors.
:param _cattrs_forbid_extra_keys: Whether the structuring function should raise a
`ForbiddenExtraKeysError` if unknown keys are encountered.
:param _cattrs_detailed_validation: Whether to store the generated code in the
_linecache_, for easier debugging and better stack traces.
.. versionchanged:: 23.2.0
The `_cattrs_forbid_extra_keys` and `_cattrs_detailed_validation` parameters
take their values from the given converter by default.
"""
mapping = {}
if is_generic(cl):
base = get_origin(cl)
mapping = generate_mapping(cl, mapping)
if base is not None:
# It's possible for this to be a subclass of a generic,
# so no origin.
cl = base
for base in getattr(cl, "__orig_bases__", ()):
if is_generic(base) and not str(base).startswith("typing.Generic"):
mapping = generate_mapping(base, mapping)
break
cl_name = cl.__name__
fn_name = "structure_" + cl_name
# We have generic parameters and need to generate a unique name for the function
for p in getattr(cl, "__parameters__", ()):
try:
name_base = mapping[p.__name__]
except KeyError:
pn = p.__name__
raise StructureHandlerNotFoundError(
f"Missing type for generic argument {pn}, specify it when structuring.",
p,
) from None
name = getattr(name_base, "__name__", None) or str(name_base)
# `<>` can be present in lambdas
# `|` can be present in unions
name = re.sub(r"[\[\.\] ,<>]", "_", name)
name = re.sub(r"\|", "u", name)
fn_name += f"_{name}"
internal_arg_parts = {"__cl": cl}
globs = {}
lines = []
post_lines = []
attrs = _adapted_fields(cl)
req_keys = _required_keys(cl)
allowed_fields = set()
if _cattrs_forbid_extra_keys == "from_converter":
# BaseConverter doesn't have it so we're careful.
_cattrs_forbid_extra_keys = getattr(converter, "forbid_extra_keys", False)
if _cattrs_detailed_validation == "from_converter":
_cattrs_detailed_validation = converter.detailed_validation
if _cattrs_forbid_extra_keys:
globs["__c_a"] = allowed_fields
globs["__c_feke"] = ForbiddenExtraKeysError
lines.append(" res = o.copy()")
if _cattrs_detailed_validation:
lines.append(" errors = []")
internal_arg_parts["__c_cve"] = ClassValidationError
internal_arg_parts["__c_avn"] = AttributeValidationNote
for ix, a in enumerate(attrs):
an = a.name
attr_required = an in req_keys
override = kwargs.get(an, neutral)
if override.omit:
continue
t = a.type
if isinstance(t, TypeVar):
t = mapping.get(t.__name__, t)
elif is_generic(t) and not is_bare(t) and not is_annotated(t):
t = deep_copy_with(t, mapping)
nrb = get_notrequired_base(t)
if nrb is not NOTHING:
t = nrb
if is_generic(t) and not is_bare(t) and not is_annotated(t):
t = deep_copy_with(t, mapping)
# For each attribute, we try resolving the type here and now.
# If a type is manually overwritten, this function should be
# regenerated.
if override.struct_hook is not None:
# If the user has requested an override, just use that.
handler = override.struct_hook
else:
handler = find_structure_handler(a, t, converter)
struct_handler_name = f"__c_structure_{ix}"
internal_arg_parts[struct_handler_name] = handler
kn = an if override.rename is None else override.rename
allowed_fields.add(kn)
i = " "
if not attr_required:
lines.append(f"{i}if '{kn}' in o:")
i = f"{i} "
lines.append(f"{i}try:")
i = f"{i} "
tn = f"__c_type_{ix}"
internal_arg_parts[tn] = t
if handler == converter._structure_call:
internal_arg_parts[struct_handler_name] = t
lines.append(f"{i}res['{an}'] = {struct_handler_name}(o['{kn}'])")
else:
lines.append(f"{i}res['{an}'] = {struct_handler_name}(o['{kn}'], {tn})")
if override.rename is not None:
lines.append(f"{i}del res['{kn}']")
i = i[:-2]
lines.append(f"{i}except Exception as e:")
i = f"{i} "
lines.append(
f'{i}e.__notes__ = [*getattr(e, \'__notes__\', []), __c_avn("Structuring typeddict {cl.__qualname__} @ attribute {an}", "{an}", {tn})]'
)
lines.append(f"{i}errors.append(e)")
if _cattrs_forbid_extra_keys:
post_lines += [
" unknown_fields = o.keys() - __c_a",
" if unknown_fields:",
" errors.append(__c_feke('', __cl, unknown_fields))",
]
post_lines.append(
f" if errors: raise __c_cve('While structuring ' + {cl.__name__!r}, errors, __cl)"
)
else:
non_required = []
# The first loop deals with required args.
for ix, a in enumerate(attrs):
an = a.name
attr_required = an in req_keys
override = kwargs.get(an, neutral)
if override.omit:
continue
if not attr_required:
non_required.append((ix, a))
continue
t = a.type
if isinstance(t, TypeVar):
t = mapping.get(t.__name__, t)
elif is_generic(t) and not is_bare(t) and not is_annotated(t):
t = deep_copy_with(t, mapping)
nrb = get_notrequired_base(t)
if nrb is not NOTHING:
t = nrb
if override.struct_hook is not None:
handler = override.struct_hook
else:
# For each attribute, we try resolving the type here and now.
# If a type is manually overwritten, this function should be
# regenerated.
handler = converter.get_structure_hook(t)
kn = an if override.rename is None else override.rename
allowed_fields.add(kn)
struct_handler_name = f"__c_structure_{ix}"
internal_arg_parts[struct_handler_name] = handler
if handler == converter._structure_call:
internal_arg_parts[struct_handler_name] = t
invocation_line = f" res['{an}'] = {struct_handler_name}(o['{kn}'])"
else:
tn = f"__c_type_{ix}"
internal_arg_parts[tn] = t
invocation_line = (
f" res['{an}'] = {struct_handler_name}(o['{kn}'], {tn})"
)
lines.append(invocation_line)
if override.rename is not None:
lines.append(f" del res['{override.rename}']")
# The second loop is for optional args.
if non_required:
for ix, a in non_required:
an = a.name
override = kwargs.get(an, neutral)
t = a.type
nrb = get_notrequired_base(t)
if nrb is not NOTHING:
t = nrb
if isinstance(t, TypeVar):
t = mapping.get(t.__name__, t)
elif is_generic(t) and not is_bare(t) and not is_annotated(t):
t = deep_copy_with(t, mapping)
if override.struct_hook is not None:
handler = override.struct_hook
else:
# For each attribute, we try resolving the type here and now.
# If a type is manually overwritten, this function should be
# regenerated.
handler = converter.get_structure_hook(t)
struct_handler_name = f"__c_structure_{ix}"
internal_arg_parts[struct_handler_name] = handler
ian = an
kn = an if override.rename is None else override.rename
allowed_fields.add(kn)
post_lines.append(f" if '{kn}' in o:")
if handler == converter._structure_call:
internal_arg_parts[struct_handler_name] = t
post_lines.append(
f" res['{ian}'] = {struct_handler_name}(o['{kn}'])"
)
else:
tn = f"__c_type_{ix}"
internal_arg_parts[tn] = t
post_lines.append(
f" res['{ian}'] = {struct_handler_name}(o['{kn}'], {tn})"
)
if override.rename is not None:
lines.append(f" res.pop('{override.rename}', None)")
if _cattrs_forbid_extra_keys:
post_lines += [
" unknown_fields = o.keys() - __c_a",
" if unknown_fields:",
" raise __c_feke('', __cl, unknown_fields)",
]
# At the end, we create the function header.
internal_arg_line = ", ".join([f"{i}={i}" for i in internal_arg_parts])
for k, v in internal_arg_parts.items():
globs[k] = v
total_lines = [
f"def {fn_name}(o, _, {internal_arg_line}):",
*lines,
*post_lines,
" return res",
]
script = "\n".join(total_lines)
fname = generate_unique_filename(
cl, "structure", lines=total_lines if _cattrs_use_linecache else []
)
eval(compile(script, fname, "exec"), globs)
return globs[fn_name]
def _adapted_fields(cls: Any) -> list[Attribute]:
annotations = get_annots(cls)
hints = get_full_type_hints(cls)
return [
Attribute(
n,
NOTHING,
None,
False,
False,
False,
False,
False,
type=hints[n] if n in hints else annotations[n],
)
for n, a in annotations.items()
]
def _is_extensions_typeddict(cls) -> bool:
if _TypedDictMeta is None:
return False
return cls.__class__ is _TypedDictMeta or (
is_generic(cls) and (cls.__origin__.__class__ is _TypedDictMeta)
)
if sys.version_info >= (3, 11):
def _required_keys(cls: type) -> set[str]:
return cls.__required_keys__
elif sys.version_info >= (3, 9):
from typing_extensions import Annotated, NotRequired, Required, get_args
# Note that there is no `typing.Required` on 3.9 and 3.10, only in
# `typing_extensions`. Therefore, `typing.TypedDict` will not honor this
# annotation, only `typing_extensions.TypedDict`.
def _required_keys(cls: type) -> set[str]:
"""Our own processor for required keys."""
if _is_extensions_typeddict(cls):
return cls.__required_keys__
# We vendor a part of the typing_extensions logic for
# gathering required keys. *sigh*
own_annotations = cls.__dict__.get("__annotations__", {})
required_keys = set()
# On 3.8 - 3.10, typing.TypedDict doesn't put typeddict superclasses
# in the MRO, therefore we cannot handle non-required keys properly
# in some situations. Oh well.
for key in getattr(cls, "__required_keys__", []):
annotation_type = own_annotations[key]
annotation_origin = get_origin(annotation_type)
if annotation_origin is Annotated:
annotation_args = get_args(annotation_type)
if annotation_args:
annotation_type = annotation_args[0]
annotation_origin = get_origin(annotation_type)
if annotation_origin is NotRequired:
pass
elif cls.__total__:
required_keys.add(key)
return required_keys
else:
from typing_extensions import Annotated, NotRequired, Required, get_args
# On 3.8, typing.TypedDicts do not have __required_keys__.
def _required_keys(cls: type) -> set[str]:
"""Our own processor for required keys."""
if _is_extensions_typeddict(cls):
return cls.__required_keys__
own_annotations = cls.__dict__.get("__annotations__", {})
required_keys = set()
for key in own_annotations:
annotation_type = own_annotations[key]
if is_annotated(annotation_type):
# If this is `Annotated`, we need to get the origin twice.
annotation_type = get_origin(annotation_type)
annotation_origin = get_origin(annotation_type)
if annotation_origin is Required:
required_keys.add(key)
elif annotation_origin is NotRequired:
pass
elif cls.__total__:
required_keys.add(key)
return required_keys

View File

@@ -0,0 +1,27 @@
import sys
from datetime import datetime
from typing import Any, Callable, TypeVar
if sys.version_info[:2] < (3, 10):
from typing_extensions import ParamSpec
else:
from typing import ParamSpec
def validate_datetime(v, _):
if not isinstance(v, datetime):
raise Exception(f"Expected datetime, got {v}")
return v
T = TypeVar("T")
P = ParamSpec("P")
def wrap(_: Callable[P, Any]) -> Callable[[Callable[..., T]], Callable[P, T]]:
"""Wrap a `Converter` `__init__` in a type-safe way."""
def impl(x: Callable[..., T]) -> Callable[P, T]:
return x
return impl

View File

@@ -0,0 +1,106 @@
"""Preconfigured converters for bson."""
from base64 import b85decode, b85encode
from datetime import date, datetime
from typing import Any, Type, TypeVar, Union
from bson import DEFAULT_CODEC_OPTIONS, CodecOptions, Int64, ObjectId, decode, encode
from cattrs._compat import AbstractSet, is_mapping
from cattrs.gen import make_mapping_structure_fn
from ..converters import BaseConverter, Converter
from ..dispatch import StructureHook
from ..strategies import configure_union_passthrough
from . import validate_datetime, wrap
T = TypeVar("T")
class Base85Bytes(bytes):
"""A subclass to help with binary key encoding/decoding."""
class BsonConverter(Converter):
def dumps(
self,
obj: Any,
unstructure_as: Any = None,
check_keys: bool = False,
codec_options: CodecOptions = DEFAULT_CODEC_OPTIONS,
) -> bytes:
return encode(
self.unstructure(obj, unstructure_as=unstructure_as),
check_keys=check_keys,
codec_options=codec_options,
)
def loads(
self,
data: bytes,
cl: Type[T],
codec_options: CodecOptions = DEFAULT_CODEC_OPTIONS,
) -> T:
return self.structure(decode(data, codec_options=codec_options), cl)
def configure_converter(converter: BaseConverter):
"""
Configure the converter for use with the bson library.
* sets are serialized as lists
* byte mapping keys are base85-encoded into strings when unstructuring, and reverse
* non-string, non-byte mapping keys are coerced into strings when unstructuring
* a deserialization hook is registered for bson.ObjectId by default
"""
def gen_unstructure_mapping(cl: Any, unstructure_to=None):
key_handler = str
args = getattr(cl, "__args__", None)
if args:
if issubclass(args[0], str):
key_handler = None
elif issubclass(args[0], bytes):
def key_handler(k):
return b85encode(k).decode("utf8")
return converter.gen_unstructure_mapping(
cl, unstructure_to=unstructure_to, key_handler=key_handler
)
def gen_structure_mapping(cl: Any) -> StructureHook:
args = getattr(cl, "__args__", None)
if args and issubclass(args[0], bytes):
h = make_mapping_structure_fn(cl, converter, key_type=Base85Bytes)
else:
h = make_mapping_structure_fn(cl, converter)
return h
converter.register_structure_hook(Base85Bytes, lambda v, _: b85decode(v))
converter.register_unstructure_hook_factory(is_mapping, gen_unstructure_mapping)
converter.register_structure_hook_factory(is_mapping, gen_structure_mapping)
converter.register_structure_hook(ObjectId, lambda v, _: ObjectId(v))
configure_union_passthrough(
Union[str, bool, int, float, None, bytes, datetime, ObjectId, Int64], converter
)
# datetime inherits from date, so identity unstructure hook used
# here to prevent the date unstructure hook running.
converter.register_unstructure_hook(datetime, lambda v: v)
converter.register_structure_hook(datetime, validate_datetime)
converter.register_unstructure_hook(date, lambda v: v.isoformat())
converter.register_structure_hook(date, lambda v, _: date.fromisoformat(v))
@wrap(BsonConverter)
def make_converter(*args: Any, **kwargs: Any) -> BsonConverter:
kwargs["unstruct_collection_overrides"] = {
AbstractSet: list,
**kwargs.get("unstruct_collection_overrides", {}),
}
res = BsonConverter(*args, **kwargs)
configure_converter(res)
return res

View File

@@ -0,0 +1,50 @@
"""Preconfigured converters for cbor2."""
from datetime import date, datetime, timezone
from typing import Any, Type, TypeVar, Union
from cbor2 import dumps, loads
from cattrs._compat import AbstractSet
from ..converters import BaseConverter, Converter
from ..strategies import configure_union_passthrough
from . import wrap
T = TypeVar("T")
class Cbor2Converter(Converter):
def dumps(self, obj: Any, unstructure_as: Any = None, **kwargs: Any) -> bytes:
return dumps(self.unstructure(obj, unstructure_as=unstructure_as), **kwargs)
def loads(self, data: bytes, cl: Type[T], **kwargs: Any) -> T:
return self.structure(loads(data, **kwargs), cl)
def configure_converter(converter: BaseConverter):
"""
Configure the converter for use with the cbor2 library.
* datetimes are serialized as timestamp floats
* sets are serialized as lists
"""
converter.register_unstructure_hook(datetime, lambda v: v.timestamp())
converter.register_structure_hook(
datetime, lambda v, _: datetime.fromtimestamp(v, timezone.utc)
)
converter.register_unstructure_hook(date, lambda v: v.isoformat())
converter.register_structure_hook(date, lambda v, _: date.fromisoformat(v))
configure_union_passthrough(Union[str, bool, int, float, None, bytes], converter)
@wrap(Cbor2Converter)
def make_converter(*args: Any, **kwargs: Any) -> Cbor2Converter:
kwargs["unstruct_collection_overrides"] = {
AbstractSet: list,
**kwargs.get("unstruct_collection_overrides", {}),
}
res = Cbor2Converter(*args, **kwargs)
configure_converter(res)
return res

View File

@@ -0,0 +1,56 @@
"""Preconfigured converters for the stdlib json."""
from base64 import b85decode, b85encode
from datetime import date, datetime
from json import dumps, loads
from typing import Any, Type, TypeVar, Union
from .._compat import AbstractSet, Counter
from ..converters import BaseConverter, Converter
from ..strategies import configure_union_passthrough
from . import wrap
T = TypeVar("T")
class JsonConverter(Converter):
def dumps(self, obj: Any, unstructure_as: Any = None, **kwargs: Any) -> str:
return dumps(self.unstructure(obj, unstructure_as=unstructure_as), **kwargs)
def loads(self, data: Union[bytes, str], cl: Type[T], **kwargs: Any) -> T:
return self.structure(loads(data, **kwargs), cl)
def configure_converter(converter: BaseConverter):
"""
Configure the converter for use with the stdlib json module.
* bytes are serialized as base85 strings
* datetimes are serialized as ISO 8601
* counters are serialized as dicts
* sets are serialized as lists
* union passthrough is configured for unions of strings, bools, ints,
floats and None
"""
converter.register_unstructure_hook(
bytes, lambda v: (b85encode(v) if v else b"").decode("utf8")
)
converter.register_structure_hook(bytes, lambda v, _: b85decode(v))
converter.register_unstructure_hook(datetime, lambda v: v.isoformat())
converter.register_structure_hook(datetime, lambda v, _: datetime.fromisoformat(v))
converter.register_unstructure_hook(date, lambda v: v.isoformat())
converter.register_structure_hook(date, lambda v, _: date.fromisoformat(v))
configure_union_passthrough(Union[str, bool, int, float, None], converter)
@wrap(JsonConverter)
def make_converter(*args: Any, **kwargs: Any) -> JsonConverter:
kwargs["unstruct_collection_overrides"] = {
AbstractSet: list,
Counter: dict,
**kwargs.get("unstruct_collection_overrides", {}),
}
res = JsonConverter(*args, **kwargs)
configure_converter(res)
return res

View File

@@ -0,0 +1,54 @@
"""Preconfigured converters for msgpack."""
from datetime import date, datetime, time, timezone
from typing import Any, Type, TypeVar, Union
from msgpack import dumps, loads
from cattrs._compat import AbstractSet
from ..converters import BaseConverter, Converter
from ..strategies import configure_union_passthrough
from . import wrap
T = TypeVar("T")
class MsgpackConverter(Converter):
def dumps(self, obj: Any, unstructure_as: Any = None, **kwargs: Any) -> bytes:
return dumps(self.unstructure(obj, unstructure_as=unstructure_as), **kwargs)
def loads(self, data: bytes, cl: Type[T], **kwargs: Any) -> T:
return self.structure(loads(data, **kwargs), cl)
def configure_converter(converter: BaseConverter):
"""
Configure the converter for use with the msgpack library.
* datetimes are serialized as timestamp floats
* sets are serialized as lists
"""
converter.register_unstructure_hook(datetime, lambda v: v.timestamp())
converter.register_structure_hook(
datetime, lambda v, _: datetime.fromtimestamp(v, timezone.utc)
)
converter.register_unstructure_hook(
date, lambda v: datetime.combine(v, time(tzinfo=timezone.utc)).timestamp()
)
converter.register_structure_hook(
date, lambda v, _: datetime.fromtimestamp(v, timezone.utc).date()
)
configure_union_passthrough(Union[str, bool, int, float, None, bytes], converter)
@wrap(MsgpackConverter)
def make_converter(*args: Any, **kwargs: Any) -> MsgpackConverter:
kwargs["unstruct_collection_overrides"] = {
AbstractSet: list,
**kwargs.get("unstruct_collection_overrides", {}),
}
res = MsgpackConverter(*args, **kwargs)
configure_converter(res)
return res

View File

@@ -0,0 +1,185 @@
"""Preconfigured converters for msgspec."""
from __future__ import annotations
from base64 import b64decode
from datetime import date, datetime
from enum import Enum
from functools import partial
from typing import Any, Callable, TypeVar, Union, get_type_hints
from attrs import has as attrs_has
from attrs import resolve_types
from msgspec import Struct, convert, to_builtins
from msgspec.json import Encoder, decode
from .._compat import (
fields,
get_args,
get_origin,
has,
is_bare,
is_mapping,
is_sequence,
)
from ..cols import is_namedtuple
from ..converters import BaseConverter, Converter
from ..dispatch import UnstructureHook
from ..fns import identity
from ..gen import make_hetero_tuple_unstructure_fn
from ..strategies import configure_union_passthrough
from . import wrap
T = TypeVar("T")
__all__ = ["MsgspecJsonConverter", "configure_converter", "make_converter"]
class MsgspecJsonConverter(Converter):
"""A converter specialized for the _msgspec_ library."""
#: The msgspec encoder for dumping.
encoder: Encoder = Encoder()
def dumps(self, obj: Any, unstructure_as: Any = None, **kwargs: Any) -> bytes:
"""Unstructure and encode `obj` into JSON bytes."""
return self.encoder.encode(
self.unstructure(obj, unstructure_as=unstructure_as), **kwargs
)
def get_dumps_hook(
self, unstructure_as: Any, **kwargs: Any
) -> Callable[[Any], bytes]:
"""Produce a `dumps` hook for the given type."""
unstruct_hook = self.get_unstructure_hook(unstructure_as)
if unstruct_hook in (identity, to_builtins):
return self.encoder.encode
return self.dumps
def loads(self, data: bytes, cl: type[T], **kwargs: Any) -> T:
"""Decode and structure `cl` from the provided JSON bytes."""
return self.structure(decode(data, **kwargs), cl)
def get_loads_hook(self, cl: type[T]) -> Callable[[bytes], T]:
"""Produce a `loads` hook for the given type."""
return partial(self.loads, cl=cl)
def configure_converter(converter: Converter) -> None:
"""Configure the converter for the msgspec library.
* bytes are serialized as base64 strings, directly by msgspec
* datetimes and dates are passed through to be serialized as RFC 3339 directly
* enums are passed through to msgspec directly
* union passthrough configured for str, bool, int, float and None
"""
configure_passthroughs(converter)
converter.register_unstructure_hook(Struct, to_builtins)
converter.register_unstructure_hook(Enum, to_builtins)
converter.register_structure_hook(Struct, convert)
converter.register_structure_hook(bytes, lambda v, _: b64decode(v))
converter.register_structure_hook(datetime, lambda v, _: convert(v, datetime))
converter.register_structure_hook(date, lambda v, _: date.fromisoformat(v))
configure_union_passthrough(Union[str, bool, int, float, None], converter)
@wrap(MsgspecJsonConverter)
def make_converter(*args: Any, **kwargs: Any) -> MsgspecJsonConverter:
res = MsgspecJsonConverter(*args, **kwargs)
configure_converter(res)
return res
def configure_passthroughs(converter: Converter) -> None:
"""Configure optimizing passthroughs.
A passthrough is when we let msgspec handle something automatically.
"""
converter.register_unstructure_hook(bytes, to_builtins)
converter.register_unstructure_hook_factory(is_mapping, mapping_unstructure_factory)
converter.register_unstructure_hook_factory(is_sequence, seq_unstructure_factory)
converter.register_unstructure_hook_factory(has, attrs_unstructure_factory)
converter.register_unstructure_hook_factory(
is_namedtuple, namedtuple_unstructure_factory
)
def seq_unstructure_factory(type, converter: Converter) -> UnstructureHook:
"""The msgspec unstructure hook factory for sequences."""
if is_bare(type):
type_arg = Any
else:
args = get_args(type)
type_arg = args[0]
handler = converter.get_unstructure_hook(type_arg, cache_result=False)
if handler in (identity, to_builtins):
return handler
return converter.gen_unstructure_iterable(type)
def mapping_unstructure_factory(type, converter: BaseConverter) -> UnstructureHook:
"""The msgspec unstructure hook factory for mappings."""
if is_bare(type):
key_arg = Any
val_arg = Any
key_handler = converter.get_unstructure_hook(key_arg, cache_result=False)
value_handler = converter.get_unstructure_hook(val_arg, cache_result=False)
else:
args = get_args(type)
if len(args) == 2:
key_arg, val_arg = args
else:
# Probably a Counter
key_arg, val_arg = args, Any
key_handler = converter.get_unstructure_hook(key_arg, cache_result=False)
value_handler = converter.get_unstructure_hook(val_arg, cache_result=False)
if key_handler in (identity, to_builtins) and value_handler in (
identity,
to_builtins,
):
return to_builtins
return converter.gen_unstructure_mapping(type)
def attrs_unstructure_factory(type: Any, converter: Converter) -> UnstructureHook:
"""Choose whether to use msgspec handling or our own."""
origin = get_origin(type)
attribs = fields(origin or type)
if attrs_has(type) and any(isinstance(a.type, str) for a in attribs):
resolve_types(type)
attribs = fields(origin or type)
if any(
attr.name.startswith("_")
or (
converter.get_unstructure_hook(attr.type, cache_result=False)
not in (identity, to_builtins)
)
for attr in attribs
):
return converter.gen_unstructure_attrs_fromdict(type)
return to_builtins
def namedtuple_unstructure_factory(
type: type[tuple], converter: BaseConverter
) -> UnstructureHook:
"""A hook factory for unstructuring namedtuples, modified for msgspec."""
if all(
converter.get_unstructure_hook(t) in (identity, to_builtins)
for t in get_type_hints(type).values()
):
return identity
return make_hetero_tuple_unstructure_fn(
type,
converter,
unstructure_to=tuple,
type_args=tuple(get_type_hints(type).values()),
)

View File

@@ -0,0 +1,95 @@
"""Preconfigured converters for orjson."""
from base64 import b85decode, b85encode
from datetime import date, datetime
from enum import Enum
from functools import partial
from typing import Any, Type, TypeVar, Union
from orjson import dumps, loads
from .._compat import AbstractSet, is_mapping
from ..cols import is_namedtuple, namedtuple_unstructure_factory
from ..converters import BaseConverter, Converter
from ..fns import identity
from ..strategies import configure_union_passthrough
from . import wrap
T = TypeVar("T")
class OrjsonConverter(Converter):
def dumps(self, obj: Any, unstructure_as: Any = None, **kwargs: Any) -> bytes:
return dumps(self.unstructure(obj, unstructure_as=unstructure_as), **kwargs)
def loads(self, data: Union[bytes, bytearray, memoryview, str], cl: Type[T]) -> T:
return self.structure(loads(data), cl)
def configure_converter(converter: BaseConverter):
"""
Configure the converter for use with the orjson library.
* bytes are serialized as base85 strings
* datetimes and dates are passed through to be serialized as RFC 3339 by orjson
* typed namedtuples are serialized as lists
* sets are serialized as lists
* string enum mapping keys have special handling
* mapping keys are coerced into strings when unstructuring
.. versionchanged: 24.1.0
Add support for typed namedtuples.
"""
converter.register_unstructure_hook(
bytes, lambda v: (b85encode(v) if v else b"").decode("utf8")
)
converter.register_structure_hook(bytes, lambda v, _: b85decode(v))
converter.register_structure_hook(datetime, lambda v, _: datetime.fromisoformat(v))
converter.register_structure_hook(date, lambda v, _: date.fromisoformat(v))
def gen_unstructure_mapping(cl: Any, unstructure_to=None):
key_handler = str
args = getattr(cl, "__args__", None)
if args:
if issubclass(args[0], str) and issubclass(args[0], Enum):
def key_handler(v):
return v.value
else:
# It's possible the handler for the key type has been overridden.
# (For example base85 encoding for bytes.)
# In that case, we want to use the override.
kh = converter.get_unstructure_hook(args[0])
if kh != identity:
key_handler = kh
return converter.gen_unstructure_mapping(
cl, unstructure_to=unstructure_to, key_handler=key_handler
)
converter._unstructure_func.register_func_list(
[
(is_mapping, gen_unstructure_mapping, True),
(
is_namedtuple,
partial(namedtuple_unstructure_factory, unstructure_to=tuple),
"extended",
),
]
)
configure_union_passthrough(Union[str, bool, int, float, None], converter)
@wrap(OrjsonConverter)
def make_converter(*args: Any, **kwargs: Any) -> OrjsonConverter:
kwargs["unstruct_collection_overrides"] = {
AbstractSet: list,
**kwargs.get("unstruct_collection_overrides", {}),
}
res = OrjsonConverter(*args, **kwargs)
configure_converter(res)
return res

View File

@@ -0,0 +1,72 @@
"""Preconfigured converters for pyyaml."""
from datetime import date, datetime
from functools import partial
from typing import Any, Type, TypeVar, Union
from yaml import safe_dump, safe_load
from .._compat import FrozenSetSubscriptable
from ..cols import is_namedtuple, namedtuple_unstructure_factory
from ..converters import BaseConverter, Converter
from ..strategies import configure_union_passthrough
from . import validate_datetime, wrap
T = TypeVar("T")
def validate_date(v, _):
if not isinstance(v, date):
raise ValueError(f"Expected date, got {v}")
return v
class PyyamlConverter(Converter):
def dumps(self, obj: Any, unstructure_as: Any = None, **kwargs: Any) -> str:
return safe_dump(self.unstructure(obj, unstructure_as=unstructure_as), **kwargs)
def loads(self, data: str, cl: Type[T]) -> T:
return self.structure(safe_load(data), cl)
def configure_converter(converter: BaseConverter):
"""
Configure the converter for use with the pyyaml library.
* frozensets are serialized as lists
* string enums are converted into strings explicitly
* datetimes and dates are validated
* typed namedtuples are serialized as lists
.. versionchanged: 24.1.0
Add support for typed namedtuples.
"""
converter.register_unstructure_hook(
str, lambda v: v if v.__class__ is str else v.value
)
# datetime inherits from date, so identity unstructure hook used
# here to prevent the date unstructure hook running.
converter.register_unstructure_hook(datetime, lambda v: v)
converter.register_structure_hook(datetime, validate_datetime)
converter.register_structure_hook(date, validate_date)
converter.register_unstructure_hook_factory(is_namedtuple)(
partial(namedtuple_unstructure_factory, unstructure_to=tuple)
)
configure_union_passthrough(
Union[str, bool, int, float, None, bytes, datetime, date], converter
)
@wrap(PyyamlConverter)
def make_converter(*args: Any, **kwargs: Any) -> PyyamlConverter:
kwargs["unstruct_collection_overrides"] = {
FrozenSetSubscriptable: list,
**kwargs.get("unstruct_collection_overrides", {}),
}
res = PyyamlConverter(*args, **kwargs)
configure_converter(res)
return res

View File

@@ -0,0 +1,87 @@
"""Preconfigured converters for tomlkit."""
from base64 import b85decode, b85encode
from datetime import date, datetime
from enum import Enum
from operator import attrgetter
from typing import Any, Type, TypeVar, Union
from tomlkit import dumps, loads
from tomlkit.items import Float, Integer, String
from cattrs._compat import AbstractSet, is_mapping
from ..converters import BaseConverter, Converter
from ..strategies import configure_union_passthrough
from . import validate_datetime, wrap
T = TypeVar("T")
_enum_value_getter = attrgetter("_value_")
class TomlkitConverter(Converter):
def dumps(self, obj: Any, unstructure_as: Any = None, **kwargs: Any) -> str:
return dumps(self.unstructure(obj, unstructure_as=unstructure_as), **kwargs)
def loads(self, data: str, cl: Type[T]) -> T:
return self.structure(loads(data), cl)
def configure_converter(converter: BaseConverter):
"""
Configure the converter for use with the tomlkit library.
* bytes are serialized as base85 strings
* sets are serialized as lists
* tuples are serializas as lists
* mapping keys are coerced into strings when unstructuring
"""
converter.register_structure_hook(bytes, lambda v, _: b85decode(v))
converter.register_unstructure_hook(
bytes, lambda v: (b85encode(v) if v else b"").decode("utf8")
)
def gen_unstructure_mapping(cl: Any, unstructure_to=None):
key_handler = str
args = getattr(cl, "__args__", None)
if args:
# Currently, tomlkit has inconsistent behavior on 3.11
# so we paper over it here.
# https://github.com/sdispater/tomlkit/issues/237
if issubclass(args[0], str):
key_handler = _enum_value_getter if issubclass(args[0], Enum) else None
elif issubclass(args[0], bytes):
def key_handler(k: bytes):
return b85encode(k).decode("utf8")
return converter.gen_unstructure_mapping(
cl, unstructure_to=unstructure_to, key_handler=key_handler
)
converter._unstructure_func.register_func_list(
[(is_mapping, gen_unstructure_mapping, True)]
)
# datetime inherits from date, so identity unstructure hook used
# here to prevent the date unstructure hook running.
converter.register_unstructure_hook(datetime, lambda v: v)
converter.register_structure_hook(datetime, validate_datetime)
converter.register_unstructure_hook(date, lambda v: v.isoformat())
converter.register_structure_hook(date, lambda v, _: date.fromisoformat(v))
configure_union_passthrough(
Union[str, String, bool, int, Integer, float, Float], converter
)
@wrap(TomlkitConverter)
def make_converter(*args: Any, **kwargs: Any) -> TomlkitConverter:
kwargs["unstruct_collection_overrides"] = {
AbstractSet: list,
tuple: list,
**kwargs.get("unstruct_collection_overrides", {}),
}
res = TomlkitConverter(*args, **kwargs)
configure_converter(res)
return res

View File

@@ -0,0 +1,55 @@
"""Preconfigured converters for ujson."""
from base64 import b85decode, b85encode
from datetime import date, datetime
from typing import Any, AnyStr, Type, TypeVar, Union
from ujson import dumps, loads
from cattrs._compat import AbstractSet
from ..converters import BaseConverter, Converter
from ..strategies import configure_union_passthrough
from . import wrap
T = TypeVar("T")
class UjsonConverter(Converter):
def dumps(self, obj: Any, unstructure_as: Any = None, **kwargs: Any) -> str:
return dumps(self.unstructure(obj, unstructure_as=unstructure_as), **kwargs)
def loads(self, data: AnyStr, cl: Type[T], **kwargs: Any) -> T:
return self.structure(loads(data, **kwargs), cl)
def configure_converter(converter: BaseConverter):
"""
Configure the converter for use with the ujson library.
* bytes are serialized as base64 strings
* datetimes are serialized as ISO 8601
* sets are serialized as lists
"""
converter.register_unstructure_hook(
bytes, lambda v: (b85encode(v) if v else b"").decode("utf8")
)
converter.register_structure_hook(bytes, lambda v, _: b85decode(v))
converter.register_unstructure_hook(datetime, lambda v: v.isoformat())
converter.register_structure_hook(datetime, lambda v, _: datetime.fromisoformat(v))
converter.register_unstructure_hook(date, lambda v: v.isoformat())
converter.register_structure_hook(date, lambda v, _: date.fromisoformat(v))
configure_union_passthrough(Union[str, bool, int, float, None], converter)
@wrap(UjsonConverter)
def make_converter(*args: Any, **kwargs: Any) -> UjsonConverter:
kwargs["unstruct_collection_overrides"] = {
AbstractSet: list,
**kwargs.get("unstruct_collection_overrides", {}),
}
res = UjsonConverter(*args, **kwargs)
configure_converter(res)
return res

View File

@@ -0,0 +1,12 @@
"""High level strategies for converters."""
from ._class_methods import use_class_methods
from ._subclasses import include_subclasses
from ._unions import configure_tagged_union, configure_union_passthrough
__all__ = [
"configure_tagged_union",
"configure_union_passthrough",
"include_subclasses",
"use_class_methods",
]

View File

@@ -0,0 +1,64 @@
"""Strategy for using class-specific (un)structuring methods."""
from inspect import signature
from typing import Any, Callable, Optional, Type, TypeVar
from .. import BaseConverter
T = TypeVar("T")
def use_class_methods(
converter: BaseConverter,
structure_method_name: Optional[str] = None,
unstructure_method_name: Optional[str] = None,
) -> None:
"""
Configure the converter such that dedicated methods are used for (un)structuring
the instance of a class if such methods are available. The default (un)structuring
will be applied if such an (un)structuring methods cannot be found.
:param converter: The `Converter` on which this strategy is applied. You can use
:class:`cattrs.BaseConverter` or any other derived class.
:param structure_method_name: Optional string with the name of the class method
which should be used for structuring. If not provided, no class method will be
used for structuring.
:param unstructure_method_name: Optional string with the name of the class method
which should be used for unstructuring. If not provided, no class method will
be used for unstructuring.
If you want to (un)structured nested objects, just append a converter parameter
to your (un)structuring methods and you will receive the converter there.
.. versionadded:: 23.2.0
"""
if structure_method_name:
def make_class_method_structure(cl: Type[T]) -> Callable[[Any, Type[T]], T]:
fn = getattr(cl, structure_method_name)
n_parameters = len(signature(fn).parameters)
if n_parameters == 1:
return lambda v, _: fn(v)
if n_parameters == 2:
return lambda v, _: fn(v, converter)
raise TypeError("Provide a class method with one or two arguments.")
converter.register_structure_hook_factory(
lambda t: hasattr(t, structure_method_name), make_class_method_structure
)
if unstructure_method_name:
def make_class_method_unstructure(cl: Type[T]) -> Callable[[T], T]:
fn = getattr(cl, unstructure_method_name)
n_parameters = len(signature(fn).parameters)
if n_parameters == 1:
return fn
if n_parameters == 2:
return lambda self_: fn(self_, converter)
raise TypeError("Provide a method with no or one argument.")
converter.register_unstructure_hook_factory(
lambda t: hasattr(t, unstructure_method_name), make_class_method_unstructure
)

View File

@@ -0,0 +1,238 @@
"""Strategies for customizing subclass behaviors."""
from __future__ import annotations
from gc import collect
from typing import Any, Callable, TypeVar, Union
from ..converters import BaseConverter
from ..gen import AttributeOverride, make_dict_structure_fn, make_dict_unstructure_fn
from ..gen._consts import already_generating
def _make_subclasses_tree(cl: type) -> list[type]:
return [cl] + [
sscl for scl in cl.__subclasses__() for sscl in _make_subclasses_tree(scl)
]
def _has_subclasses(cl: type, given_subclasses: tuple[type, ...]) -> bool:
"""Whether the given class has subclasses from `given_subclasses`."""
actual = set(cl.__subclasses__())
given = set(given_subclasses)
return bool(actual & given)
def _get_union_type(cl: type, given_subclasses_tree: tuple[type]) -> type | None:
actual_subclass_tree = tuple(_make_subclasses_tree(cl))
class_tree = tuple(set(actual_subclass_tree) & set(given_subclasses_tree))
return Union[class_tree] if len(class_tree) >= 2 else None
C = TypeVar("C", bound=BaseConverter)
def include_subclasses(
cl: type,
converter: C,
subclasses: tuple[type, ...] | None = None,
union_strategy: Callable[[Any, C], Any] | None = None,
overrides: dict[str, AttributeOverride] | None = None,
) -> None:
"""
Configure the converter so that the attrs/dataclass `cl` is un/structured as if it
was a union of itself and all its subclasses that are defined at the time when this
strategy is applied.
:param cl: A base `attrs` or `dataclass` class.
:param converter: The `Converter` on which this strategy is applied. Do note that
the strategy does not work for a :class:`cattrs.BaseConverter`.
:param subclasses: A tuple of sublcasses whose ancestor is `cl`. If left as `None`,
subclasses are detected using recursively the `__subclasses__` method of `cl`
and its descendents.
:param union_strategy: A callable of two arguments passed by position
(`subclass_union`, `converter`) that defines the union strategy to use to
disambiguate the subclasses union. If `None` (the default), the automatic unique
field disambiguation is used which means that every single subclass
participating in the union must have an attribute name that does not exist in
any other sibling class.
:param overrides: a mapping of `cl` attribute names to overrides (instantiated with
:func:`cattrs.gen.override`) to customize un/structuring.
.. versionadded:: 23.1.0
.. versionchanged:: 24.1.0
When overrides are not provided, hooks for individual classes are retrieved from
the converter instead of generated with no overrides, using converter defaults.
"""
# Due to https://github.com/python-attrs/attrs/issues/1047
collect()
if subclasses is not None:
parent_subclass_tree = (cl, *subclasses)
else:
parent_subclass_tree = tuple(_make_subclasses_tree(cl))
if union_strategy is None:
_include_subclasses_without_union_strategy(
cl, converter, parent_subclass_tree, overrides
)
else:
_include_subclasses_with_union_strategy(
converter, parent_subclass_tree, union_strategy, overrides
)
def _include_subclasses_without_union_strategy(
cl,
converter: BaseConverter,
parent_subclass_tree: tuple[type],
overrides: dict[str, AttributeOverride] | None,
):
# The iteration approach is required if subclasses are more than one level deep:
for cl in parent_subclass_tree:
# We re-create a reduced union type to handle the following case:
#
# converter.structure(d, as=Child)
#
# In the above, the `as=Child` argument will be transformed to a union type of
# itself and its subtypes, that way we guarantee that the returned object will
# not be the parent.
subclass_union = _get_union_type(cl, parent_subclass_tree)
def cls_is_cl(cls, _cl=cl):
return cls is _cl
if overrides is not None:
base_struct_hook = make_dict_structure_fn(cl, converter, **overrides)
base_unstruct_hook = make_dict_unstructure_fn(cl, converter, **overrides)
else:
base_struct_hook = converter.get_structure_hook(cl)
base_unstruct_hook = converter.get_unstructure_hook(cl)
if subclass_union is None:
def struct_hook(val: dict, _, _cl=cl, _base_hook=base_struct_hook) -> cl:
return _base_hook(val, _cl)
else:
dis_fn = converter._get_dis_func(subclass_union, overrides=overrides)
def struct_hook(
val: dict,
_,
_c=converter,
_cl=cl,
_base_hook=base_struct_hook,
_dis_fn=dis_fn,
) -> cl:
"""
If val is disambiguated to the class `cl`, use its base hook.
If val is disambiguated to a subclass, dispatch on its exact runtime
type.
"""
dis_cl = _dis_fn(val)
if dis_cl is _cl:
return _base_hook(val, _cl)
return _c.structure(val, dis_cl)
def unstruct_hook(
val: parent_subclass_tree[0],
_c=converter,
_cl=cl,
_base_hook=base_unstruct_hook,
) -> dict:
"""
If val is an instance of the class `cl`, use the hook.
If val is an instance of a subclass, dispatch on its exact runtime type.
"""
if val.__class__ is _cl:
return _base_hook(val)
return _c.unstructure(val, unstructure_as=val.__class__)
# This needs to use function dispatch, using singledispatch will again
# match A and all subclasses, which is not what we want.
converter.register_structure_hook_func(cls_is_cl, struct_hook)
converter.register_unstructure_hook_func(cls_is_cl, unstruct_hook)
def _include_subclasses_with_union_strategy(
converter: C,
union_classes: tuple[type, ...],
union_strategy: Callable[[Any, C], Any],
overrides: dict[str, AttributeOverride] | None,
):
"""
This function is tricky because we're dealing with what is essentially a circular
reference.
We need to generate a structure hook for a class that is both:
* specific for that particular class and its own fields
* but should handle specific functions for all its descendants too
Hence the dance with registering below.
"""
parent_classes = [cl for cl in union_classes if _has_subclasses(cl, union_classes)]
if not parent_classes:
return
original_unstruct_hooks = {}
original_struct_hooks = {}
for cl in union_classes:
# In the first pass, every class gets its own unstructure function according to
# the overrides.
# We just generate the hooks, and do not register them. This allows us to
# manipulate the _already_generating set to force runtime dispatch.
already_generating.working_set = set(union_classes) - {cl}
try:
if overrides is not None:
unstruct_hook = make_dict_unstructure_fn(cl, converter, **overrides)
struct_hook = make_dict_structure_fn(cl, converter, **overrides)
else:
unstruct_hook = converter.get_unstructure_hook(cl, cache_result=False)
struct_hook = converter.get_structure_hook(cl, cache_result=False)
finally:
already_generating.working_set = set()
original_unstruct_hooks[cl] = unstruct_hook
original_struct_hooks[cl] = struct_hook
# Now that's done, we can register all the hooks and generate the
# union handler. The union handler needs them.
final_union = Union[union_classes] # type: ignore
for cl, hook in original_unstruct_hooks.items():
def cls_is_cl(cls, _cl=cl):
return cls is _cl
converter.register_unstructure_hook_func(cls_is_cl, hook)
for cl, hook in original_struct_hooks.items():
def cls_is_cl(cls, _cl=cl):
return cls is _cl
converter.register_structure_hook_func(cls_is_cl, hook)
union_strategy(final_union, converter)
unstruct_hook = converter.get_unstructure_hook(final_union)
struct_hook = converter.get_structure_hook(final_union)
for cl in union_classes:
# In the second pass, we overwrite the hooks with the union hook.
def cls_is_cl(cls, _cl=cl):
return cls is _cl
converter.register_unstructure_hook_func(cls_is_cl, unstruct_hook)
subclasses = tuple([c for c in union_classes if issubclass(c, cl)])
if len(subclasses) > 1:
u = Union[subclasses] # type: ignore
union_strategy(u, converter)
struct_hook = converter.get_structure_hook(u)
def sh(payload: dict, _, _u=u, _s=struct_hook) -> cl:
return _s(payload, _u)
converter.register_structure_hook_func(cls_is_cl, sh)

View File

@@ -0,0 +1,258 @@
from collections import defaultdict
from typing import Any, Callable, Dict, Literal, Type, Union
from attrs import NOTHING
from cattrs import BaseConverter
from cattrs._compat import get_newtype_base, is_literal, is_subclass, is_union_type
__all__ = [
"default_tag_generator",
"configure_tagged_union",
"configure_union_passthrough",
]
def default_tag_generator(typ: Type) -> str:
"""Return the class name."""
return typ.__name__
def configure_tagged_union(
union: Any,
converter: BaseConverter,
tag_generator: Callable[[Type], str] = default_tag_generator,
tag_name: str = "_type",
default: Union[Type, Literal[NOTHING]] = NOTHING,
) -> None:
"""
Configure the converter so that `union` (which should be a union) is
un/structured with the help of an additional piece of data in the
unstructured payload, the tag.
:param converter: The converter to apply the strategy to.
:param tag_generator: A `tag_generator` function is used to map each
member of the union to a tag, which is then included in the
unstructured payload. The default tag generator returns the name of
the class.
:param tag_name: The key under which the tag will be set in the
unstructured payload. By default, `'_type'`.
:param default: An optional class to be used if the tag information
is not present when structuring.
The tagged union strategy currently only works with the dict
un/structuring base strategy.
.. versionadded:: 23.1.0
"""
args = union.__args__
tag_to_hook = {}
exact_cl_unstruct_hooks = {}
for cl in args:
tag = tag_generator(cl)
struct_handler = converter.get_structure_hook(cl)
unstruct_handler = converter.get_unstructure_hook(cl)
def structure_union_member(val: dict, _cl=cl, _h=struct_handler) -> cl:
return _h(val, _cl)
def unstructure_union_member(val: union, _h=unstruct_handler) -> dict:
return _h(val)
tag_to_hook[tag] = structure_union_member
exact_cl_unstruct_hooks[cl] = unstructure_union_member
cl_to_tag = {cl: tag_generator(cl) for cl in args}
if default is not NOTHING:
default_handler = converter.get_structure_hook(default)
def structure_default(val: dict, _cl=default, _h=default_handler):
return _h(val, _cl)
tag_to_hook = defaultdict(lambda: structure_default, tag_to_hook)
cl_to_tag = defaultdict(lambda: default, cl_to_tag)
def unstructure_tagged_union(
val: union,
_exact_cl_unstruct_hooks=exact_cl_unstruct_hooks,
_cl_to_tag=cl_to_tag,
_tag_name=tag_name,
) -> Dict:
res = _exact_cl_unstruct_hooks[val.__class__](val)
res[_tag_name] = _cl_to_tag[val.__class__]
return res
if default is NOTHING:
if getattr(converter, "forbid_extra_keys", False):
def structure_tagged_union(
val: dict, _, _tag_to_cl=tag_to_hook, _tag_name=tag_name
) -> union:
val = val.copy()
return _tag_to_cl[val.pop(_tag_name)](val)
else:
def structure_tagged_union(
val: dict, _, _tag_to_cl=tag_to_hook, _tag_name=tag_name
) -> union:
return _tag_to_cl[val[_tag_name]](val)
else:
if getattr(converter, "forbid_extra_keys", False):
def structure_tagged_union(
val: dict,
_,
_tag_to_hook=tag_to_hook,
_tag_name=tag_name,
_dh=default_handler,
_default=default,
) -> union:
if _tag_name in val:
val = val.copy()
return _tag_to_hook[val.pop(_tag_name)](val)
return _dh(val, _default)
else:
def structure_tagged_union(
val: dict,
_,
_tag_to_hook=tag_to_hook,
_tag_name=tag_name,
_dh=default_handler,
_default=default,
) -> union:
if _tag_name in val:
return _tag_to_hook[val[_tag_name]](val)
return _dh(val, _default)
converter.register_unstructure_hook(union, unstructure_tagged_union)
converter.register_structure_hook(union, structure_tagged_union)
def configure_union_passthrough(union: Any, converter: BaseConverter) -> None:
"""
Configure the converter to support validating and passing through unions of the
provided types and their subsets.
For example, all mature JSON libraries natively support producing unions of ints,
floats, Nones, and strings. Using this strategy, a converter can be configured
to efficiently validate and pass through unions containing these types.
The most important point is that another library (in this example the JSON
library) handles producing the union, and the converter is configured to just
validate it.
Literals of provided types are also supported, and are checked by value.
NewTypes of provided types are also supported.
The strategy is designed to be O(1) in execution time, and independent of the
ordering of types in the union.
If the union contains a class and one or more of its subclasses, the subclasses
will also be included when validating the superclass.
.. versionadded:: 23.2.0
"""
args = set(union.__args__)
def make_structure_native_union(exact_type: Any) -> Callable:
# `exact_type` is likely to be a subset of the entire configured union (`args`).
literal_values = {
v for t in exact_type.__args__ if is_literal(t) for v in t.__args__
}
# We have no idea what the actual type of `val` will be, so we can't
# use it blindly with an `in` check since it might not be hashable.
# So we do an additional check when handling literals.
# Note: do no use `literal_values` here, since {0, False} gets reduced to {0}
literal_classes = {
v.__class__
for t in exact_type.__args__
if is_literal(t)
for v in t.__args__
}
non_literal_classes = {
get_newtype_base(t) or t
for t in exact_type.__args__
if not is_literal(t) and ((get_newtype_base(t) or t) in args)
}
# We augment the set of allowed classes with any configured subclasses of
# the exact subclasses.
non_literal_classes |= {
a for a in args if any(is_subclass(a, c) for c in non_literal_classes)
}
# We check for spillover - union types not handled by the strategy.
# If spillover exists and we fail to validate our types, we call
# further into the converter with the rest.
spillover = {
a
for a in exact_type.__args__
if (get_newtype_base(a) or a) not in non_literal_classes
and not is_literal(a)
}
if spillover:
spillover_type = (
Union[tuple(spillover)] if len(spillover) > 1 else next(iter(spillover))
)
def structure_native_union(
val: Any,
_: Any,
classes=non_literal_classes,
vals=literal_values,
converter=converter,
spillover=spillover_type,
) -> exact_type:
if val.__class__ in literal_classes and val in vals:
return val
if val.__class__ in classes:
return val
return converter.structure(val, spillover)
else:
def structure_native_union(
val: Any, _: Any, classes=non_literal_classes, vals=literal_values
) -> exact_type:
if val.__class__ in literal_classes and val in vals:
return val
if val.__class__ in classes:
return val
raise TypeError(f"{val} ({val.__class__}) not part of {_}")
return structure_native_union
def contains_native_union(exact_type: Any) -> bool:
"""Can we handle this type?"""
if is_union_type(exact_type):
type_args = set(exact_type.__args__)
# We special case optionals, since they are very common
# and are handled a little more efficiently by default.
if len(type_args) == 2 and type(None) in type_args:
return False
literal_classes = {
lit_arg.__class__
for t in type_args
if is_literal(t)
for lit_arg in t.__args__
}
non_literal_types = {
get_newtype_base(t) or t for t in type_args if not is_literal(t)
}
return (literal_classes | non_literal_types) & args
return False
converter.register_structure_hook_factory(
contains_native_union, make_structure_native_union
)

View File

@@ -0,0 +1,112 @@
"""Cattrs validation."""
from typing import Callable, List, Union
from .errors import (
ClassValidationError,
ForbiddenExtraKeysError,
IterableValidationError,
)
__all__ = ["format_exception", "transform_error"]
def format_exception(exc: BaseException, type: Union[type, None]) -> str:
"""The default exception formatter, handling the most common exceptions.
The following exceptions are handled specially:
* `KeyErrors` (`required field missing`)
* `ValueErrors` (`invalid value for type, expected <type>` or just `invalid value`)
* `TypeErrors` (`invalid value for type, expected <type>` and a couple special
cases for iterables)
* `cattrs.ForbiddenExtraKeysError`
* some `AttributeErrors` (special cased for structing mappings)
"""
if isinstance(exc, KeyError):
res = "required field missing"
elif isinstance(exc, ValueError):
if type is not None:
tn = type.__name__ if hasattr(type, "__name__") else repr(type)
res = f"invalid value for type, expected {tn}"
else:
res = "invalid value"
elif isinstance(exc, TypeError):
if type is None:
if exc.args[0].endswith("object is not iterable"):
res = "invalid value for type, expected an iterable"
else:
res = f"invalid type ({exc})"
else:
tn = type.__name__ if hasattr(type, "__name__") else repr(type)
res = f"invalid value for type, expected {tn}"
elif isinstance(exc, ForbiddenExtraKeysError):
res = f"extra fields found ({', '.join(exc.extra_fields)})"
elif isinstance(exc, AttributeError) and exc.args[0].endswith(
"object has no attribute 'items'"
):
# This was supposed to be a mapping (and have .items()) but it something else.
res = "expected a mapping"
elif isinstance(exc, AttributeError) and exc.args[0].endswith(
"object has no attribute 'copy'"
):
# This was supposed to be a mapping (and have .copy()) but it something else.
# Used for TypedDicts.
res = "expected a mapping"
else:
res = f"unknown error ({exc})"
return res
def transform_error(
exc: Union[ClassValidationError, IterableValidationError, BaseException],
path: str = "$",
format_exception: Callable[
[BaseException, Union[type, None]], str
] = format_exception,
) -> List[str]:
"""Transform an exception into a list of error messages.
To get detailed error messages, the exception should be produced by a converter
with `detailed_validation` set.
By default, the error messages are in the form of `{description} @ {path}`.
While traversing the exception and subexceptions, the path is formed:
* by appending `.{field_name}` for fields in classes
* by appending `[{int}]` for indices in iterables, like lists
* by appending `[{str}]` for keys in mappings, like dictionaries
:param exc: The exception to transform into error messages.
:param path: The root path to use.
:param format_exception: A callable to use to transform `Exceptions` into
string descriptions of errors.
.. versionadded:: 23.1.0
"""
errors = []
if isinstance(exc, IterableValidationError):
with_notes, without = exc.group_exceptions()
for exc, note in with_notes:
p = f"{path}[{note.index!r}]"
if isinstance(exc, (ClassValidationError, IterableValidationError)):
errors.extend(transform_error(exc, p, format_exception))
else:
errors.append(f"{format_exception(exc, note.type)} @ {p}")
for exc in without:
errors.append(f"{format_exception(exc, None)} @ {path}")
elif isinstance(exc, ClassValidationError):
with_notes, without = exc.group_exceptions()
for exc, note in with_notes:
p = f"{path}.{note.name}"
if isinstance(exc, (ClassValidationError, IterableValidationError)):
errors.extend(transform_error(exc, p, format_exception))
else:
errors.append(f"{format_exception(exc, note.type)} @ {p}")
for exc in without:
errors.append(f"{format_exception(exc, None)} @ {path}")
else:
errors.append(f"{format_exception(exc, None)} @ {path}")
return errors