some random stuff. caelestia incoming

This commit is contained in:
voidarclabs
2025-08-21 17:40:48 +01:00
parent 12df9a0b6e
commit 1cc414a96a
1308 changed files with 217219 additions and 8 deletions

View File

@@ -0,0 +1,10 @@
from mashumaro.exceptions import MissingField
from mashumaro.helper import field_options, pass_through
from mashumaro.mixins.dict import DataClassDictMixin
__all__ = [
"MissingField",
"DataClassDictMixin",
"field_options",
"pass_through",
]

View File

@@ -0,0 +1,6 @@
from .basic import BasicDecoder, BasicEncoder
__all__ = [
"BasicDecoder",
"BasicEncoder",
]

View File

@@ -0,0 +1,106 @@
import re
from collections.abc import Callable
from typing import Any, Optional, Type
from mashumaro.core.meta.code.builder import CodeBuilder
from mashumaro.core.meta.helpers import is_optional, is_type_var_any
from mashumaro.core.meta.types.common import (
AttrsHolder,
FieldContext,
ValueSpec,
)
from mashumaro.core.meta.types.pack import PackerRegistry
from mashumaro.core.meta.types.unpack import UnpackerRegistry
CALL_EXPR = re.compile(r"^([^ ]+)\(value\)$")
class CodecCodeBuilder(CodeBuilder):
    """Code builder that compiles standalone ``decode``/``encode`` callables
    and attaches them to codec objects (instead of generating methods on a
    dataclass, as the base builder does)."""

    @classmethod
    def new(cls, **kwargs: Any) -> "CodecCodeBuilder":
        """Create a builder rooted at a synthetic ``__root__`` attrs holder."""
        if "attrs" not in kwargs:
            kwargs["attrs"] = AttrsHolder()
        return cls(AttrsHolder("__root__"), **kwargs)  # type: ignore

    def _shape_could_be_none(self, shape_type: Type) -> bool:
        # None must be passed through when the shape is Any/NoneType, an
        # unconstrained TypeVar, or an Optional[...]. Shared by the decode
        # and encode paths, which previously duplicated this expression.
        return (
            shape_type in (Any, type(None), None)
            or is_type_var_any(self.get_real_type("", shape_type))
            or is_optional(shape_type, self.get_field_resolved_type_params(""))
        )

    def add_decode_method(
        self,
        shape_type: Type,
        decoder_obj: Any,
        pre_decoder_func: Optional[Callable[[Any], Any]] = None,
    ) -> None:
        """Compile ``decode(value)`` and attach it to *decoder_obj*.

        When there is no pre-decoder and the generated body reduces to a
        single ``name(value)`` call, the underlying unpacker is attached
        directly to skip one function-call level.
        """
        self.reset()
        with self.indent("def decode(value):"):
            if pre_decoder_func:
                self.ensure_object_imported(pre_decoder_func, "decoder")
                self.add_line("value = decoder(value)")
            unpacked_value = UnpackerRegistry.get(
                ValueSpec(
                    type=shape_type,
                    expression="value",
                    builder=self,
                    field_ctx=FieldContext(name="", metadata={}),
                    could_be_none=self._shape_could_be_none(shape_type),
                )
            )
            self.add_line(f"return {unpacked_value}")
        self.add_line("setattr(decoder_obj, 'decode', decode)")
        if pre_decoder_func is None:
            m = CALL_EXPR.match(unpacked_value)
            if m:
                method_name = m.group(1)
                # Single call expression: bind the callee directly.
                self.lines.reset()
                self.add_line(f"setattr(decoder_obj, 'decode', {method_name})")
        self.ensure_object_imported(decoder_obj, "decoder_obj")
        self.ensure_object_imported(self.cls, "cls")
        self.compile()

    def add_encode_method(
        self,
        shape_type: Type,
        encoder_obj: Any,
        post_encoder_func: Optional[Callable[[Any], Any]] = None,
    ) -> None:
        """Compile ``encode(value)`` and attach it to *encoder_obj*.

        Mirrors :meth:`add_decode_method`, with the optional post-encoder
        wrapping the packed value on the way out.
        """
        self.reset()
        with self.indent("def encode(value):"):
            packed_value = PackerRegistry.get(
                ValueSpec(
                    type=shape_type,
                    expression="value",
                    builder=self,
                    field_ctx=FieldContext(name="", metadata={}),
                    could_be_none=self._shape_could_be_none(shape_type),
                    no_copy_collections=self.get_dialect_or_config_option(
                        "no_copy_collections", ()
                    ),
                )
            )
            if post_encoder_func:
                self.ensure_object_imported(post_encoder_func, "encoder")
                self.add_line(f"return encoder({packed_value})")
            else:
                self.add_line(f"return {packed_value}")
        self.add_line("setattr(encoder_obj, 'encode', encode)")
        if post_encoder_func is None:
            m = CALL_EXPR.match(packed_value)
            if m:
                method_name = m.group(1)
                # Single call expression: bind the callee directly.
                self.lines.reset()
                self.add_line(f"setattr(encoder_obj, 'encode', {method_name})")
        self.ensure_object_imported(encoder_obj, "encoder_obj")
        self.ensure_object_imported(self.cls, "cls")
        self.compile()

View File

@@ -0,0 +1,103 @@
from collections.abc import Callable
from typing import (
Any,
Generic,
Optional,
Type,
TypeVar,
Union,
final,
overload,
)
from mashumaro.codecs._builder import CodecCodeBuilder
from mashumaro.core.meta.helpers import get_args
from mashumaro.dialect import Dialect
T = TypeVar("T")
class BasicDecoder(Generic[T]):
    """Decoder from basic (already-parsed) data structures to *shape_type*.

    The real ``decode`` implementation is generated at construction time by
    ``CodecCodeBuilder.add_decode_method`` and attached to this instance via
    ``setattr``; the ``...`` body below is only a typing placeholder.
    """

    @overload
    def __init__(
        self,
        shape_type: Type[T],
        *,
        default_dialect: Optional[Type[Dialect]] = None,
        pre_decoder_func: Optional[Callable[[Any], Any]] = None,
    ): ...

    @overload
    def __init__(
        self,
        shape_type: Any,
        *,
        default_dialect: Optional[Type[Dialect]] = None,
        pre_decoder_func: Optional[Callable[[Any], Any]] = None,
    ): ...

    def __init__(
        self,
        shape_type: Union[Type[T], Any],
        *,
        default_dialect: Optional[Type[Dialect]] = None,
        pre_decoder_func: Optional[Callable[[Any], Any]] = None,
    ):
        # Compiles the decode function for this shape once, up front.
        code_builder = CodecCodeBuilder.new(
            type_args=get_args(shape_type), default_dialect=default_dialect
        )
        code_builder.add_decode_method(shape_type, self, pre_decoder_func)

    @final
    def decode(self, data: Any) -> T: ...  # replaced at runtime by the builder


class BasicEncoder(Generic[T]):
    """Encoder from *shape_type* instances to basic data structures.

    As with ``BasicDecoder``, the ``encode`` body is generated and attached
    at construction time.
    """

    @overload
    def __init__(
        self,
        shape_type: Type[T],
        *,
        default_dialect: Optional[Type[Dialect]] = None,
        post_encoder_func: Optional[Callable[[Any], Any]] = None,
    ): ...

    @overload
    def __init__(
        self,
        shape_type: Any,
        *,
        default_dialect: Optional[Type[Dialect]] = None,
        post_encoder_func: Optional[Callable[[Any], Any]] = None,
    ): ...

    def __init__(
        self,
        shape_type: Union[Type[T], Any],
        *,
        default_dialect: Optional[Type[Dialect]] = None,
        post_encoder_func: Optional[Callable[[Any], Any]] = None,
    ):
        code_builder = CodecCodeBuilder.new(
            type_args=get_args(shape_type), default_dialect=default_dialect
        )
        code_builder.add_encode_method(shape_type, self, post_encoder_func)

    @final
    def encode(self, obj: T) -> Any: ...  # replaced at runtime by the builder


def decode(data: Any, shape_type: Union[Type[T], Any]) -> T:
    """One-shot decode; builds a fresh ``BasicDecoder`` on every call."""
    return BasicDecoder(shape_type).decode(data)


def encode(obj: T, shape_type: Union[Type[T], Any]) -> Any:
    """One-shot encode; builds a fresh ``BasicEncoder`` on every call."""
    return BasicEncoder(shape_type).encode(obj)
__all__ = [
"BasicDecoder",
"BasicEncoder",
"decode",
"encode",
]

View File

@@ -0,0 +1,123 @@
import json
from collections.abc import Callable
from typing import (
Any,
Generic,
Optional,
Type,
TypeVar,
Union,
final,
overload,
)
from mashumaro.codecs._builder import CodecCodeBuilder
from mashumaro.core.meta.helpers import get_args
from mashumaro.dialect import Dialect
T = TypeVar("T")
EncodedData = Union[str, bytes, bytearray]
class JSONDecoder(Generic[T]):
    """Decoder from a JSON document to *shape_type*.

    ``pre_decoder_func`` (default ``json.loads``) parses the raw document
    before the generated unpacker runs; the generated ``decode`` replaces
    the ``...`` placeholder at construction time.
    """

    @overload
    def __init__(
        self,
        shape_type: Type[T],
        *,
        default_dialect: Optional[Type[Dialect]] = None,
        pre_decoder_func: Callable[[EncodedData], Any] = json.loads,
    ): ...

    @overload
    def __init__(
        self,
        shape_type: Any,
        *,
        default_dialect: Optional[Type[Dialect]] = None,
        pre_decoder_func: Callable[[EncodedData], Any] = json.loads,
    ): ...

    def __init__(
        self,
        shape_type: Union[Type[T], Any],
        *,
        default_dialect: Optional[Type[Dialect]] = None,
        pre_decoder_func: Callable[[EncodedData], Any] = json.loads,
    ):
        code_builder = CodecCodeBuilder.new(
            type_args=get_args(shape_type), default_dialect=default_dialect
        )
        code_builder.add_decode_method(shape_type, self, pre_decoder_func)

    @final
    def decode(self, data: EncodedData) -> T: ...


class JSONEncoder(Generic[T]):
    """Encoder from *shape_type* instances to a JSON string.

    ``post_encoder_func`` (default ``json.dumps``) serializes the packed
    basic structure produced by the generated packer.
    """

    @overload
    def __init__(
        self,
        shape_type: Type[T],
        *,
        default_dialect: Optional[Type[Dialect]] = None,
        post_encoder_func: Callable[[Any], str] = json.dumps,
    ): ...

    @overload
    def __init__(
        self,
        shape_type: Any,
        *,
        default_dialect: Optional[Type[Dialect]] = None,
        post_encoder_func: Callable[[Any], str] = json.dumps,
    ): ...

    def __init__(
        self,
        shape_type: Union[Type[T], Any],
        *,
        default_dialect: Optional[Type[Dialect]] = None,
        post_encoder_func: Callable[[Any], str] = json.dumps,
    ):
        code_builder = CodecCodeBuilder.new(
            type_args=get_args(shape_type), default_dialect=default_dialect
        )
        code_builder.add_encode_method(shape_type, self, post_encoder_func)

    @final
    def encode(self, obj: T) -> str: ...


def json_decode(
    data: EncodedData,
    shape_type: Union[Type[T], Any],
    pre_decoder_func: Callable[[EncodedData], Any] = json.loads,
) -> T:
    """One-shot JSON decode; builds a fresh ``JSONDecoder`` per call."""
    return JSONDecoder(shape_type, pre_decoder_func=pre_decoder_func).decode(
        data
    )


def json_encode(
    obj: T,
    shape_type: Union[Type[T], Any],
    post_encoder_func: Callable[[Any], str] = json.dumps,
) -> str:
    """One-shot JSON encode; builds a fresh ``JSONEncoder`` per call."""
    return JSONEncoder(shape_type, post_encoder_func=post_encoder_func).encode(
        obj
    )


# Short aliases mirroring the other codec modules.
decode = json_decode
encode = json_encode
__all__ = [
"JSONDecoder",
"JSONEncoder",
"json_decode",
"json_encode",
"decode",
"encode",
]

View File

@@ -0,0 +1,132 @@
from collections.abc import Callable
from typing import (
Any,
Generic,
Optional,
Type,
TypeVar,
Union,
final,
overload,
)
import msgpack
from mashumaro.codecs._builder import CodecCodeBuilder
from mashumaro.core.meta.helpers import get_args
from mashumaro.dialect import Dialect
from mashumaro.mixins.msgpack import MessagePackDialect
T = TypeVar("T")
EncodedData = bytes
PostEncoderFunc = Callable[[Any], EncodedData]
PreDecoderFunc = Callable[[EncodedData], Any]
def _default_decoder(data: EncodedData) -> Any:
    # raw=False decodes msgpack str objects to Python str, not bytes.
    return msgpack.unpackb(data, raw=False)


def _default_encoder(data: Any) -> EncodedData:
    # use_bin_type=True distinguishes bytes from str on the wire.
    return msgpack.packb(data, use_bin_type=True)


class MessagePackDecoder(Generic[T]):
    """Decoder from a MessagePack payload to *shape_type*.

    A user-supplied dialect is merged on top of ``MessagePackDialect``;
    the generated ``decode`` replaces the ``...`` placeholder below.
    """

    @overload
    def __init__(
        self,
        shape_type: Type[T],
        *,
        default_dialect: Optional[Type[Dialect]] = None,
        pre_decoder_func: Optional[PreDecoderFunc] = _default_decoder,
    ): ...

    @overload
    def __init__(
        self,
        shape_type: Any,
        *,
        default_dialect: Optional[Type[Dialect]] = None,
        pre_decoder_func: Optional[PreDecoderFunc] = _default_decoder,
    ): ...

    def __init__(
        self,
        shape_type: Union[Type[T], Any],
        *,
        default_dialect: Optional[Type[Dialect]] = None,
        pre_decoder_func: Optional[PreDecoderFunc] = _default_decoder,
    ):
        # MessagePackDialect always applies; a caller's dialect layers on top.
        if default_dialect is not None:
            default_dialect = MessagePackDialect.merge(default_dialect)
        else:
            default_dialect = MessagePackDialect
        code_builder = CodecCodeBuilder.new(
            type_args=get_args(shape_type), default_dialect=default_dialect
        )
        code_builder.add_decode_method(shape_type, self, pre_decoder_func)

    @final
    def decode(self, data: EncodedData) -> T: ...


class MessagePackEncoder(Generic[T]):
    """Encoder from *shape_type* instances to a MessagePack payload."""

    @overload
    def __init__(
        self,
        shape_type: Type[T],
        *,
        default_dialect: Optional[Type[Dialect]] = None,
        post_encoder_func: Optional[PostEncoderFunc] = _default_encoder,
    ): ...

    @overload
    def __init__(
        self,
        shape_type: Any,
        *,
        default_dialect: Optional[Type[Dialect]] = None,
        post_encoder_func: Optional[PostEncoderFunc] = _default_encoder,
    ): ...

    def __init__(
        self,
        shape_type: Union[Type[T], Any],
        *,
        default_dialect: Optional[Type[Dialect]] = None,
        post_encoder_func: Optional[PostEncoderFunc] = _default_encoder,
    ):
        if default_dialect is not None:
            default_dialect = MessagePackDialect.merge(default_dialect)
        else:
            default_dialect = MessagePackDialect
        code_builder = CodecCodeBuilder.new(
            type_args=get_args(shape_type), default_dialect=default_dialect
        )
        code_builder.add_encode_method(shape_type, self, post_encoder_func)

    @final
    def encode(self, obj: T) -> EncodedData: ...


def msgpack_decode(data: EncodedData, shape_type: Union[Type[T], Any]) -> T:
    """One-shot decode; builds a fresh ``MessagePackDecoder`` per call."""
    return MessagePackDecoder(shape_type).decode(data)


def msgpack_encode(obj: T, shape_type: Union[Type[T], Any]) -> EncodedData:
    """One-shot encode; builds a fresh ``MessagePackEncoder`` per call."""
    return MessagePackEncoder(shape_type).encode(obj)


# Short aliases mirroring the other codec modules.
decode = msgpack_decode
encode = msgpack_encode
__all__ = [
"MessagePackDecoder",
"MessagePackEncoder",
"msgpack_decode",
"msgpack_encode",
"decode",
"encode",
]

View File

@@ -0,0 +1,114 @@
from typing import (
Any,
Generic,
Optional,
Type,
TypeVar,
Union,
final,
overload,
)
import orjson
from mashumaro.codecs._builder import CodecCodeBuilder
from mashumaro.core.meta.helpers import get_args
from mashumaro.dialect import Dialect
from mashumaro.mixins.orjson import OrjsonDialect
T = TypeVar("T")
EncodedData = Union[str, bytes, bytearray]
class ORJSONDecoder(Generic[T]):
    """Decoder from a JSON document to *shape_type* using ``orjson.loads``.

    A user-supplied dialect is merged on top of ``OrjsonDialect``; the
    generated ``decode`` replaces the ``...`` placeholder below.
    """

    @overload
    def __init__(
        self,
        shape_type: Type[T],
        *,
        default_dialect: Optional[Type[Dialect]] = None,
    ): ...

    @overload
    def __init__(
        self,
        shape_type: Any,
        *,
        default_dialect: Optional[Type[Dialect]] = None,
    ): ...

    def __init__(
        self,
        shape_type: Union[Type[T], Any],
        *,
        default_dialect: Optional[Type[Dialect]] = None,
    ):
        if default_dialect is not None:
            default_dialect = OrjsonDialect.merge(default_dialect)
        else:
            default_dialect = OrjsonDialect
        code_builder = CodecCodeBuilder.new(
            type_args=get_args(shape_type), default_dialect=default_dialect
        )
        code_builder.add_decode_method(shape_type, self, orjson.loads)

    @final
    def decode(self, data: EncodedData) -> T: ...


class ORJSONEncoder(Generic[T]):
    """Encoder from *shape_type* instances to JSON bytes via ``orjson.dumps``."""

    @overload
    def __init__(
        self,
        shape_type: Type[T],
        *,
        default_dialect: Optional[Type[Dialect]] = None,
    ): ...

    @overload
    def __init__(
        self,
        shape_type: Any,
        *,
        default_dialect: Optional[Type[Dialect]] = None,
    ): ...

    def __init__(
        self,
        shape_type: Union[Type[T], Any],
        *,
        default_dialect: Optional[Type[Dialect]] = None,
    ):
        if default_dialect is not None:
            default_dialect = OrjsonDialect.merge(default_dialect)
        else:
            default_dialect = OrjsonDialect
        code_builder = CodecCodeBuilder.new(
            type_args=get_args(shape_type), default_dialect=default_dialect
        )
        code_builder.add_encode_method(shape_type, self, orjson.dumps)

    @final
    def encode(self, obj: T) -> bytes: ...  # orjson.dumps returns bytes
def json_decode(data: EncodedData, shape_type: Union[Type[T], Any]) -> T:
    """One-shot decode of JSON *data* into an instance of *shape_type*.

    ``shape_type`` is ``Union[Type[T], Any]`` for consistency with the
    sibling codec modules, which also accept plain generic aliases.
    """
    return ORJSONDecoder(shape_type).decode(data)


def json_encode(obj: T, shape_type: Union[Type[T], Any]) -> bytes:
    """One-shot encode of *obj* to JSON bytes using orjson."""
    return ORJSONEncoder(shape_type).encode(obj)


# Short aliases mirroring the other codec modules.
decode = json_decode
encode = json_encode
__all__ = [
"ORJSONDecoder",
"ORJSONEncoder",
"json_decode",
"json_encode",
"decode",
"encode",
]

View File

@@ -0,0 +1,119 @@
from typing import (
Any,
Generic,
Optional,
Type,
TypeVar,
Union,
final,
overload,
)
import tomli_w
from mashumaro.codecs._builder import CodecCodeBuilder
from mashumaro.core.meta.helpers import get_args
from mashumaro.dialect import Dialect
from mashumaro.mixins.toml import TOMLDialect
try:
import tomllib
except ModuleNotFoundError:
import tomli as tomllib # type: ignore
T = TypeVar("T")
EncodedData = str
class TOMLDecoder(Generic[T]):
    """Decoder from a TOML document (str) to *shape_type*.

    A user-supplied dialect is merged on top of ``TOMLDialect``; parsing is
    done by ``tomllib.loads`` and the generated ``decode`` replaces the
    ``...`` placeholder below.
    """

    @overload
    def __init__(
        self,
        shape_type: Type[T],
        *,
        default_dialect: Optional[Type[Dialect]] = None,
    ): ...

    @overload
    def __init__(
        self,
        shape_type: Any,
        *,
        default_dialect: Optional[Type[Dialect]] = None,
    ): ...

    def __init__(
        self,
        shape_type: Union[Type[T], Any],
        *,
        default_dialect: Optional[Type[Dialect]] = None,
    ):
        if default_dialect is not None:
            default_dialect = TOMLDialect.merge(default_dialect)
        else:
            default_dialect = TOMLDialect
        code_builder = CodecCodeBuilder.new(
            type_args=get_args(shape_type), default_dialect=default_dialect
        )
        code_builder.add_decode_method(shape_type, self, tomllib.loads)

    @final
    def decode(self, data: EncodedData) -> T: ...
class TOMLEncoder(Generic[T]):
    """Encoder from *shape_type* instances to a TOML document (str).

    Serialization is done by ``tomli_w.dumps``, which returns ``str`` —
    hence ``encode`` returns ``EncodedData`` (= ``str``), not ``bytes``
    as previously annotated.
    """

    @overload
    def __init__(
        self,
        shape_type: Type[T],
        *,
        default_dialect: Optional[Type[Dialect]] = None,
    ): ...

    @overload
    def __init__(
        self,
        shape_type: Any,
        *,
        default_dialect: Optional[Type[Dialect]] = None,
    ): ...

    def __init__(
        self,
        shape_type: Union[Type[T], Any],
        *,
        default_dialect: Optional[Type[Dialect]] = None,
    ):
        # TOMLDialect always applies; a caller's dialect layers on top.
        if default_dialect is not None:
            default_dialect = TOMLDialect.merge(default_dialect)
        else:
            default_dialect = TOMLDialect
        code_builder = CodecCodeBuilder.new(
            type_args=get_args(shape_type), default_dialect=default_dialect
        )
        code_builder.add_encode_method(shape_type, self, tomli_w.dumps)

    @final
    def encode(self, obj: T) -> EncodedData: ...


def toml_decode(data: EncodedData, shape_type: Union[Type[T], Any]) -> T:
    """One-shot decode; builds a fresh ``TOMLDecoder`` per call."""
    return TOMLDecoder(shape_type).decode(data)


def toml_encode(obj: T, shape_type: Union[Type[T], Any]) -> EncodedData:
    """One-shot encode to a TOML string."""
    return TOMLEncoder(shape_type).encode(obj)


# Short aliases mirroring the other codec modules.
decode = toml_decode
encode = toml_encode
__all__ = [
"TOMLDecoder",
"TOMLEncoder",
"toml_decode",
"toml_encode",
"decode",
"encode",
]

View File

@@ -0,0 +1,128 @@
from collections.abc import Callable
from typing import (
Any,
Generic,
Optional,
Type,
TypeVar,
Union,
final,
overload,
)
import yaml
from mashumaro.codecs._builder import CodecCodeBuilder
from mashumaro.core.meta.helpers import get_args
from mashumaro.dialect import Dialect
T = TypeVar("T")
EncodedData = Union[str, bytes]
PostEncoderFunc = Callable[[Any], EncodedData]
PreDecoderFunc = Callable[[EncodedData], Any]
DefaultLoader = getattr(yaml, "CSafeLoader", yaml.SafeLoader)
DefaultDumper = getattr(yaml, "CDumper", yaml.Dumper)
def _default_encoder(data: Any) -> EncodedData:
    # Uses the C dumper when libyaml is available (see DefaultDumper).
    return yaml.dump(data, Dumper=DefaultDumper)


def _default_decoder(data: EncodedData) -> Any:
    # DefaultLoader prefers CSafeLoader, falling back to SafeLoader.
    return yaml.load(data, DefaultLoader)


