186 lines
6.2 KiB
Python
186 lines
6.2 KiB
Python
![]() |
"""Preconfigured converters for msgspec."""
|
||
|
|
||
|
from __future__ import annotations
|
||
|
|
||
|
from base64 import b64decode
|
||
|
from datetime import date, datetime
|
||
|
from enum import Enum
|
||
|
from functools import partial
|
||
|
from typing import Any, Callable, TypeVar, Union, get_type_hints
|
||
|
|
||
|
from attrs import has as attrs_has
|
||
|
from attrs import resolve_types
|
||
|
from msgspec import Struct, convert, to_builtins
|
||
|
from msgspec.json import Encoder, decode
|
||
|
|
||
|
from .._compat import (
|
||
|
fields,
|
||
|
get_args,
|
||
|
get_origin,
|
||
|
has,
|
||
|
is_bare,
|
||
|
is_mapping,
|
||
|
is_sequence,
|
||
|
)
|
||
|
from ..cols import is_namedtuple
|
||
|
from ..converters import BaseConverter, Converter
|
||
|
from ..dispatch import UnstructureHook
|
||
|
from ..fns import identity
|
||
|
from ..gen import make_hetero_tuple_unstructure_fn
|
||
|
from ..strategies import configure_union_passthrough
|
||
|
from . import wrap
|
||
|
|
||
|
#: Type variable tying the class passed to ``loads``/``get_loads_hook`` to the
#: type of the structured value they return.
T = TypeVar("T")

#: Public API of this module.
__all__ = ["MsgspecJsonConverter", "configure_converter", "make_converter"]
|
||
|
|
||
|
|
||
|
class MsgspecJsonConverter(Converter):
    """A converter specialized for the _msgspec_ library."""

    #: The msgspec encoder for dumping. Defined on the class, so it is shared
    #: by every instance unless an instance assigns its own.
    encoder: Encoder = Encoder()

    def dumps(self, obj: Any, unstructure_as: Any = None, **kwargs: Any) -> bytes:
        """Unstructure and encode `obj` into JSON bytes.

        :param obj: The object to serialize.
        :param unstructure_as: Optional type to unstructure `obj` as. When
            ``None``, the converter dispatches on the runtime class of `obj`.
        :param kwargs: Forwarded verbatim to `self.encoder.encode`.
        :return: The JSON-encoded bytes.
        """
        return self.encoder.encode(
            self.unstructure(obj, unstructure_as=unstructure_as), **kwargs
        )

    def get_dumps_hook(
        self, unstructure_as: Any, **kwargs: Any
    ) -> Callable[[Any], bytes]:
        """Produce a `dumps` hook for the given type.

        The unstructure hook for `unstructure_as` is resolved once, up front.
        If that hook is `identity` or `to_builtins`, the encoder's `encode` is
        returned directly. Otherwise the returned callable applies the resolved
        hook and then encodes the result.

        :param unstructure_as: The type the hook will serialize values as.
        :param kwargs: Accepted for API compatibility; currently unused.
        :return: A callable taking one object and returning JSON bytes.
        """
        unstruct_hook = self.get_unstructure_hook(unstructure_as)
        encode = self.encoder.encode
        if unstruct_hook in (identity, to_builtins):
            # msgspec can take these values as-is; skip the unstructure step.
            return encode

        if unstructure_as is None:
            # No target type requested: dispatch on each object's runtime class.
            return self.dumps

        # Bind the hook resolved for `unstructure_as`. Returning `self.dumps`
        # here would drop `unstructure_as` and re-dispatch on each object's
        # runtime class instead of the requested type.
        def dumps_hook(obj: Any) -> bytes:
            return encode(unstruct_hook(obj))

        return dumps_hook

    def loads(self, data: bytes, cl: type[T], **kwargs: Any) -> T:
        """Decode and structure `cl` from the provided JSON bytes.

        :param data: The JSON payload.
        :param cl: The type to structure the decoded data into.
        :param kwargs: Forwarded verbatim to `msgspec.json.decode`.
        :return: An instance of `cl`.
        """
        return self.structure(decode(data, **kwargs), cl)

    def get_loads_hook(self, cl: type[T]) -> Callable[[bytes], T]:
        """Produce a `loads` hook for the given type.

        :param cl: The type the hook will structure into.
        :return: A callable taking JSON bytes and returning an instance of `cl`.
        """
        return partial(self.loads, cl=cl)
|
||
|
|
||
|
|
||
|
def configure_converter(converter: Converter) -> None:
    """Configure `converter` to cooperate with the msgspec library.

    * bytes are serialized as base64 strings, directly by msgspec
    * datetimes and dates are passed through to be serialized as RFC 3339 directly
    * enums are passed through to msgspec directly
    * union passthrough configured for str, bool, int, float and None
    """
    configure_passthroughs(converter)

    # Let msgspec turn Structs and Enums into builtins itself.
    for native_cls in (Struct, Enum):
        converter.register_unstructure_hook(native_cls, to_builtins)

    def _structure_bytes(value, _):
        # Encoded bytes arrive as base64 text.
        return b64decode(value)

    def _structure_datetime(value, _):
        return convert(value, datetime)

    def _structure_date(value, _):
        return date.fromisoformat(value)

    # Registered in a fixed order: Struct, bytes, datetime, date.
    structure_hooks = (
        (Struct, convert),
        (bytes, _structure_bytes),
        (datetime, _structure_datetime),
        (date, _structure_date),
    )
    for target_cls, structure_hook in structure_hooks:
        converter.register_structure_hook(target_cls, structure_hook)

    configure_union_passthrough(Union[str, bool, int, float, None], converter)