class YAMLDecoder(Generic[T]):
    """Decoder from a YAML document to *shape_type*.

    The generated ``decode`` replaces the ``...`` placeholder at
    construction time.
    """

    @overload
    def __init__(
        self,
        shape_type: Type[T],
        *,
        default_dialect: Optional[Type[Dialect]] = None,
        pre_decoder_func: Optional[PreDecoderFunc] = _default_decoder,
    ): ...

    @overload
    def __init__(
        self,
        shape_type: Any,
        *,
        default_dialect: Optional[Type[Dialect]] = None,
        pre_decoder_func: Optional[PreDecoderFunc] = _default_decoder,
    ): ...

    def __init__(
        self,
        shape_type: Union[Type[T], Any],
        *,
        default_dialect: Optional[Type[Dialect]] = None,
        pre_decoder_func: Optional[PreDecoderFunc] = _default_decoder,
    ):
        code_builder = CodecCodeBuilder.new(
            type_args=get_args(shape_type), default_dialect=default_dialect
        )
        code_builder.add_decode_method(shape_type, self, pre_decoder_func)

    @final
    def decode(self, data: EncodedData) -> T: ...


class YAMLEncoder(Generic[T]):
    """Encoder from *shape_type* instances to a YAML document."""

    @overload
    def __init__(
        self,
        shape_type: Type[T],
        *,
        default_dialect: Optional[Type[Dialect]] = None,
        post_encoder_func: Optional[PostEncoderFunc] = _default_encoder,
    ): ...

    @overload
    def __init__(
        self,
        shape_type: Any,
        *,
        default_dialect: Optional[Type[Dialect]] = None,
        post_encoder_func: Optional[PostEncoderFunc] = _default_encoder,
    ): ...

    def __init__(
        self,
        shape_type: Union[Type[T], Any],
        *,
        default_dialect: Optional[Type[Dialect]] = None,
        post_encoder_func: Optional[PostEncoderFunc] = _default_encoder,
    ):
        code_builder = CodecCodeBuilder.new(
            type_args=get_args(shape_type), default_dialect=default_dialect
        )
        code_builder.add_encode_method(shape_type, self, post_encoder_func)

    @final
    def encode(self, obj: T) -> EncodedData: ...


def yaml_decode(data: EncodedData, shape_type: Union[Type[T], Any]) -> T:
    """One-shot decode; builds a fresh ``YAMLDecoder`` per call."""
    return YAMLDecoder(shape_type).decode(data)


def yaml_encode(obj: T, shape_type: Union[Type[T], Any]) -> EncodedData:
    """One-shot encode; builds a fresh ``YAMLEncoder`` per call."""
    return YAMLEncoder(shape_type).encode(obj)


# Short aliases mirroring the other codec modules.
decode = yaml_decode
encode = yaml_encode
__all__ = [
"YAMLDecoder",
"YAMLEncoder",
"yaml_decode",
"yaml_encode",
"decode",
"encode",
]

View File

@@ -0,0 +1,63 @@
from collections.abc import Callable
from typing import Any, Literal, Optional, Type, TypedDict, Union
from mashumaro.core.const import Sentinel
from mashumaro.dialect import Dialect
from mashumaro.types import Discriminator, SerializationStrategy
__all__ = [
"BaseConfig",
"TO_DICT_ADD_BY_ALIAS_FLAG",
"TO_DICT_ADD_OMIT_NONE_FLAG",
"ADD_DIALECT_SUPPORT",
"ADD_SERIALIZATION_CONTEXT",
"SerializationStrategyValueType",
]
TO_DICT_ADD_BY_ALIAS_FLAG = "TO_DICT_ADD_BY_ALIAS_FLAG"
TO_DICT_ADD_OMIT_NONE_FLAG = "TO_DICT_ADD_OMIT_NONE_FLAG"
ADD_DIALECT_SUPPORT = "ADD_DIALECT_SUPPORT"
ADD_SERIALIZATION_CONTEXT = "ADD_SERIALIZATION_CONTEXT"
CodeGenerationOption = Literal[
"TO_DICT_ADD_BY_ALIAS_FLAG",
"TO_DICT_ADD_OMIT_NONE_FLAG",
"ADD_DIALECT_SUPPORT",
"ADD_SERIALIZATION_CONTEXT",
]
class SerializationStrategyDict(TypedDict, total=False):
    # Either entry may be a method name (str) or a callable.
    serialize: Union[str, Callable]
    deserialize: Union[str, Callable]


# A per-type strategy is either a SerializationStrategy instance or a dict
# with "serialize"/"deserialize" entries.
SerializationStrategyValueType = Union[
    SerializationStrategy, SerializationStrategyDict
]


class BaseConfig:
    """Per-model serialization settings, typically subclassed as an inner
    ``Config`` class. These are class-level declarative defaults; the
    mutable container defaults ({} / []) are conventions meant to be
    overridden in subclasses, not mutated in place.
    """

    debug: bool = False  # presumably emits generated code for inspection — confirm
    code_generation_options: list[CodeGenerationOption] = []
    serialization_strategy: dict[Any, SerializationStrategyValueType] = {}
    aliases: dict[str, str] = {}  # field name -> serialized key alias
    # Sentinel.MISSING distinguishes "not configured" from explicit False.
    serialize_by_alias: Union[bool, Literal[Sentinel.MISSING]] = (
        Sentinel.MISSING
    )
    namedtuple_as_dict: Union[bool, Literal[Sentinel.MISSING]] = (
        Sentinel.MISSING
    )
    allow_postponed_evaluation: bool = True
    dialect: Optional[Type[Dialect]] = None
    omit_none: Union[bool, Literal[Sentinel.MISSING]] = Sentinel.MISSING
    omit_default: Union[bool, Literal[Sentinel.MISSING]] = Sentinel.MISSING
    orjson_options: Optional[int] = 0  # passed to orjson — semantics defined elsewhere
    json_schema: dict[str, Any] = {}
    discriminator: Optional[Discriminator] = None
    lazy_compilation: bool = False
    sort_keys: bool = False
    allow_deserialization_not_by_alias: bool = False
    forbid_extra_keys: bool = False

View File

@@ -0,0 +1,27 @@
import enum
import sys
__all__ = [
"PY_39",
"PY_310",
"PY_310_MIN",
"PY_311_MIN",
"PY_312_MIN",
"PY_313_MIN",
"Sentinel",
]
# Exact-minor flags for the running CPython 3.x interpreter, plus cumulative
# "minimum version" flags chained from the newest release downward.
_PY_MAJOR = sys.version_info.major
_PY_MINOR = sys.version_info.minor

PY_39 = _PY_MAJOR == 3 and _PY_MINOR == 9
PY_310 = _PY_MAJOR == 3 and _PY_MINOR == 10
PY_311 = _PY_MAJOR == 3 and _PY_MINOR == 11
PY_312 = _PY_MAJOR == 3 and _PY_MINOR == 12
PY_313_MIN = _PY_MAJOR == 3 and _PY_MINOR >= 13
PY_312_MIN = PY_312 or PY_313_MIN
PY_311_MIN = PY_311 or PY_312_MIN
PY_310_MIN = PY_310 or PY_311_MIN
class Sentinel(enum.Enum):
    # Canonical "value not provided" marker: ``Sentinel.MISSING`` lets config
    # code distinguish an unset option from an explicit False/None.
    MISSING = enum.auto()

View File

@@ -0,0 +1,35 @@
import datetime
import re
__all__ = [
"parse_timezone",
"ConfigValue",
"UTC_OFFSET_PATTERN",
]
# Accepts "UTC" or "UTC±hh:mm" with hh in 00..29 and mm in 00..59.
UTC_OFFSET_PATTERN = r"^UTC(([+-][0-2][0-9]):([0-5][0-9]))?$"
UTC_OFFSET_RE = re.compile(UTC_OFFSET_PATTERN)


def parse_timezone(s: str) -> datetime.timezone:
    """Parse ``"UTC"`` or ``"UTC±hh:mm"`` into a ``datetime.timezone``.

    Raises:
        ValueError: if *s* does not match ``UTC_OFFSET_PATTERN``.
    """
    match = UTC_OFFSET_RE.match(s)
    if not match:
        raise ValueError(
            f"Time zone {s} must be either UTC or in format UTC[+-]hh:mm"
        )
    if not match.group(1):
        return datetime.timezone.utc
    # The sign belongs to the whole offset. Deriving it from int(hours)
    # loses it when the hour field is "-00" (int("-00") == 0), which made
    # "UTC-00:30" parse as +00:30 — so extract it explicitly and apply it
    # to both components.
    sign = -1 if match.group(2).startswith("-") else 1
    hours = abs(int(match.group(2)))
    minutes = int(match.group(3))
    return datetime.timezone(
        datetime.timedelta(hours=sign * hours, minutes=sign * minutes)
    )
class ConfigValue:
    """Named marker referring to a config option to be resolved later."""

    def __init__(self, name: str):
        self.name = name

    def __repr__(self) -> str:
        # A readable repr aids debugging; backward-compatible addition.
        return f"{type(self).__name__}({self.name!r})"

View File

@@ -0,0 +1,38 @@
from collections.abc import Generator
from contextlib import contextmanager
from typing import Optional
__all__ = ["CodeLines"]
class CodeLines:
    """Accumulator for generated source lines with indentation tracking."""

    def __init__(self) -> None:
        self._lines: list[str] = []
        self._current_indent: str = ""

    def append(self, line: str) -> None:
        """Add *line* at the current indentation level."""
        self._lines.append(self._current_indent + line)

    def extend(self, lines: "CodeLines") -> None:
        """Splice another buffer's lines in, re-indented to this level."""
        prefix = self._current_indent
        self._lines.extend(prefix + nested for nested in lines._lines)

    @contextmanager
    def indent(
        self,
        expr: Optional[str] = None,
    ) -> Generator[None, None, None]:
        """Deepen indentation by four spaces for the duration of the block,
        optionally emitting *expr* (e.g. ``"def f():"``) first."""
        if expr:
            self.append(expr)
        self._current_indent += "    "
        try:
            yield
        finally:
            # Always unwind, even if the body raised.
            self._current_indent = self._current_indent[:-4]

    def as_text(self) -> str:
        """Render the accumulated lines as a single newline-joined string."""
        return "\n".join(self._lines)

    def reset(self) -> None:
        """Drop all lines and return to zero indentation."""
        self._lines = []
        self._current_indent = ""

View File

@@ -0,0 +1,812 @@
import dataclasses
import enum
import inspect
import re
import sys
import types
import typing
from collections.abc import Callable, Hashable, Iterable, Iterator
from contextlib import suppress
# noinspection PyProtectedMember
from dataclasses import _FIELDS # type: ignore
from hashlib import md5
from typing import (
Any,
ClassVar,
ForwardRef,
Optional,
Sequence,
Tuple,
Type,
Union,
)
try:
from typing import Unpack # type: ignore[attr-defined]
except ImportError:
from typing_extensions import Unpack
import typing_extensions
from mashumaro.core.const import (
PY_39,
PY_310_MIN,
PY_311_MIN,
PY_312_MIN,
PY_313_MIN,
)
from mashumaro.dialect import Dialect
__all__ = [
"get_type_origin",
"get_args",
"type_name",
"is_special_typing_primitive",
"is_generic",
"is_typed_dict",
"is_named_tuple",
"is_optional",
"is_union",
"not_none_type_arg",
"is_type_var",
"is_type_var_any",
"is_class_var",
"is_final",
"is_init_var",
"get_class_that_defines_method",
"get_class_that_defines_field",
"is_dataclass_dict_mixin",
"is_dataclass_dict_mixin_subclass",
"collect_type_params",
"resolve_type_params",
"substitute_type_params",
"get_generic_name",
"get_name_error_name",
"is_dialect_subclass",
"is_new_type",
"is_annotated",
"get_type_annotations",
"is_literal",
"is_local_type_name",
"get_literal_values",
"is_self",
"is_required",
"is_not_required",
"get_function_arg_annotation",
"get_function_return_annotation",
"is_unpack",
"is_type_var_tuple",
"hash_type_args",
"iter_all_subclasses",
"is_hashable",
"is_hashable_type",
"evaluate_forward_ref",
"get_forward_ref_referencing_globals",
"is_type_alias_type",
]
NoneType = type(None)
# Fully-qualified path of DataClassDictMixin, with the package root derived
# from this module's own __name__ (presumably "mashumaro" — the rsplit trick
# strips the last three dotted components). Compared against type_name()
# output in is_dataclass_dict_mixin().
DataClassDictMixinPath = (
    f"{__name__.rsplit('.', 3)[:-3][0]}.mixins.dict.DataClassDictMixin"
)
def get_type_origin(typ: Type) -> Type:
    """Return ``typ.__origin__`` if present, otherwise *typ* itself."""
    return getattr(typ, "__origin__", typ)


def is_builtin_type(typ: Type) -> bool:
    """Tell whether *typ* is defined in the ``builtins`` module."""
    return getattr(typ, "__module__", None) == "builtins"


def get_generic_name(typ: Type, short: bool = False) -> str:
    """Return the display name of a generic alias (e.g. ``typing.List``),
    optionally without the module prefix."""
    name = getattr(typ, "_name", None)
    if name is not None:
        return name if short else f"{typ.__module__}.{name}"
    # No _name: fall back to the origin, or to type_name for a bare origin.
    origin = get_type_origin(typ)
    if origin is typ:
        return type_name(origin, short, is_type_origin=True)
    return get_generic_name(origin, short)


def get_args(typ: Optional[Type]) -> tuple[Type, ...]:
    """Return ``typ.__args__``, or an empty tuple when absent."""
    try:
        return typ.__args__  # type: ignore[union-attr]
    except AttributeError:
        return ()
def _get_args_str(
    typ: Type,
    short: bool,
    resolved_type_params: Optional[dict[Type, Type]] = None,
    limit: Optional[int] = None,
    none_type_as_none: bool = False,
    sep: str = ", ",
) -> str:
    """Render the type arguments of *typ* as a joined string of names.

    ``limit`` truncates the argument list; ``none_type_as_none`` renders
    NoneType as "None". The empty tuple types render as "()".
    """
    if typ == Tuple[()]:
        return "()"
    elif typ == tuple[()]:
        return "()"
    args = _flatten_type_args(get_args(typ)[:limit])
    to_join = []
    for arg in args:
        to_join.append(
            type_name(
                typ=arg,
                short=short,
                resolved_type_params=resolved_type_params,
                none_type_as_none=none_type_as_none,
            )
        )
    # With multiple args, "()" placeholders are dropped from the join.
    if len(to_join) > 1:
        return sep.join(s for s in to_join if s != "()")
    else:
        return sep.join(to_join)


def get_literal_values(typ: Type) -> tuple[Any, ...]:
    """Return all values of a Literal type, flattening nested Literals."""
    values = typ.__args__
    result: list[Any] = []
    for value in values:
        if is_literal(value):
            result.extend(get_literal_values(value))
        else:
            result.append(value)
    return tuple(result)


def _get_literal_values_str(typ: Type, short: bool) -> str:
    """Render Literal values for display; silently skips unsupported types."""
    values_str = []
    for value in get_literal_values(typ):
        if isinstance(value, enum.Enum):
            values_str.append(f"{type_name(type(value), short)}.{value.name}")
        elif isinstance(
            value,
            (int, str, bytes, bool, NoneType),  # type: ignore
        ):
            values_str.append(repr(value))
    return ", ".join(values_str)


def _typing_name(
    typ_name: str,
    short: bool = False,
    module_name: str = "typing",
) -> str:
    """Qualify *typ_name* with *module_name* unless a short name is wanted."""
    return typ_name if short else f"{module_name}.{typ_name}"
def type_name(
    typ: Optional[Type],
    short: bool = False,
    resolved_type_params: Optional[dict[Type, Type]] = None,
    is_type_origin: bool = False,
    none_type_as_none: bool = False,
) -> str:
    """Return a printable name for *typ*, resolving type variables through
    ``resolved_type_params``.

    ``short`` drops module prefixes; ``is_type_origin`` suppresses the
    Literal/Unpack/generic handling when *typ* is already a bare origin;
    ``none_type_as_none`` renders NoneType as "None".
    """
    if resolved_type_params is None:
        resolved_type_params = {}
    if typ is None:
        return "None"
    elif typ is NoneType and none_type_as_none:
        return "None"
    elif typ is Ellipsis:
        return "..."
    elif typ is Any:
        return _typing_name("Any", short)
    elif is_optional(typ, resolved_type_params):
        # Two-arg Union with None renders as Optional[X].
        args_str = type_name(
            typ=not_none_type_arg(get_args(typ), resolved_type_params),
            short=short,
            resolved_type_params=resolved_type_params,
        )
        return f"{_typing_name('Optional', short)}[{args_str}]"
    elif is_union(typ):
        args_str = _get_args_str(
            typ, short, resolved_type_params, none_type_as_none=True
        )
        return f"{_typing_name('Union', short)}[{args_str}]"
    elif is_annotated(typ):
        # Annotated[X, ...] renders as its underlying type X.
        return type_name(get_args(typ)[0], short, resolved_type_params)
    elif not is_type_origin and is_literal(typ):
        args_str = _get_literal_values_str(typ, short)
        return f"{_typing_name('Literal', short, typ.__module__)}[{args_str}]"
    elif not is_type_origin and is_unpack(typ):
        if (
            typ in resolved_type_params
            and resolved_type_params[typ] is not typ
        ):
            return type_name(
                resolved_type_params[typ], short, resolved_type_params
            )
        else:
            unpacked_type_arg = get_args(typ)[0]
            # Fixed-length tuples unpack to their flat argument list.
            if not is_variable_length_tuple(
                unpacked_type_arg
            ) and not is_type_var_tuple(unpacked_type_arg):
                return _get_args_str(
                    unpacked_type_arg, short, resolved_type_params
                )
            unpacked_type_name = type_name(
                unpacked_type_arg, short, resolved_type_params
            )
            # 3.11+ has native "*Ts" syntax; older versions need Unpack[Ts].
            if PY_311_MIN:
                return f"*{unpacked_type_name}"
            else:
                _unpack = _typing_name("Unpack", short, typ.__module__)
                return f"{_unpack}[{unpacked_type_name}]"
    elif not is_type_origin and is_generic(typ):
        args_str = _get_args_str(typ, short, resolved_type_params)
        if not args_str:
            return get_generic_name(typ, short)
        else:
            return f"{get_generic_name(typ, short)}[{args_str}]"
    elif is_builtin_type(typ):
        return typ.__qualname__
    elif is_type_var(typ):
        if (
            typ in resolved_type_params
            and resolved_type_params[typ] is not typ
        ):
            return type_name(
                resolved_type_params[typ], short, resolved_type_params
            )
        elif is_type_var_any(typ):
            return _typing_name("Any", short)
        # Unresolved TypeVar: render constraints as a Union, otherwise fall
        # back to its default (PEP 696) or bound.
        constraints = getattr(typ, "__constraints__")
        if constraints:
            args_str = ", ".join(
                type_name(c, short, resolved_type_params) for c in constraints
            )
            return f"{_typing_name('Union', short)}[{args_str}]"
        else:
            if type_var_has_default(typ):
                bound = get_type_var_default(typ)
            else:
                bound = getattr(typ, "__bound__")
            return type_name(bound, short, resolved_type_params)
    elif is_new_type(typ) and not PY_310_MIN:
        # because __qualname__ and __module__ are messed up
        typ = typ.__supertype__
    try:
        if short:
            return typ.__qualname__  # type: ignore
        else:
            return f"{typ.__module__}.{typ.__qualname__}"  # type: ignore
    except AttributeError:
        return str(typ)
def is_special_typing_primitive(typ: Any) -> bool:
    """True for typing constructs that cannot act as classes.

    Real classes are accepted by ``issubclass``; special forms such as
    ``Union[...]`` make it raise TypeError.
    """
    try:
        issubclass(typ, object)
    except TypeError:
        return True
    return False
def is_generic(typ: Type) -> bool:
    """Heuristically detect generic types and aliases.

    Anything exposing ``__class_getitem__`` (subscriptable classes) or that
    is a typing/PEP 585 alias counts. Note: if an exception escapes inside
    the suppress block, the function falls through and implicitly returns
    None (falsy).
    """
    with suppress(Exception):
        if hasattr(typ, "__class_getitem__"):
            return True
        # noinspection PyProtectedMember
        # noinspection PyUnresolvedReferences
        if (
            issubclass(typ.__class__, typing._BaseGenericAlias)  # type: ignore
            or type(typ) is types.GenericAlias  # type: ignore # noqa: E721
        ):
            return True
        else:
            return False
    # else: # for PEP 585 generics without args
    #     try:
    #         return (
    #             hasattr(typ, "__class_getitem__")
    #             and type(typ[str]) is types.GenericAlias  # type: ignore
    #         )
    #     except (TypeError, AttributeError):
    #         return False


def is_typed_dict(typ: Type) -> bool:
    """Detect TypedDict classes from either typing or typing_extensions."""
    for module in (typing, typing_extensions):
        with suppress(AttributeError):
            if type(typ) is getattr(module, "_TypedDictMeta"):
                return True
    return False


def is_readonly(typ: Type) -> bool:
    """Detect PEP 705 ``ReadOnly[...]`` from typing or typing_extensions."""
    origin = get_type_origin(typ)
    for module in (typing, typing_extensions):
        with suppress(AttributeError):
            if origin is getattr(module, "ReadOnly"):
                return True
    return False
def is_named_tuple(typ: Type) -> bool:
    """Detect NamedTuple classes: tuple subclasses carrying ``_fields``."""
    try:
        if not issubclass(typ, tuple):
            return False
    except TypeError:
        # Non-class (e.g. a typing special form) cannot be a namedtuple.
        return False
    return hasattr(typ, "_fields")


def is_new_type(typ: Type) -> bool:
    """Detect ``typing.NewType`` objects via their ``__supertype__``."""
    return hasattr(typ, "__supertype__")
def is_union(typ: Type) -> bool:
    """Detect both ``typing.Union[...]`` and PEP 604 ``X | Y`` unions."""
    try:
        if PY_310_MIN and isinstance(typ, types.UnionType):  # type: ignore
            return True
        return typ.__origin__ is Union
    except AttributeError:
        return False


def is_optional(
    typ: Type, resolved_type_params: Optional[dict[Type, Type]] = None
) -> bool:
    """True for a two-argument union where one arm resolves to NoneType."""
    if resolved_type_params is None:
        resolved_type_params = {}
    if not is_union(typ):
        return False
    args = get_args(typ)
    # Only the exact Optional[X] shape (X | None) counts, not wider unions.
    if len(args) != 2:
        return False
    for arg in args:
        if resolved_type_params.get(arg, arg) is NoneType:
            return True
    return False


def is_annotated(typ: Type) -> bool:
    """Detect ``Annotated[...]`` from typing or typing_extensions."""
    for module in (typing, typing_extensions):
        with suppress(AttributeError):
            if type(typ) is getattr(module, "_AnnotatedAlias"):
                return True
    return False
def get_type_annotations(typ: Type) -> Sequence[Any]:
    """Return the metadata attached via ``Annotated[...]`` (empty if none)."""
    try:
        return typ.__metadata__
    except AttributeError:
        return []
def is_literal(typ: Type) -> bool:
    """Detect ``Literal[...]`` aliases, with a per-version strategy.

    3.9 identifies them by generic name; 3.10+ by the private alias class.
    """
    if PY_39:
        with suppress(AttributeError):
            return is_generic(typ) and get_generic_name(typ, True) == "Literal"
    elif PY_310_MIN:
        with suppress(AttributeError):
            # noinspection PyProtectedMember
            # noinspection PyUnresolvedReferences
            return type(typ) is typing._LiteralGenericAlias  # type: ignore
    return False


def is_local_type_name(typ_name: str) -> bool:
    """True when the qualified name refers to a function-local class."""
    return "<locals>" in typ_name


def not_none_type_arg(
    type_args: tuple[Type, ...],
    resolved_type_params: Optional[dict[Type, Type]] = None,
) -> Optional[Type]:
    """Return the first argument that does not resolve to NoneType
    (i.e. the X in Optional[X]), or None when all arms are NoneType."""
    if resolved_type_params is None:
        resolved_type_params = {}
    for type_arg in type_args:
        if resolved_type_params.get(type_arg, type_arg) is not NoneType:
            return type_arg
    return None
def is_type_var(typ: Type) -> bool:
    """TypeVar-like objects are recognized by their ``__constraints__``."""
    try:
        typ.__constraints__
    except AttributeError:
        return False
    return True
def is_type_var_any(typ: Type) -> bool:
    """True for a completely unconstrained TypeVar (no constraints, no
    bound other than None/Any, no PEP 696 default) — equivalent to Any."""
    if not is_type_var(typ):
        return False
    elif typ.__constraints__ != ():
        return False
    elif typ.__bound__ not in (None, Any):
        return False
    elif type_var_has_default(typ):
        return False
    else:
        return True


def is_class_var(typ: Type) -> bool:
    """Detect ``ClassVar[...]`` annotations."""
    return get_type_origin(typ) is ClassVar


def is_final(typ: Type) -> bool:
    """Detect ``Final[...]`` annotations (typing_extensions variant)."""
    return get_type_origin(typ) is typing_extensions.Final


def is_init_var(typ: Type) -> bool:
    """Detect dataclass ``InitVar[...]`` pseudo-fields."""
    return isinstance(typ, dataclasses.InitVar)
def get_class_that_defines_method(
    method_name: str, cls: Type
) -> Optional[Type]:
    """Walk the MRO from the most derived class and return the first class
    whose own ``__dict__`` defines *method_name*, or None if nobody does."""
    return next(
        (klass for klass in cls.__mro__ if method_name in klass.__dict__),
        None,
    )
def get_class_that_defines_field(field_name: str, cls: Type) -> Optional[Type]:
    """Return the class in the MRO where *field_name* was (re)declared.

    Bases are scanned from the root of the MRO downwards; a class counts
    as the definer when it introduces a Field object different from the
    one seen on earlier bases.  Falls back to *cls* when not found.
    """
    defining_cls = None
    last_seen_field = None
    for base in reversed(cls.__mro__):
        if not dataclasses.is_dataclass(base):
            continue
        candidate = getattr(base, _FIELDS).get(field_name)
        if candidate and candidate != last_seen_field:
            last_seen_field = candidate
            defining_cls = base
    return defining_cls or cls
def is_dataclass_dict_mixin(typ: Type) -> bool:
    """Return True if *typ* is the DataClassDictMixin base class itself."""
    full_name = type_name(typ)
    return full_name == DataClassDictMixinPath
def is_dataclass_dict_mixin_subclass(typ: Type) -> bool:
    """Return True if any class in *typ*'s MRO is DataClassDictMixin.

    Objects without an ``__mro__`` (non-classes) yield False.
    """
    with suppress(AttributeError):
        return any(is_dataclass_dict_mixin(base) for base in typ.__mro__)
    return False
def get_orig_bases(typ: Type) -> tuple[Type, ...]:
    """Return ``__orig_bases__`` of a generic class, or () if absent."""
    no_bases: tuple[Type, ...] = ()
    return getattr(typ, "__orig_bases__", no_bases)
def collect_type_params(typ: Type) -> Sequence[Type]:
    """Collect the distinct type parameters (TypeVars and
    Unpack[TypeVarTuple]s) appearing in *typ*, in first-seen order."""
    collected: list[Type] = []

    def _add(param: Type) -> None:
        # preserve order, drop duplicates
        if param not in collected:
            collected.append(param)

    for arg in get_args(typ):
        if is_type_var(arg):
            _add(arg)
        elif is_unpack(arg) and is_type_var_tuple(get_args(arg)[0]):
            _add(arg)
        else:
            # not a parameter itself: descend into nested arguments
            for nested in collect_type_params(arg):
                _add(nested)
    return collected
def _check_generic(
    typ: Type, type_params: Sequence[Type], type_args: Sequence[Type]
) -> None:
    """Validate the number of *type_args* supplied for *type_params*.

    Raises TypeError on multiple Unpack parameters or on too few
    arguments.  See https://github.com/python/cpython/issues/99382
    """
    unpack_count = sum(1 for param in type_params if is_unpack(param))
    if unpack_count > 1:
        raise TypeError(
            "Multiple unpacks are disallowed within a single type parameter "
            f"list for {type_name(typ)}"
        )
    if unpack_count == 1:
        # One Unpack absorbs any surplus, so only a lower bound applies.
        expected_count = len(type_params) - 1
        expected_msg = f"at least {len(type_params) - 1}"
    else:
        expected_count = len(type_params)
        expected_msg = f"{expected_count}"
    args_len = len(type_args)
    if 0 < args_len < expected_count:
        raise TypeError(
            f"Too few arguments for {type_name(typ)}; "
            f"actual {args_len}, expected {expected_msg}"
        )
def _flatten_type_args(
    type_args: Sequence[Type],
    allow_ellipsis_if_many_args: bool = False,
) -> Sequence[Type]:
    # Normalize a sequence of type arguments by expanding Unpack[...]:
    #   * Unpack[TypeVarTuple] is kept as-is (nothing to expand yet);
    #   * Unpack[tuple[X, ...]] is expanded only when it is the sole
    #     argument, or when the caller explicitly allows the ellipsis to
    #     sit next to other arguments;
    #   * Unpack[Tuple[()]] / Unpack[tuple[()]] become a literal empty
    #     tuple when alone (and are dropped otherwise);
    #   * any other Unpack[tuple[...]] is expanded recursively.
    # Non-Unpack arguments pass through untouched.
    result = []
    for type_arg in type_args:
        if is_unpack(type_arg):
            unpacked_type = get_args(type_arg)[0]
            if is_type_var_tuple(unpacked_type):
                result.append(type_arg)
            elif is_variable_length_tuple(unpacked_type):
                if len(type_args) == 1:
                    result.extend(_flatten_type_args(get_args(unpacked_type)))
                elif allow_ellipsis_if_many_args:
                    result.extend(_flatten_type_args(get_args(unpacked_type)))
                else:
                    result.append(type_arg)
            elif unpacked_type == Tuple[()]:
                if len(type_args) == 1:
                    result.append(())  # type: ignore
            elif unpacked_type == tuple[()]:  # type: ignore
                if len(type_args) == 1:
                    result.append(())  # type: ignore
            else:
                result.extend(_flatten_type_args(get_args(unpacked_type)))
        else:
            result.append(type_arg)
    return result
def resolve_type_params(
    typ: Type,
    type_args: Sequence[Type] = (),
    include_bases: bool = True,
) -> dict[Type, dict[Type, Type]]:
    """Map the type parameters of *typ* (and, optionally, its bases) to
    the concrete *type_args* supplied for it.

    Returns a dict keyed by class; each value maps a type parameter
    (TypeVar or Unpack) to its resolved argument.
    """
    resolved_type_params: dict[Type, Type] = {}
    result = {typ: resolved_type_params}
    type_params = []
    # Collect the distinct type parameters declared across all original
    # (generic) bases, preserving first-seen order.
    for base in get_orig_bases(typ):
        base_type_params = collect_type_params(base)
        for type_param in base_type_params:
            if type_param not in type_params:
                type_params.append(type_param)
    _check_generic(typ, type_params, type_args)
    type_args = _flatten_type_args(type_args, allow_ellipsis_if_many_args=True)
    # Matching strategy: walk params/args forwards until an Unpack param
    # is met; then restart from the *end* of both sequences and walk
    # backwards (negative indices), leaving whatever remains in the
    # middle for the Unpack itself.
    param_idx = 0
    unpack_param_idx = -1
    arg_idx = 0
    while param_idx < len(type_params):
        type_param = type_params[param_idx]
        if not is_unpack(type_param):
            if type_param not in resolved_type_params:
                try:
                    next_type_arg = type_args[arg_idx]
                    if next_type_arg is Ellipsis:
                        # A trailing "..." repeats the previous argument.
                        next_type_arg = type_args[arg_idx - 1]
                    else:
                        if unpack_param_idx < 0:
                            arg_idx += 1
                        else:
                            arg_idx -= 1
                except IndexError:
                    # No argument supplied: the parameter maps to itself.
                    next_type_arg = type_param
                resolved_type_params[type_param] = next_type_arg
            if unpack_param_idx < 0:
                param_idx += 1
            else:
                param_idx -= 1
        elif unpack_param_idx < 0:
            # First Unpack found: switch to backward matching.
            unpack_param_idx = param_idx
            param_idx = -1
            arg_idx = -1
            unpacked_param = get_args(type_param)[0]
            for y in reversed(get_args(unpacked_param)):  # pragma: no cover
                # We turn Tuple[x,y] to x, y, but leave this here just in case
                type_params.insert(param_idx, y)
        else:
            # Met the Unpack again while walking backwards: everything
            # still unmatched belongs to it.
            if not type_args and is_type_var_tuple(get_args(type_param)[0]):
                resolved_type_params[type_param] = Unpack[
                    Tuple[Any, ...]  # type: ignore
                ]
                break
            t_args = type_args[unpack_param_idx : len(type_args) + arg_idx + 1]
            if len(t_args) == 1 and t_args[0] == ():
                x: Any = ()
            elif len(t_args) > 2 and t_args[-1] is Ellipsis:
                x = (*t_args[:-2], Unpack[Tuple[t_args[-2], ...]])
            else:
                x = tuple(t_args)
            resolved_type_params[type_param] = Unpack[Tuple[x]]  # type: ignore
            break
    if include_bases:
        # Propagate the resolved arguments upwards through the bases.
        orig_bases = {
            get_type_origin(orig_base): orig_base
            for orig_base in get_orig_bases(typ)
        }
        for base in getattr(typ, "__bases__", ()):
            orig_base = orig_bases.get(get_type_origin(base))
            base_type_params = get_args(orig_base)
            base_type_args = tuple(
                [resolved_type_params.get(a, a) for a in base_type_params]
            )
            result.update(resolve_type_params(base, base_type_args))
    return result
def substitute_type_params(typ: Type, substitutions: dict[Type, Type]) -> Type:
    # Rebuild *typ* with its type parameters replaced per *substitutions*.
    if is_annotated(typ):
        # Substitute the underlying type but keep the annotations attached.
        origin = get_type_origin(typ)
        subst = substitutions.get(origin, origin)
        return typing_extensions.Annotated[
            (subst, *get_type_annotations(typ))  # type: ignore
        ]
    else:
        new_type_args = []
        for type_param in collect_type_params(typ):
            new_type_args.append(substitutions.get(type_param, type_param))
        if new_type_args:
            # Re-parameterizing can fail for non-subscriptable types;
            # fall through to direct substitution in that case.
            with suppress(TypeError, KeyError):
                return typ[tuple(new_type_args)]
        if is_hashable(typ):
            return substitutions.get(typ, typ)
        else:
            # unhashable types cannot be dict keys, so nothing to look up
            return typ
def get_name_error_name(e: NameError) -> str:
    """Extract the missing identifier from a NameError."""
    if PY_310_MIN:
        # NameError.name is available since Python 3.10.
        return e.name  # type: ignore
    # Older versions: parse the quoted name out of the message.
    match = re.search("'(.*)'", e.args[0])
    if match:
        return match.group(1)
    return ""
def is_dialect_subclass(typ: Type) -> bool:
    """Return True if *typ* is a class inheriting from Dialect."""
    with suppress(TypeError):
        return issubclass(typ, Dialect)
    # issubclass raised: *typ* is not a class at all
    return False
def is_self(typ: Type) -> bool:
    """Return True if *typ* is the ``Self`` special form."""
    self_type = typing_extensions.Self
    return typ is self_type
def is_required(typ: Type) -> bool:
    """Return True if *typ* is a TypedDict ``Required[...]`` annotation."""
    origin = get_type_origin(typ)
    return origin is typing_extensions.Required  # noqa
def is_not_required(typ: Type) -> bool:
    """Return True if *typ* is a TypedDict ``NotRequired[...]`` annotation."""
    origin = get_type_origin(typ)
    return origin is typing_extensions.NotRequired  # noqa
def get_function_arg_annotation(
    function: Callable[..., Any],
    arg_name: Optional[str] = None,
    arg_pos: Optional[int] = None,
) -> type:
    """Return the annotation of one argument of *function*.

    The argument is selected by *arg_name* or by *arg_pos* (one of them
    must be given).  String annotations are wrapped into a ForwardRef
    bound to the function's module.

    Raises:
        ValueError: if no selector is passed or the selected argument
            has no annotation.
    """
    parameters = inspect.signature(function).parameters
    if arg_name is not None:
        parameter = parameters[arg_name]
    elif arg_pos is not None:
        parameter = parameters[list(parameters.keys())[arg_pos]]
    else:
        raise ValueError("arg_name or arg_pos must be passed")
    annotation = parameter.annotation
    if annotation is inspect.Signature.empty:
        # Use the resolved parameter name so the message is correct even
        # when the argument was selected by position (arg_name is None).
        raise ValueError(f"Argument {parameter.name} doesn't have annotation")
    if isinstance(annotation, str):
        annotation = str_to_forward_ref(
            annotation, inspect.getmodule(function)
        )
    return annotation
def get_function_return_annotation(function: Callable[[Any], Any]) -> Type:
    """Return the return annotation of *function*; string annotations are
    turned into ForwardRefs.  Raises ValueError when absent."""
    annotation = inspect.signature(function).return_annotation
    if annotation is inspect.Signature.empty:
        raise ValueError("Function doesn't have return annotation")
    if not isinstance(annotation, str):
        return annotation
    return str_to_forward_ref(annotation, inspect.getmodule(function))
def is_unpack(typ: Type) -> bool:
    """Return True if *typ* is an ``Unpack[...]`` form from either
    ``typing`` or ``typing_extensions``."""
    missing = object()
    origin = get_type_origin(typ)
    for module in (typing, typing_extensions):
        # sentinel default replaces the AttributeError suppression
        if origin is getattr(module, "Unpack", missing):
            return True
    return False
def is_type_var_tuple(typ: Type) -> bool:
    """Return True if *typ* is a TypeVarTuple instance (PEP 646)."""
    missing = object()
    for module in (typing, typing_extensions):
        if type(typ) is getattr(module, "TypeVarTuple", missing):
            return True
    return False
def is_variable_length_tuple(typ: Type) -> bool:
    """Return True for ``tuple[X, ...]``-style annotations."""
    args = get_args(typ)
    if len(args) != 2:
        return False
    return args[1] is Ellipsis
def hash_type_args(type_args: Iterable[Type]) -> str:
    """Return a stable md5 digest identifying a sequence of type args."""
    joined = ",".join(type_name(arg) for arg in type_args)
    return md5(joined.encode()).hexdigest()
def iter_all_subclasses(cls: Type) -> Iterator[Type]:
    """Yield every direct and indirect subclass of *cls*, depth-first
    in declaration order."""
    pending = list(cls.__subclasses__())
    while pending:
        subclass = pending.pop(0)
        yield subclass
        # prepend children to keep pre-order depth-first traversal
        pending = subclass.__subclasses__() + pending
def is_hashable(value: Any) -> bool:
    """Return True if *value* can actually be hashed (EAFP check)."""
    try:
        hash(value)
    except TypeError:
        return False
    return True
def is_hashable_type(typ: Any) -> bool:
    """Return True unless *typ* is a class that is provably unhashable.

    Non-class inputs make ``issubclass`` raise TypeError, in which case
    we optimistically report True.
    """
    with suppress(TypeError):
        return issubclass(typ, Hashable)
    return True
def str_to_forward_ref(
    annotation: str, module: Optional[types.ModuleType] = None
) -> ForwardRef:
    # Wrap a string annotation into a ForwardRef that remembers the
    # module it came from, so it can later be evaluated in the right
    # namespace.
    return ForwardRef(annotation, module=module)
def evaluate_forward_ref(
    typ: ForwardRef, globalns: dict[str, Any], localns: dict[str, Any]
) -> Optional[Type]:
    # Resolve a ForwardRef via the private ``_evaluate`` API, whose
    # signature gained a required ``type_params`` argument in Python 3.13.
    if PY_313_MIN:
        return typ._evaluate(
            globalns, localns, type_params=(), recursive_guard=frozenset()
        )  # type: ignore[call-arg]
    else:
        return typ._evaluate(
            globalns, localns, recursive_guard=frozenset()
        )  # type: ignore[call-arg]
def get_forward_ref_referencing_globals(
    referenced_type: ForwardRef,
    referencing_object: Optional[Any] = None,
    fallback: Optional[dict[str, Any]] = None,
) -> dict[str, Any]:
    """Choose the globals mapping for evaluating *referenced_type*:
    its own module's namespace if known, else the namespace of the
    module of *referencing_object*, else *fallback* (default {})."""
    if fallback is None:
        fallback = {}
    forward_module = getattr(referenced_type, "__forward_module__", None)
    if forward_module:
        return getattr(forward_module, "__dict__", fallback)
    if referencing_object:
        # We can't get the module in which ForwardRef's value is defined
        # on Python < 3.10, so use the namespace of the module in which
        # this ForwardRef is used as globalns.
        module = sys.modules.get(referencing_object.__module__, None)
        return getattr(module, "__dict__", fallback)
    return fallback
def is_type_alias_type(typ: Type) -> bool:
    # ``type X = ...`` aliases (TypeAliasType) exist only on Python 3.12+;
    # on older interpreters nothing can be such an alias.
    if PY_312_MIN:
        return isinstance(typ, typing.TypeAliasType)  # type: ignore
    else:
        return False
def type_var_has_default(typ: Any) -> bool:
    """Return True if the TypeVar-like object carries a PEP 696 default.

    Prefers the ``has_default()`` method (typing on 3.13+ and recent
    typing_extensions); falls back to probing ``__default__``.
    """
    try:
        return typ.has_default()
    except AttributeError:
        default = getattr(typ, "__default__", None)
        return default is not None
def get_type_var_default(typ: Any) -> Type:
    """Return the PEP 696 default of a TypeVar-like object.

    Raises AttributeError when *typ* has no ``__default__``.
    """
    return typ.__default__

View File

@@ -0,0 +1,52 @@
from typing import Any, Optional, Type
from mashumaro.core.meta.code.builder import CodeBuilder
from mashumaro.dialect import Dialect
from mashumaro.exceptions import UnresolvedTypeReferenceError
__all__ = [
"compile_mixin_packer",
"compile_mixin_unpacker",
]
def compile_mixin_packer(
    cls: Type,
    format_name: str = "dict",
    dialect: Optional[Type[Dialect]] = None,
    encoder: Any = None,
    encoder_kwargs: Optional[dict[str, dict[str, tuple[str, Any]]]] = None,
) -> None:
    """Generate and attach the pack ("to_*") method for *cls*.

    UnresolvedTypeReferenceError is suppressed when the class config
    enables ``allow_postponed_evaluation``; otherwise it propagates.
    """
    builder = CodeBuilder(
        cls=cls,
        format_name=format_name,
        encoder=encoder,
        encoder_kwargs=encoder_kwargs,
        default_dialect=dialect,
    )
    config = builder.get_config()
    try:
        builder.add_pack_method()
    except UnresolvedTypeReferenceError:
        if not config.allow_postponed_evaluation:
            raise
def compile_mixin_unpacker(
    cls: Type,
    format_name: str = "dict",
    dialect: Optional[Type[Dialect]] = None,
    decoder: Any = None,
) -> None:
    """Generate and attach the unpack ("from_*") method for *cls*.

    UnresolvedTypeReferenceError is suppressed when the class config
    enables ``allow_postponed_evaluation``; otherwise it propagates.
    """
    builder = CodeBuilder(
        cls=cls,
        format_name=format_name,
        decoder=decoder,
        default_dialect=dialect,
    )
    config = builder.get_config()
    try:
        builder.add_unpack_method()
    except UnresolvedTypeReferenceError:
        if not config.allow_postponed_evaluation:
            raise

View File

@@ -0,0 +1,285 @@
import re
import uuid
from abc import ABC, abstractmethod
from collections.abc import Callable, Mapping, Sequence
from dataclasses import dataclass, field, replace
from functools import cached_property
from types import new_class
from typing import TYPE_CHECKING, Any, Optional, Type, TypeVar, Union
from typing_extensions import ParamSpec, TypeAlias
from mashumaro.core.meta.code.lines import CodeLines
from mashumaro.core.meta.helpers import (
get_type_origin,
is_annotated,
is_generic,
is_hashable_type,
is_self,
type_name,
)
from mashumaro.exceptions import UnserializableField
if TYPE_CHECKING: # pragma: no cover
from mashumaro.core.meta.code.builder import CodeBuilder
else:
CodeBuilder = Any
class TypeMatchEligibleExpression(str):
    """Marker ``str`` subclass for expressions that may take part in
    type-match dispatch; adds no behavior of its own."""
    pass
# Local alias for the type of None (types.NoneType needs Python 3.10+).
NoneType = type(None)
# A generated source-code fragment; TypeMatchEligibleExpression marks
# fragments eligible for type-match dispatch.
Expression: TypeAlias = Union[str, TypeMatchEligibleExpression]
P = ParamSpec("P")
T = TypeVar("T")
# Matches any character invalid in a Python identifier, plus a leading
# digit — used to sanitize generated names.
_PY_VALID_ID_RE = re.compile(r"\W|^(?=\d)")
class AttrsHolder:
    """Factory producing fresh anonymous classes used as attribute holders.

    Note: ``AttrsHolder(...)`` does NOT return an AttrsHolder instance —
    ``__new__`` returns a brand-new class created via ``types.new_class``,
    named ``attrs_<id>`` unless an explicit name is given.
    """
    def __new__(
        cls, name: Optional[str] = None, *args: Any, **kwargs: Any
    ) -> Any:
        ah = new_class("AttrsHolder")
        ah_id = id(ah)
        if not name:
            name = f"attrs_{ah_id}"
        ah.__name__ = ah.__qualname__ = name
        return ah
class ExpressionWrapper:
    """Wraps a ready-made source expression so callers can distinguish it
    from a plain serialization callable or option string."""
    def __init__(self, expression: str):
        self.expression = expression
@dataclass
class FieldContext:
    """Per-field context threaded through packer/unpacker generation."""
    # field name ("" when the value is not attached to a named field)
    name: str
    # metadata mapping consulted for options (e.g. the "serialize" key)
    metadata: Mapping
    # cached packer/unpacker call expressions, filled in lazily
    packer: Optional[str] = None
    unpacker: Optional[str] = None
    def copy(self, **changes: Any) -> "FieldContext":
        """Return a copy with *changes* applied (dataclasses.replace)."""
        return replace(self, **changes)
@dataclass
class ValueSpec:
    """Everything needed to generate a pack/unpack expression for one value."""
    # declared type; assigning it also refreshes origin_type (see __setattr__)
    type: Type
    origin_type: Type = field(init=False)
    # source expression that yields the value in generated code
    expression: Expression
    builder: CodeBuilder
    field_ctx: FieldContext
    could_be_none: bool = True
    # the original Annotated[...] type, when the value was annotated
    annotated_type: Optional[Type] = None
    owner: Optional[Type] = None
    no_copy_collections: Sequence = tuple()
    def __setattr__(self, key: str, value: Any) -> None:
        # keep origin_type in sync whenever type is (re)assigned
        if key == "type":
            self.origin_type = get_type_origin(value)
        super().__setattr__(key, value)
    def copy(self, **changes: Any) -> "ValueSpec":
        """Return a copy with *changes* applied (dataclasses.replace)."""
        return replace(self, **changes)
    @cached_property
    def annotations(self) -> Sequence[str]:
        # Annotated[...] metadata, or [] when there is none
        return getattr(self.annotated_type, "__metadata__", [])
    @cached_property
    def attrs(self) -> Any:
        # Holder for generated helper attributes: the builder's own attrs
        # when "nailed", otherwise a per-type holder from the registry.
        if self.builder.is_nailed:
            return self.builder.attrs
        if is_self(self.type):
            typ = self.builder.cls
        else:
            typ = self.origin_type
        attrs = self.attrs_registry.get(typ)
        if attrs is None:
            attrs = AttrsHolder()
            self.attrs_registry[typ] = attrs
        return attrs
    @cached_property
    def cls_attrs_name(self) -> str:
        # name under which the class-level attrs holder is reachable
        # inside generated code
        if self.builder.is_nailed:
            return "cls"
        else:
            self.builder.ensure_object_imported(self.attrs)
            return self.attrs.__name__
    @cached_property
    def self_attrs_name(self) -> str:
        # name under which the instance-level attrs holder is reachable
        if self.builder.is_nailed:
            return "self"
        else:
            self.builder.ensure_object_imported(self.attrs)
            return self.attrs.__name__
    @cached_property
    def attrs_registry(self) -> dict[Any, Any]:
        return self.builder.attrs_registry
    @cached_property
    def attrs_registry_name(self) -> str:
        # import the registry into the generated module under a unique name
        name = f"attrs_registry_{id(self.attrs_registry)}"
        self.builder.ensure_object_imported(self.attrs_registry, name)
        return name