|
||
|
|
||
|
|
||
|
@wrap(MsgspecJsonConverter)
def make_converter(*args: Any, **kwargs: Any) -> MsgspecJsonConverter:
    """Build a `MsgspecJsonConverter` and apply `configure_converter` to it.

    All positional and keyword arguments are forwarded unchanged to the
    `MsgspecJsonConverter` constructor.
    """
    converter = MsgspecJsonConverter(*args, **kwargs)
    configure_converter(converter)
    return converter
|
||
|
|
||
|
|
||
|
def configure_passthroughs(converter: Converter) -> None:
    """Install hooks that hand values off to msgspec where possible.

    A passthrough is when we let msgspec handle something automatically.
    """
    converter.register_unstructure_hook(bytes, to_builtins)

    # NOTE(review): predicate-factory precedence may depend on registration
    # order, so this sequence mirrors the original order exactly.
    factory_table = (
        (is_mapping, mapping_unstructure_factory),
        (is_sequence, seq_unstructure_factory),
        (has, attrs_unstructure_factory),
        (is_namedtuple, namedtuple_unstructure_factory),
    )
    for predicate, factory in factory_table:
        converter.register_unstructure_hook_factory(predicate, factory)
|
||
|
|
||
|
|
||
|
def seq_unstructure_factory(type, converter: Converter) -> UnstructureHook:
    """Pick the unstructure hook for a sequence type.

    When the element type's hook is `identity` or `to_builtins`, that same
    hook is reused for the whole sequence. Otherwise a regular cattrs iterable
    unstructurer is generated.
    """
    # Bare sequences (no type argument) are treated as holding `Any`.
    element_type = Any if is_bare(type) else get_args(type)[0]
    element_hook = converter.get_unstructure_hook(element_type, cache_result=False)

    if element_hook not in (identity, to_builtins):
        return converter.gen_unstructure_iterable(type)
    return element_hook
|
||
|
|
||
|
|
||
|
def mapping_unstructure_factory(type, converter: BaseConverter) -> UnstructureHook:
    """The msgspec unstructure hook factory for mappings.

    When both the key hook and the value hook are `identity` or `to_builtins`,
    the whole mapping is handed to `to_builtins`. Otherwise a regular cattrs
    mapping unstructurer is generated.

    :param type: The mapping type being unstructured.
    :param converter: The converter whose hooks are consulted.
    :return: The unstructure hook for `type`.
    """
    if is_bare(type):
        key_arg = val_arg = Any
    else:
        args = get_args(type)
        if len(args) == 2:
            key_arg, val_arg = args
        else:
            # Probably a Counter (e.g. `Counter[str]`), whose single type
            # argument is the key type. Use that element, not the whole args
            # tuple, as the key type.
            key_arg = args[0] if args else Any
            val_arg = Any

    passthroughs = (identity, to_builtins)
    key_handler = converter.get_unstructure_hook(key_arg, cache_result=False)
    value_handler = converter.get_unstructure_hook(val_arg, cache_result=False)

    if key_handler in passthroughs and value_handler in passthroughs:
        return to_builtins
    return converter.gen_unstructure_mapping(type)
|
||
|
|
||
|
|
||
|
def attrs_unstructure_factory(type: Any, converter: Converter) -> UnstructureHook:
    """Decide between msgspec's native handling and a generated cattrs hook.

    Returns `to_builtins` only when every field is public (its name does not
    start with ``_``) and every field type's hook is `identity` or
    `to_builtins`. Otherwise a hook is generated with
    `converter.gen_unstructure_attrs_fromdict`.
    NOTE(review): private fields presumably force the fallback because msgspec
    and cattrs treat them differently; confirm against msgspec docs.
    """
    target = get_origin(type) or type
    attribs = fields(target)

    # String (postponed) annotations must be resolved before their hooks can
    # be looked up.
    if attrs_has(type) and any(isinstance(a.type, str) for a in attribs):
        resolve_types(type)
        attribs = fields(target)

    passthroughs = (identity, to_builtins)
    for attr in attribs:
        if attr.name.startswith("_"):
            return converter.gen_unstructure_attrs_fromdict(type)
        hook = converter.get_unstructure_hook(attr.type, cache_result=False)
        if hook not in passthroughs:
            return converter.gen_unstructure_attrs_fromdict(type)

    return to_builtins
|
||
|
|
||
|
|
||
|
def namedtuple_unstructure_factory(
    type: type[tuple], converter: BaseConverter
) -> UnstructureHook:
    """A hook factory for unstructuring namedtuples, modified for msgspec.

    If every field's hook is `identity` or `to_builtins`, `identity` is
    returned and the namedtuple is left for msgspec to encode. Otherwise a
    heterogeneous-tuple unstructurer is generated that produces a plain
    `tuple`, using the field annotations as its type arguments.

    :param type: The namedtuple class being unstructured.
    :param converter: The converter whose hooks are consulted.
    :return: The unstructure hook for `type`.
    """
    # Resolve the annotations once; they feed both the passthrough check and
    # the generated unstructurer.
    field_types = tuple(get_type_hints(type).values())

    # Probe with cache_result=False, like the sibling factories, so these
    # lookups do not populate the converter's hook cache.
    if all(
        converter.get_unstructure_hook(t, cache_result=False)
        in (identity, to_builtins)
        for t in field_types
    ):
        return identity

    return make_hetero_tuple_unstructure_fn(
        type,
        converter,
        unstructure_to=tuple,
        type_args=field_types,
    )
|