class AbstractMethodBuilder(ABC):
    """Template for builders that generate a helper method, attach it to
    an attrs holder and return the expression that calls it."""
    @abstractmethod
    def get_method_prefix(self) -> str:  # pragma: no cover
        raise NotImplementedError
    def _generate_method_name(
        self, spec: ValueSpec
    ) -> str:  # pragma: no cover
        # collision-free name: __<prefix>_<Class>[_<field>]__<random hex>
        prefix = self.get_method_prefix()
        if prefix:
            prefix = f"{prefix}_"
        if spec.field_ctx.name:
            suffix = f"_{spec.field_ctx.name}"
        else:
            suffix = ""
        return f"__{prefix}{spec.builder.cls.__name__}{suffix}__{random_hex()}"
    @abstractmethod
    def _add_definition(self, spec: ValueSpec, lines: CodeLines) -> str:
        raise NotImplementedError
    @abstractmethod
    def _generate_method_args(self, spec: ValueSpec) -> str:
        raise NotImplementedError
    @abstractmethod
    def _add_body(
        self, spec: ValueSpec, lines: CodeLines
    ) -> None:  # pragma: no cover
        raise NotImplementedError
    def _add_setattr(
        self, spec: ValueSpec, method_name: str, lines: CodeLines
    ) -> None:
        # attach the freshly generated function to the attrs holder
        lines.append(
            f"setattr({spec.cls_attrs_name}, '{method_name}', {method_name})"
        )
    def _compile(self, spec: ValueSpec, lines: CodeLines) -> None:
        # execute the generated source in the builder's namespace,
        # echoing it first when debug is enabled in the config
        if spec.builder.get_config().debug:
            print(f"{type_name(spec.builder.cls)}:")
            print(lines.as_text())
        exec(lines.as_text(), spec.builder.globals, spec.builder.__dict__)
    @abstractmethod
    def _get_call_expr(self, spec: ValueSpec, method_name: str) -> str:
        raise NotImplementedError
    def _before_build(self, spec: ValueSpec) -> None:
        # hook for subclasses; default is a no-op
        pass
    def _get_existing_method(self, spec: ValueSpec) -> Optional[str]:
        # subclasses may return a cached call expression to skip a rebuild
        return None
    def build(self, spec: ValueSpec) -> str:
        """Generate, attach and compile the method; return its call expression."""
        self._before_build(spec)
        if method := self._get_existing_method(spec):
            return method
        lines = CodeLines()
        method_name = self._add_definition(spec, lines)
        with lines.indent():
            self._add_body(spec, lines)
        self._add_setattr(spec, method_name, lines)
        self._compile(spec, lines)
        return self._get_call_expr(spec, method_name)
# Signature of a packer/unpacker creator: return an expression for the
# given spec, or None to let the next registered creator try.
ValueSpecExprCreator: TypeAlias = Callable[[ValueSpec], Optional[Expression]]
@dataclass
class Registry:
    """Ordered chain of expression creators tried one after another."""
    _registry: list[ValueSpecExprCreator] = field(default_factory=list)
    def register(self, function: ValueSpecExprCreator) -> ValueSpecExprCreator:
        """Append *function* to the chain (usable as a decorator)."""
        self._registry.append(function)
        return function
    def get(self, spec: ValueSpec) -> Expression:
        """Return the first non-None expression produced for *spec*.

        The spec is normalized first (Annotated is unwrapped, the real
        type resolved) — note this mutates *spec* in place.  Raises
        UnserializableField when no creator handles the type.
        """
        if is_annotated(spec.type):
            spec.annotated_type = spec.builder.get_real_type(
                spec.field_ctx.name, spec.type
            )
            spec.type = get_type_origin(spec.type)
        spec.type = spec.builder.get_real_type(spec.field_ctx.name, spec.type)
        spec.builder.add_type_modules(spec.type)
        for packer in self._registry:
            expr = packer(spec)
            if expr is not None:
                return expr
        raise UnserializableField(
            spec.field_ctx.name, spec.type, spec.builder.cls
        )
def ensure_generic_collection(spec: ValueSpec) -> bool:
    """Return True only if the spec's type is a parameterized generic."""
    return bool(is_generic(spec.type))
def ensure_mapping_key_type_hashable(
    spec: ValueSpec, type_args: Sequence[Type]
) -> bool:
    """Raise UnserializableField if the mapping key type is unhashable;
    return True otherwise (including when no type args are given)."""
    if not type_args:
        return True
    key_type = type_args[0]
    if is_hashable_type(key_type):
        return True
    raise UnserializableField(
        field_name=spec.field_ctx.name,
        field_type=spec.type,
        holder_class=spec.builder.cls,
        msg=(
            f"{type_name(key_type, short=True)} "
            "is unhashable and can not be used as a key"
        ),
    )
def ensure_generic_collection_subclass(
    spec: ValueSpec, *checked_types: Type
) -> bool:
    """True when the origin type subclasses *checked_types* and the spec
    is a parameterized generic (checked in that order)."""
    if not issubclass(spec.origin_type, checked_types):
        return False
    return ensure_generic_collection(spec)
def ensure_generic_mapping(
    spec: ValueSpec, args: Sequence[Type], checked_type: Type
) -> bool:
    """True when spec is a generic *checked_type* whose key type is
    hashable (the hashability check may raise UnserializableField)."""
    if not ensure_generic_collection_subclass(spec, checked_type):
        return False
    return ensure_mapping_key_type_hashable(spec, args)
def expr_or_maybe_none(spec: ValueSpec, new_expr: Expression) -> Expression:
    """Wrap *new_expr* in a None-guard when the value could be None."""
    if not spec.could_be_none:
        return new_expr
    return f"{new_expr} if {spec.expression} is not None else None"
def random_hex() -> str:
    """Return 32 random hexadecimal characters (a uuid4 without dashes)."""
    return uuid.uuid4().hex
def clean_id(value: str) -> str:
    """Turn an arbitrary string into a valid Python identifier by
    replacing invalid characters (and a leading digit) with underscores;
    an empty string becomes a single underscore."""
    return _PY_VALID_ID_RE.sub("_", value) if value else "_"

View File

@@ -0,0 +1,873 @@
import datetime
import enum
import ipaddress
import os
import re
import typing
import uuid
import zoneinfo
from base64 import encodebytes
from collections import ChainMap, Counter, OrderedDict, deque
from collections.abc import Callable, Collection, Mapping, Sequence, Set
from contextlib import suppress
from dataclasses import is_dataclass
from decimal import Decimal
from fractions import Fraction
from typing import Any, ForwardRef, Optional, Tuple, Union
import typing_extensions
from mashumaro.core.const import PY_311_MIN
from mashumaro.core.meta.code.lines import CodeLines
from mashumaro.core.meta.helpers import (
get_args,
get_class_that_defines_method,
get_function_return_annotation,
get_literal_values,
get_type_origin,
get_type_var_default,
is_final,
is_generic,
is_literal,
is_named_tuple,
is_new_type,
is_not_required,
is_optional,
is_readonly,
is_required,
is_self,
is_special_typing_primitive,
is_type_alias_type,
is_type_var,
is_type_var_any,
is_type_var_tuple,
is_typed_dict,
is_union,
is_unpack,
not_none_type_arg,
resolve_type_params,
substitute_type_params,
type_name,
type_var_has_default,
)
from mashumaro.core.meta.types.common import (
Expression,
ExpressionWrapper,
NoneType,
Registry,
ValueSpec,
clean_id,
ensure_generic_collection,
ensure_generic_collection_subclass,
ensure_generic_mapping,
expr_or_maybe_none,
random_hex,
)
from mashumaro.exceptions import (
UnserializableDataError,
UnserializableField,
UnsupportedSerializationEngine,
)
from mashumaro.helper import pass_through
from mashumaro.types import (
GenericSerializableType,
SerializableType,
SerializationStrategy,
)
__all__ = ["PackerRegistry"]
# Global, ordered chain of packer creators; each @register below adds one.
PackerRegistry = Registry()
register = PackerRegistry.register
def _pack_with_annotated_serialization_strategy(
    spec: ValueSpec,
    strategy: SerializationStrategy,
) -> Expression:
    # Pack through a SerializationStrategy whose ``serialize`` return
    # annotation tells us which type to continue packing with.
    strategy_type = type(strategy)
    try:
        value_type: Union[type, Any] = get_function_return_annotation(
            strategy.serialize
        )
    except (KeyError, ValueError):
        # no usable return annotation: fall back to Any
        value_type = Any
    if isinstance(value_type, ForwardRef):
        value_type = spec.builder.evaluate_forward_ref(
            value_type, spec.origin_type
        )
    # substitute the strategy's own type parameters with the concrete
    # arguments of the field's type
    value_type = substitute_type_params(
        value_type,  # type: ignore
        resolve_type_params(strategy_type, get_args(spec.type))[strategy_type],
    )
    overridden_fn = f"__{spec.field_ctx.name}_serialize_{random_hex()}"
    setattr(spec.attrs, overridden_fn, strategy.serialize)
    new_spec = spec.copy(
        type=value_type,
        expression=(
            f"{spec.self_attrs_name}.{overridden_fn}({spec.expression})"
        ),
    )
    # Strip the strategy from the metadata so the recursive lookup below
    # does not apply it a second time.  NOTE(review): this mutates the
    # FieldContext shared with *spec* (spec.copy keeps the same
    # field_ctx object), which is why the fresh copy passed to
    # PackerRegistry.get below still sees the stripped metadata even
    # though ``new_spec`` itself is otherwise unused — confirm intended.
    field_metadata = new_spec.field_ctx.metadata
    if field_metadata.get("serialization_strategy") is strategy:
        new_spec.field_ctx.metadata = {
            k: v
            for k, v in field_metadata.items()
            if k != "serialization_strategy"
        }
    return PackerRegistry.get(
        spec.copy(
            type=value_type,
            expression=(
                f"{spec.self_attrs_name}.{overridden_fn}({spec.expression})"
            ),
        )
    )
def get_overridden_serialization_method(
    spec: ValueSpec,
) -> Optional[Union[Callable, str, ExpressionWrapper]]:
    # Resolution order: explicit "serialize" field metadata first, then
    # serialization strategies registered for the annotated, declared
    # and origin types (in that order).  Returns None when no override
    # applies.
    serialize_option = spec.field_ctx.metadata.get("serialize")
    if serialize_option is not None:
        return serialize_option
    checking_types = [spec.type, spec.origin_type]
    if spec.annotated_type:
        checking_types.insert(0, spec.annotated_type)
    for typ in checking_types:
        for strategy in spec.builder.iter_serialization_strategies(
            spec.field_ctx.metadata, typ
        ):
            if strategy is pass_through:
                return pass_through
            elif isinstance(strategy, dict):
                serialize_option = strategy.get("serialize")
            elif isinstance(strategy, SerializationStrategy):
                if strategy.__use_annotations__ or is_generic(type(strategy)):
                    # annotation-driven strategies are resolved eagerly
                    # into a ready-made expression
                    return ExpressionWrapper(
                        _pack_with_annotated_serialization_strategy(
                            spec=spec,
                            strategy=strategy,
                        )
                    )
                else:
                    serialize_option = strategy.serialize
            if serialize_option is not None:
                return serialize_option
@register
def pack_type_with_overridden_serialization(
    spec: ValueSpec,
) -> Optional[Expression]:
    # First packer in the chain: honor user-provided overrides.
    # Returns None (falls through to the next packer) when there is none.
    serialization_method = get_overridden_serialization_method(spec)
    if serialization_method is pass_through:
        return spec.expression
    elif isinstance(serialization_method, ExpressionWrapper):
        return serialization_method.expression
    elif callable(serialization_method):
        overridden_fn = f"__{spec.field_ctx.name}_serialize_{random_hex()}"
        # staticmethod so the callable is not bound when looked up on attrs
        setattr(spec.attrs, overridden_fn, staticmethod(serialization_method))
        return f"{spec.self_attrs_name}.{overridden_fn}({spec.expression})"
def _pack_annotated_serializable_type(
    spec: ValueSpec,
) -> Optional[Expression]:
    # SerializableType with __use_annotations__: the return annotation of
    # ``_serialize`` defines the type to keep packing with.
    try:
        # noinspection PyProtectedMember
        # noinspection PyUnresolvedReferences
        value_type = get_function_return_annotation(
            spec.origin_type._serialize
        )
    except (KeyError, ValueError):
        raise UnserializableField(
            field_name=spec.field_ctx.name,
            field_type=spec.type,
            holder_class=spec.builder.cls,
            msg="Method _serialize must have return annotation",
        ) from None
    if is_self(value_type):
        # ``-> Self``: the serialized form needs no further packing
        return f"{spec.expression}._serialize()"
    if isinstance(value_type, ForwardRef):
        value_type = spec.builder.evaluate_forward_ref(
            value_type, spec.origin_type
        )
    # substitute the class's type parameters with the field's concrete args
    value_type = substitute_type_params(
        value_type,
        resolve_type_params(spec.origin_type, get_args(spec.type))[
            spec.origin_type
        ],
    )
    return PackerRegistry.get(
        spec.copy(
            type=value_type,
            expression=f"{spec.expression}._serialize()",
        )
    )
@register
def pack_serializable_type(spec: ValueSpec) -> Optional[Expression]:
    """Pack a SerializableType subclass via its ``_serialize`` method."""
    try:
        is_serializable = issubclass(spec.origin_type, SerializableType)
    except TypeError:
        # not a class at all (e.g. a special typing form)
        return None
    if not is_serializable:
        return None
    if spec.origin_type.__use_annotations__:
        return _pack_annotated_serializable_type(spec)
    return f"{spec.expression}._serialize()"
@register
def pack_generic_serializable_type(spec: ValueSpec) -> Optional[Expression]:
    """Pack a GenericSerializableType by calling ``_serialize`` with the
    concrete type arguments of the annotation.

    Returns None (falls through to the next packer) when the origin type
    is not a GenericSerializableType subclass, or not a class at all.
    """
    with suppress(TypeError):
        if issubclass(spec.origin_type, GenericSerializableType):
            type_args = get_args(spec.type)
            spec.builder.add_type_modules(*type_args)
            # str.join accepts any iterable; no need to build a list first
            type_arg_names = ", ".join(map(type_name, type_args))
            return f"{spec.expression}._serialize([{type_arg_names}])"
@register
def pack_dataclass(spec: ValueSpec) -> Optional[Expression]:
    # Nested dataclass: ensure a pack method for the current format and
    # dialect has been compiled for it, then emit a call to that method.
    if is_dataclass(spec.origin_type):
        type_args = get_args(spec.type)
        method_name = spec.builder.get_pack_method_name(
            type_args, spec.builder.format_name
        )
        # Generated methods live on the class itself in "nailed" mode,
        # otherwise on a detached attrs holder.
        method_loc = spec.origin_type if spec.builder.is_nailed else spec.attrs
        # Compile the nested method only when it is not already defined
        # on method_loc and it is not the very method currently being
        # generated for this class (presumably to avoid re-entry —
        # confirm against CodeBuilder).
        if get_class_that_defines_method(
            method_name, method_loc
        ) != method_loc and (
            spec.origin_type is not spec.builder.cls
            or spec.builder.get_pack_method_name(
                type_args=type_args,
                format_name=spec.builder.format_name,
                encoder=spec.builder.encoder,
            )
            != method_name
        ):
            builder = spec.builder.__class__(
                spec.origin_type,
                type_args,
                dialect=spec.builder.dialect,
                format_name=spec.builder.format_name,
                default_dialect=spec.builder.default_dialect,
                attrs=method_loc,
                attrs_registry=(
                    spec.attrs_registry if not spec.builder.is_nailed else None
                ),
            )
            builder.add_pack_method()
        flags = spec.builder.get_pack_method_flags(spec.type)
        if spec.builder.is_nailed:
            return f"{spec.expression}.{method_name}({flags})"
        else:
            # call the generated function directly under an imported alias
            cls_alias = clean_id(type_name(spec.origin_type))
            method_name_alias = f"{cls_alias}_{method_name}"
            spec.builder.ensure_object_imported(
                getattr(spec.attrs, method_name), method_name_alias
            )
            method_args = spec.expression
            return f"{method_name_alias}({method_args})"
@register
def pack_final(spec: ValueSpec) -> Optional[Expression]:
    """Unwrap ``Final[X]`` and pack the value as plain ``X``."""
    if not is_final(spec.type):
        return None
    inner_type = get_args(spec.type)[0]
    return PackerRegistry.get(spec.copy(type=inner_type))
@register
def pack_any(spec: ValueSpec) -> Optional[Expression]:
    """``Any`` values are passed through untouched."""
    if spec.type is not Any:
        return None
    return spec.expression
def pack_union(
    spec: ValueSpec, args: tuple[type, ...], prefix: str = "union"
) -> Expression:
    # Generate a helper method that tries each union member's packer in
    # turn and return the expression that calls it.
    # Reuse the cached call expression when re-entering this same union.
    if spec.type is spec.owner and spec.field_ctx.packer:
        return spec.field_ctx.packer
    lines = CodeLines()
    method_name = (
        f"__pack_{prefix}_{spec.builder.cls.__name__}_{spec.field_ctx.name}__"
        f"{random_hex()}"
    )
    # Cache the call expression up front so recursive references to this
    # union (via the early return above) resolve to the same method.
    if not spec.field_ctx.packer:
        method_args = ", ".join(
            filter(None, ("value", spec.builder.get_pack_method_flags()))
        )
        if spec.builder.is_nailed:
            union_packer = (
                f"{spec.self_attrs_name}.{method_name}({method_args})"
            )
        else:
            union_packer = f"{method_name}({method_args})"
        spec.field_ctx.packer = union_packer
    method_args = "self, value" if spec.builder.is_nailed else "value"
    default_kwargs = spec.builder.get_pack_method_default_flag_values()
    if default_kwargs:
        lines.append(f"def {method_name}({method_args}, {default_kwargs}):")
    else:
        lines.append(f"def {method_name}({method_args}):")
    # Deduplicate member packers; the pass-through packer ("value") is
    # moved to the front so it is tried first via an exact type check.
    packers: list[str] = []
    packer_arg_types: dict[str, list[type]] = {}
    for type_arg in args:
        packer = PackerRegistry.get(
            spec.copy(type=type_arg, expression="value", owner=spec.type)
        )
        if packer not in packers:
            if packer == "value":
                packers.insert(0, packer)
            else:
                packers.append(packer)
        packer_arg_types.setdefault(packer, []).append(type_arg)
    if len(packers) == 1 and packers[0] == "value":
        # every member passes through unchanged — no helper needed
        return spec.expression
    with lines.indent():
        for packer in packers:
            packer_arg_type_names = []
            for packer_arg_type in packer_arg_types[packer]:
                if is_generic(packer_arg_type):
                    packer_arg_type = get_type_origin(packer_arg_type)
                packer_arg_type_name = clean_id(type_name(packer_arg_type))
                spec.builder.ensure_object_imported(
                    packer_arg_type, packer_arg_type_name
                )
                if packer_arg_type_name not in packer_arg_type_names:
                    packer_arg_type_names.append(packer_arg_type_name)
            if len(packer_arg_type_names) > 1:
                packer_arg_type_check = (
                    f"in ({', '.join(packer_arg_type_names)})"
                )
            else:
                packer_arg_type_check = f"is {packer_arg_type_names[0]}"
            if packer == "value":
                # pass-through: only for exact class matches
                with lines.indent(
                    f"if value.__class__ {packer_arg_type_check}:"
                ):
                    lines.append(f"return {packer}")
            else:
                # other members: try/except falls through to the next one
                with lines.indent("try:"):
                    lines.append(f"return {packer}")
                with lines.indent("except Exception:"):
                    lines.append("pass")
        # nothing matched: raise from the generated method
        field_type = spec.builder.get_type_name_identifier(
            typ=spec.type,
            resolved_type_params=spec.builder.get_field_resolved_type_params(
                spec.field_ctx.name
            ),
        )
        if spec.builder.is_nailed:
            lines.append(
                "raise InvalidFieldValue("
                f"'{spec.field_ctx.name}',{field_type},value,type(self))"
            )
        else:
            lines.append("raise ValueError(value)")
    lines.append(
        f"setattr({spec.cls_attrs_name}, '{method_name}', {method_name})"
    )
    if spec.builder.get_config().debug:
        print(f"{type_name(spec.builder.cls)}:")
        print(lines.as_text())
    exec(lines.as_text(), spec.builder.globals, spec.builder.__dict__)
    method_args = ", ".join(
        filter(None, (spec.expression, spec.builder.get_pack_method_flags()))
    )
    if spec.builder.is_nailed:
        return f"{spec.self_attrs_name}.{method_name}({method_args})"
    else:
        spec.builder.ensure_object_imported(
            getattr(spec.attrs, method_name), method_name
        )
        return f"{method_name}({method_args})"
def pack_literal(spec: ValueSpec) -> Expression:
    # Generate a helper method that packs a Literal[...] value by
    # comparing it against each allowed literal value in turn.
    spec.builder.add_type_modules(spec.type)
    lines = CodeLines()
    method_name = (
        f"__pack_literal_{spec.builder.cls.__name__}_{spec.field_ctx.name}__"
        f"{random_hex()}"
    )
    method_args = "self, value" if spec.builder.is_nailed else "value"
    default_kwargs = spec.builder.get_pack_method_default_flag_values()
    if default_kwargs:
        lines.append(f"def {method_name}({method_args}, {default_kwargs}):")
    else:
        lines.append(f"def {method_name}({method_args}):")
    resolved_type_params = spec.builder.get_field_resolved_type_params(
        spec.field_ctx.name
    )
    with lines.indent():
        for literal_value in get_literal_values(spec.type):
            value_type = type(literal_value)
            packer = PackerRegistry.get(
                spec.copy(type=value_type, expression="value")
            )
            if isinstance(literal_value, enum.Enum):
                # enum literals are compared against the member itself
                enum_type_name = spec.builder.get_type_name_identifier(
                    typ=value_type,
                    resolved_type_params=resolved_type_params,
                )
                with lines.indent(
                    f"if value == {enum_type_name}.{literal_value.name}:"
                ):
                    lines.append(f"return {packer}")
            elif isinstance(
                literal_value,
                (int, str, bytes, bool, NoneType),  # type: ignore
            ):
                # plain literals are compared against their repr
                with lines.indent(f"if value == {literal_value!r}:"):
                    lines.append(f"return {packer}")
        # no literal matched: raise from the generated method
        field_type = spec.builder.get_type_name_identifier(
            typ=spec.type,
            resolved_type_params=resolved_type_params,
        )
        if spec.builder.is_nailed:
            lines.append(
                f"raise InvalidFieldValue('{spec.field_ctx.name}',"
                f"{field_type},value,type(self))"
            )
        else:
            lines.append("raise ValueError(value)")
    lines.append(
        f"setattr({spec.cls_attrs_name}, '{method_name}', {method_name})"
    )
    if spec.builder.get_config().debug:
        print(f"{type_name(spec.builder.cls)}:")
        print(lines.as_text())
    exec(lines.as_text(), spec.builder.globals, spec.builder.__dict__)
    method_args = ", ".join(
        filter(None, (spec.expression, spec.builder.get_pack_method_flags()))
    )
    return f"{spec.self_attrs_name}.{method_name}({method_args})"
@register
def pack_special_typing_primitive(spec: ValueSpec) -> Optional[Expression]:
    # Dispatch for typing constructs that are not plain classes:
    # Union/Optional, AnyStr, TypeVar(+Tuple), NewType, Literal,
    # LiteralString, Self, (Not)Required, Unpack, ForwardRef,
    # TypeAliasType, ReadOnly.
    if is_special_typing_primitive(spec.origin_type):
        if is_union(spec.type):
            resolved_type_params = spec.builder.get_field_resolved_type_params(
                spec.field_ctx.name
            )
            if is_optional(spec.type, resolved_type_params):
                # Optional[X]: pack X and guard for None
                arg = not_none_type_arg(
                    get_args(spec.type), resolved_type_params
                )
                pv = PackerRegistry.get(spec.copy(type=arg))
                return expr_or_maybe_none(spec, pv)
            else:
                return pack_union(spec, get_args(spec.type))
        elif spec.origin_type is typing.AnyStr:
            raise UnserializableDataError(
                "AnyStr is not supported by mashumaro"
            )
        elif is_type_var_any(spec.type):
            # a bare TypeVar is equivalent to Any: pass through
            return spec.expression
        elif is_type_var(spec.type):
            constraints = getattr(spec.type, "__constraints__")
            if constraints:
                # a constrained TypeVar behaves like a union of constraints
                return pack_union(spec, constraints, "type_var")
            else:
                if type_var_has_default(spec.type):
                    bound = get_type_var_default(spec.type)
                else:
                    bound = getattr(spec.type, "__bound__")
                # act as if it was Optional[bound]
                pv = PackerRegistry.get(spec.copy(type=bound))
                return expr_or_maybe_none(spec, pv)
        elif is_new_type(spec.type):
            return PackerRegistry.get(spec.copy(type=spec.type.__supertype__))
        elif is_literal(spec.type):
            return pack_literal(spec)
        elif spec.type is typing_extensions.LiteralString:
            return PackerRegistry.get(spec.copy(type=str))
        elif is_self(spec.type):
            # ``Self``: ensure the enclosing class has a pack method under
            # the expected name, then call it on the value.
            method_name = spec.builder.get_pack_method_name(
                format_name=spec.builder.format_name
            )
            method_loc = (
                spec.builder.cls if spec.builder.is_nailed else spec.attrs
            )
            if (
                get_class_that_defines_method(method_name, method_loc)
                != method_loc
                # not hasattr(self.cls, method_name)
                and spec.builder.get_pack_method_name(
                    format_name=spec.builder.format_name,
                    encoder=spec.builder.encoder,
                )
                != method_name
            ):
                builder = spec.builder.__class__(
                    spec.builder.cls,
                    dialect=spec.builder.dialect,
                    format_name=spec.builder.format_name,
                    default_dialect=spec.builder.default_dialect,
                    attrs=method_loc,
                    attrs_registry=(
                        spec.attrs_registry
                        if not spec.builder.is_nailed
                        else None
                    ),
                )
                builder.add_pack_method()
            flags = spec.builder.get_pack_method_flags(spec.builder.cls)
            if spec.builder.is_nailed:
                return f"{spec.expression}.{method_name}({flags})"
            else:
                method_args = spec.expression
                return f"_cls.{method_name}({method_args})"
        elif is_required(spec.type) or is_not_required(spec.type):
            # Required[X] / NotRequired[X] (TypedDict): unwrap to X
            return PackerRegistry.get(spec.copy(type=get_args(spec.type)[0]))
        elif is_unpack(spec.type):
            # Unpack[...]: splat the packed sequence into the container
            packer = PackerRegistry.get(spec.copy(type=get_args(spec.type)[0]))
            return f"*{packer}"
        elif is_type_var_tuple(spec.type):
            # TypeVarTuple behaves like tuple[Any, ...]
            return PackerRegistry.get(spec.copy(type=tuple[Any, ...]))
        elif isinstance(spec.type, ForwardRef):
            evaluated = spec.builder.evaluate_forward_ref(
                spec.type, spec.owner
            )
            if evaluated is not None:
                return PackerRegistry.get(spec.copy(type=evaluated))
        elif is_type_alias_type(spec.type):
            # ``type X = ...``: pack the aliased value type
            return PackerRegistry.get(spec.copy(type=spec.type.__value__))
        elif is_readonly(spec.type):
            return PackerRegistry.get(spec.copy(type=get_args(spec.type)[0]))
        raise UnserializableDataError(
            f"{spec.type} as a field type is not supported by mashumaro"
        )
@register
def pack_number_and_bool_and_none(spec: ValueSpec) -> Optional[Expression]:
    """Pass int/float/bool/None values through unchanged."""
    passthrough_types = (int, float, bool, NoneType, None)
    if spec.origin_type not in passthrough_types:
        return None
    return spec.expression
@register
def pack_date_objects(spec: ValueSpec) -> Optional[Expression]:
    """Serialize datetime/date/time via their ISO 8601 representation."""
    iso_types = (datetime.datetime, datetime.date, datetime.time)
    if spec.origin_type not in iso_types:
        return None
    return f"{spec.expression}.isoformat()"
@register
def pack_timedelta(spec: ValueSpec) -> Optional[Expression]:
    """Serialize a timedelta as a float number of seconds."""
    if spec.origin_type is not datetime.timedelta:
        return None
    return f"{spec.expression}.total_seconds()"
@register
def pack_timezone(spec: ValueSpec) -> Optional[Expression]:
    """Serialize a fixed-offset timezone via its tzname() (e.g. 'UTC+03:00')."""
    if spec.origin_type is not datetime.timezone:
        return None
    return f"{spec.expression}.tzname(None)"
@register
def pack_zone_info(spec: ValueSpec) -> Optional[Expression]:
    """Serialize a ZoneInfo as its IANA key via str()."""
    if spec.origin_type is not zoneinfo.ZoneInfo:
        return None
    return f"str({spec.expression})"
@register
def pack_uuid(spec: ValueSpec) -> Optional[Expression]:
    """Serialize a UUID as its canonical string form."""
    if spec.origin_type is not uuid.UUID:
        return None
    return f"str({spec.expression})"
@register
def pack_ipaddress(spec: ValueSpec) -> Optional[Expression]:
    """Serialize ipaddress addresses, networks and interfaces as strings."""
    supported = (
        ipaddress.IPv4Address,
        ipaddress.IPv6Address,
        ipaddress.IPv4Network,
        ipaddress.IPv6Network,
        ipaddress.IPv4Interface,
        ipaddress.IPv6Interface,
    )
    if spec.origin_type not in supported:
        return None
    return f"str({spec.expression})"
@register
def pack_decimal(spec: ValueSpec) -> Optional[Expression]:
    """Serialize a Decimal as its string representation."""
    if spec.origin_type is not Decimal:
        return None
    return f"str({spec.expression})"
@register
def pack_fraction(spec: ValueSpec) -> Optional[Expression]:
    """Serialize a Fraction as its 'numerator/denominator' string form."""
    if spec.origin_type is not Fraction:
        return None
    return f"str({spec.expression})"
def pack_tuple(spec: ValueSpec, args: tuple[type, ...]) -> Expression:
    # Build a packing expression for a tuple type; `args` are the type
    # arguments of the tuple annotation (may contain Ellipsis or Unpack).
    if not args:
        if spec.type in (Tuple, tuple):
            # Bare `tuple` / `Tuple` is treated as tuple[Any, ...]
            args = [Any, ...]  # type: ignore
        else:
            # Parameterless tuple subclass: nothing to pack
            return "[]"
    elif len(args) == 1 and args[0] == ():
        # tuple[()] — the empty-tuple spelling used before Python 3.11
        if not PY_311_MIN:
            return "[]"
    if len(args) == 2 and args[1] is Ellipsis:
        # Homogeneous variable-length tuple: tuple[X, ...]
        packer = PackerRegistry.get(
            spec.copy(type=args[0], expression="value", could_be_none=True)
        )
        return f"[{packer} for value in {spec.expression}]"
    else:
        # Heterogeneous tuple, possibly with a single Unpack[...] segment.
        # Plain positions get an int index; the Unpack position gets a
        # (start, stop) slice pair covering the variable-length middle.
        arg_indexes: list[Union[int, tuple[int, Union[int, None]]]] = []
        unpack_idx: Optional[int] = None
        for arg_idx, type_arg in enumerate(args):
            if is_unpack(type_arg):
                if unpack_idx is not None:
                    raise TypeError(
                        "Multiple unpacks are disallowed within a single type "
                        f"parameter list for {type_name(spec.type)}"
                    )
                unpack_idx = arg_idx
                if len(args) == 1:
                    arg_indexes.append((arg_idx, None))
                elif arg_idx < len(args) - 1:
                    # Slice stop is negative: counts back from the end so the
                    # trailing fixed elements keep their positions.
                    arg_indexes.append((arg_idx, arg_idx + 1 - len(args)))
                else:
                    arg_indexes.append((arg_idx, None))
            else:
                if unpack_idx is None:
                    arg_indexes.append(arg_idx)
                else:
                    # After the unpack, index from the end (negative index)
                    arg_indexes.append(arg_idx - len(args))
        packers: list[Expression] = []
        for _idx, _arg_idx in enumerate(arg_indexes):
            if isinstance(_arg_idx, tuple):
                p_expr = f"{spec.expression}[{_arg_idx[0]}:{_arg_idx[1]}]"
            else:
                p_expr = f"{spec.expression}[{_arg_idx}]"
            packer = PackerRegistry.get(
                spec.copy(
                    type=args[_idx],
                    expression=p_expr,
                    could_be_none=True,
                )
            )
            # "*[]" is an empty Unpack expansion — contributes nothing
            if packer != "*[]":
                packers.append(packer)
        return f"[{', '.join(packers)}]"
def pack_named_tuple(spec: ValueSpec) -> Expression:
    # Build a packing expression for a NamedTuple, either as a list
    # (default) or as a dict keyed by field names.
    resolved = resolve_type_params(spec.origin_type, get_args(spec.type))[
        spec.origin_type
    ]
    # Substitute resolved type parameters into the field annotations
    annotations = {
        k: resolved.get(v, v)
        for k, v in getattr(spec.origin_type, "__annotations__", {}).items()
    }
    fields = getattr(spec.type, "_fields", ())
    packers = []
    # Dialect/config default, overridable per-field via "serialize" metadata
    as_dict = spec.builder.get_dialect_or_config_option(
        "namedtuple_as_dict", False
    )
    serialize_option = get_overridden_serialization_method(spec)
    if serialize_option is not None:
        if serialize_option == "as_dict":
            as_dict = True
        elif serialize_option == "as_list":
            as_dict = False
        else:
            raise UnsupportedSerializationEngine(
                field_name=spec.field_ctx.name,
                field_type=spec.type,
                holder_class=spec.builder.cls,
                engine=serialize_option,
            )
    # One positional packer per field; untyped fields pack as Any
    for idx, field in enumerate(fields):
        packer = PackerRegistry.get(
            spec.copy(
                type=annotations.get(field, Any),
                expression=f"{spec.expression}[{idx}]",
                could_be_none=True,
            )
        )
        packers.append(packer)
    if as_dict:
        kv = (f"'{key}': {value}" for key, value in zip(fields, packers))
        return f"{{{', '.join(kv)}}}"
    else:
        return f"[{', '.join(packers)}]"
def pack_typed_dict(spec: ValueSpec) -> Expression:
    # Generate (via exec) a dedicated helper method that packs a TypedDict:
    # required keys are read unconditionally, optional keys only if present.
    resolved = resolve_type_params(spec.origin_type, get_args(spec.type))[
        spec.origin_type
    ]
    annotations = {
        k: resolved.get(v, v)
        for k, v in spec.origin_type.__annotations__.items()
    }
    all_keys = list(annotations.keys())
    required_keys = getattr(spec.type, "__required_keys__", all_keys)
    optional_keys = getattr(spec.type, "__optional_keys__", [])
    lines = CodeLines()
    # Unique name so multiple TypedDict fields don't collide on the holder
    method_name = (
        f"__pack_typed_dict_{spec.builder.cls.__name__}_"
        f"{spec.field_ctx.name}__{random_hex()}"
    )
    method_args = "self, value" if spec.builder.is_nailed else "value"
    default_kwargs = spec.builder.get_pack_method_default_flag_values()
    if default_kwargs:
        lines.append(f"def {method_name}({method_args}, {default_kwargs}):")
    else:
        lines.append(f"def {method_name}({method_args}):")
    with lines.indent():
        lines.append("d = {}")
        # Keys are emitted in declaration order (sorted by annotation index)
        for key in sorted(required_keys, key=all_keys.index):
            packer = PackerRegistry.get(
                spec.copy(
                    type=annotations[key],
                    expression=f"value['{key}']",
                    could_be_none=True,
                    owner=spec.type,
                )
            )
            lines.append(f"d['{key}'] = {packer}")
        for key in sorted(optional_keys, key=all_keys.index):
            # MISSING sentinel distinguishes "absent" from an explicit None
            lines.append(f"key_value = value.get('{key}', MISSING)")
            with lines.indent("if key_value is not MISSING:"):
                packer = PackerRegistry.get(
                    spec.copy(
                        type=annotations[key],
                        expression="key_value",
                        could_be_none=True,
                        owner=spec.type,
                    )
                )
                lines.append(f"d['{key}'] = {packer}")
        lines.append("return d")
    # Attach the generated helper to the attrs holder / class
    lines.append(
        f"setattr({spec.cls_attrs_name}, '{method_name}', {method_name})"
    )
    if spec.builder.get_config().debug:
        print(f"{type_name(spec.builder.cls)}:")
        print(lines.as_text())
    exec(lines.as_text(), spec.builder.globals, spec.builder.__dict__)
    method_args = ", ".join(
        filter(None, (spec.expression, spec.builder.get_pack_method_flags()))
    )
    return f"{spec.self_attrs_name}.{method_name}({method_args})"
@register
def pack_collection(spec: ValueSpec) -> Optional[Expression]:
    # Dispatch packer for all Collection subclasses except Enums (which are
    # Collections of members but have their own packer).
    if not issubclass(spec.origin_type, Collection):
        return None
    elif issubclass(spec.origin_type, enum.Enum):
        return None
    args = get_args(spec.type)
    def inner_expr(
        arg_num: int = 0, v_name: str = "value", v_type: Optional[type] = None
    ) -> Expression:
        # Packer expression for a single element; falls back to Any when the
        # annotation carries no type argument at position `arg_num`.
        if v_type:
            return PackerRegistry.get(
                spec.copy(type=v_type, expression=v_name)
            )
        else:
            if args and len(args) > arg_num:
                type_arg: Any = args[arg_num]
            else:
                type_arg = Any
            return PackerRegistry.get(
                spec.copy(
                    type=type_arg,
                    expression=v_name,
                    could_be_none=True,
                    field_ctx=spec.field_ctx.copy(metadata={}),
                )
            )
    def _make_sequence_expression(ie: Expression) -> Expression:
        # An identity element packer lets us skip the per-element loop:
        # either pass the collection through or shallow-copy it.
        if ie == "value":
            if spec.origin_type in spec.no_copy_collections:
                return spec.expression
            elif spec.origin_type is list:
                return f"{spec.expression}.copy()"
        return f"[{ie} for value in {spec.expression}]"
    def _make_mapping_expression(ke: Expression, ve: Expression) -> Expression:
        # Same identity shortcut for mappings
        if ke == "key" and ve == "value":
            if spec.origin_type in spec.no_copy_collections:
                return spec.expression
            elif spec.origin_type is dict:
                return f"{spec.expression}.copy()"
        return f"{{{ke}: {ve} for key, value in {spec.expression}.items()}}"
    # Order matters below: more specific types are checked first.
    if issubclass(spec.origin_type, typing.ByteString): # type: ignore
        spec.builder.ensure_object_imported(encodebytes)
        return f"encodebytes({spec.expression}).decode()"
    elif issubclass(spec.origin_type, str):
        return spec.expression
    elif issubclass(spec.origin_type, tuple):
        if is_named_tuple(spec.origin_type):
            return pack_named_tuple(spec)
        elif ensure_generic_collection(spec):
            return pack_tuple(spec, args)
    elif ensure_generic_collection_subclass(spec, list, deque, Set):
        ie = inner_expr()
        return _make_sequence_expression(ie)
    elif ensure_generic_mapping(spec, args, ChainMap):
        # ChainMap packs as a list of plain dicts, one per map layer
        ke = inner_expr(0, "key")
        ve = inner_expr(1)
        return (
            f"[{{{ke}: {ve} for key, value in m.items()}} "
            f"for m in {spec.expression}.maps]"
        )
    elif ensure_generic_mapping(spec, args, OrderedDict):
        ke = inner_expr(0, "key")
        ve = inner_expr(1)
        return _make_mapping_expression(ke, ve)
    elif ensure_generic_mapping(spec, args, Counter):
        # Counter values are always ints regardless of the annotation
        ke = inner_expr(0, "key")
        ve = inner_expr(1, v_type=int)
        return _make_mapping_expression(ke, ve)
    elif is_typed_dict(spec.origin_type):
        return pack_typed_dict(spec)
    elif ensure_generic_mapping(spec, args, Mapping):
        ke = inner_expr(0, "key")
        ve = inner_expr(1)
        return _make_mapping_expression(ke, ve)
    elif ensure_generic_collection_subclass(spec, Sequence):
        ie = inner_expr()
        return _make_sequence_expression(ie)
@register
def pack_pathlike(spec: ValueSpec) -> Optional[Expression]:
    """Serialize os.PathLike objects via their __fspath__() string."""
    if not issubclass(spec.origin_type, os.PathLike):
        return None
    return f"{spec.expression}.__fspath__()"
@register
def pack_enum(spec: ValueSpec) -> Optional[Expression]:
    """Serialize an Enum member as its .value."""
    if not issubclass(spec.origin_type, enum.Enum):
        return None
    return f"{spec.expression}.value"
@register
def pack_pattern(spec: ValueSpec) -> Optional[Expression]:
    """Serialize a compiled regex as its source pattern string."""
    if spec.origin_type not in (typing.Pattern, re.Pattern):
        return None
    return f"{spec.expression}.pattern"

View File

@@ -0,0 +1,62 @@
from collections.abc import Callable, Sequence
from types import new_class
from typing import Any, Type, Union, cast
from typing_extensions import Literal
from mashumaro.core.const import Sentinel
from mashumaro.types import SerializationStrategy
__all__ = ["Dialect"]
# A serialization_strategy entry is either a SerializationStrategy instance
# or a mapping with "serialize"/"deserialize" callables or engine names.
SerializationStrategyValueType = Union[
    SerializationStrategy, dict[str, Union[str, Callable]]
]
class Dialect:
    """Base class for serialization dialects.

    Subclasses override any of the class attributes; ``Sentinel.MISSING``
    means "not set by this dialect" so that merging can tell an explicit
    ``False`` apart from an unset option.
    """

    serialization_strategy: dict[Any, SerializationStrategyValueType] = {}
    serialize_by_alias: Union[bool, Literal[Sentinel.MISSING]] = (
        Sentinel.MISSING
    )
    namedtuple_as_dict: Union[bool, Literal[Sentinel.MISSING]] = (
        Sentinel.MISSING
    )
    omit_none: Union[bool, Literal[Sentinel.MISSING]] = Sentinel.MISSING
    omit_default: Union[bool, Literal[Sentinel.MISSING]] = Sentinel.MISSING
    no_copy_collections: Union[Sequence[Any], Literal[Sentinel.MISSING]] = (
        Sentinel.MISSING
    )

    @classmethod
    def merge(cls, other: Type["Dialect"]) -> Type["Dialect"]:
        """Return a new Dialect combining *cls* with *other* (*other* wins)."""
        serialization_strategy: dict[Any, SerializationStrategyValueType] = {}
        for key, value in cls.serialization_strategy.items():
            if isinstance(value, SerializationStrategy):
                serialization_strategy[key] = value
            else:
                # copy so updating from `other` can't mutate cls's dict
                serialization_strategy[key] = value.copy()
        for key, value in other.serialization_strategy.items():
            if isinstance(value, SerializationStrategy):
                serialization_strategy[key] = value
            elif isinstance(
                serialization_strategy.get(key), SerializationStrategy
            ):
                # a dict spec in `other` replaces a strategy object from cls
                serialization_strategy[key] = value
            else:
                (
                    serialization_strategy.setdefault(
                        key, {}
                    ).update(  # type: ignore
                        value
                    )
                )
        new_dialect = cast(Type[Dialect], new_class("Dialect", (Dialect,)))
        new_dialect.serialization_strategy = serialization_strategy
        # BUGFIX: previously only omit_none / omit_default /
        # no_copy_collections were merged; serialize_by_alias and
        # namedtuple_as_dict were silently reset to Sentinel.MISSING on the
        # merged dialect. Merge all five sentinel-valued options.
        for key in (
            "serialize_by_alias",
            "namedtuple_as_dict",
            "omit_none",
            "omit_default",
            "no_copy_collections",
        ):
            if (others_value := getattr(other, key)) is not Sentinel.MISSING:
                setattr(new_dialect, key, others_value)
            else:
                setattr(new_dialect, key, getattr(cls, key))
        return new_dialect

View File

@@ -0,0 +1,216 @@
from typing import Any, Optional, Type
from mashumaro.core.meta.helpers import type_name
class MissingField(LookupError):
    """Raised when a required dataclass field is absent from the input."""

    def __init__(self, field_name: str, field_type: Type, holder_class: Type):
        self.field_name = field_name
        self.field_type = field_type
        self.holder_class = holder_class

    @property
    def field_type_name(self) -> str:
        return type_name(self.field_type, short=True)

    @property
    def holder_class_name(self) -> str:
        return type_name(self.holder_class, short=True)

    def __str__(self) -> str:
        message = f'Field "{self.field_name}" of type {self.field_type_name}'
        message += f" is missing in {self.holder_class_name} instance"
        return message
class ExtraKeysError(ValueError):
    """Raised when a serialized dict contains keys unknown to the target type.

    Attributes:
        extra_keys: the offending key names.
        target_type: the class the dict was being deserialized into.
    """

    def __init__(self, extra_keys: set[str], target_type: Type):
        self.extra_keys = extra_keys
        self.target_type = target_type

    @property
    def target_class_name(self) -> str:
        return type_name(self.target_type, short=True)

    def __str__(self) -> str:
        # Sort the keys: iteration order of a set is arbitrary, which made
        # this message nondeterministic across runs. (Also drops the
        # redundant `k for k in ...` generator.)
        extra_keys_str = ", ".join(sorted(self.extra_keys))
        return (
            "Serialized dict has keys that are not defined in "
            f"{self.target_class_name}: {extra_keys_str}"
        )
class UnserializableDataError(TypeError):
    """Base error for data that mashumaro cannot (de)serialize."""
    pass
class UnserializableField(UnserializableDataError):
    """Raised when a specific field's type is not serializable."""

    def __init__(
        self,
        field_name: str,
        field_type: Type,
        holder_class: Type,
        msg: Optional[str] = None,
    ):
        self.field_name = field_name
        self.field_type = field_type
        self.holder_class = holder_class
        self.msg = msg

    @property
    def field_type_name(self) -> str:
        return type_name(self.field_type, short=True)

    @property
    def holder_class_name(self) -> str:
        return type_name(self.holder_class, short=True)

    def __str__(self) -> str:
        base = (
            f'Field "{self.field_name}" of type {self.field_type_name} '
            f"in {self.holder_class_name} is not serializable"
        )
        suffix = f": {self.msg}" if self.msg else ""
        return base + suffix
class UnsupportedSerializationEngine(UnserializableField):
    """Raised when field metadata names an unknown "serialize" engine."""

    def __init__(
        self,
        field_name: str,
        field_type: Type,
        holder_class: Type,
        engine: Any,
    ):
        message = f'Unsupported serialization engine "{engine}"'
        super().__init__(field_name, field_type, holder_class, msg=message)
class UnsupportedDeserializationEngine(UnserializableField):
    """Raised when field metadata names an unknown "deserialize" engine."""

    def __init__(
        self,
        field_name: str,
        field_type: Type,
        holder_class: Type,
        engine: Any,
    ):
        message = f'Unsupported deserialization engine "{engine}"'
        super().__init__(field_name, field_type, holder_class, msg=message)
class InvalidFieldValue(ValueError):
    """Raised when a field's input value cannot be converted to its type."""

    def __init__(
        self,
        field_name: str,
        field_type: Type,
        field_value: Any,
        holder_class: Type,
        msg: Optional[str] = None,
    ):
        self.field_name = field_name
        self.field_type = field_type
        self.field_value = field_value
        self.holder_class = holder_class
        self.msg = msg

    @property
    def field_type_name(self) -> str:
        return type_name(self.field_type, short=True)

    @property
    def holder_class_name(self) -> str:
        return type_name(self.holder_class, short=True)

    def __str__(self) -> str:
        base = (
            f'Field "{self.field_name}" of type {self.field_type_name} '
            f"in {self.holder_class_name} has invalid value "
            f"{repr(self.field_value)}"
        )
        suffix = f": {self.msg}" if self.msg else ""
        return base + suffix
class MissingDiscriminatorError(LookupError):
    """Raised when the discriminator field is absent from the input dict."""

    def __init__(self, field_name: str):
        self.field_name = field_name

    def __str__(self) -> str:
        return "Discriminator '%s' is missing" % self.field_name
class SuitableVariantNotFoundError(ValueError):
    """Raised when no subtype of a discriminated union matches the input."""

    def __init__(
        self,
        variants_type: Type,
        discriminator_name: Optional[str] = None,
        discriminator_value: Any = None,
    ):
        self.variants_type = variants_type
        self.discriminator_name = discriminator_name
        self.discriminator_value = discriminator_value

    def __str__(self) -> str:
        prefix = f"{type_name(self.variants_type)} has no "
        if self.discriminator_value is None:
            return prefix + "suitable subtype"
        return prefix + (
            f"subtype with attribute '{self.discriminator_name}' "
            f"equal to {self.discriminator_value!r}"
        )
class BadHookSignature(TypeError):
    """Raised when a serialization hook has an incompatible signature."""
    pass
class ThirdPartyModuleNotFoundError(ModuleNotFoundError):
    """Raised when a field's serialization engine needs an uninstalled module."""

    def __init__(self, module_name: str, field_name: str, holder_class: Type):
        self.module_name = module_name
        self.field_name = field_name
        self.holder_class = holder_class

    @property
    def holder_class_name(self) -> str:
        return type_name(self.holder_class, short=True)

    def __str__(self) -> str:
        return (
            f'Install "{self.module_name}" to use it as the serialization '
            f'method for the field "{self.field_name}" '
            f"in {self.holder_class_name}"
        )
class UnresolvedTypeReferenceError(NameError):
    """Raised when a forward type reference in a field cannot be resolved."""

    def __init__(self, holder_class: Type, unresolved_type_name: str):
        self.holder_class = holder_class
        self.name = unresolved_type_name

    @property
    def holder_class_name(self) -> str:
        return type_name(self.holder_class, short=True)

    def __str__(self) -> str:
        return (
            f"Class {self.holder_class_name} has unresolved type reference "
            f"{self.name} in some of its fields"
        )
class BadDialect(ValueError):
    """Raised when a dialect definition is invalid."""
    pass

View File

@@ -0,0 +1,59 @@
from collections.abc import Callable
from typing import Any, Optional, TypeVar, Union
from typing_extensions import Literal
from mashumaro.types import SerializationStrategy
__all__ = [
    "field_options",
    "pass_through",
]
# Engine names accepted by field_options(deserialize=...)
NamedTupleDeserializationEngine = Literal["as_dict", "as_list"]
DateTimeDeserializationEngine = Literal["ciso8601", "pendulum"]
AnyDeserializationEngine = Literal[
    NamedTupleDeserializationEngine, DateTimeDeserializationEngine
]
# Engine names accepted by field_options(serialize=...)
NamedTupleSerializationEngine = Literal["as_dict", "as_list"]
OmitSerializationEngine = Literal["omit"]
AnySerializationEngine = Union[
    NamedTupleSerializationEngine, OmitSerializationEngine
]
T = TypeVar("T")
def field_options(
    serialize: Optional[
        Union[AnySerializationEngine, Callable[[Any], Any]]
    ] = None,
    deserialize: Optional[
        Union[AnyDeserializationEngine, Callable[[Any], Any]]
    ] = None,
    serialization_strategy: Optional[SerializationStrategy] = None,
    alias: Optional[str] = None,
) -> dict[str, Any]:
    """Build a metadata mapping for ``dataclasses.field(metadata=...)``."""
    return dict(
        serialize=serialize,
        deserialize=deserialize,
        serialization_strategy=serialization_strategy,
        alias=alias,
    )
class _PassThrough(SerializationStrategy):
    """Identity strategy: values are kept as-is in both directions."""
    def __call__(self, *args: Any, **kwargs: Any) -> Any:
        # Guards against using the `pass_through` singleton as a function.
        raise NotImplementedError
    def serialize(self, value: T) -> T:
        return value
    def deserialize(self, value: T) -> T:
        return value
# Module-level singleton used in serialization_strategy / field_options.
pass_through = _PassThrough()

View File

@@ -0,0 +1,9 @@
from .builder import JSONSchemaBuilder, build_json_schema
from .dialects import DRAFT_2020_12, OPEN_API_3_1
# Public API of mashumaro.jsonschema
__all__ = [
    "JSONSchemaBuilder",
    "build_json_schema",
    "DRAFT_2020_12",
    "OPEN_API_3_1",
]

View File

@@ -0,0 +1,134 @@
from dataclasses import dataclass
from mashumaro.jsonschema.models import JSONSchema, Number
class Annotation:
    """Base marker for typing.Annotated extras understood by the builder."""
    pass
class Constraint(Annotation):
    """An annotation that maps to a JSON Schema validation keyword."""
    pass
class NumberConstraint(Constraint):
    pass
# --- Validation keywords for numeric instances ---
@dataclass(unsafe_hash=True)
class Minimum(NumberConstraint):
    value: Number
@dataclass(unsafe_hash=True)
class Maximum(NumberConstraint):
    value: Number
@dataclass(unsafe_hash=True)
class ExclusiveMinimum(NumberConstraint):
    value: Number
@dataclass(unsafe_hash=True)
class ExclusiveMaximum(NumberConstraint):
    value: Number
@dataclass(unsafe_hash=True)
class MultipleOf(NumberConstraint):
    value: Number
class StringConstraint(Constraint):
    pass
# --- Validation keywords for strings ---
@dataclass(unsafe_hash=True)
class MinLength(StringConstraint):
    value: int
@dataclass(unsafe_hash=True)
class MaxLength(StringConstraint):
    value: int
@dataclass(unsafe_hash=True)
class Pattern(StringConstraint):
    value: str
class ArrayConstraint(Constraint):
    pass
# --- Validation keywords for arrays ---
@dataclass(unsafe_hash=True)
class MinItems(ArrayConstraint):
    value: int
@dataclass(unsafe_hash=True)
class MaxItems(ArrayConstraint):
    value: int
@dataclass(unsafe_hash=True)
class UniqueItems(ArrayConstraint):
    value: bool
@dataclass(unsafe_hash=True)
class Contains(ArrayConstraint):
    value: JSONSchema
@dataclass(unsafe_hash=True)
class MinContains(ArrayConstraint):
    value: int
@dataclass(unsafe_hash=True)
class MaxContains(ArrayConstraint):
    value: int
class ObjectConstraint(Constraint):
    pass
# --- Validation keywords for objects ---
@dataclass(unsafe_hash=True)
class MaxProperties(ObjectConstraint):
    value: int
@dataclass(unsafe_hash=True)
class MinProperties(ObjectConstraint):
    value: int
# NOTE(review): no unsafe_hash here — the dict value is unhashable anyway,
# so instances are intentionally(?) unhashable; confirm if hashing is needed.
@dataclass
class DependentRequired(ObjectConstraint):
    value: dict[str, set[str]]
# Only the concrete annotations are exported; the abstract Constraint base
# classes are not part of the public API.
__all__ = [
    "Annotation",
    "MultipleOf",
    "Maximum",
    "ExclusiveMaximum",
    "Minimum",
    "ExclusiveMinimum",
    "MaxLength",
    "MinLength",
    "Pattern",
    "MaxItems",
    "MinItems",
    "UniqueItems",
    "Contains",
    "MaxContains",
    "MinContains",
    "MaxProperties",
    "MinProperties",
    "DependentRequired",
]

View File

@@ -0,0 +1,97 @@
from collections.abc import Sequence
from dataclasses import dataclass
from typing import Any, Optional, Type
from mashumaro.jsonschema.dialects import DRAFT_2020_12, JSONSchemaDialect
from mashumaro.jsonschema.models import Context, JSONSchema
from mashumaro.jsonschema.plugins import BasePlugin
from mashumaro.jsonschema.schema import Instance, get_schema
try:
from mashumaro.mixins.orjson import (
DataClassORJSONMixin as DataClassJSONMixin,
)
except ImportError: # pragma: no cover
from mashumaro.mixins.json import DataClassJSONMixin # type: ignore
def build_json_schema(
    instance_type: Type,
    context: Optional[Context] = None,
    with_definitions: bool = True,
    all_refs: Optional[bool] = None,
    with_dialect_uri: bool = False,
    dialect: Optional[JSONSchemaDialect] = None,
    ref_prefix: Optional[str] = None,
    plugins: Sequence[BasePlugin] = (),
) -> JSONSchema:
    """Build a JSON Schema for *instance_type*.

    Explicit keyword arguments override the corresponding settings of the
    given *context*; unset options fall back to the dialect's defaults.
    Returns the schema, with collected definitions attached when
    *with_definitions* is true.
    """
    if context is None:
        context = Context()
    else:
        # Shallow working copy so the caller's Context isn't mutated below.
        # NOTE: the definitions dict is shared with the given context, so
        # definitions keep accumulating there (JSONSchemaBuilder relies on
        # this).
        context = Context(
            dialect=context.dialect,
            definitions=context.definitions,
            all_refs=context.all_refs,
            ref_prefix=context.ref_prefix,
            plugins=context.plugins,
        )
    if dialect is not None:
        context.dialect = dialect
    if all_refs is not None:
        context.all_refs = all_refs
    elif context.all_refs is None:
        context.all_refs = context.dialect.all_refs
    if ref_prefix is not None:
        context.ref_prefix = ref_prefix.rstrip("/")
    elif context.ref_prefix is None:
        context.ref_prefix = context.dialect.definitions_root_pointer
    if plugins:
        context.plugins = plugins
    instance = Instance(instance_type)
    schema = get_schema(instance, context, with_dialect_uri=with_dialect_uri)
    if with_definitions and context.definitions:
        schema.definitions = context.definitions
    return schema
@dataclass
class JSONSchemaDefinitions(DataClassJSONMixin):
    """Wrapper serializing a definitions mapping as its bare dict."""
    definitions: dict[str, JSONSchema]
    # NOTE(review): the declared return annotation is a list but the hook
    # returns the serialized definitions mapping itself (hence type: ignore).
    def __post_serialize__( # type: ignore
        self, d: dict[Any, Any]
    ) -> list[dict[str, Any]]:
        return d["definitions"]
class JSONSchemaBuilder:
    """Builds schemas for several types sharing one definitions registry."""

    def __init__(
        self,
        dialect: JSONSchemaDialect = DRAFT_2020_12,
        all_refs: Optional[bool] = None,
        ref_prefix: Optional[str] = None,
        plugins: Sequence[BasePlugin] = (),
    ):
        # Unset options fall back to the dialect's defaults.
        effective_all_refs = dialect.all_refs if all_refs is None else all_refs
        if ref_prefix is None:
            ref_prefix = dialect.definitions_root_pointer
        self.context = Context(
            dialect=dialect,
            all_refs=effective_all_refs,
            ref_prefix=ref_prefix.rstrip("/"),
            plugins=plugins,
        )

    def build(self, instance_type: Type) -> JSONSchema:
        """Build a schema for *instance_type*; definitions accumulate in
        the shared context rather than being attached to the result."""
        return build_json_schema(
            instance_type=instance_type,
            context=self.context,
            with_definitions=False,
        )

    def get_definitions(self) -> JSONSchemaDefinitions:
        """Return all definitions collected by previous build() calls."""
        return JSONSchemaDefinitions(self.context.definitions)
# JSONSchemaDefinitions stays internal to this module.
__all__ = ["JSONSchemaBuilder", "build_json_schema"]

View File

@@ -0,0 +1,29 @@
from dataclasses import dataclass
@dataclass(frozen=True)
class JSONSchemaDialect:
    """Identifies a JSON Schema dialect and its reference conventions."""
    # dialect URI, emitted as "$schema" when requested
    uri: str
    # JSON pointer under which shared definitions are placed
    definitions_root_pointer: str
    # whether every nested schema is emitted as a "$ref" by default
    all_refs: bool
@dataclass(frozen=True)
class JSONSchemaDraft202012Dialect(JSONSchemaDialect):
    """JSON Schema draft 2020-12."""
    uri: str = "https://json-schema.org/draft/2020-12/schema"
    definitions_root_pointer: str = "#/$defs"
    all_refs: bool = False
@dataclass(frozen=True)
class OpenAPISchema31Dialect(JSONSchemaDialect):
    """OpenAPI 3.1 dialect (components/schemas, refs everywhere)."""
    uri: str = "https://spec.openapis.org/oas/3.1/dialect/base"
    definitions_root_pointer: str = "#/components/schemas"
    all_refs: bool = True
# Shared singleton instances
DRAFT_2020_12 = JSONSchemaDraft202012Dialect()
OPEN_API_3_1 = OpenAPISchema31Dialect()
__all__ = ["JSONSchemaDialect", "DRAFT_2020_12", "OPEN_API_3_1"]

View File

@@ -0,0 +1,212 @@
import datetime
import ipaddress
from collections.abc import Sequence
from dataclasses import MISSING, dataclass, field
from enum import Enum
from typing import Any, Dict, List, Optional, Union
from typing_extensions import TYPE_CHECKING, Self, TypeAlias
from mashumaro.config import BaseConfig
from mashumaro.core.meta.helpers import iter_all_subclasses
from mashumaro.helper import pass_through
from mashumaro.jsonschema.dialects import DRAFT_2020_12, JSONSchemaDialect
if TYPE_CHECKING: # pragma: no cover
from mashumaro.jsonschema.plugins import BasePlugin
else:
BasePlugin = Any
try:
from mashumaro.mixins.orjson import (
DataClassORJSONMixin as DataClassJSONMixin,
)
except ImportError: # pragma: no cover
from mashumaro.mixins.json import DataClassJSONMixin # type: ignore
# https://github.com/python/mypy/issues/3186
Number: TypeAlias = Union[int, float]
# Sentinel distinguishing an explicit JSON null from "attribute not set".
Null = object()
class JSONSchemaInstanceType(Enum):
    """The seven JSON Schema primitive instance types."""
    NULL = "null"
    BOOLEAN = "boolean"
    OBJECT = "object"
    ARRAY = "array"
    NUMBER = "number"
    STRING = "string"
    INTEGER = "integer"
class JSONSchemaInstanceFormat(Enum):
    """Base class for values of the "format" keyword."""
    pass
class JSONSchemaStringFormat(JSONSchemaInstanceFormat):
    """Formats defined by the JSON Schema specification."""
    DATETIME = "date-time"
    DATE = "date"
    TIME = "time"
    DURATION = "duration"
    EMAIL = "email"
    IDN_EMAIL = "idn-email"
    HOSTNAME = "hostname"
    IDN_HOSTNAME = "idn-hostname"
    IPV4ADDRESS = "ipv4"
    IPV6ADDRESS = "ipv6"
    URI = "uri"
    URI_REFERENCE = "uri-reference"
    IRI = "iri"
    IRI_REFERENCE = "iri-reference"
    UUID = "uuid"
    URI_TEMPLATE = "uri-template"
    JSON_POINTER = "json-pointer"
    RELATIVE_JSON_POINTER = "relative-json-pointer"
    REGEX = "regex"
class JSONSchemaInstanceFormatExtension(JSONSchemaInstanceFormat):
    """Mashumaro-specific format extensions for extra Python types."""
    TIMEDELTA = "time-delta"
    TIME_ZONE = "time-zone"
    IPV4NETWORK = "ipv4network"
    IPV6NETWORK = "ipv6network"
    IPV4INTERFACE = "ipv4interface"
    IPV6INTERFACE = "ipv6interface"
    DECIMAL = "decimal"
    FRACTION = "fraction"
    BASE64 = "base64"
    PATH = "path"
# Lookup tables mapping Python types to their schema "format" value
DATETIME_FORMATS = {
    datetime.datetime: JSONSchemaStringFormat.DATETIME,
    datetime.date: JSONSchemaStringFormat.DATE,
    datetime.time: JSONSchemaStringFormat.TIME,
}
IPADDRESS_FORMATS = {
    ipaddress.IPv4Address: JSONSchemaStringFormat.IPV4ADDRESS,
    ipaddress.IPv6Address: JSONSchemaStringFormat.IPV6ADDRESS,
    ipaddress.IPv4Network: JSONSchemaInstanceFormatExtension.IPV4NETWORK,
    ipaddress.IPv6Network: JSONSchemaInstanceFormatExtension.IPV6NETWORK,
    ipaddress.IPv4Interface: JSONSchemaInstanceFormatExtension.IPV4INTERFACE,
    ipaddress.IPv6Interface: JSONSchemaInstanceFormatExtension.IPV6INTERFACE,
}
def _deserialize_json_schema_instance_format(
    value: Any,
) -> JSONSchemaInstanceFormat:
    """Find the JSONSchemaInstanceFormat subclass member matching *value*."""
    for fmt_cls in iter_all_subclasses(JSONSchemaInstanceFormat):
        try:
            return fmt_cls(value)
        except (ValueError, TypeError):
            continue
    raise ValueError(value)
# Data model for a JSON Schema document. Field names mirror the spec's
# keywords (serialized by alias where they clash with Python, see Config).
@dataclass(unsafe_hash=True)
class JSONSchema(DataClassJSONMixin):
    # Common keywords
    schema: Optional[str] = None
    type: Optional[JSONSchemaInstanceType] = None
    enum: Optional[list[Any]] = None
    # MISSING sentinel distinguishes "keyword not set" from an explicit
    # null (Null); see __pre_serialize__/__post_serialize__ below.
    const: Optional[Any] = field(default_factory=lambda: MISSING)
    format: Optional[JSONSchemaInstanceFormat] = None
    title: Optional[str] = None
    description: Optional[str] = None
    anyOf: Optional[List["JSONSchema"]] = None
    reference: Optional[str] = None
    definitions: Optional[Dict[str, "JSONSchema"]] = None
    default: Optional[Any] = field(default_factory=lambda: MISSING)
    deprecated: Optional[bool] = None
    examples: Optional[list[Any]] = None
    # Keywords for Objects
    properties: Optional[Dict[str, "JSONSchema"]] = None
    patternProperties: Optional[Dict[str, "JSONSchema"]] = None
    additionalProperties: Union["JSONSchema", bool, None] = None
    propertyNames: Optional["JSONSchema"] = None
    # Keywords for Arrays
    prefixItems: Optional[List["JSONSchema"]] = None
    items: Optional["JSONSchema"] = None
    contains: Optional["JSONSchema"] = None
    # Validation keywords for numeric instances
    multipleOf: Optional[Number] = None
    maximum: Optional[Number] = None
    exclusiveMaximum: Optional[Number] = None
    minimum: Optional[Number] = None
    exclusiveMinimum: Optional[Number] = None
    # Validation keywords for Strings
    maxLength: Optional[int] = None
    minLength: Optional[int] = None
    pattern: Optional[str] = None
    # Validation keywords for Arrays
    maxItems: Optional[int] = None
    minItems: Optional[int] = None
    uniqueItems: Optional[bool] = None
    maxContains: Optional[int] = None
    minContains: Optional[int] = None
    # Validation keywords for Objects
    maxProperties: Optional[int] = None
    minProperties: Optional[int] = None
    required: Optional[list[str]] = None
    dependentRequired: Optional[dict[str, set[str]]] = None
    class Config(BaseConfig):
        # Unset (None) keywords are dropped from the output entirely
        omit_none = True
        serialize_by_alias = True
        aliases = {
            "schema": "$schema",
            "reference": "$ref",
            "definitions": "$defs",
        }
        serialization_strategy = {
            int: pass_through,
            float: pass_through,
            Null: pass_through,
            JSONSchemaInstanceFormat: {
                "deserialize": _deserialize_json_schema_instance_format,
            },
        }
    def __pre_serialize__(self) -> Self:
        # Swap explicit None for the Null sentinel so omit_none doesn't
        # drop a deliberate `const: null` / `default: null`.
        if self.const is None:
            self.const = Null
        if self.default is None:
            self.default = Null
        return self
    def __post_serialize__(self, d: dict[Any, Any]) -> dict[Any, Any]:
        # Reverse the sentinel mapping: MISSING -> keyword absent,
        # Null -> JSON null.
        const = d.get("const")
        if const is MISSING:
            d.pop("const")
        elif const is Null:
            d["const"] = None
        default = d.get("default")
        if default is MISSING:
            d.pop("default")
        elif default is Null:
            d["default"] = None
        return d
@dataclass
class JSONObjectSchema(JSONSchema):
    """JSONSchema preset to type "object"."""
    type: Optional[JSONSchemaInstanceType] = JSONSchemaInstanceType.OBJECT
@dataclass
class JSONArraySchema(JSONSchema):
    """JSONSchema preset to type "array"."""
    type: Optional[JSONSchemaInstanceType] = JSONSchemaInstanceType.ARRAY
@dataclass
class Context:
    """Mutable state shared across one schema-building pass."""
    dialect: JSONSchemaDialect = DRAFT_2020_12
    # registry of shared definitions, keyed by definition name
    definitions: dict[str, JSONSchema] = field(default_factory=dict)
    all_refs: Optional[bool] = None
    ref_prefix: Optional[str] = None
    plugins: Sequence[BasePlugin] = ()

View File

@@ -0,0 +1,28 @@
from dataclasses import is_dataclass
from inspect import cleandoc
from typing import Optional
from mashumaro.jsonschema.models import Context, JSONSchema
from mashumaro.jsonschema.schema import Instance
class BasePlugin:
    """Extension hook invoked during JSON Schema generation."""
    def get_schema(
        self,
        instance: Instance,
        ctx: Context,
        schema: Optional[JSONSchema] = None,
    ) -> Optional[JSONSchema]:
        # NOTE(review): implementations in this file mutate `schema` in
        # place and return None — confirm return-value semantics against
        # the call site in schema.py.
        pass
class DocstringDescriptionPlugin(BasePlugin):
    """Fills schema.description from a dataclass's docstring."""

    def get_schema(
        self,
        instance: Instance,
        ctx: Context,
        schema: Optional[JSONSchema] = None,
    ) -> Optional[JSONSchema]:
        if not schema:
            return None
        doc = instance.type.__doc__ if is_dataclass(instance.type) else None
        if doc:
            schema.description = cleandoc(doc)
        return None

View File

@@ -0,0 +1,886 @@
import datetime
import ipaddress
import os
import warnings
from base64 import encodebytes
from collections import ChainMap, Counter, deque
from collections.abc import (
ByteString,
Callable,
Collection,
Iterable,
Mapping,
Sequence,
Set,
)
from dataclasses import MISSING, dataclass, field, is_dataclass, replace
from decimal import Decimal
from enum import Enum
from fractions import Fraction
from functools import cached_property
from typing import Any, ForwardRef, Optional, Tuple, Type, Union
from uuid import UUID
from zoneinfo import ZoneInfo
from typing_extensions import TypeAlias
from mashumaro.config import BaseConfig
from mashumaro.core.const import PY_311_MIN
from mashumaro.core.meta.code.builder import CodeBuilder
from mashumaro.core.meta.helpers import (
evaluate_forward_ref,
get_args,
get_forward_ref_referencing_globals,
get_function_return_annotation,
get_literal_values,
get_type_origin,
is_annotated,
is_generic,
is_literal,
is_named_tuple,
is_new_type,
is_not_required,
is_readonly,
is_required,
is_special_typing_primitive,
is_type_var,
is_type_var_any,
is_type_var_tuple,
is_typed_dict,
is_union,
is_unpack,
resolve_type_params,
type_name,
)
from mashumaro.core.meta.types.common import NoneType
from mashumaro.helper import pass_through
from mashumaro.jsonschema.annotations import (
Annotation,
Contains,
DependentRequired,
ExclusiveMaximum,
ExclusiveMinimum,
MaxContains,
Maximum,
MaxItems,
MaxLength,
MaxProperties,
MinContains,
Minimum,
MinItems,
MinLength,
MinProperties,
MultipleOf,
Pattern,
UniqueItems,
)
from mashumaro.jsonschema.models import (
DATETIME_FORMATS,
IPADDRESS_FORMATS,
Context,
JSONArraySchema,
JSONObjectSchema,
JSONSchema,
JSONSchemaInstanceFormatExtension,
JSONSchemaInstanceType,
JSONSchemaStringFormat,
)
from mashumaro.types import SerializationStrategy
try:
from mashumaro.mixins.orjson import (
DataClassORJSONMixin as DataClassJSONMixin,
)
except ImportError: # pragma: no cover
from mashumaro.mixins.json import DataClassJSONMixin # type: ignore
# Matches "UTC" optionally followed by a +/-HH:MM offset, e.g. "UTC+03:00".
UTC_OFFSET_PATTERN = r"^UTC([+-][0-2][0-9]:[0-5][0-9])?$"
@dataclass
class Instance:
type: Type
name: Optional[str] = None
__owner_builder: Optional[CodeBuilder] = None
__self_builder: Optional[CodeBuilder] = None
# Original type despite custom serialization. To be revised.
_original_type: Type = field(init=False)
origin_type: Type = field(init=False)
annotations: list[Annotation] = field(init=False, default_factory=list)
    @cached_property
    def metadata(self) -> dict[str, Any]:
        """Field metadata from the owner builder (empty for root instances)."""
        if self.name and self.__owner_builder:
            return dict(**self.__owner_builder.metadatas.get(self.name, {}))
        else:
            return {}
    @property
    def _self_builder(self) -> CodeBuilder:
        # Only valid when this instance's type is a dataclass
        # (the builder is created in update_type).
        assert self.__self_builder
        return self.__self_builder
    @property
    def alias(self) -> Optional[str]:
        """Resolved field alias: metadata -> owner Config.aliases -> name."""
        alias = self.metadata.get("alias")
        if alias is None:
            aliases_config = self.get_owner_config().aliases
            alias = aliases_config.get(self.name)  # type: ignore
        if alias is None:
            alias = self.name
        return alias
    @property
    def owner_class(self) -> Optional[Type]:
        """The dataclass owning this field, or None for root instances."""
        if self.__owner_builder:
            return self.__owner_builder.cls
        return None
def derive(self, **changes: Any) -> "Instance":
new_type = changes.get("type")
if isinstance(new_type, ForwardRef):
changes["type"] = evaluate_forward_ref(
new_type,
get_forward_ref_referencing_globals(new_type, self.type),
self.__dict__,
)
new_instance = replace(self, **changes)
if is_dataclass(self.origin_type):
new_instance.__owner_builder = self.__self_builder
return new_instance
def __post_init__(self) -> None:
self._original_type = self.type
self.update_type(self.type)
if is_annotated(self.type):
self.annotations = getattr(self.type, "__metadata__", [])
self.type = get_args(self.type)[0]
self.origin_type = get_type_origin(self.type)
def update_type(self, new_type: Type) -> None:
if self.__owner_builder:
self.type = self.__owner_builder.get_real_type(
field_name=self.name, # type: ignore
field_type=new_type,
)
self.origin_type = get_type_origin(self.type)
if is_dataclass(self.origin_type):
type_args = get_args(self.type)
self.__self_builder = CodeBuilder(self.origin_type, type_args)
self.__self_builder.reset()
else:
self.__self_builder = None
def fields(self) -> Iterable[tuple[str, Type, bool, Any]]:
for f_name, f_type in self._self_builder.get_field_types(
include_extras=True
).items():
f = self._self_builder.dataclass_fields.get(f_name)
if not f or f and not f.init:
continue
f_default = f.default
if f_default is MISSING:
f_default = self._self_builder.namespace.get(f_name, MISSING)
if f_default is not MISSING:
f_default = _default(f_type, f_default, self.get_self_config())
has_default = (
f.default is not MISSING or f.default_factory is not MISSING
)
yield f_name, f_type, has_default, f_default
def get_overridden_serialization_method(
self,
) -> Optional[Union[Callable, str]]:
if not self.__owner_builder:
return None
serialize_option = self.metadata.get("serialize")
if serialize_option is not None:
if callable(serialize_option):
self.metadata.pop("serialize", None) # prevent recursion
return serialize_option
for strategy in self.__owner_builder.iter_serialization_strategies(
self.metadata, self.type
):
if strategy is pass_through:
return pass_through
elif isinstance(strategy, dict):
serialize_option = strategy.get("serialize")
elif isinstance(strategy, SerializationStrategy):
serialize_option = strategy.serialize
if serialize_option is not None:
return serialize_option
return None
def get_owner_config(self) -> Type[BaseConfig]:
if self.__owner_builder:
return self.__owner_builder.get_config()
else:
return BaseConfig
def get_owner_dialect_or_config_option(
self, option: str, default: Any
) -> Any:
if self.__owner_builder:
return self.__owner_builder.get_dialect_or_config_option(
option, default
)
else:
return default
def get_self_config(self) -> Type[BaseConfig]:
if self.__self_builder:
return self.__self_builder.get_config()
else:
return BaseConfig
# Signature shared by all schema creators: return a JSONSchema for the
# given instance, or None to let the next registered creator try.
InstanceSchemaCreator: TypeAlias = Callable[
    [Instance, Context], Optional[JSONSchema]
]
@dataclass
class InstanceSchemaCreatorRegistry:
    """Ordered registry of schema creator callbacks, tried first-to-last."""

    _registry: list[InstanceSchemaCreator] = field(default_factory=list)

    def register(self, func: InstanceSchemaCreator) -> InstanceSchemaCreator:
        """Append *func* to the registry; returns it so it works as a decorator."""
        self._registry.append(func)
        return func

    def iter(self) -> Iterable[InstanceSchemaCreator]:
        """Iterate creators in registration order."""
        return iter(self._registry)
@dataclass
class EmptyJSONSchema(JSONSchema):
    """A schema with no constraints ({}): it validates any instance."""

    pass
def get_schema(
    instance: Instance, ctx: Context, with_dialect_uri: bool = False
) -> JSONSchema:
    """Build the JSON Schema for *instance*.

    Registered creators are tried in order; the first non-None result wins.
    Plugins may then refine (or provide) the schema; a plugin opts out by
    raising NotImplementedError.
    """
    result: Optional[JSONSchema] = None
    for creator in Registry.iter():
        candidate = creator(instance, ctx)
        if candidate is None:
            continue
        result = candidate
        if with_dialect_uri:
            result.schema = ctx.dialect.uri
        break
    for plugin in ctx.plugins:
        try:
            refined = plugin.get_schema(instance, ctx, result)
        except NotImplementedError:
            continue
        if refined:
            result = refined
    if result:
        return result
    raise NotImplementedError(
        f'Type {type_name(instance.type)} of field "{instance.name}" '
        f"in {type_name(instance.owner_class)} isn't supported"
    )
def _get_schema_or_none(
    instance: Instance, ctx: Context
) -> Optional[JSONSchema]:
    """Like get_schema, but collapse an unconstrained (empty) schema to None."""
    result = get_schema(instance, ctx)
    return None if isinstance(result, EmptyJSONSchema) else result
def _default(f_type: Type, f_value: Any, config_cls: Type[BaseConfig]) -> Any:
    """Serialize default value *f_value* of type *f_type* exactly the way the
    owner class would: round-trip it through a throwaway dataclass whose
    Config inherits the owner's config."""

    @dataclass
    class CC(DataClassJSONMixin):
        # The annotation is the *runtime* type object, injected on purpose.
        x: f_type = f_value  # type: ignore

        class Config(config_cls):  # type: ignore
            pass

    return CC(f_value).to_dict()["x"]
# Global creator registry; @register adds creators in definition order,
# which is also the order get_schema() tries them in.
Registry = InstanceSchemaCreatorRegistry()
register = Registry.register
# Types an overridden serializer may be declared as that map directly
# onto JSON Schema primitive types.
BASIC_TYPES = {str, int, float, bool}
@register
def on_type_with_overridden_serialization(
    instance: Instance, ctx: Context
) -> Optional[JSONSchema]:
    """Handle fields whose serialization method is overridden.

    For pass-through overrides (or no override at all) this creator steps
    aside. Otherwise the schema is re-derived from the override's declared
    return type, falling back to Any (with a warning) when that type cannot
    be determined.
    """

    def override_with_any(reason: Any) -> None:
        # Last resort: we cannot know what the override produces.
        if instance.owner_class is not None:
            name = f"{type_name(instance.owner_class)}.{instance.name}"
        else:  # pragma: no cover
            # we will have an owner class, but leave this here just in case
            name = type_name(instance.type)
        warnings.warn(
            f"Type Any will be used for {name} with "
            f"overridden serialization method: {reason}"
        )
        instance.update_type(Any)  # type: ignore[arg-type]

    overridden_method = instance.get_overridden_serialization_method()
    if overridden_method is pass_through:
        # Identity serialization: let the regular creators handle the type.
        return None
    elif overridden_method in BASIC_TYPES:
        # e.g. serialize=str: the output type is the method itself.
        instance.update_type(overridden_method)  # type: ignore
    elif callable(overridden_method):
        try:
            new_type = get_function_return_annotation(overridden_method)
            if new_type is instance.type:
                # Same type as before; nothing to re-derive.
                return None
            else:
                instance.update_type(new_type)
        except Exception as e:
            override_with_any(e)
    else:
        # No override (None) or an unrecognized marker. Previously control
        # fell through to get_schema() below, which re-enters this creator
        # (it is registered first) with an unchanged instance and recurses
        # without bound. Step aside so the other creators can run.
        return None
    return get_schema(instance, ctx)
@register
def on_dataclass(instance: Instance, ctx: Context) -> Optional[JSONSchema]:
    """Object schema for a dataclass: one property per init field, with
    aliases, defaults, descriptions and per-class overrides applied.

    When ctx.all_refs is set, the schema is stored in ctx.definitions and a
    $ref to it is returned instead.
    """
    # TODO: Self references might not work
    if is_dataclass(instance.origin_type):
        jsonschema_config = instance.get_self_config().json_schema
        schema = JSONObjectSchema(
            title=instance.origin_type.__name__,
            additionalProperties=jsonschema_config.get(
                "additionalProperties", False
            ),
        )
        properties: dict[str, JSONSchema] = {}
        required = []
        field_schema_overrides = jsonschema_config.get("properties", {})
        for f_name, f_type, has_default, f_default in instance.fields():
            override = field_schema_overrides.get(f_name)
            f_instance = instance.derive(type=f_type, name=f_name)
            if override:
                # Config-level override replaces the derived schema wholesale.
                f_schema = JSONSchema.from_dict(override)
            else:
                f_schema = get_schema(f_instance, ctx)
            if f_instance.alias:
                # Schema keys use the serialized (aliased) name.
                f_name = f_instance.alias
            if f_default is not MISSING:
                f_schema.default = f_default
            description = f_instance.metadata.get("description")
            if description:
                f_schema.description = description
            if not has_default:
                required.append(f_name)
            properties[f_name] = f_schema
        if properties:
            schema.properties = properties
        if required:
            schema.required = required
        if ctx.all_refs:
            ctx.definitions[instance.origin_type.__name__] = schema
            ref_prefix = ctx.ref_prefix or ctx.dialect.definitions_root_pointer
            return JSONSchema(
                reference=f"{ref_prefix}/{instance.origin_type.__name__}"
            )
        else:
            return schema
@register
def on_any(instance: Instance, ctx: Context) -> Optional[JSONSchema]:
    """Any maps to the empty schema, which accepts every instance."""
    if instance.type is not Any:
        return None
    return EmptyJSONSchema()
def on_literal(instance: Instance, ctx: Context) -> Optional[JSONSchema]:
    """Map Literal[...] to a const (single value) or enum schema.

    Enum members contribute their .value, bytes are base64-encoded, and
    unsupported literal kinds are silently skipped.
    """
    allowed: list = []
    for literal_value in get_literal_values(instance.type):
        if isinstance(literal_value, Enum):
            allowed.append(literal_value.value)
        elif isinstance(literal_value, bytes):
            allowed.append(encodebytes(literal_value).decode())
        elif isinstance(literal_value, (int, str, bool, NoneType)):  # type: ignore
            allowed.append(literal_value)
    if len(allowed) == 1:
        return JSONSchema(const=allowed[0])
    return JSONSchema(enum=allowed)
@register
def on_special_typing_primitive(
    instance: Instance, ctx: Context
) -> Optional[JSONSchema]:
    """Dispatch on typing-module constructs: Union, TypeVar(s), NewType,
    Literal, Required/NotReadOnly wrappers, Unpack, TypeVarTuple, ReadOnly
    and ForwardRef. Returns None for ordinary types."""
    if not is_special_typing_primitive(instance.origin_type):
        return None
    args = get_args(instance.type)
    if is_union(instance.type):
        return JSONSchema(
            anyOf=[get_schema(instance.derive(type=arg), ctx) for arg in args]
        )
    elif is_type_var_any(instance.type):
        # Unconstrained, unbound TypeVar behaves like Any.
        return EmptyJSONSchema()
    elif is_type_var(instance.type):
        constraints = getattr(instance.type, "__constraints__")
        if constraints:
            return JSONSchema(
                anyOf=[
                    get_schema(instance.derive(type=arg), ctx)
                    for arg in constraints
                ]
            )
        else:
            # No constraints: fall back to the TypeVar's bound.
            bound = getattr(instance.type, "__bound__")
            return get_schema(instance.derive(type=bound), ctx)
    elif is_new_type(instance.type):
        return get_schema(
            instance.derive(type=instance.type.__supertype__), ctx
        )
    elif is_literal(instance.type):
        return on_literal(instance, ctx)
    # elif is_self(instance.type):
    #     raise NotImplementedError
    elif is_required(instance.type) or is_not_required(instance.type):
        # Required[X]/NotRequired[X] are transparent for the value schema.
        return get_schema(instance.derive(type=args[0]), ctx)
    elif is_unpack(instance.type):
        return get_schema(
            instance.derive(type=get_args(instance.type)[0]), ctx
        )
    elif is_type_var_tuple(instance.type):
        # *Ts with no substitution: any number of any items.
        return get_schema(instance.derive(type=tuple[Any, ...]), ctx)
    elif is_readonly(instance.type):
        return get_schema(instance.derive(type=args[0]), ctx)
    elif isinstance(instance.type, ForwardRef):
        evaluated = evaluate_forward_ref(
            instance.type,
            get_forward_ref_referencing_globals(instance.type),
            None,
        )
        if evaluated is not None:
            return get_schema(instance.derive(type=evaluated), ctx)
@register
def on_number(instance: Instance, ctx: Context) -> Optional[JSONSchema]:
    """Integer/number schema with numeric constraint annotations applied."""
    origin = instance.origin_type
    if origin is int:
        schema = JSONSchema(type=JSONSchemaInstanceType.INTEGER)
    elif origin is float:
        schema = JSONSchema(type=JSONSchemaInstanceType.NUMBER)
    else:
        return None
    for constraint in instance.annotations:
        if isinstance(constraint, Maximum):
            schema.maximum = constraint.value
        elif isinstance(constraint, Minimum):
            schema.minimum = constraint.value
        elif isinstance(constraint, ExclusiveMaximum):
            schema.exclusiveMaximum = constraint.value
        elif isinstance(constraint, ExclusiveMinimum):
            schema.exclusiveMinimum = constraint.value
        elif isinstance(constraint, MultipleOf):
            schema.multipleOf = constraint.value
    return schema
@register
def on_bool(instance: Instance, ctx: Context) -> Optional[JSONSchema]:
    """Boolean schema for the bool primitive."""
    if instance.origin_type is not bool:
        return None
    return JSONSchema(type=JSONSchemaInstanceType.BOOLEAN)
@register
def on_none(instance: Instance, ctx: Context) -> Optional[JSONSchema]:
    """Null schema for NoneType / None."""
    if instance.origin_type not in (NoneType, None):
        return None
    return JSONSchema(type=JSONSchemaInstanceType.NULL)
@register
def on_date_objects(instance: Instance, ctx: Context) -> Optional[JSONSchema]:
    """String schema with the matching date/time format for datetime types."""
    if instance.origin_type not in (
        datetime.datetime,
        datetime.date,
        datetime.time,
    ):
        return None
    return JSONSchema(
        type=JSONSchemaInstanceType.STRING,
        format=DATETIME_FORMATS[instance.origin_type],
    )
@register
def on_timedelta(instance: Instance, ctx: Context) -> Optional[JSONSchema]:
    """timedelta serializes as a number of seconds."""
    if instance.origin_type is not datetime.timedelta:
        return None
    return JSONSchema(
        type=JSONSchemaInstanceType.NUMBER,
        format=JSONSchemaInstanceFormatExtension.TIMEDELTA,
    )
@register
def on_timezone(instance: Instance, ctx: Context) -> Optional[JSONSchema]:
    """Fixed-offset timezones serialize as "UTC[+-HH:MM]" strings."""
    if instance.origin_type is not datetime.timezone:
        return None
    return JSONSchema(
        type=JSONSchemaInstanceType.STRING, pattern=UTC_OFFSET_PATTERN
    )
@register
def on_zone_info(instance: Instance, ctx: Context) -> Optional[JSONSchema]:
    """ZoneInfo serializes as an IANA time zone key string."""
    if instance.origin_type is not ZoneInfo:
        return None
    return JSONSchema(
        type=JSONSchemaInstanceType.STRING,
        format=JSONSchemaInstanceFormatExtension.TIME_ZONE,
    )
@register
def on_uuid(instance: Instance, ctx: Context) -> Optional[JSONSchema]:
    """UUID serializes as a string with the standard "uuid" format."""
    if instance.origin_type is not UUID:
        return None
    return JSONSchema(
        type=JSONSchemaInstanceType.STRING,
        format=JSONSchemaStringFormat.UUID,
    )
@register
def on_ipaddress(instance: Instance, ctx: Context) -> Optional[JSONSchema]:
    """String schema with the proper format for ipaddress module types."""
    if instance.origin_type not in (
        ipaddress.IPv4Address,
        ipaddress.IPv6Address,
        ipaddress.IPv4Network,
        ipaddress.IPv6Network,
        ipaddress.IPv4Interface,
        ipaddress.IPv6Interface,
    ):
        return None
    return JSONSchema(
        type=JSONSchemaInstanceType.STRING,
        format=IPADDRESS_FORMATS[instance.origin_type],  # type: ignore
    )
@register
def on_decimal(instance: Instance, ctx: Context) -> Optional[JSONSchema]:
    """Decimal serializes as a string to avoid float precision loss."""
    if instance.origin_type is not Decimal:
        return None
    return JSONSchema(
        type=JSONSchemaInstanceType.STRING,
        format=JSONSchemaInstanceFormatExtension.DECIMAL,
    )
@register
def on_fraction(instance: Instance, ctx: Context) -> Optional[JSONSchema]:
    """Fraction serializes as a "num/den" string."""
    if instance.origin_type is not Fraction:
        return None
    return JSONSchema(
        type=JSONSchemaInstanceType.STRING,
        format=JSONSchemaInstanceFormatExtension.FRACTION,
    )
def on_tuple(instance: Instance, ctx: Context) -> JSONArraySchema:
    """Array schema for tuple types.

    Handles bare tuple (any items), empty tuple, homogeneous tuple[X, ...],
    and heterogeneous tuples — including a single Unpack[...] segment whose
    own prefixItems/min/max/items are spliced into the result.
    """
    args = get_args(instance.type)
    if not args:
        if instance.type in (Tuple, tuple):
            # Bare tuple: any number of items of any type.
            args = [Any, ...]  # type: ignore
        else:
            return JSONArraySchema(maxItems=0)
    elif len(args) == 1 and args[0] == ():
        # tuple[()] spelled the pre-3.11 way means the empty tuple.
        if not PY_311_MIN:
            return JSONArraySchema(maxItems=0)
    if len(args) == 2 and args[1] is Ellipsis:
        # Homogeneous variable-length tuple.
        items_schema = _get_schema_or_none(instance.derive(type=args[0]), ctx)
        return JSONArraySchema(items=items_schema)
    else:
        min_items = 0
        max_items = 0
        prefix_items = []
        items: Optional[JSONSchema] = None
        unpack_schema: Optional[JSONSchema] = None
        unpack_idx = 0
        for arg_idx, arg in enumerate(args, start=1):
            if not is_unpack(arg):
                min_items += 1
                # Positions after an Unpack cannot be expressed as
                # prefixItems, so only collect them before it.
                if not unpack_schema:
                    prefix_items.append(
                        get_schema(instance.derive(type=arg), ctx)
                    )
            else:
                unpack_schema = get_schema(instance.derive(type=arg), ctx)
                unpack_idx = arg_idx
        if unpack_schema:
            # Splice the unpacked tuple's own constraints into this one.
            prefix_items.extend(unpack_schema.prefixItems or [])
            min_items += unpack_schema.minItems or 0
            max_items += unpack_schema.maxItems or 0
            if unpack_idx == len(args):
                # Trailing Unpack: its tail schema constrains our tail too.
                items = unpack_schema.items
        else:
            min_items = len(args)
            max_items = len(args)
        return JSONArraySchema(
            prefixItems=prefix_items or None,
            items=items,
            minItems=min_items or None,
            maxItems=max_items or None,
        )
def on_named_tuple(instance: Instance, ctx: Context) -> JSONSchema:
    """Schema for a NamedTuple: an object when serialized as a dict, or a
    fixed-length array when serialized as a list (the default).

    The owner's "namedtuple_as_dict" option picks the mode; a per-field
    "as_dict"/"as_list" serialize override takes precedence.
    """
    resolved = resolve_type_params(
        instance.origin_type, get_args(instance.type)
    )[instance.origin_type]
    annotations = {
        k: resolved.get(v, v)
        for k, v in getattr(
            instance.origin_type, "__annotations__", {}
        ).items()
    }
    fields = getattr(instance.type, "_fields", ())
    defaults = getattr(instance.type, "_field_defaults", {})
    as_dict = instance.get_owner_dialect_or_config_option(
        "namedtuple_as_dict", False
    )
    serialize_option = instance.get_overridden_serialization_method()
    if serialize_option == "as_dict":
        as_dict = True
    elif serialize_option == "as_list":
        as_dict = False
    properties = {}
    for f_name in fields:
        f_type = annotations.get(f_name, Any)
        f_schema = get_schema(instance.derive(type=f_type), ctx)
        f_default = defaults.get(f_name, MISSING)
        if f_default is not MISSING:
            if isinstance(f_schema, EmptyJSONSchema):
                # An empty schema cannot carry a default; promote it.
                f_schema = JSONSchema()
            f_schema.default = _default(
                f_type, f_default, instance.get_self_config()
            )
        properties[f_name] = f_schema
    if as_dict:
        return JSONObjectSchema(
            properties=properties or None,
            required=list(fields),
            additionalProperties=False,
        )
    else:
        return JSONArraySchema(
            prefixItems=list(properties.values()) or None,
            maxItems=len(properties) or None,
            minItems=len(properties) or None,
        )
def on_typed_dict(instance: Instance, ctx: Context) -> JSONObjectSchema:
    """Object schema for a TypedDict, resolving its type parameters and
    honoring __required_keys__ (all keys required when absent)."""
    resolved = resolve_type_params(
        instance.origin_type, get_args(instance.type)
    )[instance.origin_type]
    annotations = {
        key: resolved.get(tp, tp)
        for key, tp in instance.origin_type.__annotations__.items()
    }
    keys = list(annotations)
    required_keys = getattr(instance.type, "__required_keys__", keys)
    props = {
        key: get_schema(instance.derive(type=annotations[key]), ctx)
        for key in keys
    }
    return JSONObjectSchema(
        properties=props or None,
        required=sorted(required_keys) or None,
        additionalProperties=False,
    )
def apply_array_constraints(
    instance: Instance,
    schema: JSONSchema,
) -> JSONSchema:
    """Copy array-related annotations from *instance* onto *schema* in place."""
    contains_seen = False
    pending_min: Optional[int] = None
    pending_max: Optional[int] = None
    for ann in instance.annotations:
        if isinstance(ann, MinItems):
            schema.minItems = ann.value
        elif isinstance(ann, MaxItems):
            schema.maxItems = ann.value
        elif isinstance(ann, UniqueItems):
            schema.uniqueItems = ann.value
        elif isinstance(ann, Contains):
            schema.contains = ann.value
            contains_seen = True
        elif isinstance(ann, MinContains):
            pending_min = ann.value
        elif isinstance(ann, MaxContains):
            pending_max = ann.value
    # minContains/maxContains only take effect alongside "contains".
    if contains_seen:
        if pending_min is not None:
            schema.minContains = pending_min
        if pending_max is not None:
            schema.maxContains = pending_max
    return schema
def apply_object_constraints(
    instance: Instance, schema: JSONSchema
) -> JSONSchema:
    """Copy object-related annotations from *instance* onto *schema* in place."""
    for ann in instance.annotations:
        if isinstance(ann, MaxProperties):
            schema.maxProperties = ann.value
        elif isinstance(ann, MinProperties):
            schema.minProperties = ann.value
        elif isinstance(ann, DependentRequired):
            schema.dependentRequired = ann.value
    return schema
@register
def on_collection(instance: Instance, ctx: Context) -> Optional[JSONSchema]:
    """Dispatch over the collection hierarchy.

    Branch order matters: str/bytes before Sequence, Counter before Mapping,
    tuple (incl. NamedTuple) and TypedDict before their generic supertypes.
    Enums are excluded here so on_enum can handle them later.
    """
    if not issubclass(instance.origin_type, Collection):
        return None
    elif issubclass(instance.origin_type, Enum):
        # Handled by on_enum further down the registry.
        return None
    args = get_args(instance.type)
    if issubclass(instance.origin_type, ByteString):  # type: ignore[arg-type]
        # bytes/bytearray serialize as base64 text.
        return JSONSchema(
            type=JSONSchemaInstanceType.STRING,
            format=JSONSchemaInstanceFormatExtension.BASE64,
        )
    elif issubclass(instance.origin_type, str):
        schema = JSONSchema(type=JSONSchemaInstanceType.STRING)
        for annotation in instance.annotations:
            if isinstance(annotation, MinLength):
                schema.minLength = annotation.value
            elif isinstance(annotation, MaxLength):
                schema.maxLength = annotation.value
            elif isinstance(annotation, Pattern):
                schema.pattern = annotation.value
        return schema
    elif is_generic(instance.type) and issubclass(
        instance.origin_type, (list, deque)
    ):
        return apply_array_constraints(
            instance,
            JSONArraySchema(
                items=(
                    _get_schema_or_none(instance.derive(type=args[0]), ctx)
                    if args
                    else None
                )
            ),
        )
    elif issubclass(instance.origin_type, tuple):
        if is_named_tuple(instance.origin_type):
            return apply_array_constraints(
                instance, on_named_tuple(instance, ctx)
            )
        elif is_generic(instance.type):
            return apply_array_constraints(instance, on_tuple(instance, ctx))
    elif is_generic(instance.type) and issubclass(
        instance.origin_type, (frozenset, Set)
    ):
        return apply_array_constraints(
            instance,
            JSONArraySchema(
                items=(
                    _get_schema_or_none(instance.derive(type=args[0]), ctx)
                    if args
                    else None
                ),
                uniqueItems=True,
            ),
        )
    elif is_generic(instance.type) and issubclass(
        instance.origin_type, ChainMap
    ):
        # A ChainMap serializes as a list of its constituent mappings.
        return apply_array_constraints(
            instance,
            JSONArraySchema(
                items=get_schema(
                    instance=instance.derive(
                        type=(
                            dict[args[0], args[1]]  # type: ignore
                            if args
                            else dict
                        )
                    ),
                    ctx=ctx,
                )
            ),
        )
    elif is_generic(instance.type) and issubclass(
        instance.origin_type, Counter
    ):
        # Counter values are always ints; keys follow the type argument.
        schema = JSONObjectSchema(
            additionalProperties=get_schema(instance.derive(type=int), ctx),
        )
        if args:
            schema.propertyNames = _get_schema_or_none(
                instance.derive(type=args[0]), ctx
            )
        return apply_object_constraints(instance, schema)
    elif is_typed_dict(instance.origin_type):
        return on_typed_dict(instance, ctx)
    elif is_generic(instance.type) and issubclass(
        instance.origin_type, Mapping
    ):
        schema = JSONObjectSchema(
            additionalProperties=(
                _get_schema_or_none(instance.derive(type=args[1]), ctx)
                if args
                else None
            ),
            propertyNames=(
                _get_schema_or_none(instance.derive(type=args[0]), ctx)
                if args
                else None
            ),
        )
        return apply_object_constraints(instance, schema)
    elif is_generic(instance.type) and issubclass(
        instance.origin_type, Sequence
    ):
        return apply_array_constraints(
            instance,
            JSONArraySchema(
                items=(
                    _get_schema_or_none(instance.derive(type=args[0]), ctx)
                    if args
                    else None
                )
            ),
        )
@register
def on_pathlike(instance: Instance, ctx: Context) -> Optional[JSONSchema]:
    """os.PathLike types become string schemas with the "path" format."""
    if not issubclass(instance.origin_type, os.PathLike):
        return None
    schema = JSONSchema(
        type=JSONSchemaInstanceType.STRING,
        format=JSONSchemaInstanceFormatExtension.PATH,
    )
    for ann in instance.annotations:
        if isinstance(ann, MinLength):
            schema.minLength = ann.value
        elif isinstance(ann, MaxLength):
            schema.maxLength = ann.value
    return schema
@register
def on_enum(instance: Instance, ctx: Context) -> Optional[JSONSchema]:
    """Enum classes become an enum schema listing the member values."""
    if not issubclass(instance.origin_type, Enum):
        return None
    values = [member.value for member in instance.origin_type]
    return JSONSchema(enum=values)
# Public API of this module.
__all__ = ["Instance", "get_schema"]

View File

@@ -0,0 +1,68 @@
from collections.abc import Mapping
from typing import Any, Type, TypeVar, final
from mashumaro.core.meta.mixin import (
compile_mixin_packer,
compile_mixin_unpacker,
)
__all__ = ["DataClassDictMixin"]
T = TypeVar("T", bound="DataClassDictMixin")
class DataClassDictMixin:
    """Base mixin adding to_dict/from_dict to dataclasses.

    The method bodies below are placeholders (`...`): real implementations
    are generated and attached by compile_mixin_packer/unpacker the moment
    a subclass is defined.
    """

    __slots__ = ()
    __mashumaro_builder_params = {"packer": {}, "unpacker": {}}  # type: ignore
    def __init_subclass__(cls: Type[T], **kwargs: Any):
        super().__init_subclass__(**kwargs)
        # Walk the MRO from object down so the most derived mixin's
        # builder params are applied last and win.
        for ancestor in cls.__mro__[-1:0:-1]:
            builder_params_ = f"_{ancestor.__name__}__mashumaro_builder_params"
            builder_params = getattr(ancestor, builder_params_, None)
            if builder_params:
                compile_mixin_unpacker(cls, **builder_params["unpacker"])
                compile_mixin_packer(cls, **builder_params["packer"])
    # Placeholder replaced by generated code (see __init_subclass__).
    @final
    def to_dict(
        self: T,
        # *
        # keyword-only arguments that exist with the code generation options:
        # omit_none: bool = False
        # by_alias: bool = False
        # dialect: Type[Dialect] = None
        **kwargs: Any,
    ) -> dict[Any, Any]: ...
    # Placeholder replaced by generated code (see __init_subclass__).
    @classmethod
    @final
    def from_dict(
        cls: Type[T],
        d: Mapping,
        # *
        # keyword-only arguments that exist with the code generation options:
        # dialect: Type[Dialect] = None
        **kwargs: Any,
    ) -> T: ...
    # User-overridable hook: runs on the raw dict before unpacking.
    @classmethod
    def __pre_deserialize__(
        cls: Type[T], d: dict[Any, Any]
    ) -> dict[Any, Any]: ...
    # User-overridable hook: runs on the freshly built instance.
    @classmethod
    def __post_deserialize__(cls: Type[T], obj: T) -> T: ...
    # User-overridable hook: runs on self before packing.
    def __pre_serialize__(
        self: T,
        # context: Any = None, # added with ADD_SERIALIZATION_CONTEXT option
    ) -> T: ...
    # User-overridable hook: runs on the packed dict before returning it.
    def __post_serialize__(
        self: T,
        d: dict[Any, Any],
        # context: Any = None, # added with ADD_SERIALIZATION_CONTEXT option
    ) -> dict[Any, Any]: ...

View File

@@ -0,0 +1,32 @@
import json
from collections.abc import Callable
from typing import Any, Type, TypeVar, Union
from mashumaro.mixins.dict import DataClassDictMixin
T = TypeVar("T", bound="DataClassJSONMixin")
EncodedData = Union[str, bytes, bytearray]
Encoder = Callable[[Any], EncodedData]
Decoder = Callable[[EncodedData], dict[Any, Any]]
class DataClassJSONMixin(DataClassDictMixin):
    """Mixin adding JSON (de)serialization on top of dict conversion."""

    __slots__ = ()

    def to_json(
        self: T,
        encoder: Encoder = json.dumps,
        **to_dict_kwargs: Any,
    ) -> EncodedData:
        """Serialize to JSON using *encoder* (json.dumps by default)."""
        plain = self.to_dict(**to_dict_kwargs)
        return encoder(plain)

    @classmethod
    def from_json(
        cls: Type[T],
        data: EncodedData,
        decoder: Decoder = json.loads,
        **from_dict_kwargs: Any,
    ) -> T:
        """Deserialize a JSON document using *decoder* (json.loads by default)."""
        plain = decoder(data)
        return cls.from_dict(plain, **from_dict_kwargs)

View File

@@ -0,0 +1,67 @@
from collections.abc import Callable
from typing import Any, Type, TypeVar, final
import msgpack
from mashumaro.dialect import Dialect
from mashumaro.helper import pass_through
from mashumaro.mixins.dict import DataClassDictMixin
T = TypeVar("T", bound="DataClassMessagePackMixin")
EncodedData = bytes
Encoder = Callable[[Any], EncodedData]
Decoder = Callable[[EncodedData], dict[Any, Any]]
class MessagePackDialect(Dialect):
    """Dialect tuned for MessagePack output."""

    # list/dict can be handed to msgpack without defensive copies.
    no_copy_collections = (list, dict)
    serialization_strategy = {
        # msgpack has a native bin type, so bytes need no encoding.
        bytes: pass_through,
        bytearray: {
            "deserialize": bytearray,
            "serialize": pass_through,
        },
    }
def default_encoder(data: Any) -> EncodedData:
    """Pack *data* to MessagePack bytes, using the bin type for bytes."""
    return msgpack.packb(data, use_bin_type=True)
def default_decoder(data: EncodedData) -> dict[Any, Any]:
    """Unpack MessagePack bytes, decoding str payloads as UTF-8."""
    return msgpack.unpackb(data, raw=False)
class DataClassMessagePackMixin(DataClassDictMixin):
    """Mixin adding MessagePack (de)serialization.

    The to_msgpack/from_msgpack bodies are placeholders (`...`): real
    implementations are generated per subclass from the builder params
    below by DataClassDictMixin.__init_subclass__.
    """

    __slots__ = ()
    # Consumed by DataClassDictMixin.__init_subclass__ for code generation.
    __mashumaro_builder_params = {
        "packer": {
            "format_name": "msgpack",
            "dialect": MessagePackDialect,
            "encoder": default_encoder,
        },
        "unpacker": {
            "format_name": "msgpack",
            "dialect": MessagePackDialect,
            "decoder": default_decoder,
        },
    }
    # Placeholder replaced by generated code.
    @final
    def to_msgpack(
        self: T,
        encoder: Encoder = default_encoder,
        **to_dict_kwargs: Any,
    ) -> EncodedData: ...
    # Placeholder replaced by generated code.
    @classmethod
    @final
    def from_msgpack(
        cls: Type[T],
        data: EncodedData,
        decoder: Decoder = default_decoder,
        **from_dict_kwargs: Any,
    ) -> T: ...

View File

@@ -0,0 +1,69 @@
from collections.abc import Callable
from datetime import date, datetime, time
from typing import Any, Type, TypeVar, Union, final
from uuid import UUID
import orjson
from mashumaro.core.helpers import ConfigValue
from mashumaro.dialect import Dialect
from mashumaro.helper import pass_through
from mashumaro.mixins.dict import DataClassDictMixin
T = TypeVar("T", bound="DataClassORJSONMixin")
EncodedData = Union[str, bytes, bytearray]
Encoder = Callable[[Any], EncodedData]
Decoder = Callable[[EncodedData], dict[Any, Any]]
class OrjsonDialect(Dialect):
    """Dialect tuned for orjson output."""

    # list/dict can be handed to orjson without defensive copies.
    no_copy_collections = (list, dict)
    serialization_strategy = {
        # orjson serializes these natively, so pass the objects through.
        datetime: {"serialize": pass_through},
        date: {"serialize": pass_through},
        time: {"serialize": pass_through},
        UUID: {"serialize": pass_through},
    }
class DataClassORJSONMixin(DataClassDictMixin):
    """Mixin adding JSON (de)serialization backed by orjson.

    to_jsonb/from_json bodies are placeholders (`...`) replaced by code
    generated from the builder params below; to_json is a thin real
    wrapper that decodes the generated to_jsonb output.
    """

    __slots__ = ()
    # Consumed by DataClassDictMixin.__init_subclass__ for code generation.
    __mashumaro_builder_params = {
        "packer": {
            "format_name": "jsonb",
            "dialect": OrjsonDialect,
            "encoder": orjson.dumps,
            "encoder_kwargs": {
                # Forward the per-call "orjson_options" kwarg as orjson's
                # "option" argument.
                "option": ("orjson_options", ConfigValue("orjson_options")),
            },
        },
        "unpacker": {
            "format_name": "json",
            "dialect": OrjsonDialect,
            "decoder": orjson.loads,
        },
    }
    # Placeholder replaced by generated code.
    @final
    def to_jsonb(
        self: T,
        encoder: Encoder = orjson.dumps,
        *,
        orjson_options: int = ...,
        **to_dict_kwargs: Any,
    ) -> bytes: ...
    def to_json(self: T, **kwargs: Any) -> str:
        """Serialize to a JSON string by decoding the to_jsonb() bytes."""
        return self.to_jsonb(**kwargs).decode()
    # Placeholder replaced by generated code.
    @classmethod
    @final
    def from_json(
        cls: Type[T],
        data: EncodedData,
        decoder: Decoder = orjson.loads,
        **from_dict_kwargs: Any,
    ) -> T: ...

View File

@@ -0,0 +1,42 @@
from collections.abc import Callable
from typing import Any, Type, TypeVar, Union, final
import orjson
from mashumaro.dialect import Dialect
from mashumaro.mixins.dict import DataClassDictMixin
T = TypeVar("T", bound="DataClassORJSONMixin")
EncodedData = Union[str, bytes, bytearray]
Encoder = Callable[[Any], EncodedData]
Decoder = Callable[[EncodedData], dict[Any, Any]]
class OrjsonDialect(Dialect):
    """Type stub for the orjson dialect; the real strategy lives in the
    runtime module."""

    serialization_strategy: Any
class DataClassORJSONMixin(DataClassDictMixin):
    """Type stub declaring the orjson mixin's public surface."""

    __slots__ = ()
    @final
    def to_jsonb(
        self: T,
        encoder: Encoder = orjson.dumps,
        *,
        orjson_options: int = ...,
        **to_dict_kwargs: Any,
    ) -> bytes: ...
    def to_json(
        self: T,
        encoder: Encoder = orjson.dumps,
        *,
        orjson_options: int = ...,
        **to_dict_kwargs: Any,
    ) -> str: ...
    @classmethod
    @final
    def from_json(
        cls: Type[T],
        data: EncodedData,
        decoder: Decoder = orjson.loads,
        **from_dict_kwargs: Any,
    ) -> T: ...

View File

@@ -0,0 +1,64 @@
from collections.abc import Callable
from datetime import date, datetime, time
from typing import Any, Type, TypeVar, final
import tomli_w
from mashumaro.dialect import Dialect
from mashumaro.helper import pass_through
from mashumaro.mixins.dict import DataClassDictMixin
try:
import tomllib
except ModuleNotFoundError:
import tomli as tomllib # type: ignore
T = TypeVar("T", bound="DataClassTOMLMixin")
EncodedData = str
Encoder = Callable[[Any], EncodedData]
Decoder = Callable[[EncodedData], dict[Any, Any]]
class TOMLDialect(Dialect):
    """Dialect tuned for TOML output."""

    # list/dict can be handed to the TOML writer without defensive copies.
    no_copy_collections = (list, dict)
    # TOML has no null: fields holding None are omitted.
    omit_none = True
    serialization_strategy = {
        # TOML has native date/time types, so pass the objects through.
        datetime: pass_through,
        date: pass_through,
        time: pass_through,
    }
class DataClassTOMLMixin(DataClassDictMixin):
    """Mixin adding TOML (de)serialization.

    The to_toml/from_toml bodies are placeholders (`...`): real
    implementations are generated per subclass from the builder params
    below by DataClassDictMixin.__init_subclass__.
    """

    __slots__ = ()
    # Consumed by DataClassDictMixin.__init_subclass__ for code generation.
    __mashumaro_builder_params = {
        "packer": {
            "format_name": "toml",
            "dialect": TOMLDialect,
            "encoder": tomli_w.dumps,
        },
        "unpacker": {
            "format_name": "toml",
            "dialect": TOMLDialect,
            "decoder": tomllib.loads,
        },
    }
    # Placeholder replaced by generated code.
    @final
    def to_toml(
        self: T,
        encoder: Encoder = tomli_w.dumps,
        **to_dict_kwargs: Any,
    ) -> EncodedData: ...
    # Placeholder replaced by generated code.
    @classmethod
    @final
    def from_toml(
        cls: Type[T],
        data: EncodedData,
        decoder: Decoder = tomllib.loads,
        **from_dict_kwargs: Any,
    ) -> T: ...

View File

@@ -0,0 +1,45 @@
from collections.abc import Callable
from typing import Any, Type, TypeVar, Union
import yaml
from mashumaro.mixins.dict import DataClassDictMixin
T = TypeVar("T", bound="DataClassYAMLMixin")
EncodedData = Union[str, bytes]
Encoder = Callable[[Any], EncodedData]
Decoder = Callable[[EncodedData], dict[Any, Any]]
# Prefer the C-accelerated loader/dumper when libyaml is available.
DefaultLoader = getattr(yaml, "CSafeLoader", yaml.SafeLoader)
DefaultDumper = getattr(yaml, "CDumper", yaml.Dumper)
def default_encoder(data: Any) -> EncodedData:
    """Dump *data* to a YAML document with the default dumper."""
    return yaml.dump(data, Dumper=DefaultDumper)
def default_decoder(data: EncodedData) -> dict[Any, Any]:
    """Parse a YAML document with the default loader (safe variant when
    the C extension is available)."""
    return yaml.load(data, DefaultLoader)
class DataClassYAMLMixin(DataClassDictMixin):
    """Mixin adding YAML (de)serialization on top of dict conversion."""

    __slots__ = ()

    def to_yaml(
        self: T,
        encoder: Encoder = default_encoder,
        **to_dict_kwargs: Any,
    ) -> EncodedData:
        """Serialize to YAML using *encoder*."""
        plain = self.to_dict(**to_dict_kwargs)
        return encoder(plain)

    @classmethod
    def from_yaml(
        cls: Type[T],
        data: EncodedData,
        decoder: Decoder = default_decoder,
        **from_dict_kwargs: Any,
    ) -> T:
        """Deserialize a YAML document using *decoder*."""
        plain = decoder(data)
        return cls.from_dict(plain, **from_dict_kwargs)

View File

@@ -0,0 +1,127 @@
import decimal
from collections.abc import Callable
from dataclasses import dataclass
from typing import Any, Optional, Type, Union
from typing_extensions import Literal
from mashumaro.core.const import Sentinel
__all__ = [
"SerializableType",
"GenericSerializableType",
"SerializationStrategy",
"RoundedDecimal",
"Discriminator",
"Alias",
]
class SerializableType:
    """Base class for user types providing custom `_serialize`/`_deserialize`.

    Subclasses may pass ``use_annotations=True`` at class creation to let
    the hooks' type annotations participate in code generation.
    """

    __slots__ = ()
    # Whether hook annotations participate in code generation.
    __use_annotations__ = False
    def __init_subclass__(
        cls,
        use_annotations: Union[
            bool, Literal[Sentinel.MISSING]
        ] = Sentinel.MISSING,
        **kwargs: Any,
    ):
        super().__init_subclass__(**kwargs)
        # Only override the inherited flag when explicitly provided.
        if use_annotations is not Sentinel.MISSING:
            cls.__use_annotations__ = use_annotations
    def _serialize(self) -> Any:
        """Return the serialized representation of this instance."""
        raise NotImplementedError
    @classmethod
    def _deserialize(cls, value: Any) -> Any:
        """Reconstruct an instance from its serialized representation."""
        raise NotImplementedError
class GenericSerializableType:
    """Base class for generic user types with custom serialization hooks
    that receive the resolved generic type arguments."""

    __slots__ = ()

    def _serialize(self, types: list[Type]) -> Any:
        """Serialize self, given the resolved type arguments in *types*."""
        raise NotImplementedError

    @classmethod
    def _deserialize(cls, value: Any, types: list[Type]) -> Any:
        """Reconstruct an instance from *value*, given the type arguments."""
        raise NotImplementedError
class SerializationStrategy:
    """Base class for reusable custom (de)serialization strategies.

    Subclasses may pass ``use_annotations=True`` at class creation to let
    the hooks' type annotations participate in code generation.
    """

    # Whether hook annotations participate in code generation.
    __use_annotations__ = False
    def __init_subclass__(
        cls,
        use_annotations: Union[
            bool, Literal[Sentinel.MISSING]
        ] = Sentinel.MISSING,
        **kwargs: Any,
    ):
        super().__init_subclass__(**kwargs)
        # Only override the inherited flag when explicitly provided.
        if use_annotations is not Sentinel.MISSING:
            cls.__use_annotations__ = use_annotations
    def serialize(self, value: Any) -> Any:
        """Convert *value* to its serialized representation."""
        raise NotImplementedError
    def deserialize(self, value: Any) -> Any:
        """Convert a serialized representation back to *value*'s type."""
        raise NotImplementedError
class RoundedDecimal(SerializationStrategy):
    """Serialize Decimal as a string, optionally quantized to *places*
    decimal places with an optional *rounding* mode."""

    def __init__(
        self, places: Optional[int] = None, rounding: Optional[str] = None
    ):
        # Quantization exponent, e.g. places=2 -> Decimal("0.01").
        if places is None:
            self.exp = None  # type: ignore
        else:
            self.exp = decimal.Decimal((0, (1,), -places))
        self.rounding = rounding

    def serialize(self, value: decimal.Decimal) -> str:
        if not self.exp:
            return str(value)
        if self.rounding:
            return str(value.quantize(self.exp, rounding=self.rounding))
        return str(value.quantize(self.exp))

    def deserialize(self, value: str) -> decimal.Decimal:
        # str() keeps this tolerant of non-string inputs (e.g. floats).
        return decimal.Decimal(str(value))
@dataclass(unsafe_hash=True)
class Discriminator:
    """Configuration for subtype/supertype discrimination during unpacking."""

    # Name of the field holding the variant tag, if tag-based.
    field: Optional[str] = None
    include_supertypes: bool = False
    include_subtypes: bool = False
    # Optional callable deriving the tag value from a variant class.
    variant_tagger_fn: Optional[Callable[[Any], Any]] = None

    def __post_init__(self) -> None:
        # A discriminator including neither side could never match anything.
        if not (self.include_supertypes or self.include_subtypes):
            raise ValueError(
                "Either 'include_supertypes' or 'include_subtypes' "
                "must be enabled"
            )
class Alias:
    """A field alias marker, comparable and hashable by its name."""

    def __init__(self, name: str, /):
        self.name = name

    def __repr__(self) -> str:
        return f"Alias(name='{self.name}')"

    def __eq__(self, other: Any) -> bool:
        return isinstance(other, Alias) and self.name == other.name

    def __hash__(self) -> int:
        return hash(self.name)