some random stuff. caelestia incoming
This commit is contained in:
@@ -0,0 +1,10 @@
|
||||
from mashumaro.exceptions import MissingField
|
||||
from mashumaro.helper import field_options, pass_through
|
||||
from mashumaro.mixins.dict import DataClassDictMixin
|
||||
|
||||
__all__ = [
|
||||
"MissingField",
|
||||
"DataClassDictMixin",
|
||||
"field_options",
|
||||
"pass_through",
|
||||
]
|
||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -0,0 +1,6 @@
|
||||
from .basic import BasicDecoder, BasicEncoder
|
||||
|
||||
__all__ = [
|
||||
"BasicDecoder",
|
||||
"BasicEncoder",
|
||||
]
|
||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -0,0 +1,106 @@
|
||||
import re
|
||||
from collections.abc import Callable
|
||||
from typing import Any, Optional, Type
|
||||
|
||||
from mashumaro.core.meta.code.builder import CodeBuilder
|
||||
from mashumaro.core.meta.helpers import is_optional, is_type_var_any
|
||||
from mashumaro.core.meta.types.common import (
|
||||
AttrsHolder,
|
||||
FieldContext,
|
||||
ValueSpec,
|
||||
)
|
||||
from mashumaro.core.meta.types.pack import PackerRegistry
|
||||
from mashumaro.core.meta.types.unpack import UnpackerRegistry
|
||||
|
||||
CALL_EXPR = re.compile(r"^([^ ]+)\(value\)$")
|
||||
|
||||
|
||||
class CodecCodeBuilder(CodeBuilder):
|
||||
@classmethod
|
||||
def new(cls, **kwargs: Any) -> "CodecCodeBuilder":
|
||||
if "attrs" not in kwargs:
|
||||
kwargs["attrs"] = AttrsHolder()
|
||||
return cls(AttrsHolder("__root__"), **kwargs) # type: ignore
|
||||
|
||||
def add_decode_method(
|
||||
self,
|
||||
shape_type: Type,
|
||||
decoder_obj: Any,
|
||||
pre_decoder_func: Optional[Callable[[Any], Any]] = None,
|
||||
) -> None:
|
||||
self.reset()
|
||||
with self.indent("def decode(value):"):
|
||||
if pre_decoder_func:
|
||||
self.ensure_object_imported(pre_decoder_func, "decoder")
|
||||
self.add_line("value = decoder(value)")
|
||||
could_be_none = (
|
||||
shape_type in (Any, type(None), None)
|
||||
or is_type_var_any(self.get_real_type("", shape_type))
|
||||
or is_optional(
|
||||
shape_type, self.get_field_resolved_type_params("")
|
||||
)
|
||||
)
|
||||
unpacked_value = UnpackerRegistry.get(
|
||||
ValueSpec(
|
||||
type=shape_type,
|
||||
expression="value",
|
||||
builder=self,
|
||||
field_ctx=FieldContext(name="", metadata={}),
|
||||
could_be_none=could_be_none,
|
||||
)
|
||||
)
|
||||
self.add_line(f"return {unpacked_value}")
|
||||
self.add_line("setattr(decoder_obj, 'decode', decode)")
|
||||
if pre_decoder_func is None:
|
||||
m = CALL_EXPR.match(unpacked_value)
|
||||
if m:
|
||||
method_name = m.group(1)
|
||||
self.lines.reset()
|
||||
self.add_line(f"setattr(decoder_obj, 'decode', {method_name})")
|
||||
self.ensure_object_imported(decoder_obj, "decoder_obj")
|
||||
self.ensure_object_imported(self.cls, "cls")
|
||||
self.compile()
|
||||
|
||||
def add_encode_method(
|
||||
self,
|
||||
shape_type: Type,
|
||||
encoder_obj: Any,
|
||||
post_encoder_func: Optional[Callable[[Any], Any]] = None,
|
||||
) -> None:
|
||||
self.reset()
|
||||
with self.indent("def encode(value):"):
|
||||
could_be_none = (
|
||||
shape_type in (Any, type(None), None)
|
||||
or is_type_var_any(self.get_real_type("", shape_type))
|
||||
or is_optional(
|
||||
shape_type, self.get_field_resolved_type_params("")
|
||||
)
|
||||
)
|
||||
packed_value = PackerRegistry.get(
|
||||
ValueSpec(
|
||||
type=shape_type,
|
||||
expression="value",
|
||||
builder=self,
|
||||
field_ctx=FieldContext(name="", metadata={}),
|
||||
could_be_none=could_be_none,
|
||||
no_copy_collections=self.get_dialect_or_config_option(
|
||||
"no_copy_collections", ()
|
||||
),
|
||||
)
|
||||
)
|
||||
if post_encoder_func:
|
||||
self.ensure_object_imported(post_encoder_func, "encoder")
|
||||
self.add_line(f"return encoder({packed_value})")
|
||||
else:
|
||||
self.add_line(f"return {packed_value}")
|
||||
self.add_line("setattr(encoder_obj, 'encode', encode)")
|
||||
if post_encoder_func is None:
|
||||
m = CALL_EXPR.match(packed_value)
|
||||
if m:
|
||||
method_name = m.group(1)
|
||||
self.lines.reset()
|
||||
self.add_line(f"setattr(encoder_obj, 'encode', {method_name})")
|
||||
self.ensure_object_imported(encoder_obj, "encoder_obj")
|
||||
self.ensure_object_imported(self.cls, "cls")
|
||||
self.ensure_object_imported(self.cls, "self")
|
||||
self.compile()
|
||||
@@ -0,0 +1,103 @@
|
||||
from collections.abc import Callable
|
||||
from typing import (
|
||||
Any,
|
||||
Generic,
|
||||
Optional,
|
||||
Type,
|
||||
TypeVar,
|
||||
Union,
|
||||
final,
|
||||
overload,
|
||||
)
|
||||
|
||||
from mashumaro.codecs._builder import CodecCodeBuilder
|
||||
from mashumaro.core.meta.helpers import get_args
|
||||
from mashumaro.dialect import Dialect
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
class BasicDecoder(Generic[T]):
|
||||
@overload
|
||||
def __init__(
|
||||
self,
|
||||
shape_type: Type[T],
|
||||
*,
|
||||
default_dialect: Optional[Type[Dialect]] = None,
|
||||
pre_decoder_func: Optional[Callable[[Any], Any]] = None,
|
||||
): ...
|
||||
|
||||
@overload
|
||||
def __init__(
|
||||
self,
|
||||
shape_type: Any,
|
||||
*,
|
||||
default_dialect: Optional[Type[Dialect]] = None,
|
||||
pre_decoder_func: Optional[Callable[[Any], Any]] = None,
|
||||
): ...
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
shape_type: Union[Type[T], Any],
|
||||
*,
|
||||
default_dialect: Optional[Type[Dialect]] = None,
|
||||
pre_decoder_func: Optional[Callable[[Any], Any]] = None,
|
||||
):
|
||||
code_builder = CodecCodeBuilder.new(
|
||||
type_args=get_args(shape_type), default_dialect=default_dialect
|
||||
)
|
||||
code_builder.add_decode_method(shape_type, self, pre_decoder_func)
|
||||
|
||||
@final
|
||||
def decode(self, data: Any) -> T: ...
|
||||
|
||||
|
||||
class BasicEncoder(Generic[T]):
|
||||
@overload
|
||||
def __init__(
|
||||
self,
|
||||
shape_type: Type[T],
|
||||
*,
|
||||
default_dialect: Optional[Type[Dialect]] = None,
|
||||
post_encoder_func: Optional[Callable[[Any], Any]] = None,
|
||||
): ...
|
||||
|
||||
@overload
|
||||
def __init__(
|
||||
self,
|
||||
shape_type: Any,
|
||||
*,
|
||||
default_dialect: Optional[Type[Dialect]] = None,
|
||||
post_encoder_func: Optional[Callable[[Any], Any]] = None,
|
||||
): ...
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
shape_type: Union[Type[T], Any],
|
||||
*,
|
||||
default_dialect: Optional[Type[Dialect]] = None,
|
||||
post_encoder_func: Optional[Callable[[Any], Any]] = None,
|
||||
):
|
||||
code_builder = CodecCodeBuilder.new(
|
||||
type_args=get_args(shape_type), default_dialect=default_dialect
|
||||
)
|
||||
code_builder.add_encode_method(shape_type, self, post_encoder_func)
|
||||
|
||||
@final
|
||||
def encode(self, obj: T) -> Any: ...
|
||||
|
||||
|
||||
def decode(data: Any, shape_type: Union[Type[T], Any]) -> T:
|
||||
return BasicDecoder(shape_type).decode(data)
|
||||
|
||||
|
||||
def encode(obj: T, shape_type: Union[Type[T], Any]) -> Any:
|
||||
return BasicEncoder(shape_type).encode(obj)
|
||||
|
||||
|
||||
__all__ = [
|
||||
"BasicDecoder",
|
||||
"BasicEncoder",
|
||||
"decode",
|
||||
"encode",
|
||||
]
|
||||
@@ -0,0 +1,123 @@
|
||||
import json
|
||||
from collections.abc import Callable
|
||||
from typing import (
|
||||
Any,
|
||||
Generic,
|
||||
Optional,
|
||||
Type,
|
||||
TypeVar,
|
||||
Union,
|
||||
final,
|
||||
overload,
|
||||
)
|
||||
|
||||
from mashumaro.codecs._builder import CodecCodeBuilder
|
||||
from mashumaro.core.meta.helpers import get_args
|
||||
from mashumaro.dialect import Dialect
|
||||
|
||||
T = TypeVar("T")
|
||||
EncodedData = Union[str, bytes, bytearray]
|
||||
|
||||
|
||||
class JSONDecoder(Generic[T]):
|
||||
@overload
|
||||
def __init__(
|
||||
self,
|
||||
shape_type: Type[T],
|
||||
*,
|
||||
default_dialect: Optional[Type[Dialect]] = None,
|
||||
pre_decoder_func: Callable[[EncodedData], Any] = json.loads,
|
||||
): ...
|
||||
|
||||
@overload
|
||||
def __init__(
|
||||
self,
|
||||
shape_type: Any,
|
||||
*,
|
||||
default_dialect: Optional[Type[Dialect]] = None,
|
||||
pre_decoder_func: Callable[[EncodedData], Any] = json.loads,
|
||||
): ...
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
shape_type: Union[Type[T], Any],
|
||||
*,
|
||||
default_dialect: Optional[Type[Dialect]] = None,
|
||||
pre_decoder_func: Callable[[EncodedData], Any] = json.loads,
|
||||
):
|
||||
code_builder = CodecCodeBuilder.new(
|
||||
type_args=get_args(shape_type), default_dialect=default_dialect
|
||||
)
|
||||
code_builder.add_decode_method(shape_type, self, pre_decoder_func)
|
||||
|
||||
@final
|
||||
def decode(self, data: EncodedData) -> T: ...
|
||||
|
||||
|
||||
class JSONEncoder(Generic[T]):
|
||||
@overload
|
||||
def __init__(
|
||||
self,
|
||||
shape_type: Type[T],
|
||||
*,
|
||||
default_dialect: Optional[Type[Dialect]] = None,
|
||||
post_encoder_func: Callable[[Any], str] = json.dumps,
|
||||
): ...
|
||||
|
||||
@overload
|
||||
def __init__(
|
||||
self,
|
||||
shape_type: Any,
|
||||
*,
|
||||
default_dialect: Optional[Type[Dialect]] = None,
|
||||
post_encoder_func: Callable[[Any], str] = json.dumps,
|
||||
): ...
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
shape_type: Union[Type[T], Any],
|
||||
*,
|
||||
default_dialect: Optional[Type[Dialect]] = None,
|
||||
post_encoder_func: Callable[[Any], str] = json.dumps,
|
||||
):
|
||||
code_builder = CodecCodeBuilder.new(
|
||||
type_args=get_args(shape_type), default_dialect=default_dialect
|
||||
)
|
||||
code_builder.add_encode_method(shape_type, self, post_encoder_func)
|
||||
|
||||
@final
|
||||
def encode(self, obj: T) -> str: ...
|
||||
|
||||
|
||||
def json_decode(
|
||||
data: EncodedData,
|
||||
shape_type: Union[Type[T], Any],
|
||||
pre_decoder_func: Callable[[EncodedData], Any] = json.loads,
|
||||
) -> T:
|
||||
return JSONDecoder(shape_type, pre_decoder_func=pre_decoder_func).decode(
|
||||
data
|
||||
)
|
||||
|
||||
|
||||
def json_encode(
|
||||
obj: T,
|
||||
shape_type: Union[Type[T], Any],
|
||||
post_encoder_func: Callable[[Any], str] = json.dumps,
|
||||
) -> str:
|
||||
return JSONEncoder(shape_type, post_encoder_func=post_encoder_func).encode(
|
||||
obj
|
||||
)
|
||||
|
||||
|
||||
decode = json_decode
|
||||
encode = json_encode
|
||||
|
||||
|
||||
__all__ = [
|
||||
"JSONDecoder",
|
||||
"JSONEncoder",
|
||||
"json_decode",
|
||||
"json_encode",
|
||||
"decode",
|
||||
"encode",
|
||||
]
|
||||
@@ -0,0 +1,132 @@
|
||||
from collections.abc import Callable
|
||||
from typing import (
|
||||
Any,
|
||||
Generic,
|
||||
Optional,
|
||||
Type,
|
||||
TypeVar,
|
||||
Union,
|
||||
final,
|
||||
overload,
|
||||
)
|
||||
|
||||
import msgpack
|
||||
|
||||
from mashumaro.codecs._builder import CodecCodeBuilder
|
||||
from mashumaro.core.meta.helpers import get_args
|
||||
from mashumaro.dialect import Dialect
|
||||
from mashumaro.mixins.msgpack import MessagePackDialect
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
EncodedData = bytes
|
||||
PostEncoderFunc = Callable[[Any], EncodedData]
|
||||
PreDecoderFunc = Callable[[EncodedData], Any]
|
||||
|
||||
|
||||
def _default_decoder(data: EncodedData) -> Any:
|
||||
return msgpack.unpackb(data, raw=False)
|
||||
|
||||
|
||||
def _default_encoder(data: Any) -> EncodedData:
|
||||
return msgpack.packb(data, use_bin_type=True)
|
||||
|
||||
|
||||
class MessagePackDecoder(Generic[T]):
|
||||
@overload
|
||||
def __init__(
|
||||
self,
|
||||
shape_type: Type[T],
|
||||
*,
|
||||
default_dialect: Optional[Type[Dialect]] = None,
|
||||
pre_decoder_func: Optional[PreDecoderFunc] = _default_decoder,
|
||||
): ...
|
||||
|
||||
@overload
|
||||
def __init__(
|
||||
self,
|
||||
shape_type: Any,
|
||||
*,
|
||||
default_dialect: Optional[Type[Dialect]] = None,
|
||||
pre_decoder_func: Optional[PreDecoderFunc] = _default_decoder,
|
||||
): ...
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
shape_type: Union[Type[T], Any],
|
||||
*,
|
||||
default_dialect: Optional[Type[Dialect]] = None,
|
||||
pre_decoder_func: Optional[PreDecoderFunc] = _default_decoder,
|
||||
):
|
||||
if default_dialect is not None:
|
||||
default_dialect = MessagePackDialect.merge(default_dialect)
|
||||
else:
|
||||
default_dialect = MessagePackDialect
|
||||
code_builder = CodecCodeBuilder.new(
|
||||
type_args=get_args(shape_type), default_dialect=default_dialect
|
||||
)
|
||||
code_builder.add_decode_method(shape_type, self, pre_decoder_func)
|
||||
|
||||
@final
|
||||
def decode(self, data: EncodedData) -> T: ...
|
||||
|
||||
|
||||
class MessagePackEncoder(Generic[T]):
|
||||
@overload
|
||||
def __init__(
|
||||
self,
|
||||
shape_type: Type[T],
|
||||
*,
|
||||
default_dialect: Optional[Type[Dialect]] = None,
|
||||
post_encoder_func: Optional[PostEncoderFunc] = _default_encoder,
|
||||
): ...
|
||||
|
||||
@overload
|
||||
def __init__(
|
||||
self,
|
||||
shape_type: Any,
|
||||
*,
|
||||
default_dialect: Optional[Type[Dialect]] = None,
|
||||
post_encoder_func: Optional[PostEncoderFunc] = _default_encoder,
|
||||
): ...
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
shape_type: Union[Type[T], Any],
|
||||
*,
|
||||
default_dialect: Optional[Type[Dialect]] = None,
|
||||
post_encoder_func: Optional[PostEncoderFunc] = _default_encoder,
|
||||
):
|
||||
if default_dialect is not None:
|
||||
default_dialect = MessagePackDialect.merge(default_dialect)
|
||||
else:
|
||||
default_dialect = MessagePackDialect
|
||||
code_builder = CodecCodeBuilder.new(
|
||||
type_args=get_args(shape_type), default_dialect=default_dialect
|
||||
)
|
||||
code_builder.add_encode_method(shape_type, self, post_encoder_func)
|
||||
|
||||
@final
|
||||
def encode(self, obj: T) -> EncodedData: ...
|
||||
|
||||
|
||||
def msgpack_decode(data: EncodedData, shape_type: Union[Type[T], Any]) -> T:
|
||||
return MessagePackDecoder(shape_type).decode(data)
|
||||
|
||||
|
||||
def msgpack_encode(obj: T, shape_type: Union[Type[T], Any]) -> EncodedData:
|
||||
return MessagePackEncoder(shape_type).encode(obj)
|
||||
|
||||
|
||||
decode = msgpack_decode
|
||||
encode = msgpack_encode
|
||||
|
||||
|
||||
__all__ = [
|
||||
"MessagePackDecoder",
|
||||
"MessagePackEncoder",
|
||||
"msgpack_decode",
|
||||
"msgpack_encode",
|
||||
"decode",
|
||||
"encode",
|
||||
]
|
||||
@@ -0,0 +1,114 @@
|
||||
from typing import (
|
||||
Any,
|
||||
Generic,
|
||||
Optional,
|
||||
Type,
|
||||
TypeVar,
|
||||
Union,
|
||||
final,
|
||||
overload,
|
||||
)
|
||||
|
||||
import orjson
|
||||
|
||||
from mashumaro.codecs._builder import CodecCodeBuilder
|
||||
from mashumaro.core.meta.helpers import get_args
|
||||
from mashumaro.dialect import Dialect
|
||||
from mashumaro.mixins.orjson import OrjsonDialect
|
||||
|
||||
T = TypeVar("T")
|
||||
EncodedData = Union[str, bytes, bytearray]
|
||||
|
||||
|
||||
class ORJSONDecoder(Generic[T]):
|
||||
@overload
|
||||
def __init__(
|
||||
self,
|
||||
shape_type: Type[T],
|
||||
*,
|
||||
default_dialect: Optional[Type[Dialect]] = None,
|
||||
): ...
|
||||
|
||||
@overload
|
||||
def __init__(
|
||||
self,
|
||||
shape_type: Any,
|
||||
*,
|
||||
default_dialect: Optional[Type[Dialect]] = None,
|
||||
): ...
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
shape_type: Union[Type[T], Any],
|
||||
*,
|
||||
default_dialect: Optional[Type[Dialect]] = None,
|
||||
):
|
||||
if default_dialect is not None:
|
||||
default_dialect = OrjsonDialect.merge(default_dialect)
|
||||
else:
|
||||
default_dialect = OrjsonDialect
|
||||
code_builder = CodecCodeBuilder.new(
|
||||
type_args=get_args(shape_type), default_dialect=default_dialect
|
||||
)
|
||||
code_builder.add_decode_method(shape_type, self, orjson.loads)
|
||||
|
||||
@final
|
||||
def decode(self, data: EncodedData) -> T: ...
|
||||
|
||||
|
||||
class ORJSONEncoder(Generic[T]):
|
||||
@overload
|
||||
def __init__(
|
||||
self,
|
||||
shape_type: Type[T],
|
||||
*,
|
||||
default_dialect: Optional[Type[Dialect]] = None,
|
||||
): ...
|
||||
|
||||
@overload
|
||||
def __init__(
|
||||
self,
|
||||
shape_type: Any,
|
||||
*,
|
||||
default_dialect: Optional[Type[Dialect]] = None,
|
||||
): ...
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
shape_type: Union[Type[T], Any],
|
||||
*,
|
||||
default_dialect: Optional[Type[Dialect]] = None,
|
||||
):
|
||||
if default_dialect is not None:
|
||||
default_dialect = OrjsonDialect.merge(default_dialect)
|
||||
else:
|
||||
default_dialect = OrjsonDialect
|
||||
code_builder = CodecCodeBuilder.new(
|
||||
type_args=get_args(shape_type), default_dialect=default_dialect
|
||||
)
|
||||
code_builder.add_encode_method(shape_type, self, orjson.dumps)
|
||||
|
||||
@final
|
||||
def encode(self, obj: T) -> bytes: ...
|
||||
|
||||
|
||||
def json_decode(data: EncodedData, shape_type: Type[T]) -> T:
|
||||
return ORJSONDecoder(shape_type).decode(data)
|
||||
|
||||
|
||||
def json_encode(obj: T, shape_type: Union[Type[T], Any]) -> bytes:
|
||||
return ORJSONEncoder(shape_type).encode(obj)
|
||||
|
||||
|
||||
decode = json_decode
|
||||
encode = json_encode
|
||||
|
||||
|
||||
__all__ = [
|
||||
"ORJSONDecoder",
|
||||
"ORJSONEncoder",
|
||||
"json_decode",
|
||||
"json_encode",
|
||||
"decode",
|
||||
"encode",
|
||||
]
|
||||
@@ -0,0 +1,119 @@
|
||||
from typing import (
|
||||
Any,
|
||||
Generic,
|
||||
Optional,
|
||||
Type,
|
||||
TypeVar,
|
||||
Union,
|
||||
final,
|
||||
overload,
|
||||
)
|
||||
|
||||
import tomli_w
|
||||
|
||||
from mashumaro.codecs._builder import CodecCodeBuilder
|
||||
from mashumaro.core.meta.helpers import get_args
|
||||
from mashumaro.dialect import Dialect
|
||||
from mashumaro.mixins.toml import TOMLDialect
|
||||
|
||||
try:
|
||||
import tomllib
|
||||
except ModuleNotFoundError:
|
||||
import tomli as tomllib # type: ignore
|
||||
|
||||
T = TypeVar("T")
|
||||
EncodedData = str
|
||||
|
||||
|
||||
class TOMLDecoder(Generic[T]):
|
||||
@overload
|
||||
def __init__(
|
||||
self,
|
||||
shape_type: Type[T],
|
||||
*,
|
||||
default_dialect: Optional[Type[Dialect]] = None,
|
||||
): ...
|
||||
|
||||
@overload
|
||||
def __init__(
|
||||
self,
|
||||
shape_type: Any,
|
||||
*,
|
||||
default_dialect: Optional[Type[Dialect]] = None,
|
||||
): ...
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
shape_type: Union[Type[T], Any],
|
||||
*,
|
||||
default_dialect: Optional[Type[Dialect]] = None,
|
||||
):
|
||||
if default_dialect is not None:
|
||||
default_dialect = TOMLDialect.merge(default_dialect)
|
||||
else:
|
||||
default_dialect = TOMLDialect
|
||||
code_builder = CodecCodeBuilder.new(
|
||||
type_args=get_args(shape_type), default_dialect=default_dialect
|
||||
)
|
||||
code_builder.add_decode_method(shape_type, self, tomllib.loads)
|
||||
|
||||
@final
|
||||
def decode(self, data: EncodedData) -> T: ...
|
||||
|
||||
|
||||
class TOMLEncoder(Generic[T]):
|
||||
@overload
|
||||
def __init__(
|
||||
self,
|
||||
shape_type: Type[T],
|
||||
*,
|
||||
default_dialect: Optional[Type[Dialect]] = None,
|
||||
): ...
|
||||
|
||||
@overload
|
||||
def __init__(
|
||||
self,
|
||||
shape_type: Any,
|
||||
*,
|
||||
default_dialect: Optional[Type[Dialect]] = None,
|
||||
): ...
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
shape_type: Union[Type[T], Any],
|
||||
*,
|
||||
default_dialect: Optional[Type[Dialect]] = None,
|
||||
):
|
||||
if default_dialect is not None:
|
||||
default_dialect = TOMLDialect.merge(default_dialect)
|
||||
else:
|
||||
default_dialect = TOMLDialect
|
||||
code_builder = CodecCodeBuilder.new(
|
||||
type_args=get_args(shape_type), default_dialect=default_dialect
|
||||
)
|
||||
code_builder.add_encode_method(shape_type, self, tomli_w.dumps)
|
||||
|
||||
@final
|
||||
def encode(self, obj: T) -> bytes: ...
|
||||
|
||||
|
||||
def toml_decode(data: EncodedData, shape_type: Type[T]) -> T:
|
||||
return TOMLDecoder(shape_type).decode(data)
|
||||
|
||||
|
||||
def toml_encode(obj: T, shape_type: Union[Type[T], Any]) -> bytes:
|
||||
return TOMLEncoder(shape_type).encode(obj)
|
||||
|
||||
|
||||
decode = toml_decode
|
||||
encode = toml_encode
|
||||
|
||||
|
||||
__all__ = [
|
||||
"TOMLDecoder",
|
||||
"TOMLEncoder",
|
||||
"toml_decode",
|
||||
"toml_encode",
|
||||
"decode",
|
||||
"encode",
|
||||
]
|
||||
@@ -0,0 +1,128 @@
|
||||
from collections.abc import Callable
|
||||
from typing import (
|
||||
Any,
|
||||
Generic,
|
||||
Optional,
|
||||
Type,
|
||||
TypeVar,
|
||||
Union,
|
||||
final,
|
||||
overload,
|
||||
)
|
||||
|
||||
import yaml
|
||||
|
||||
from mashumaro.codecs._builder import CodecCodeBuilder
|
||||
from mashumaro.core.meta.helpers import get_args
|
||||
from mashumaro.dialect import Dialect
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
EncodedData = Union[str, bytes]
|
||||
PostEncoderFunc = Callable[[Any], EncodedData]
|
||||
PreDecoderFunc = Callable[[EncodedData], Any]
|
||||
|
||||
|
||||
DefaultLoader = getattr(yaml, "CSafeLoader", yaml.SafeLoader)
|
||||
DefaultDumper = getattr(yaml, "CDumper", yaml.Dumper)
|
||||
|
||||
|
||||
def _default_encoder(data: Any) -> EncodedData:
|
||||
return yaml.dump(data, Dumper=DefaultDumper)
|
||||
|
||||
|
||||
def _default_decoder(data: EncodedData) -> Any:
|
||||
return yaml.load(data, DefaultLoader)
|
||||
|
||||
|
||||
class YAMLDecoder(Generic[T]):
|
||||
@overload
|
||||
def __init__(
|
||||
self,
|
||||
shape_type: Type[T],
|
||||
*,
|
||||
default_dialect: Optional[Type[Dialect]] = None,
|
||||
pre_decoder_func: Optional[PreDecoderFunc] = _default_decoder,
|
||||
): ...
|
||||
|
||||
@overload
|
||||
def __init__(
|
||||
self,
|
||||
shape_type: Any,
|
||||
*,
|
||||
default_dialect: Optional[Type[Dialect]] = None,
|
||||
pre_decoder_func: Optional[PreDecoderFunc] = _default_decoder,
|
||||
): ...
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
shape_type: Union[Type[T], Any],
|
||||
*,
|
||||
default_dialect: Optional[Type[Dialect]] = None,
|
||||
pre_decoder_func: Optional[PreDecoderFunc] = _default_decoder,
|
||||
):
|
||||
code_builder = CodecCodeBuilder.new(
|
||||
type_args=get_args(shape_type), default_dialect=default_dialect
|
||||
)
|
||||
code_builder.add_decode_method(shape_type, self, pre_decoder_func)
|
||||
|
||||
@final
|
||||
def decode(self, data: EncodedData) -> T: ...
|
||||
|
||||
|
||||
class YAMLEncoder(Generic[T]):
|
||||
@overload
|
||||
def __init__(
|
||||
self,
|
||||
shape_type: Type[T],
|
||||
*,
|
||||
default_dialect: Optional[Type[Dialect]] = None,
|
||||
post_encoder_func: Optional[PostEncoderFunc] = _default_encoder,
|
||||
): ...
|
||||
|
||||
@overload
|
||||
def __init__(
|
||||
self,
|
||||
shape_type: Any,
|
||||
*,
|
||||
default_dialect: Optional[Type[Dialect]] = None,
|
||||
post_encoder_func: Optional[PostEncoderFunc] = _default_encoder,
|
||||
): ...
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
shape_type: Union[Type[T], Any],
|
||||
*,
|
||||
default_dialect: Optional[Type[Dialect]] = None,
|
||||
post_encoder_func: Optional[PostEncoderFunc] = _default_encoder,
|
||||
):
|
||||
code_builder = CodecCodeBuilder.new(
|
||||
type_args=get_args(shape_type), default_dialect=default_dialect
|
||||
)
|
||||
code_builder.add_encode_method(shape_type, self, post_encoder_func)
|
||||
|
||||
@final
|
||||
def encode(self, obj: T) -> EncodedData: ...
|
||||
|
||||
|
||||
def yaml_decode(data: EncodedData, shape_type: Union[Type[T], Any]) -> T:
|
||||
return YAMLDecoder(shape_type).decode(data)
|
||||
|
||||
|
||||
def yaml_encode(obj: T, shape_type: Union[Type[T], Any]) -> EncodedData:
|
||||
return YAMLEncoder(shape_type).encode(obj)
|
||||
|
||||
|
||||
decode = yaml_decode
|
||||
encode = yaml_encode
|
||||
|
||||
|
||||
__all__ = [
|
||||
"YAMLDecoder",
|
||||
"YAMLEncoder",
|
||||
"yaml_decode",
|
||||
"yaml_encode",
|
||||
"decode",
|
||||
"encode",
|
||||
]
|
||||
@@ -0,0 +1,63 @@
|
||||
from collections.abc import Callable
|
||||
from typing import Any, Literal, Optional, Type, TypedDict, Union
|
||||
|
||||
from mashumaro.core.const import Sentinel
|
||||
from mashumaro.dialect import Dialect
|
||||
from mashumaro.types import Discriminator, SerializationStrategy
|
||||
|
||||
__all__ = [
|
||||
"BaseConfig",
|
||||
"TO_DICT_ADD_BY_ALIAS_FLAG",
|
||||
"TO_DICT_ADD_OMIT_NONE_FLAG",
|
||||
"ADD_DIALECT_SUPPORT",
|
||||
"ADD_SERIALIZATION_CONTEXT",
|
||||
"SerializationStrategyValueType",
|
||||
]
|
||||
|
||||
|
||||
TO_DICT_ADD_BY_ALIAS_FLAG = "TO_DICT_ADD_BY_ALIAS_FLAG"
|
||||
TO_DICT_ADD_OMIT_NONE_FLAG = "TO_DICT_ADD_OMIT_NONE_FLAG"
|
||||
ADD_DIALECT_SUPPORT = "ADD_DIALECT_SUPPORT"
|
||||
ADD_SERIALIZATION_CONTEXT = "ADD_SERIALIZATION_CONTEXT"
|
||||
|
||||
|
||||
CodeGenerationOption = Literal[
|
||||
"TO_DICT_ADD_BY_ALIAS_FLAG",
|
||||
"TO_DICT_ADD_OMIT_NONE_FLAG",
|
||||
"ADD_DIALECT_SUPPORT",
|
||||
"ADD_SERIALIZATION_CONTEXT",
|
||||
]
|
||||
|
||||
|
||||
class SerializationStrategyDict(TypedDict, total=False):
|
||||
serialize: Union[str, Callable]
|
||||
deserialize: Union[str, Callable]
|
||||
|
||||
|
||||
SerializationStrategyValueType = Union[
|
||||
SerializationStrategy, SerializationStrategyDict
|
||||
]
|
||||
|
||||
|
||||
class BaseConfig:
|
||||
debug: bool = False
|
||||
code_generation_options: list[CodeGenerationOption] = []
|
||||
serialization_strategy: dict[Any, SerializationStrategyValueType] = {}
|
||||
aliases: dict[str, str] = {}
|
||||
serialize_by_alias: Union[bool, Literal[Sentinel.MISSING]] = (
|
||||
Sentinel.MISSING
|
||||
)
|
||||
namedtuple_as_dict: Union[bool, Literal[Sentinel.MISSING]] = (
|
||||
Sentinel.MISSING
|
||||
)
|
||||
allow_postponed_evaluation: bool = True
|
||||
dialect: Optional[Type[Dialect]] = None
|
||||
omit_none: Union[bool, Literal[Sentinel.MISSING]] = Sentinel.MISSING
|
||||
omit_default: Union[bool, Literal[Sentinel.MISSING]] = Sentinel.MISSING
|
||||
orjson_options: Optional[int] = 0
|
||||
json_schema: dict[str, Any] = {}
|
||||
discriminator: Optional[Discriminator] = None
|
||||
lazy_compilation: bool = False
|
||||
sort_keys: bool = False
|
||||
allow_deserialization_not_by_alias: bool = False
|
||||
forbid_extra_keys: bool = False
|
||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -0,0 +1,27 @@
|
||||
import enum
|
||||
import sys
|
||||
|
||||
__all__ = [
|
||||
"PY_39",
|
||||
"PY_310",
|
||||
"PY_310_MIN",
|
||||
"PY_311_MIN",
|
||||
"PY_312_MIN",
|
||||
"PY_313_MIN",
|
||||
"Sentinel",
|
||||
]
|
||||
|
||||
|
||||
PY_39 = sys.version_info.major == 3 and sys.version_info.minor == 9
|
||||
PY_310 = sys.version_info.major == 3 and sys.version_info.minor == 10
|
||||
PY_311 = sys.version_info.major == 3 and sys.version_info.minor == 11
|
||||
PY_312 = sys.version_info.major == 3 and sys.version_info.minor == 12
|
||||
PY_313_MIN = sys.version_info.major == 3 and sys.version_info.minor >= 13
|
||||
|
||||
PY_312_MIN = PY_312 or PY_313_MIN
|
||||
PY_311_MIN = PY_311 or PY_312_MIN
|
||||
PY_310_MIN = PY_310 or PY_311_MIN
|
||||
|
||||
|
||||
class Sentinel(enum.Enum):
|
||||
MISSING = enum.auto()
|
||||
@@ -0,0 +1,35 @@
|
||||
import datetime
|
||||
import re
|
||||
|
||||
__all__ = [
|
||||
"parse_timezone",
|
||||
"ConfigValue",
|
||||
"UTC_OFFSET_PATTERN",
|
||||
]
|
||||
|
||||
|
||||
UTC_OFFSET_PATTERN = r"^UTC(([+-][0-2][0-9]):([0-5][0-9]))?$"
|
||||
UTC_OFFSET_RE = re.compile(UTC_OFFSET_PATTERN)
|
||||
|
||||
|
||||
def parse_timezone(s: str) -> datetime.timezone:
|
||||
match = UTC_OFFSET_RE.match(s)
|
||||
if not match:
|
||||
raise ValueError(
|
||||
f"Time zone {s} must be either UTC or in format UTC[+-]hh:mm"
|
||||
)
|
||||
if match.group(1):
|
||||
hours = int(match.group(2))
|
||||
minutes = int(match.group(3))
|
||||
return datetime.timezone(
|
||||
datetime.timedelta(
|
||||
hours=hours, minutes=minutes if hours >= 0 else -minutes
|
||||
)
|
||||
)
|
||||
else:
|
||||
return datetime.timezone.utc
|
||||
|
||||
|
||||
class ConfigValue:
|
||||
def __init__(self, name: str):
|
||||
self.name = name
|
||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,38 @@
|
||||
from collections.abc import Generator
|
||||
from contextlib import contextmanager
|
||||
from typing import Optional
|
||||
|
||||
__all__ = ["CodeLines"]
|
||||
|
||||
|
||||
class CodeLines:
|
||||
def __init__(self) -> None:
|
||||
self._lines: list[str] = []
|
||||
self._current_indent: str = ""
|
||||
|
||||
def append(self, line: str) -> None:
|
||||
self._lines.append(f"{self._current_indent}{line}")
|
||||
|
||||
def extend(self, lines: "CodeLines") -> None:
|
||||
for line in lines._lines:
|
||||
self._lines.append(f"{self._current_indent}{line}")
|
||||
|
||||
@contextmanager
|
||||
def indent(
|
||||
self,
|
||||
expr: Optional[str] = None,
|
||||
) -> Generator[None, None, None]:
|
||||
if expr:
|
||||
self.append(expr)
|
||||
self._current_indent += " " * 4
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
self._current_indent = self._current_indent[:-4]
|
||||
|
||||
def as_text(self) -> str:
|
||||
return "\n".join(self._lines)
|
||||
|
||||
def reset(self) -> None:
|
||||
self._lines = []
|
||||
self._current_indent = ""
|
||||
@@ -0,0 +1,812 @@
|
||||
import dataclasses
|
||||
import enum
|
||||
import inspect
|
||||
import re
|
||||
import sys
|
||||
import types
|
||||
import typing
|
||||
from collections.abc import Callable, Hashable, Iterable, Iterator
|
||||
from contextlib import suppress
|
||||
|
||||
# noinspection PyProtectedMember
|
||||
from dataclasses import _FIELDS # type: ignore
|
||||
from hashlib import md5
|
||||
from typing import (
|
||||
Any,
|
||||
ClassVar,
|
||||
ForwardRef,
|
||||
Optional,
|
||||
Sequence,
|
||||
Tuple,
|
||||
Type,
|
||||
Union,
|
||||
)
|
||||
|
||||
try:
|
||||
from typing import Unpack # type: ignore[attr-defined]
|
||||
except ImportError:
|
||||
from typing_extensions import Unpack
|
||||
|
||||
import typing_extensions
|
||||
|
||||
from mashumaro.core.const import (
|
||||
PY_39,
|
||||
PY_310_MIN,
|
||||
PY_311_MIN,
|
||||
PY_312_MIN,
|
||||
PY_313_MIN,
|
||||
)
|
||||
from mashumaro.dialect import Dialect
|
||||
|
||||
__all__ = [
|
||||
"get_type_origin",
|
||||
"get_args",
|
||||
"type_name",
|
||||
"is_special_typing_primitive",
|
||||
"is_generic",
|
||||
"is_typed_dict",
|
||||
"is_named_tuple",
|
||||
"is_optional",
|
||||
"is_union",
|
||||
"not_none_type_arg",
|
||||
"is_type_var",
|
||||
"is_type_var_any",
|
||||
"is_class_var",
|
||||
"is_final",
|
||||
"is_init_var",
|
||||
"get_class_that_defines_method",
|
||||
"get_class_that_defines_field",
|
||||
"is_dataclass_dict_mixin",
|
||||
"is_dataclass_dict_mixin_subclass",
|
||||
"collect_type_params",
|
||||
"resolve_type_params",
|
||||
"substitute_type_params",
|
||||
"get_generic_name",
|
||||
"get_name_error_name",
|
||||
"is_dialect_subclass",
|
||||
"is_new_type",
|
||||
"is_annotated",
|
||||
"get_type_annotations",
|
||||
"is_literal",
|
||||
"is_local_type_name",
|
||||
"get_literal_values",
|
||||
"is_self",
|
||||
"is_required",
|
||||
"is_not_required",
|
||||
"get_function_arg_annotation",
|
||||
"get_function_return_annotation",
|
||||
"is_unpack",
|
||||
"is_type_var_tuple",
|
||||
"hash_type_args",
|
||||
"iter_all_subclasses",
|
||||
"is_hashable",
|
||||
"is_hashable_type",
|
||||
"evaluate_forward_ref",
|
||||
"get_forward_ref_referencing_globals",
|
||||
"is_type_alias_type",
|
||||
]
|
||||
|
||||
|
||||
NoneType = type(None)
|
||||
DataClassDictMixinPath = (
|
||||
f"{__name__.rsplit('.', 3)[:-3][0]}.mixins.dict.DataClassDictMixin"
|
||||
)
|
||||
|
||||
|
||||
def get_type_origin(typ: Type) -> Type:
|
||||
try:
|
||||
return typ.__origin__
|
||||
except AttributeError:
|
||||
return typ
|
||||
|
||||
|
||||
def is_builtin_type(typ: Type) -> bool:
|
||||
try:
|
||||
return typ.__module__ == "builtins"
|
||||
except AttributeError:
|
||||
return False
|
||||
|
||||
|
||||
def get_generic_name(typ: Type, short: bool = False) -> str:
|
||||
name = getattr(typ, "_name", None)
|
||||
if name is None:
|
||||
origin = get_type_origin(typ)
|
||||
if origin is typ:
|
||||
return type_name(origin, short, is_type_origin=True)
|
||||
else:
|
||||
return get_generic_name(origin, short)
|
||||
if short:
|
||||
return name
|
||||
else:
|
||||
return f"{typ.__module__}.{name}"
|
||||
|
||||
|
||||
def get_args(typ: Optional[Type]) -> tuple[Type, ...]:
|
||||
return getattr(typ, "__args__", ())
|
||||
|
||||
|
||||
def _get_args_str(
|
||||
typ: Type,
|
||||
short: bool,
|
||||
resolved_type_params: Optional[dict[Type, Type]] = None,
|
||||
limit: Optional[int] = None,
|
||||
none_type_as_none: bool = False,
|
||||
sep: str = ", ",
|
||||
) -> str:
|
||||
if typ == Tuple[()]:
|
||||
return "()"
|
||||
elif typ == tuple[()]:
|
||||
return "()"
|
||||
args = _flatten_type_args(get_args(typ)[:limit])
|
||||
to_join = []
|
||||
for arg in args:
|
||||
to_join.append(
|
||||
type_name(
|
||||
typ=arg,
|
||||
short=short,
|
||||
resolved_type_params=resolved_type_params,
|
||||
none_type_as_none=none_type_as_none,
|
||||
)
|
||||
)
|
||||
if len(to_join) > 1:
|
||||
return sep.join(s for s in to_join if s != "()")
|
||||
else:
|
||||
return sep.join(to_join)
|
||||
|
||||
|
||||
def get_literal_values(typ: Type) -> tuple[Any, ...]:
|
||||
values = typ.__args__
|
||||
result: list[Any] = []
|
||||
for value in values:
|
||||
if is_literal(value):
|
||||
result.extend(get_literal_values(value))
|
||||
else:
|
||||
result.append(value)
|
||||
return tuple(result)
|
||||
|
||||
|
||||
def _get_literal_values_str(typ: Type, short: bool) -> str:
|
||||
values_str = []
|
||||
for value in get_literal_values(typ):
|
||||
if isinstance(value, enum.Enum):
|
||||
values_str.append(f"{type_name(type(value), short)}.{value.name}")
|
||||
elif isinstance(
|
||||
value,
|
||||
(int, str, bytes, bool, NoneType), # type: ignore
|
||||
):
|
||||
values_str.append(repr(value))
|
||||
return ", ".join(values_str)
|
||||
|
||||
|
||||
def _typing_name(
|
||||
typ_name: str,
|
||||
short: bool = False,
|
||||
module_name: str = "typing",
|
||||
) -> str:
|
||||
return typ_name if short else f"{module_name}.{typ_name}"
|
||||
|
||||
|
||||
def type_name(
|
||||
typ: Optional[Type],
|
||||
short: bool = False,
|
||||
resolved_type_params: Optional[dict[Type, Type]] = None,
|
||||
is_type_origin: bool = False,
|
||||
none_type_as_none: bool = False,
|
||||
) -> str:
|
||||
if resolved_type_params is None:
|
||||
resolved_type_params = {}
|
||||
if typ is None:
|
||||
return "None"
|
||||
elif typ is NoneType and none_type_as_none:
|
||||
return "None"
|
||||
elif typ is Ellipsis:
|
||||
return "..."
|
||||
elif typ is Any:
|
||||
return _typing_name("Any", short)
|
||||
elif is_optional(typ, resolved_type_params):
|
||||
args_str = type_name(
|
||||
typ=not_none_type_arg(get_args(typ), resolved_type_params),
|
||||
short=short,
|
||||
resolved_type_params=resolved_type_params,
|
||||
)
|
||||
return f"{_typing_name('Optional', short)}[{args_str}]"
|
||||
elif is_union(typ):
|
||||
args_str = _get_args_str(
|
||||
typ, short, resolved_type_params, none_type_as_none=True
|
||||
)
|
||||
return f"{_typing_name('Union', short)}[{args_str}]"
|
||||
elif is_annotated(typ):
|
||||
return type_name(get_args(typ)[0], short, resolved_type_params)
|
||||
elif not is_type_origin and is_literal(typ):
|
||||
args_str = _get_literal_values_str(typ, short)
|
||||
return f"{_typing_name('Literal', short, typ.__module__)}[{args_str}]"
|
||||
elif not is_type_origin and is_unpack(typ):
|
||||
if (
|
||||
typ in resolved_type_params
|
||||
and resolved_type_params[typ] is not typ
|
||||
):
|
||||
return type_name(
|
||||
resolved_type_params[typ], short, resolved_type_params
|
||||
)
|
||||
else:
|
||||
unpacked_type_arg = get_args(typ)[0]
|
||||
if not is_variable_length_tuple(
|
||||
unpacked_type_arg
|
||||
) and not is_type_var_tuple(unpacked_type_arg):
|
||||
return _get_args_str(
|
||||
unpacked_type_arg, short, resolved_type_params
|
||||
)
|
||||
unpacked_type_name = type_name(
|
||||
unpacked_type_arg, short, resolved_type_params
|
||||
)
|
||||
if PY_311_MIN:
|
||||
return f"*{unpacked_type_name}"
|
||||
else:
|
||||
_unpack = _typing_name("Unpack", short, typ.__module__)
|
||||
return f"{_unpack}[{unpacked_type_name}]"
|
||||
elif not is_type_origin and is_generic(typ):
|
||||
args_str = _get_args_str(typ, short, resolved_type_params)
|
||||
if not args_str:
|
||||
return get_generic_name(typ, short)
|
||||
else:
|
||||
return f"{get_generic_name(typ, short)}[{args_str}]"
|
||||
elif is_builtin_type(typ):
|
||||
return typ.__qualname__
|
||||
elif is_type_var(typ):
|
||||
if (
|
||||
typ in resolved_type_params
|
||||
and resolved_type_params[typ] is not typ
|
||||
):
|
||||
return type_name(
|
||||
resolved_type_params[typ], short, resolved_type_params
|
||||
)
|
||||
elif is_type_var_any(typ):
|
||||
return _typing_name("Any", short)
|
||||
constraints = getattr(typ, "__constraints__")
|
||||
if constraints:
|
||||
args_str = ", ".join(
|
||||
type_name(c, short, resolved_type_params) for c in constraints
|
||||
)
|
||||
return f"{_typing_name('Union', short)}[{args_str}]"
|
||||
else:
|
||||
if type_var_has_default(typ):
|
||||
bound = get_type_var_default(typ)
|
||||
else:
|
||||
bound = getattr(typ, "__bound__")
|
||||
return type_name(bound, short, resolved_type_params)
|
||||
elif is_new_type(typ) and not PY_310_MIN:
|
||||
# because __qualname__ and __module__ are messed up
|
||||
typ = typ.__supertype__
|
||||
try:
|
||||
if short:
|
||||
return typ.__qualname__ # type: ignore
|
||||
else:
|
||||
return f"{typ.__module__}.{typ.__qualname__}" # type: ignore
|
||||
except AttributeError:
|
||||
return str(typ)
|
||||
|
||||
|
||||
def is_special_typing_primitive(typ: Any) -> bool:
|
||||
try:
|
||||
issubclass(typ, object)
|
||||
return False
|
||||
except TypeError:
|
||||
return True
|
||||
|
||||
|
||||
def is_generic(typ: Type) -> bool:
|
||||
with suppress(Exception):
|
||||
if hasattr(typ, "__class_getitem__"):
|
||||
return True
|
||||
# noinspection PyProtectedMember
|
||||
# noinspection PyUnresolvedReferences
|
||||
if (
|
||||
issubclass(typ.__class__, typing._BaseGenericAlias) # type: ignore
|
||||
or type(typ) is types.GenericAlias # type: ignore # noqa: E721
|
||||
):
|
||||
return True
|
||||
else:
|
||||
return False
|
||||
# else: # for PEP 585 generics without args
|
||||
# try:
|
||||
# return (
|
||||
# hasattr(typ, "__class_getitem__")
|
||||
# and type(typ[str]) is types.GenericAlias # type: ignore
|
||||
# )
|
||||
# except (TypeError, AttributeError):
|
||||
# return False
|
||||
|
||||
|
||||
def is_typed_dict(typ: Type) -> bool:
|
||||
for module in (typing, typing_extensions):
|
||||
with suppress(AttributeError):
|
||||
if type(typ) is getattr(module, "_TypedDictMeta"):
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def is_readonly(typ: Type) -> bool:
|
||||
origin = get_type_origin(typ)
|
||||
for module in (typing, typing_extensions):
|
||||
with suppress(AttributeError):
|
||||
if origin is getattr(module, "ReadOnly"):
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def is_named_tuple(typ: Type) -> bool:
|
||||
try:
|
||||
return issubclass(typ, tuple) and hasattr(typ, "_fields")
|
||||
except TypeError:
|
||||
return False
|
||||
|
||||
|
||||
def is_new_type(typ: Type) -> bool:
|
||||
return hasattr(typ, "__supertype__")
|
||||
|
||||
|
||||
def is_union(typ: Type) -> bool:
|
||||
try:
|
||||
if PY_310_MIN and isinstance(typ, types.UnionType): # type: ignore
|
||||
return True
|
||||
return typ.__origin__ is Union
|
||||
except AttributeError:
|
||||
return False
|
||||
|
||||
|
||||
def is_optional(
|
||||
typ: Type, resolved_type_params: Optional[dict[Type, Type]] = None
|
||||
) -> bool:
|
||||
if resolved_type_params is None:
|
||||
resolved_type_params = {}
|
||||
if not is_union(typ):
|
||||
return False
|
||||
args = get_args(typ)
|
||||
if len(args) != 2:
|
||||
return False
|
||||
for arg in args:
|
||||
if resolved_type_params.get(arg, arg) is NoneType:
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def is_annotated(typ: Type) -> bool:
|
||||
for module in (typing, typing_extensions):
|
||||
with suppress(AttributeError):
|
||||
if type(typ) is getattr(module, "_AnnotatedAlias"):
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def get_type_annotations(typ: Type) -> Sequence[Any]:
|
||||
return getattr(typ, "__metadata__", [])
|
||||
|
||||
|
||||
def is_literal(typ: Type) -> bool:
|
||||
if PY_39:
|
||||
with suppress(AttributeError):
|
||||
return is_generic(typ) and get_generic_name(typ, True) == "Literal"
|
||||
elif PY_310_MIN:
|
||||
with suppress(AttributeError):
|
||||
# noinspection PyProtectedMember
|
||||
# noinspection PyUnresolvedReferences
|
||||
return type(typ) is typing._LiteralGenericAlias # type: ignore
|
||||
return False
|
||||
|
||||
|
||||
def is_local_type_name(typ_name: str) -> bool:
|
||||
return "<locals>" in typ_name
|
||||
|
||||
|
||||
def not_none_type_arg(
|
||||
type_args: tuple[Type, ...],
|
||||
resolved_type_params: Optional[dict[Type, Type]] = None,
|
||||
) -> Optional[Type]:
|
||||
if resolved_type_params is None:
|
||||
resolved_type_params = {}
|
||||
for type_arg in type_args:
|
||||
if resolved_type_params.get(type_arg, type_arg) is not NoneType:
|
||||
return type_arg
|
||||
return None
|
||||
|
||||
|
||||
def is_type_var(typ: Type) -> bool:
|
||||
return hasattr(typ, "__constraints__")
|
||||
|
||||
|
||||
def is_type_var_any(typ: Type) -> bool:
|
||||
if not is_type_var(typ):
|
||||
return False
|
||||
elif typ.__constraints__ != ():
|
||||
return False
|
||||
elif typ.__bound__ not in (None, Any):
|
||||
return False
|
||||
elif type_var_has_default(typ):
|
||||
return False
|
||||
else:
|
||||
return True
|
||||
|
||||
|
||||
def is_class_var(typ: Type) -> bool:
|
||||
return get_type_origin(typ) is ClassVar
|
||||
|
||||
|
||||
def is_final(typ: Type) -> bool:
|
||||
return get_type_origin(typ) is typing_extensions.Final
|
||||
|
||||
|
||||
def is_init_var(typ: Type) -> bool:
|
||||
return isinstance(typ, dataclasses.InitVar)
|
||||
|
||||
|
||||
def get_class_that_defines_method(
|
||||
method_name: str, cls: Type
|
||||
) -> Optional[Type]:
|
||||
for cls in cls.__mro__:
|
||||
if method_name in cls.__dict__:
|
||||
return cls
|
||||
return None
|
||||
|
||||
|
||||
def get_class_that_defines_field(field_name: str, cls: Type) -> Optional[Type]:
|
||||
prev_cls = None
|
||||
prev_field = None
|
||||
for base in reversed(cls.__mro__):
|
||||
if dataclasses.is_dataclass(base):
|
||||
field = getattr(base, _FIELDS).get(field_name)
|
||||
if field and field != prev_field:
|
||||
prev_field = field
|
||||
prev_cls = base
|
||||
return prev_cls or cls
|
||||
|
||||
|
||||
def is_dataclass_dict_mixin(typ: Type) -> bool:
|
||||
return type_name(typ) == DataClassDictMixinPath
|
||||
|
||||
|
||||
def is_dataclass_dict_mixin_subclass(typ: Type) -> bool:
|
||||
with suppress(AttributeError):
|
||||
for cls in typ.__mro__:
|
||||
if is_dataclass_dict_mixin(cls):
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def get_orig_bases(typ: Type) -> tuple[Type, ...]:
|
||||
return getattr(typ, "__orig_bases__", ())
|
||||
|
||||
|
||||
def collect_type_params(typ: Type) -> Sequence[Type]:
|
||||
type_params = []
|
||||
for type_arg in get_args(typ):
|
||||
if type_arg in type_params:
|
||||
continue
|
||||
elif is_type_var(type_arg):
|
||||
type_params.append(type_arg)
|
||||
elif is_unpack(type_arg) and is_type_var_tuple(get_args(type_arg)[0]):
|
||||
type_params.append(type_arg)
|
||||
else:
|
||||
for _type_param in collect_type_params(type_arg):
|
||||
if _type_param not in type_params:
|
||||
type_params.append(_type_param)
|
||||
return type_params
|
||||
|
||||
|
||||
def _check_generic(
|
||||
typ: Type, type_params: Sequence[Type], type_args: Sequence[Type]
|
||||
) -> None:
|
||||
# https://github.com/python/cpython/issues/99382
|
||||
unpacks = len(list(filter(is_unpack, type_params)))
|
||||
if unpacks > 1:
|
||||
raise TypeError(
|
||||
"Multiple unpacks are disallowed within a single type parameter "
|
||||
f"list for {type_name(typ)}"
|
||||
)
|
||||
elif unpacks == 1:
|
||||
expected_count = len(type_params) - 1
|
||||
expected_msg = f"at least {len(type_params) - 1}"
|
||||
else:
|
||||
expected_count = len(type_params)
|
||||
expected_msg = f"{expected_count}"
|
||||
args_len = len(type_args)
|
||||
if 0 < args_len < expected_count:
|
||||
raise TypeError(
|
||||
f"Too few arguments for {type_name(typ)}; "
|
||||
f"actual {args_len}, expected {expected_msg}"
|
||||
)
|
||||
|
||||
|
||||
def _flatten_type_args(
|
||||
type_args: Sequence[Type],
|
||||
allow_ellipsis_if_many_args: bool = False,
|
||||
) -> Sequence[Type]:
|
||||
result = []
|
||||
for type_arg in type_args:
|
||||
if is_unpack(type_arg):
|
||||
unpacked_type = get_args(type_arg)[0]
|
||||
if is_type_var_tuple(unpacked_type):
|
||||
result.append(type_arg)
|
||||
elif is_variable_length_tuple(unpacked_type):
|
||||
if len(type_args) == 1:
|
||||
result.extend(_flatten_type_args(get_args(unpacked_type)))
|
||||
elif allow_ellipsis_if_many_args:
|
||||
result.extend(_flatten_type_args(get_args(unpacked_type)))
|
||||
else:
|
||||
result.append(type_arg)
|
||||
elif unpacked_type == Tuple[()]:
|
||||
if len(type_args) == 1:
|
||||
result.append(()) # type: ignore
|
||||
elif unpacked_type == tuple[()]: # type: ignore
|
||||
if len(type_args) == 1:
|
||||
result.append(()) # type: ignore
|
||||
else:
|
||||
result.extend(_flatten_type_args(get_args(unpacked_type)))
|
||||
else:
|
||||
result.append(type_arg)
|
||||
return result
|
||||
|
||||
|
||||
def resolve_type_params(
|
||||
typ: Type,
|
||||
type_args: Sequence[Type] = (),
|
||||
include_bases: bool = True,
|
||||
) -> dict[Type, dict[Type, Type]]:
|
||||
resolved_type_params: dict[Type, Type] = {}
|
||||
result = {typ: resolved_type_params}
|
||||
type_params = []
|
||||
|
||||
for base in get_orig_bases(typ):
|
||||
base_type_params = collect_type_params(base)
|
||||
for type_param in base_type_params:
|
||||
if type_param not in type_params:
|
||||
type_params.append(type_param)
|
||||
|
||||
_check_generic(typ, type_params, type_args)
|
||||
|
||||
type_args = _flatten_type_args(type_args, allow_ellipsis_if_many_args=True)
|
||||
param_idx = 0
|
||||
unpack_param_idx = -1
|
||||
arg_idx = 0
|
||||
while param_idx < len(type_params):
|
||||
type_param = type_params[param_idx]
|
||||
if not is_unpack(type_param):
|
||||
if type_param not in resolved_type_params:
|
||||
try:
|
||||
next_type_arg = type_args[arg_idx]
|
||||
if next_type_arg is Ellipsis:
|
||||
next_type_arg = type_args[arg_idx - 1]
|
||||
else:
|
||||
if unpack_param_idx < 0:
|
||||
arg_idx += 1
|
||||
else:
|
||||
arg_idx -= 1
|
||||
except IndexError:
|
||||
next_type_arg = type_param
|
||||
resolved_type_params[type_param] = next_type_arg
|
||||
if unpack_param_idx < 0:
|
||||
param_idx += 1
|
||||
else:
|
||||
param_idx -= 1
|
||||
elif unpack_param_idx < 0:
|
||||
unpack_param_idx = param_idx
|
||||
param_idx = -1
|
||||
arg_idx = -1
|
||||
unpacked_param = get_args(type_param)[0]
|
||||
for y in reversed(get_args(unpacked_param)): # pragma: no cover
|
||||
# We turn Tuple[x,y] to x, y, but leave this here just in case
|
||||
type_params.insert(param_idx, y)
|
||||
else:
|
||||
if not type_args and is_type_var_tuple(get_args(type_param)[0]):
|
||||
resolved_type_params[type_param] = Unpack[
|
||||
Tuple[Any, ...] # type: ignore
|
||||
]
|
||||
break
|
||||
t_args = type_args[unpack_param_idx : len(type_args) + arg_idx + 1]
|
||||
if len(t_args) == 1 and t_args[0] == ():
|
||||
x: Any = ()
|
||||
elif len(t_args) > 2 and t_args[-1] is Ellipsis:
|
||||
x = (*t_args[:-2], Unpack[Tuple[t_args[-2], ...]])
|
||||
else:
|
||||
x = tuple(t_args)
|
||||
resolved_type_params[type_param] = Unpack[Tuple[x]] # type: ignore
|
||||
break
|
||||
|
||||
if include_bases:
|
||||
orig_bases = {
|
||||
get_type_origin(orig_base): orig_base
|
||||
for orig_base in get_orig_bases(typ)
|
||||
}
|
||||
for base in getattr(typ, "__bases__", ()):
|
||||
orig_base = orig_bases.get(get_type_origin(base))
|
||||
base_type_params = get_args(orig_base)
|
||||
base_type_args = tuple(
|
||||
[resolved_type_params.get(a, a) for a in base_type_params]
|
||||
)
|
||||
result.update(resolve_type_params(base, base_type_args))
|
||||
|
||||
return result
|
||||
|
||||
|
||||
def substitute_type_params(typ: Type, substitutions: dict[Type, Type]) -> Type:
|
||||
if is_annotated(typ):
|
||||
origin = get_type_origin(typ)
|
||||
subst = substitutions.get(origin, origin)
|
||||
return typing_extensions.Annotated[
|
||||
(subst, *get_type_annotations(typ)) # type: ignore
|
||||
]
|
||||
else:
|
||||
new_type_args = []
|
||||
for type_param in collect_type_params(typ):
|
||||
new_type_args.append(substitutions.get(type_param, type_param))
|
||||
if new_type_args:
|
||||
with suppress(TypeError, KeyError):
|
||||
return typ[tuple(new_type_args)]
|
||||
if is_hashable(typ):
|
||||
return substitutions.get(typ, typ)
|
||||
else:
|
||||
return typ
|
||||
|
||||
|
||||
def get_name_error_name(e: NameError) -> str:
|
||||
if PY_310_MIN:
|
||||
return e.name # type: ignore
|
||||
else:
|
||||
match = re.search("'(.*)'", e.args[0])
|
||||
return match.group(1) if match else ""
|
||||
|
||||
|
||||
def is_dialect_subclass(typ: Type) -> bool:
|
||||
try:
|
||||
return issubclass(typ, Dialect)
|
||||
except TypeError:
|
||||
return False
|
||||
|
||||
|
||||
def is_self(typ: Type) -> bool:
|
||||
return typ is typing_extensions.Self
|
||||
|
||||
|
||||
def is_required(typ: Type) -> bool:
|
||||
return get_type_origin(typ) is typing_extensions.Required # noqa
|
||||
|
||||
|
||||
def is_not_required(typ: Type) -> bool:
|
||||
return get_type_origin(typ) is typing_extensions.NotRequired # noqa
|
||||
|
||||
|
||||
def get_function_arg_annotation(
|
||||
function: Callable[..., Any],
|
||||
arg_name: Optional[str] = None,
|
||||
arg_pos: Optional[int] = None,
|
||||
) -> type:
|
||||
parameters = inspect.signature(function).parameters
|
||||
if arg_name is not None:
|
||||
parameter = parameters[arg_name]
|
||||
elif arg_pos is not None:
|
||||
parameter = parameters[list(parameters.keys())[arg_pos]]
|
||||
else:
|
||||
raise ValueError("arg_name or arg_pos must be passed")
|
||||
annotation = parameter.annotation
|
||||
if annotation is inspect.Signature.empty:
|
||||
raise ValueError(f"Argument {arg_name} doesn't have annotation")
|
||||
if isinstance(annotation, str):
|
||||
annotation = str_to_forward_ref(
|
||||
annotation, inspect.getmodule(function)
|
||||
)
|
||||
return annotation
|
||||
|
||||
|
||||
def get_function_return_annotation(function: Callable[[Any], Any]) -> Type:
|
||||
annotation = inspect.signature(function).return_annotation
|
||||
if annotation is inspect.Signature.empty:
|
||||
raise ValueError("Function doesn't have return annotation")
|
||||
if isinstance(annotation, str):
|
||||
annotation = str_to_forward_ref(
|
||||
annotation, inspect.getmodule(function)
|
||||
)
|
||||
return annotation
|
||||
|
||||
|
||||
def is_unpack(typ: Type) -> bool:
|
||||
for module in (typing, typing_extensions):
|
||||
with suppress(AttributeError):
|
||||
if get_type_origin(typ) is getattr(module, "Unpack"):
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def is_type_var_tuple(typ: Type) -> bool:
|
||||
for module in (typing, typing_extensions):
|
||||
with suppress(AttributeError):
|
||||
if type(typ) is getattr(module, "TypeVarTuple"):
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def is_variable_length_tuple(typ: Type) -> bool:
|
||||
type_args = get_args(typ)
|
||||
return len(type_args) == 2 and type_args[1] is Ellipsis
|
||||
|
||||
|
||||
def hash_type_args(type_args: Iterable[Type]) -> str:
|
||||
return md5(",".join(map(type_name, type_args)).encode()).hexdigest()
|
||||
|
||||
|
||||
def iter_all_subclasses(cls: Type) -> Iterator[Type]:
|
||||
for subclass in cls.__subclasses__():
|
||||
yield subclass
|
||||
yield from iter_all_subclasses(subclass)
|
||||
|
||||
|
||||
def is_hashable(value: Any) -> bool:
|
||||
try:
|
||||
hash(value)
|
||||
return True
|
||||
except TypeError:
|
||||
return False
|
||||
|
||||
|
||||
def is_hashable_type(typ: Any) -> bool:
|
||||
try:
|
||||
return issubclass(typ, Hashable)
|
||||
except TypeError:
|
||||
return True
|
||||
|
||||
|
||||
def str_to_forward_ref(
|
||||
annotation: str, module: Optional[types.ModuleType] = None
|
||||
) -> ForwardRef:
|
||||
return ForwardRef(annotation, module=module)
|
||||
|
||||
|
||||
def evaluate_forward_ref(
|
||||
typ: ForwardRef, globalns: dict[str, Any], localns: dict[str, Any]
|
||||
) -> Optional[Type]:
|
||||
if PY_313_MIN:
|
||||
return typ._evaluate(
|
||||
globalns, localns, type_params=(), recursive_guard=frozenset()
|
||||
) # type: ignore[call-arg]
|
||||
else:
|
||||
return typ._evaluate(
|
||||
globalns, localns, recursive_guard=frozenset()
|
||||
) # type: ignore[call-arg]
|
||||
|
||||
|
||||
def get_forward_ref_referencing_globals(
|
||||
referenced_type: ForwardRef,
|
||||
referencing_object: Optional[Any] = None,
|
||||
fallback: Optional[dict[str, Any]] = None,
|
||||
) -> dict[str, Any]:
|
||||
if fallback is None:
|
||||
fallback = {}
|
||||
forward_module = getattr(referenced_type, "__forward_module__", None)
|
||||
if not forward_module and referencing_object:
|
||||
# We can't get the module in which ForwardRef's value is defined on
|
||||
# Python < 3.10, ForwardRef evaluation might not work properly
|
||||
# without this information, so we will consider the namespace of
|
||||
# the module in which this ForwardRef is used as globalns.
|
||||
return getattr(
|
||||
sys.modules.get(referencing_object.__module__, None),
|
||||
"__dict__",
|
||||
fallback,
|
||||
)
|
||||
else:
|
||||
return getattr(forward_module, "__dict__", fallback)
|
||||
|
||||
|
||||
def is_type_alias_type(typ: Type) -> bool:
|
||||
if PY_312_MIN:
|
||||
return isinstance(typ, typing.TypeAliasType) # type: ignore
|
||||
else:
|
||||
return False
|
||||
|
||||
|
||||
def type_var_has_default(typ: Any) -> bool:
|
||||
try:
|
||||
return typ.has_default()
|
||||
except AttributeError:
|
||||
return getattr(typ, "__default__", None) is not None
|
||||
|
||||
|
||||
def get_type_var_default(typ: Any) -> Type:
|
||||
return getattr(typ, "__default__")
|
||||
@@ -0,0 +1,52 @@
|
||||
from typing import Any, Optional, Type
|
||||
|
||||
from mashumaro.core.meta.code.builder import CodeBuilder
|
||||
from mashumaro.dialect import Dialect
|
||||
from mashumaro.exceptions import UnresolvedTypeReferenceError
|
||||
|
||||
__all__ = [
|
||||
"compile_mixin_packer",
|
||||
"compile_mixin_unpacker",
|
||||
]
|
||||
|
||||
|
||||
def compile_mixin_packer(
|
||||
cls: Type,
|
||||
format_name: str = "dict",
|
||||
dialect: Optional[Type[Dialect]] = None,
|
||||
encoder: Any = None,
|
||||
encoder_kwargs: Optional[dict[str, dict[str, tuple[str, Any]]]] = None,
|
||||
) -> None:
|
||||
builder = CodeBuilder(
|
||||
cls=cls,
|
||||
format_name=format_name,
|
||||
encoder=encoder,
|
||||
encoder_kwargs=encoder_kwargs,
|
||||
default_dialect=dialect,
|
||||
)
|
||||
config = builder.get_config()
|
||||
try:
|
||||
builder.add_pack_method()
|
||||
except UnresolvedTypeReferenceError:
|
||||
if not config.allow_postponed_evaluation:
|
||||
raise
|
||||
|
||||
|
||||
def compile_mixin_unpacker(
|
||||
cls: Type,
|
||||
format_name: str = "dict",
|
||||
dialect: Optional[Type[Dialect]] = None,
|
||||
decoder: Any = None,
|
||||
) -> None:
|
||||
builder = CodeBuilder(
|
||||
cls=cls,
|
||||
format_name=format_name,
|
||||
decoder=decoder,
|
||||
default_dialect=dialect,
|
||||
)
|
||||
config = builder.get_config()
|
||||
try:
|
||||
builder.add_unpack_method()
|
||||
except UnresolvedTypeReferenceError:
|
||||
if not config.allow_postponed_evaluation:
|
||||
raise
|
||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -0,0 +1,285 @@
|
||||
import re
|
||||
import uuid
|
||||
from abc import ABC, abstractmethod
|
||||
from collections.abc import Callable, Mapping, Sequence
|
||||
from dataclasses import dataclass, field, replace
|
||||
from functools import cached_property
|
||||
from types import new_class
|
||||
from typing import TYPE_CHECKING, Any, Optional, Type, TypeVar, Union
|
||||
|
||||
from typing_extensions import ParamSpec, TypeAlias
|
||||
|
||||
from mashumaro.core.meta.code.lines import CodeLines
|
||||
from mashumaro.core.meta.helpers import (
|
||||
get_type_origin,
|
||||
is_annotated,
|
||||
is_generic,
|
||||
is_hashable_type,
|
||||
is_self,
|
||||
type_name,
|
||||
)
|
||||
from mashumaro.exceptions import UnserializableField
|
||||
|
||||
if TYPE_CHECKING: # pragma: no cover
|
||||
from mashumaro.core.meta.code.builder import CodeBuilder
|
||||
else:
|
||||
CodeBuilder = Any
|
||||
|
||||
|
||||
class TypeMatchEligibleExpression(str):
|
||||
pass
|
||||
|
||||
|
||||
NoneType = type(None)
|
||||
Expression: TypeAlias = Union[str, TypeMatchEligibleExpression]
|
||||
|
||||
P = ParamSpec("P")
|
||||
T = TypeVar("T")
|
||||
|
||||
_PY_VALID_ID_RE = re.compile(r"\W|^(?=\d)")
|
||||
|
||||
|
||||
class AttrsHolder:
|
||||
def __new__(
|
||||
cls, name: Optional[str] = None, *args: Any, **kwargs: Any
|
||||
) -> Any:
|
||||
ah = new_class("AttrsHolder")
|
||||
ah_id = id(ah)
|
||||
if not name:
|
||||
name = f"attrs_{ah_id}"
|
||||
ah.__name__ = ah.__qualname__ = name
|
||||
return ah
|
||||
|
||||
|
||||
class ExpressionWrapper:
|
||||
def __init__(self, expression: str):
|
||||
self.expression = expression
|
||||
|
||||
|
||||
@dataclass
|
||||
class FieldContext:
|
||||
name: str
|
||||
metadata: Mapping
|
||||
packer: Optional[str] = None
|
||||
unpacker: Optional[str] = None
|
||||
|
||||
def copy(self, **changes: Any) -> "FieldContext":
|
||||
return replace(self, **changes)
|
||||
|
||||
|
||||
@dataclass
|
||||
class ValueSpec:
|
||||
type: Type
|
||||
origin_type: Type = field(init=False)
|
||||
expression: Expression
|
||||
builder: CodeBuilder
|
||||
field_ctx: FieldContext
|
||||
could_be_none: bool = True
|
||||
annotated_type: Optional[Type] = None
|
||||
owner: Optional[Type] = None
|
||||
no_copy_collections: Sequence = tuple()
|
||||
|
||||
def __setattr__(self, key: str, value: Any) -> None:
|
||||
if key == "type":
|
||||
self.origin_type = get_type_origin(value)
|
||||
super().__setattr__(key, value)
|
||||
|
||||
def copy(self, **changes: Any) -> "ValueSpec":
|
||||
return replace(self, **changes)
|
||||
|
||||
@cached_property
|
||||
def annotations(self) -> Sequence[str]:
|
||||
return getattr(self.annotated_type, "__metadata__", [])
|
||||
|
||||
@cached_property
|
||||
def attrs(self) -> Any:
|
||||
if self.builder.is_nailed:
|
||||
return self.builder.attrs
|
||||
if is_self(self.type):
|
||||
typ = self.builder.cls
|
||||
else:
|
||||
typ = self.origin_type
|
||||
attrs = self.attrs_registry.get(typ)
|
||||
if attrs is None:
|
||||
attrs = AttrsHolder()
|
||||
self.attrs_registry[typ] = attrs
|
||||
return attrs
|
||||
|
||||
@cached_property
|
||||
def cls_attrs_name(self) -> str:
|
||||
if self.builder.is_nailed:
|
||||
return "cls"
|
||||
else:
|
||||
self.builder.ensure_object_imported(self.attrs)
|
||||
return self.attrs.__name__
|
||||
|
||||
@cached_property
|
||||
def self_attrs_name(self) -> str:
|
||||
if self.builder.is_nailed:
|
||||
return "self"
|
||||
else:
|
||||
self.builder.ensure_object_imported(self.attrs)
|
||||
return self.attrs.__name__
|
||||
|
||||
@cached_property
|
||||
def attrs_registry(self) -> dict[Any, Any]:
|
||||
return self.builder.attrs_registry
|
||||
|
||||
@cached_property
|
||||
def attrs_registry_name(self) -> str:
|
||||
name = f"attrs_registry_{id(self.attrs_registry)}"
|
||||
self.builder.ensure_object_imported(self.attrs_registry, name)
|
||||
return name
|
||||
|
||||
|
||||
class AbstractMethodBuilder(ABC):
|
||||
@abstractmethod
|
||||
def get_method_prefix(self) -> str: # pragma: no cover
|
||||
raise NotImplementedError
|
||||
|
||||
def _generate_method_name(
|
||||
self, spec: ValueSpec
|
||||
) -> str: # pragma: no cover
|
||||
prefix = self.get_method_prefix()
|
||||
if prefix:
|
||||
prefix = f"{prefix}_"
|
||||
if spec.field_ctx.name:
|
||||
suffix = f"_{spec.field_ctx.name}"
|
||||
else:
|
||||
suffix = ""
|
||||
return f"__{prefix}{spec.builder.cls.__name__}{suffix}__{random_hex()}"
|
||||
|
||||
@abstractmethod
|
||||
def _add_definition(self, spec: ValueSpec, lines: CodeLines) -> str:
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def _generate_method_args(self, spec: ValueSpec) -> str:
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def _add_body(
|
||||
self, spec: ValueSpec, lines: CodeLines
|
||||
) -> None: # pragma: no cover
|
||||
raise NotImplementedError
|
||||
|
||||
def _add_setattr(
|
||||
self, spec: ValueSpec, method_name: str, lines: CodeLines
|
||||
) -> None:
|
||||
lines.append(
|
||||
f"setattr({spec.cls_attrs_name}, '{method_name}', {method_name})"
|
||||
)
|
||||
|
||||
def _compile(self, spec: ValueSpec, lines: CodeLines) -> None:
|
||||
if spec.builder.get_config().debug:
|
||||
print(f"{type_name(spec.builder.cls)}:")
|
||||
print(lines.as_text())
|
||||
exec(lines.as_text(), spec.builder.globals, spec.builder.__dict__)
|
||||
|
||||
@abstractmethod
|
||||
def _get_call_expr(self, spec: ValueSpec, method_name: str) -> str:
|
||||
raise NotImplementedError
|
||||
|
||||
def _before_build(self, spec: ValueSpec) -> None:
|
||||
pass
|
||||
|
||||
def _get_existing_method(self, spec: ValueSpec) -> Optional[str]:
|
||||
return None
|
||||
|
||||
def build(self, spec: ValueSpec) -> str:
|
||||
self._before_build(spec)
|
||||
if method := self._get_existing_method(spec):
|
||||
return method
|
||||
lines = CodeLines()
|
||||
method_name = self._add_definition(spec, lines)
|
||||
with lines.indent():
|
||||
self._add_body(spec, lines)
|
||||
self._add_setattr(spec, method_name, lines)
|
||||
self._compile(spec, lines)
|
||||
return self._get_call_expr(spec, method_name)
|
||||
|
||||
|
||||
ValueSpecExprCreator: TypeAlias = Callable[[ValueSpec], Optional[Expression]]
|
||||
|
||||
|
||||
@dataclass
|
||||
class Registry:
|
||||
_registry: list[ValueSpecExprCreator] = field(default_factory=list)
|
||||
|
||||
def register(self, function: ValueSpecExprCreator) -> ValueSpecExprCreator:
|
||||
self._registry.append(function)
|
||||
return function
|
||||
|
||||
def get(self, spec: ValueSpec) -> Expression:
|
||||
if is_annotated(spec.type):
|
||||
spec.annotated_type = spec.builder.get_real_type(
|
||||
spec.field_ctx.name, spec.type
|
||||
)
|
||||
spec.type = get_type_origin(spec.type)
|
||||
spec.type = spec.builder.get_real_type(spec.field_ctx.name, spec.type)
|
||||
spec.builder.add_type_modules(spec.type)
|
||||
for packer in self._registry:
|
||||
expr = packer(spec)
|
||||
if expr is not None:
|
||||
return expr
|
||||
raise UnserializableField(
|
||||
spec.field_ctx.name, spec.type, spec.builder.cls
|
||||
)
|
||||
|
||||
|
||||
def ensure_generic_collection(spec: ValueSpec) -> bool:
|
||||
if not is_generic(spec.type):
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
def ensure_mapping_key_type_hashable(
|
||||
spec: ValueSpec, type_args: Sequence[Type]
|
||||
) -> bool:
|
||||
if type_args:
|
||||
first_type_arg = type_args[0]
|
||||
if not is_hashable_type(first_type_arg):
|
||||
raise UnserializableField(
|
||||
field_name=spec.field_ctx.name,
|
||||
field_type=spec.type,
|
||||
holder_class=spec.builder.cls,
|
||||
msg=(
|
||||
f"{type_name(first_type_arg, short=True)} "
|
||||
"is unhashable and can not be used as a key"
|
||||
),
|
||||
)
|
||||
return True
|
||||
|
||||
|
||||
def ensure_generic_collection_subclass(
|
||||
spec: ValueSpec, *checked_types: Type
|
||||
) -> bool:
|
||||
return issubclass(
|
||||
spec.origin_type, checked_types
|
||||
) and ensure_generic_collection(spec)
|
||||
|
||||
|
||||
def ensure_generic_mapping(
|
||||
spec: ValueSpec, args: Sequence[Type], checked_type: Type
|
||||
) -> bool:
|
||||
return ensure_generic_collection_subclass(
|
||||
spec, checked_type
|
||||
) and ensure_mapping_key_type_hashable(spec, args)
|
||||
|
||||
|
||||
def expr_or_maybe_none(spec: ValueSpec, new_expr: Expression) -> Expression:
|
||||
if spec.could_be_none:
|
||||
return f"{new_expr} if {spec.expression} is not None else None"
|
||||
else:
|
||||
return new_expr
|
||||
|
||||
|
||||
def random_hex() -> str:
|
||||
return str(uuid.uuid4().hex)
|
||||
|
||||
|
||||
def clean_id(value: str) -> str:
|
||||
if not value:
|
||||
return "_"
|
||||
|
||||
return _PY_VALID_ID_RE.sub("_", value)
|
||||
@@ -0,0 +1,873 @@
|
||||
import datetime
|
||||
import enum
|
||||
import ipaddress
|
||||
import os
|
||||
import re
|
||||
import typing
|
||||
import uuid
|
||||
import zoneinfo
|
||||
from base64 import encodebytes
|
||||
from collections import ChainMap, Counter, OrderedDict, deque
|
||||
from collections.abc import Callable, Collection, Mapping, Sequence, Set
|
||||
from contextlib import suppress
|
||||
from dataclasses import is_dataclass
|
||||
from decimal import Decimal
|
||||
from fractions import Fraction
|
||||
from typing import Any, ForwardRef, Optional, Tuple, Union
|
||||
|
||||
import typing_extensions
|
||||
|
||||
from mashumaro.core.const import PY_311_MIN
|
||||
from mashumaro.core.meta.code.lines import CodeLines
|
||||
from mashumaro.core.meta.helpers import (
|
||||
get_args,
|
||||
get_class_that_defines_method,
|
||||
get_function_return_annotation,
|
||||
get_literal_values,
|
||||
get_type_origin,
|
||||
get_type_var_default,
|
||||
is_final,
|
||||
is_generic,
|
||||
is_literal,
|
||||
is_named_tuple,
|
||||
is_new_type,
|
||||
is_not_required,
|
||||
is_optional,
|
||||
is_readonly,
|
||||
is_required,
|
||||
is_self,
|
||||
is_special_typing_primitive,
|
||||
is_type_alias_type,
|
||||
is_type_var,
|
||||
is_type_var_any,
|
||||
is_type_var_tuple,
|
||||
is_typed_dict,
|
||||
is_union,
|
||||
is_unpack,
|
||||
not_none_type_arg,
|
||||
resolve_type_params,
|
||||
substitute_type_params,
|
||||
type_name,
|
||||
type_var_has_default,
|
||||
)
|
||||
from mashumaro.core.meta.types.common import (
|
||||
Expression,
|
||||
ExpressionWrapper,
|
||||
NoneType,
|
||||
Registry,
|
||||
ValueSpec,
|
||||
clean_id,
|
||||
ensure_generic_collection,
|
||||
ensure_generic_collection_subclass,
|
||||
ensure_generic_mapping,
|
||||
expr_or_maybe_none,
|
||||
random_hex,
|
||||
)
|
||||
from mashumaro.exceptions import (
|
||||
UnserializableDataError,
|
||||
UnserializableField,
|
||||
UnsupportedSerializationEngine,
|
||||
)
|
||||
from mashumaro.helper import pass_through
|
||||
from mashumaro.types import (
|
||||
GenericSerializableType,
|
||||
SerializableType,
|
||||
SerializationStrategy,
|
||||
)
|
||||
|
||||
__all__ = ["PackerRegistry"]
|
||||
|
||||
|
||||
PackerRegistry = Registry()
|
||||
register = PackerRegistry.register
|
||||
|
||||
|
||||
def _pack_with_annotated_serialization_strategy(
|
||||
spec: ValueSpec,
|
||||
strategy: SerializationStrategy,
|
||||
) -> Expression:
|
||||
strategy_type = type(strategy)
|
||||
try:
|
||||
value_type: Union[type, Any] = get_function_return_annotation(
|
||||
strategy.serialize
|
||||
)
|
||||
except (KeyError, ValueError):
|
||||
value_type = Any
|
||||
if isinstance(value_type, ForwardRef):
|
||||
value_type = spec.builder.evaluate_forward_ref(
|
||||
value_type, spec.origin_type
|
||||
)
|
||||
value_type = substitute_type_params(
|
||||
value_type, # type: ignore
|
||||
resolve_type_params(strategy_type, get_args(spec.type))[strategy_type],
|
||||
)
|
||||
overridden_fn = f"__{spec.field_ctx.name}_serialize_{random_hex()}"
|
||||
setattr(spec.attrs, overridden_fn, strategy.serialize)
|
||||
new_spec = spec.copy(
|
||||
type=value_type,
|
||||
expression=(
|
||||
f"{spec.self_attrs_name}.{overridden_fn}({spec.expression})"
|
||||
),
|
||||
)
|
||||
field_metadata = new_spec.field_ctx.metadata
|
||||
if field_metadata.get("serialization_strategy") is strategy:
|
||||
new_spec.field_ctx.metadata = {
|
||||
k: v
|
||||
for k, v in field_metadata.items()
|
||||
if k != "serialization_strategy"
|
||||
}
|
||||
return PackerRegistry.get(
|
||||
spec.copy(
|
||||
type=value_type,
|
||||
expression=(
|
||||
f"{spec.self_attrs_name}.{overridden_fn}({spec.expression})"
|
||||
),
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
def get_overridden_serialization_method(
|
||||
spec: ValueSpec,
|
||||
) -> Optional[Union[Callable, str, ExpressionWrapper]]:
|
||||
serialize_option = spec.field_ctx.metadata.get("serialize")
|
||||
if serialize_option is not None:
|
||||
return serialize_option
|
||||
checking_types = [spec.type, spec.origin_type]
|
||||
if spec.annotated_type:
|
||||
checking_types.insert(0, spec.annotated_type)
|
||||
for typ in checking_types:
|
||||
for strategy in spec.builder.iter_serialization_strategies(
|
||||
spec.field_ctx.metadata, typ
|
||||
):
|
||||
if strategy is pass_through:
|
||||
return pass_through
|
||||
elif isinstance(strategy, dict):
|
||||
serialize_option = strategy.get("serialize")
|
||||
elif isinstance(strategy, SerializationStrategy):
|
||||
if strategy.__use_annotations__ or is_generic(type(strategy)):
|
||||
return ExpressionWrapper(
|
||||
_pack_with_annotated_serialization_strategy(
|
||||
spec=spec,
|
||||
strategy=strategy,
|
||||
)
|
||||
)
|
||||
else:
|
||||
serialize_option = strategy.serialize
|
||||
if serialize_option is not None:
|
||||
return serialize_option
|
||||
|
||||
|
||||
@register
|
||||
def pack_type_with_overridden_serialization(
|
||||
spec: ValueSpec,
|
||||
) -> Optional[Expression]:
|
||||
serialization_method = get_overridden_serialization_method(spec)
|
||||
if serialization_method is pass_through:
|
||||
return spec.expression
|
||||
elif isinstance(serialization_method, ExpressionWrapper):
|
||||
return serialization_method.expression
|
||||
elif callable(serialization_method):
|
||||
overridden_fn = f"__{spec.field_ctx.name}_serialize_{random_hex()}"
|
||||
setattr(spec.attrs, overridden_fn, staticmethod(serialization_method))
|
||||
return f"{spec.self_attrs_name}.{overridden_fn}({spec.expression})"
|
||||
|
||||
|
||||
def _pack_annotated_serializable_type(
|
||||
spec: ValueSpec,
|
||||
) -> Optional[Expression]:
|
||||
try:
|
||||
# noinspection PyProtectedMember
|
||||
# noinspection PyUnresolvedReferences
|
||||
value_type = get_function_return_annotation(
|
||||
spec.origin_type._serialize
|
||||
)
|
||||
except (KeyError, ValueError):
|
||||
raise UnserializableField(
|
||||
field_name=spec.field_ctx.name,
|
||||
field_type=spec.type,
|
||||
holder_class=spec.builder.cls,
|
||||
msg="Method _serialize must have return annotation",
|
||||
) from None
|
||||
if is_self(value_type):
|
||||
return f"{spec.expression}._serialize()"
|
||||
if isinstance(value_type, ForwardRef):
|
||||
value_type = spec.builder.evaluate_forward_ref(
|
||||
value_type, spec.origin_type
|
||||
)
|
||||
value_type = substitute_type_params(
|
||||
value_type,
|
||||
resolve_type_params(spec.origin_type, get_args(spec.type))[
|
||||
spec.origin_type
|
||||
],
|
||||
)
|
||||
return PackerRegistry.get(
|
||||
spec.copy(
|
||||
type=value_type,
|
||||
expression=f"{spec.expression}._serialize()",
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
@register
|
||||
def pack_serializable_type(spec: ValueSpec) -> Optional[Expression]:
|
||||
try:
|
||||
if not issubclass(spec.origin_type, SerializableType):
|
||||
return None
|
||||
except TypeError:
|
||||
return None
|
||||
if spec.origin_type.__use_annotations__:
|
||||
return _pack_annotated_serializable_type(spec)
|
||||
else:
|
||||
return f"{spec.expression}._serialize()"
|
||||
|
||||
|
||||
@register
|
||||
def pack_generic_serializable_type(spec: ValueSpec) -> Optional[Expression]:
|
||||
with suppress(TypeError):
|
||||
if issubclass(spec.origin_type, GenericSerializableType):
|
||||
type_args = get_args(spec.type)
|
||||
spec.builder.add_type_modules(*type_args)
|
||||
type_arg_names = ", ".join(list(map(type_name, type_args)))
|
||||
return f"{spec.expression}._serialize([{type_arg_names}])"
|
||||
|
||||
|
||||
@register
|
||||
def pack_dataclass(spec: ValueSpec) -> Optional[Expression]:
|
||||
if is_dataclass(spec.origin_type):
|
||||
type_args = get_args(spec.type)
|
||||
method_name = spec.builder.get_pack_method_name(
|
||||
type_args, spec.builder.format_name
|
||||
)
|
||||
method_loc = spec.origin_type if spec.builder.is_nailed else spec.attrs
|
||||
if get_class_that_defines_method(
|
||||
method_name, method_loc
|
||||
) != method_loc and (
|
||||
spec.origin_type is not spec.builder.cls
|
||||
or spec.builder.get_pack_method_name(
|
||||
type_args=type_args,
|
||||
format_name=spec.builder.format_name,
|
||||
encoder=spec.builder.encoder,
|
||||
)
|
||||
!= method_name
|
||||
):
|
||||
builder = spec.builder.__class__(
|
||||
spec.origin_type,
|
||||
type_args,
|
||||
dialect=spec.builder.dialect,
|
||||
format_name=spec.builder.format_name,
|
||||
default_dialect=spec.builder.default_dialect,
|
||||
attrs=method_loc,
|
||||
attrs_registry=(
|
||||
spec.attrs_registry if not spec.builder.is_nailed else None
|
||||
),
|
||||
)
|
||||
builder.add_pack_method()
|
||||
flags = spec.builder.get_pack_method_flags(spec.type)
|
||||
if spec.builder.is_nailed:
|
||||
return f"{spec.expression}.{method_name}({flags})"
|
||||
else:
|
||||
cls_alias = clean_id(type_name(spec.origin_type))
|
||||
method_name_alias = f"{cls_alias}_{method_name}"
|
||||
spec.builder.ensure_object_imported(
|
||||
getattr(spec.attrs, method_name), method_name_alias
|
||||
)
|
||||
method_args = spec.expression
|
||||
return f"{method_name_alias}({method_args})"
|
||||
|
||||
|
||||
@register
|
||||
def pack_final(spec: ValueSpec) -> Optional[Expression]:
|
||||
if is_final(spec.type):
|
||||
return PackerRegistry.get(spec.copy(type=get_args(spec.type)[0]))
|
||||
|
||||
|
||||
@register
|
||||
def pack_any(spec: ValueSpec) -> Optional[Expression]:
|
||||
if spec.type is Any:
|
||||
return spec.expression
|
||||
|
||||
|
||||
def pack_union(
|
||||
spec: ValueSpec, args: tuple[type, ...], prefix: str = "union"
|
||||
) -> Expression:
|
||||
if spec.type is spec.owner and spec.field_ctx.packer:
|
||||
return spec.field_ctx.packer
|
||||
lines = CodeLines()
|
||||
|
||||
method_name = (
|
||||
f"__pack_{prefix}_{spec.builder.cls.__name__}_{spec.field_ctx.name}__"
|
||||
f"{random_hex()}"
|
||||
)
|
||||
|
||||
if not spec.field_ctx.packer:
|
||||
method_args = ", ".join(
|
||||
filter(None, ("value", spec.builder.get_pack_method_flags()))
|
||||
)
|
||||
if spec.builder.is_nailed:
|
||||
union_packer = (
|
||||
f"{spec.self_attrs_name}.{method_name}({method_args})"
|
||||
)
|
||||
else:
|
||||
union_packer = f"{method_name}({method_args})"
|
||||
spec.field_ctx.packer = union_packer
|
||||
|
||||
method_args = "self, value" if spec.builder.is_nailed else "value"
|
||||
default_kwargs = spec.builder.get_pack_method_default_flag_values()
|
||||
if default_kwargs:
|
||||
lines.append(f"def {method_name}({method_args}, {default_kwargs}):")
|
||||
else:
|
||||
lines.append(f"def {method_name}({method_args}):")
|
||||
packers: list[str] = []
|
||||
packer_arg_types: dict[str, list[type]] = {}
|
||||
for type_arg in args:
|
||||
packer = PackerRegistry.get(
|
||||
spec.copy(type=type_arg, expression="value", owner=spec.type)
|
||||
)
|
||||
if packer not in packers:
|
||||
if packer == "value":
|
||||
packers.insert(0, packer)
|
||||
else:
|
||||
packers.append(packer)
|
||||
packer_arg_types.setdefault(packer, []).append(type_arg)
|
||||
|
||||
if len(packers) == 1 and packers[0] == "value":
|
||||
return spec.expression
|
||||
|
||||
with lines.indent():
|
||||
for packer in packers:
|
||||
packer_arg_type_names = []
|
||||
for packer_arg_type in packer_arg_types[packer]:
|
||||
if is_generic(packer_arg_type):
|
||||
packer_arg_type = get_type_origin(packer_arg_type)
|
||||
packer_arg_type_name = clean_id(type_name(packer_arg_type))
|
||||
spec.builder.ensure_object_imported(
|
||||
packer_arg_type, packer_arg_type_name
|
||||
)
|
||||
if packer_arg_type_name not in packer_arg_type_names:
|
||||
packer_arg_type_names.append(packer_arg_type_name)
|
||||
if len(packer_arg_type_names) > 1:
|
||||
packer_arg_type_check = (
|
||||
f"in ({', '.join(packer_arg_type_names)})"
|
||||
)
|
||||
else:
|
||||
packer_arg_type_check = f"is {packer_arg_type_names[0]}"
|
||||
if packer == "value":
|
||||
with lines.indent(
|
||||
f"if value.__class__ {packer_arg_type_check}:"
|
||||
):
|
||||
lines.append(f"return {packer}")
|
||||
else:
|
||||
with lines.indent("try:"):
|
||||
lines.append(f"return {packer}")
|
||||
with lines.indent("except Exception:"):
|
||||
lines.append("pass")
|
||||
field_type = spec.builder.get_type_name_identifier(
|
||||
typ=spec.type,
|
||||
resolved_type_params=spec.builder.get_field_resolved_type_params(
|
||||
spec.field_ctx.name
|
||||
),
|
||||
)
|
||||
if spec.builder.is_nailed:
|
||||
lines.append(
|
||||
"raise InvalidFieldValue("
|
||||
f"'{spec.field_ctx.name}',{field_type},value,type(self))"
|
||||
)
|
||||
else:
|
||||
lines.append("raise ValueError(value)")
|
||||
lines.append(
|
||||
f"setattr({spec.cls_attrs_name}, '{method_name}', {method_name})"
|
||||
)
|
||||
if spec.builder.get_config().debug:
|
||||
print(f"{type_name(spec.builder.cls)}:")
|
||||
print(lines.as_text())
|
||||
|
||||
exec(lines.as_text(), spec.builder.globals, spec.builder.__dict__)
|
||||
|
||||
method_args = ", ".join(
|
||||
filter(None, (spec.expression, spec.builder.get_pack_method_flags()))
|
||||
)
|
||||
if spec.builder.is_nailed:
|
||||
return f"{spec.self_attrs_name}.{method_name}({method_args})"
|
||||
else:
|
||||
spec.builder.ensure_object_imported(
|
||||
getattr(spec.attrs, method_name), method_name
|
||||
)
|
||||
return f"{method_name}({method_args})"
|
||||
|
||||
|
||||
def pack_literal(spec: ValueSpec) -> Expression:
|
||||
spec.builder.add_type_modules(spec.type)
|
||||
lines = CodeLines()
|
||||
method_name = (
|
||||
f"__pack_literal_{spec.builder.cls.__name__}_{spec.field_ctx.name}__"
|
||||
f"{random_hex()}"
|
||||
)
|
||||
method_args = "self, value" if spec.builder.is_nailed else "value"
|
||||
default_kwargs = spec.builder.get_pack_method_default_flag_values()
|
||||
if default_kwargs:
|
||||
lines.append(f"def {method_name}({method_args}, {default_kwargs}):")
|
||||
else:
|
||||
lines.append(f"def {method_name}({method_args}):")
|
||||
resolved_type_params = spec.builder.get_field_resolved_type_params(
|
||||
spec.field_ctx.name
|
||||
)
|
||||
with lines.indent():
|
||||
for literal_value in get_literal_values(spec.type):
|
||||
value_type = type(literal_value)
|
||||
packer = PackerRegistry.get(
|
||||
spec.copy(type=value_type, expression="value")
|
||||
)
|
||||
if isinstance(literal_value, enum.Enum):
|
||||
enum_type_name = spec.builder.get_type_name_identifier(
|
||||
typ=value_type,
|
||||
resolved_type_params=resolved_type_params,
|
||||
)
|
||||
with lines.indent(
|
||||
f"if value == {enum_type_name}.{literal_value.name}:"
|
||||
):
|
||||
lines.append(f"return {packer}")
|
||||
elif isinstance(
|
||||
literal_value,
|
||||
(int, str, bytes, bool, NoneType), # type: ignore
|
||||
):
|
||||
with lines.indent(f"if value == {literal_value!r}:"):
|
||||
lines.append(f"return {packer}")
|
||||
field_type = spec.builder.get_type_name_identifier(
|
||||
typ=spec.type,
|
||||
resolved_type_params=resolved_type_params,
|
||||
)
|
||||
if spec.builder.is_nailed:
|
||||
lines.append(
|
||||
f"raise InvalidFieldValue('{spec.field_ctx.name}',"
|
||||
f"{field_type},value,type(self))"
|
||||
)
|
||||
else:
|
||||
lines.append("raise ValueError(value)")
|
||||
lines.append(
|
||||
f"setattr({spec.cls_attrs_name}, '{method_name}', {method_name})"
|
||||
)
|
||||
if spec.builder.get_config().debug:
|
||||
print(f"{type_name(spec.builder.cls)}:")
|
||||
print(lines.as_text())
|
||||
exec(lines.as_text(), spec.builder.globals, spec.builder.__dict__)
|
||||
method_args = ", ".join(
|
||||
filter(None, (spec.expression, spec.builder.get_pack_method_flags()))
|
||||
)
|
||||
return f"{spec.self_attrs_name}.{method_name}({method_args})"
|
||||
|
||||
|
||||
@register
|
||||
def pack_special_typing_primitive(spec: ValueSpec) -> Optional[Expression]:
|
||||
if is_special_typing_primitive(spec.origin_type):
|
||||
if is_union(spec.type):
|
||||
resolved_type_params = spec.builder.get_field_resolved_type_params(
|
||||
spec.field_ctx.name
|
||||
)
|
||||
if is_optional(spec.type, resolved_type_params):
|
||||
arg = not_none_type_arg(
|
||||
get_args(spec.type), resolved_type_params
|
||||
)
|
||||
pv = PackerRegistry.get(spec.copy(type=arg))
|
||||
return expr_or_maybe_none(spec, pv)
|
||||
else:
|
||||
return pack_union(spec, get_args(spec.type))
|
||||
elif spec.origin_type is typing.AnyStr:
|
||||
raise UnserializableDataError(
|
||||
"AnyStr is not supported by mashumaro"
|
||||
)
|
||||
elif is_type_var_any(spec.type):
|
||||
return spec.expression
|
||||
elif is_type_var(spec.type):
|
||||
constraints = getattr(spec.type, "__constraints__")
|
||||
if constraints:
|
||||
return pack_union(spec, constraints, "type_var")
|
||||
else:
|
||||
if type_var_has_default(spec.type):
|
||||
bound = get_type_var_default(spec.type)
|
||||
else:
|
||||
bound = getattr(spec.type, "__bound__")
|
||||
# act as if it was Optional[bound]
|
||||
pv = PackerRegistry.get(spec.copy(type=bound))
|
||||
return expr_or_maybe_none(spec, pv)
|
||||
elif is_new_type(spec.type):
|
||||
return PackerRegistry.get(spec.copy(type=spec.type.__supertype__))
|
||||
elif is_literal(spec.type):
|
||||
return pack_literal(spec)
|
||||
elif spec.type is typing_extensions.LiteralString:
|
||||
return PackerRegistry.get(spec.copy(type=str))
|
||||
elif is_self(spec.type):
|
||||
method_name = spec.builder.get_pack_method_name(
|
||||
format_name=spec.builder.format_name
|
||||
)
|
||||
method_loc = (
|
||||
spec.builder.cls if spec.builder.is_nailed else spec.attrs
|
||||
)
|
||||
if (
|
||||
get_class_that_defines_method(method_name, method_loc)
|
||||
!= method_loc
|
||||
# not hasattr(self.cls, method_name)
|
||||
and spec.builder.get_pack_method_name(
|
||||
format_name=spec.builder.format_name,
|
||||
encoder=spec.builder.encoder,
|
||||
)
|
||||
!= method_name
|
||||
):
|
||||
builder = spec.builder.__class__(
|
||||
spec.builder.cls,
|
||||
dialect=spec.builder.dialect,
|
||||
format_name=spec.builder.format_name,
|
||||
default_dialect=spec.builder.default_dialect,
|
||||
attrs=method_loc,
|
||||
attrs_registry=(
|
||||
spec.attrs_registry
|
||||
if not spec.builder.is_nailed
|
||||
else None
|
||||
),
|
||||
)
|
||||
builder.add_pack_method()
|
||||
flags = spec.builder.get_pack_method_flags(spec.builder.cls)
|
||||
if spec.builder.is_nailed:
|
||||
return f"{spec.expression}.{method_name}({flags})"
|
||||
else:
|
||||
method_args = spec.expression
|
||||
return f"_cls.{method_name}({method_args})"
|
||||
elif is_required(spec.type) or is_not_required(spec.type):
|
||||
return PackerRegistry.get(spec.copy(type=get_args(spec.type)[0]))
|
||||
elif is_unpack(spec.type):
|
||||
packer = PackerRegistry.get(spec.copy(type=get_args(spec.type)[0]))
|
||||
return f"*{packer}"
|
||||
elif is_type_var_tuple(spec.type):
|
||||
return PackerRegistry.get(spec.copy(type=tuple[Any, ...]))
|
||||
elif isinstance(spec.type, ForwardRef):
|
||||
evaluated = spec.builder.evaluate_forward_ref(
|
||||
spec.type, spec.owner
|
||||
)
|
||||
if evaluated is not None:
|
||||
return PackerRegistry.get(spec.copy(type=evaluated))
|
||||
elif is_type_alias_type(spec.type):
|
||||
return PackerRegistry.get(spec.copy(type=spec.type.__value__))
|
||||
elif is_readonly(spec.type):
|
||||
return PackerRegistry.get(spec.copy(type=get_args(spec.type)[0]))
|
||||
raise UnserializableDataError(
|
||||
f"{spec.type} as a field type is not supported by mashumaro"
|
||||
)
|
||||
|
||||
|
||||
@register
|
||||
def pack_number_and_bool_and_none(spec: ValueSpec) -> Optional[Expression]:
|
||||
if spec.origin_type in (int, float, bool, NoneType, None):
|
||||
return spec.expression
|
||||
|
||||
|
||||
@register
|
||||
def pack_date_objects(spec: ValueSpec) -> Optional[Expression]:
|
||||
if spec.origin_type in (datetime.datetime, datetime.date, datetime.time):
|
||||
return f"{spec.expression}.isoformat()"
|
||||
|
||||
|
||||
@register
|
||||
def pack_timedelta(spec: ValueSpec) -> Optional[Expression]:
|
||||
if spec.origin_type is datetime.timedelta:
|
||||
return f"{spec.expression}.total_seconds()"
|
||||
|
||||
|
||||
@register
|
||||
def pack_timezone(spec: ValueSpec) -> Optional[Expression]:
|
||||
if spec.origin_type is datetime.timezone:
|
||||
return f"{spec.expression}.tzname(None)"
|
||||
|
||||
|
||||
@register
|
||||
def pack_zone_info(spec: ValueSpec) -> Optional[Expression]:
|
||||
if spec.origin_type is zoneinfo.ZoneInfo:
|
||||
return f"str({spec.expression})"
|
||||
|
||||
|
||||
@register
|
||||
def pack_uuid(spec: ValueSpec) -> Optional[Expression]:
|
||||
if spec.origin_type is uuid.UUID:
|
||||
return f"str({spec.expression})"
|
||||
|
||||
|
||||
@register
|
||||
def pack_ipaddress(spec: ValueSpec) -> Optional[Expression]:
|
||||
if spec.origin_type in (
|
||||
ipaddress.IPv4Address,
|
||||
ipaddress.IPv6Address,
|
||||
ipaddress.IPv4Network,
|
||||
ipaddress.IPv6Network,
|
||||
ipaddress.IPv4Interface,
|
||||
ipaddress.IPv6Interface,
|
||||
):
|
||||
return f"str({spec.expression})"
|
||||
|
||||
|
||||
@register
|
||||
def pack_decimal(spec: ValueSpec) -> Optional[Expression]:
|
||||
if spec.origin_type is Decimal:
|
||||
return f"str({spec.expression})"
|
||||
|
||||
|
||||
@register
|
||||
def pack_fraction(spec: ValueSpec) -> Optional[Expression]:
|
||||
if spec.origin_type is Fraction:
|
||||
return f"str({spec.expression})"
|
||||
|
||||
|
||||
def pack_tuple(spec: ValueSpec, args: tuple[type, ...]) -> Expression:
|
||||
if not args:
|
||||
if spec.type in (Tuple, tuple):
|
||||
args = [Any, ...] # type: ignore
|
||||
else:
|
||||
return "[]"
|
||||
elif len(args) == 1 and args[0] == ():
|
||||
if not PY_311_MIN:
|
||||
return "[]"
|
||||
if len(args) == 2 and args[1] is Ellipsis:
|
||||
packer = PackerRegistry.get(
|
||||
spec.copy(type=args[0], expression="value", could_be_none=True)
|
||||
)
|
||||
return f"[{packer} for value in {spec.expression}]"
|
||||
else:
|
||||
arg_indexes: list[Union[int, tuple[int, Union[int, None]]]] = []
|
||||
unpack_idx: Optional[int] = None
|
||||
for arg_idx, type_arg in enumerate(args):
|
||||
if is_unpack(type_arg):
|
||||
if unpack_idx is not None:
|
||||
raise TypeError(
|
||||
"Multiple unpacks are disallowed within a single type "
|
||||
f"parameter list for {type_name(spec.type)}"
|
||||
)
|
||||
unpack_idx = arg_idx
|
||||
if len(args) == 1:
|
||||
arg_indexes.append((arg_idx, None))
|
||||
elif arg_idx < len(args) - 1:
|
||||
arg_indexes.append((arg_idx, arg_idx + 1 - len(args)))
|
||||
else:
|
||||
arg_indexes.append((arg_idx, None))
|
||||
else:
|
||||
if unpack_idx is None:
|
||||
arg_indexes.append(arg_idx)
|
||||
else:
|
||||
arg_indexes.append(arg_idx - len(args))
|
||||
packers: list[Expression] = []
|
||||
for _idx, _arg_idx in enumerate(arg_indexes):
|
||||
if isinstance(_arg_idx, tuple):
|
||||
p_expr = f"{spec.expression}[{_arg_idx[0]}:{_arg_idx[1]}]"
|
||||
else:
|
||||
p_expr = f"{spec.expression}[{_arg_idx}]"
|
||||
packer = PackerRegistry.get(
|
||||
spec.copy(
|
||||
type=args[_idx],
|
||||
expression=p_expr,
|
||||
could_be_none=True,
|
||||
)
|
||||
)
|
||||
if packer != "*[]":
|
||||
packers.append(packer)
|
||||
return f"[{', '.join(packers)}]"
|
||||
|
||||
|
||||
def pack_named_tuple(spec: ValueSpec) -> Expression:
|
||||
resolved = resolve_type_params(spec.origin_type, get_args(spec.type))[
|
||||
spec.origin_type
|
||||
]
|
||||
annotations = {
|
||||
k: resolved.get(v, v)
|
||||
for k, v in getattr(spec.origin_type, "__annotations__", {}).items()
|
||||
}
|
||||
fields = getattr(spec.type, "_fields", ())
|
||||
packers = []
|
||||
as_dict = spec.builder.get_dialect_or_config_option(
|
||||
"namedtuple_as_dict", False
|
||||
)
|
||||
serialize_option = get_overridden_serialization_method(spec)
|
||||
if serialize_option is not None:
|
||||
if serialize_option == "as_dict":
|
||||
as_dict = True
|
||||
elif serialize_option == "as_list":
|
||||
as_dict = False
|
||||
else:
|
||||
raise UnsupportedSerializationEngine(
|
||||
field_name=spec.field_ctx.name,
|
||||
field_type=spec.type,
|
||||
holder_class=spec.builder.cls,
|
||||
engine=serialize_option,
|
||||
)
|
||||
for idx, field in enumerate(fields):
|
||||
packer = PackerRegistry.get(
|
||||
spec.copy(
|
||||
type=annotations.get(field, Any),
|
||||
expression=f"{spec.expression}[{idx}]",
|
||||
could_be_none=True,
|
||||
)
|
||||
)
|
||||
packers.append(packer)
|
||||
if as_dict:
|
||||
kv = (f"'{key}': {value}" for key, value in zip(fields, packers))
|
||||
return f"{{{', '.join(kv)}}}"
|
||||
else:
|
||||
return f"[{', '.join(packers)}]"
|
||||
|
||||
|
||||
def pack_typed_dict(spec: ValueSpec) -> Expression:
|
||||
resolved = resolve_type_params(spec.origin_type, get_args(spec.type))[
|
||||
spec.origin_type
|
||||
]
|
||||
annotations = {
|
||||
k: resolved.get(v, v)
|
||||
for k, v in spec.origin_type.__annotations__.items()
|
||||
}
|
||||
all_keys = list(annotations.keys())
|
||||
required_keys = getattr(spec.type, "__required_keys__", all_keys)
|
||||
optional_keys = getattr(spec.type, "__optional_keys__", [])
|
||||
lines = CodeLines()
|
||||
method_name = (
|
||||
f"__pack_typed_dict_{spec.builder.cls.__name__}_"
|
||||
f"{spec.field_ctx.name}__{random_hex()}"
|
||||
)
|
||||
method_args = "self, value" if spec.builder.is_nailed else "value"
|
||||
default_kwargs = spec.builder.get_pack_method_default_flag_values()
|
||||
if default_kwargs:
|
||||
lines.append(f"def {method_name}({method_args}, {default_kwargs}):")
|
||||
else:
|
||||
lines.append(f"def {method_name}({method_args}):")
|
||||
with lines.indent():
|
||||
lines.append("d = {}")
|
||||
for key in sorted(required_keys, key=all_keys.index):
|
||||
packer = PackerRegistry.get(
|
||||
spec.copy(
|
||||
type=annotations[key],
|
||||
expression=f"value['{key}']",
|
||||
could_be_none=True,
|
||||
owner=spec.type,
|
||||
)
|
||||
)
|
||||
lines.append(f"d['{key}'] = {packer}")
|
||||
for key in sorted(optional_keys, key=all_keys.index):
|
||||
lines.append(f"key_value = value.get('{key}', MISSING)")
|
||||
with lines.indent("if key_value is not MISSING:"):
|
||||
packer = PackerRegistry.get(
|
||||
spec.copy(
|
||||
type=annotations[key],
|
||||
expression="key_value",
|
||||
could_be_none=True,
|
||||
owner=spec.type,
|
||||
)
|
||||
)
|
||||
lines.append(f"d['{key}'] = {packer}")
|
||||
lines.append("return d")
|
||||
lines.append(
|
||||
f"setattr({spec.cls_attrs_name}, '{method_name}', {method_name})"
|
||||
)
|
||||
if spec.builder.get_config().debug:
|
||||
print(f"{type_name(spec.builder.cls)}:")
|
||||
print(lines.as_text())
|
||||
exec(lines.as_text(), spec.builder.globals, spec.builder.__dict__)
|
||||
method_args = ", ".join(
|
||||
filter(None, (spec.expression, spec.builder.get_pack_method_flags()))
|
||||
)
|
||||
return f"{spec.self_attrs_name}.{method_name}({method_args})"
|
||||
|
||||
|
||||
@register
|
||||
def pack_collection(spec: ValueSpec) -> Optional[Expression]:
|
||||
if not issubclass(spec.origin_type, Collection):
|
||||
return None
|
||||
elif issubclass(spec.origin_type, enum.Enum):
|
||||
return None
|
||||
|
||||
args = get_args(spec.type)
|
||||
|
||||
def inner_expr(
|
||||
arg_num: int = 0, v_name: str = "value", v_type: Optional[type] = None
|
||||
) -> Expression:
|
||||
if v_type:
|
||||
return PackerRegistry.get(
|
||||
spec.copy(type=v_type, expression=v_name)
|
||||
)
|
||||
else:
|
||||
if args and len(args) > arg_num:
|
||||
type_arg: Any = args[arg_num]
|
||||
else:
|
||||
type_arg = Any
|
||||
return PackerRegistry.get(
|
||||
spec.copy(
|
||||
type=type_arg,
|
||||
expression=v_name,
|
||||
could_be_none=True,
|
||||
field_ctx=spec.field_ctx.copy(metadata={}),
|
||||
)
|
||||
)
|
||||
|
||||
def _make_sequence_expression(ie: Expression) -> Expression:
|
||||
if ie == "value":
|
||||
if spec.origin_type in spec.no_copy_collections:
|
||||
return spec.expression
|
||||
elif spec.origin_type is list:
|
||||
return f"{spec.expression}.copy()"
|
||||
return f"[{ie} for value in {spec.expression}]"
|
||||
|
||||
def _make_mapping_expression(ke: Expression, ve: Expression) -> Expression:
|
||||
if ke == "key" and ve == "value":
|
||||
if spec.origin_type in spec.no_copy_collections:
|
||||
return spec.expression
|
||||
elif spec.origin_type is dict:
|
||||
return f"{spec.expression}.copy()"
|
||||
return f"{{{ke}: {ve} for key, value in {spec.expression}.items()}}"
|
||||
|
||||
if issubclass(spec.origin_type, typing.ByteString): # type: ignore
|
||||
spec.builder.ensure_object_imported(encodebytes)
|
||||
return f"encodebytes({spec.expression}).decode()"
|
||||
elif issubclass(spec.origin_type, str):
|
||||
return spec.expression
|
||||
elif issubclass(spec.origin_type, tuple):
|
||||
if is_named_tuple(spec.origin_type):
|
||||
return pack_named_tuple(spec)
|
||||
elif ensure_generic_collection(spec):
|
||||
return pack_tuple(spec, args)
|
||||
elif ensure_generic_collection_subclass(spec, list, deque, Set):
|
||||
ie = inner_expr()
|
||||
return _make_sequence_expression(ie)
|
||||
elif ensure_generic_mapping(spec, args, ChainMap):
|
||||
ke = inner_expr(0, "key")
|
||||
ve = inner_expr(1)
|
||||
return (
|
||||
f"[{{{ke}: {ve} for key, value in m.items()}} "
|
||||
f"for m in {spec.expression}.maps]"
|
||||
)
|
||||
elif ensure_generic_mapping(spec, args, OrderedDict):
|
||||
ke = inner_expr(0, "key")
|
||||
ve = inner_expr(1)
|
||||
return _make_mapping_expression(ke, ve)
|
||||
elif ensure_generic_mapping(spec, args, Counter):
|
||||
ke = inner_expr(0, "key")
|
||||
ve = inner_expr(1, v_type=int)
|
||||
return _make_mapping_expression(ke, ve)
|
||||
elif is_typed_dict(spec.origin_type):
|
||||
return pack_typed_dict(spec)
|
||||
elif ensure_generic_mapping(spec, args, Mapping):
|
||||
ke = inner_expr(0, "key")
|
||||
ve = inner_expr(1)
|
||||
return _make_mapping_expression(ke, ve)
|
||||
elif ensure_generic_collection_subclass(spec, Sequence):
|
||||
ie = inner_expr()
|
||||
return _make_sequence_expression(ie)
|
||||
|
||||
|
||||
@register
|
||||
def pack_pathlike(spec: ValueSpec) -> Optional[Expression]:
|
||||
if issubclass(spec.origin_type, os.PathLike):
|
||||
return f"{spec.expression}.__fspath__()"
|
||||
|
||||
|
||||
@register
|
||||
def pack_enum(spec: ValueSpec) -> Optional[Expression]:
|
||||
if issubclass(spec.origin_type, enum.Enum):
|
||||
return f"{spec.expression}.value"
|
||||
|
||||
|
||||
@register
|
||||
def pack_pattern(spec: ValueSpec) -> Optional[Expression]:
|
||||
if spec.origin_type in (typing.Pattern, re.Pattern):
|
||||
return f"{spec.expression}.pattern"
|
||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,62 @@
|
||||
from collections.abc import Callable, Sequence
|
||||
from types import new_class
|
||||
from typing import Any, Type, Union, cast
|
||||
|
||||
from typing_extensions import Literal
|
||||
|
||||
from mashumaro.core.const import Sentinel
|
||||
from mashumaro.types import SerializationStrategy
|
||||
|
||||
__all__ = ["Dialect"]
|
||||
|
||||
|
||||
SerializationStrategyValueType = Union[
|
||||
SerializationStrategy, dict[str, Union[str, Callable]]
|
||||
]
|
||||
|
||||
|
||||
class Dialect:
|
||||
serialization_strategy: dict[Any, SerializationStrategyValueType] = {}
|
||||
serialize_by_alias: Union[bool, Literal[Sentinel.MISSING]] = (
|
||||
Sentinel.MISSING
|
||||
)
|
||||
namedtuple_as_dict: Union[bool, Literal[Sentinel.MISSING]] = (
|
||||
Sentinel.MISSING
|
||||
)
|
||||
omit_none: Union[bool, Literal[Sentinel.MISSING]] = Sentinel.MISSING
|
||||
omit_default: Union[bool, Literal[Sentinel.MISSING]] = Sentinel.MISSING
|
||||
no_copy_collections: Union[Sequence[Any], Literal[Sentinel.MISSING]] = (
|
||||
Sentinel.MISSING
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def merge(cls, other: Type["Dialect"]) -> Type["Dialect"]:
|
||||
serialization_strategy: dict[Any, SerializationStrategyValueType] = {}
|
||||
for key, value in cls.serialization_strategy.items():
|
||||
if isinstance(value, SerializationStrategy):
|
||||
serialization_strategy[key] = value
|
||||
else:
|
||||
serialization_strategy[key] = value.copy()
|
||||
for key, value in other.serialization_strategy.items():
|
||||
if isinstance(value, SerializationStrategy):
|
||||
serialization_strategy[key] = value
|
||||
elif isinstance(
|
||||
serialization_strategy.get(key), SerializationStrategy
|
||||
):
|
||||
serialization_strategy[key] = value
|
||||
else:
|
||||
(
|
||||
serialization_strategy.setdefault(
|
||||
key, {}
|
||||
).update( # type: ignore
|
||||
value
|
||||
)
|
||||
)
|
||||
new_dialect = cast(Type[Dialect], new_class("Dialect", (Dialect,)))
|
||||
new_dialect.serialization_strategy = serialization_strategy
|
||||
for key in ("omit_none", "omit_default", "no_copy_collections"):
|
||||
if (others_value := getattr(other, key)) is not Sentinel.MISSING:
|
||||
setattr(new_dialect, key, others_value)
|
||||
else:
|
||||
setattr(new_dialect, key, getattr(cls, key))
|
||||
return new_dialect
|
||||
@@ -0,0 +1,216 @@
|
||||
from typing import Any, Optional, Type
|
||||
|
||||
from mashumaro.core.meta.helpers import type_name
|
||||
|
||||
|
||||
class MissingField(LookupError):
|
||||
def __init__(self, field_name: str, field_type: Type, holder_class: Type):
|
||||
self.field_name = field_name
|
||||
self.field_type = field_type
|
||||
self.holder_class = holder_class
|
||||
|
||||
@property
|
||||
def field_type_name(self) -> str:
|
||||
return type_name(self.field_type, short=True)
|
||||
|
||||
@property
|
||||
def holder_class_name(self) -> str:
|
||||
return type_name(self.holder_class, short=True)
|
||||
|
||||
def __str__(self) -> str:
|
||||
return (
|
||||
f'Field "{self.field_name}" of type {self.field_type_name}'
|
||||
f" is missing in {self.holder_class_name} instance"
|
||||
)
|
||||
|
||||
|
||||
class ExtraKeysError(ValueError):
|
||||
def __init__(self, extra_keys: set[str], target_type: Type):
|
||||
self.extra_keys = extra_keys
|
||||
self.target_type = target_type
|
||||
|
||||
@property
|
||||
def target_class_name(self) -> str:
|
||||
return type_name(self.target_type, short=True)
|
||||
|
||||
def __str__(self) -> str:
|
||||
extra_keys_str = ", ".join(k for k in self.extra_keys)
|
||||
return (
|
||||
"Serialized dict has keys that are not defined in "
|
||||
f"{self.target_class_name}: {extra_keys_str}"
|
||||
)
|
||||
|
||||
|
||||
class UnserializableDataError(TypeError):
|
||||
pass
|
||||
|
||||
|
||||
class UnserializableField(UnserializableDataError):
|
||||
def __init__(
|
||||
self,
|
||||
field_name: str,
|
||||
field_type: Type,
|
||||
holder_class: Type,
|
||||
msg: Optional[str] = None,
|
||||
):
|
||||
self.field_name = field_name
|
||||
self.field_type = field_type
|
||||
self.holder_class = holder_class
|
||||
self.msg = msg
|
||||
|
||||
@property
|
||||
def field_type_name(self) -> str:
|
||||
return type_name(self.field_type, short=True)
|
||||
|
||||
@property
|
||||
def holder_class_name(self) -> str:
|
||||
return type_name(self.holder_class, short=True)
|
||||
|
||||
def __str__(self) -> str:
|
||||
s = (
|
||||
f'Field "{self.field_name}" of type {self.field_type_name} '
|
||||
f"in {self.holder_class_name} is not serializable"
|
||||
)
|
||||
if self.msg:
|
||||
s += f": {self.msg}"
|
||||
return s
|
||||
|
||||
|
||||
class UnsupportedSerializationEngine(UnserializableField):
|
||||
def __init__(
|
||||
self,
|
||||
field_name: str,
|
||||
field_type: Type,
|
||||
holder_class: Type,
|
||||
engine: Any,
|
||||
):
|
||||
super(UnsupportedSerializationEngine, self).__init__(
|
||||
field_name,
|
||||
field_type,
|
||||
holder_class,
|
||||
msg=f'Unsupported serialization engine "{engine}"',
|
||||
)
|
||||
|
||||
|
||||
class UnsupportedDeserializationEngine(UnserializableField):
|
||||
def __init__(
|
||||
self,
|
||||
field_name: str,
|
||||
field_type: Type,
|
||||
holder_class: Type,
|
||||
engine: Any,
|
||||
):
|
||||
super(UnsupportedDeserializationEngine, self).__init__(
|
||||
field_name,
|
||||
field_type,
|
||||
holder_class,
|
||||
msg=f'Unsupported deserialization engine "{engine}"',
|
||||
)
|
||||
|
||||
|
||||
class InvalidFieldValue(ValueError):
|
||||
def __init__(
|
||||
self,
|
||||
field_name: str,
|
||||
field_type: Type,
|
||||
field_value: Any,
|
||||
holder_class: Type,
|
||||
msg: Optional[str] = None,
|
||||
):
|
||||
self.field_name = field_name
|
||||
self.field_type = field_type
|
||||
self.field_value = field_value
|
||||
self.holder_class = holder_class
|
||||
self.msg = msg
|
||||
|
||||
@property
|
||||
def field_type_name(self) -> str:
|
||||
return type_name(self.field_type, short=True)
|
||||
|
||||
@property
|
||||
def holder_class_name(self) -> str:
|
||||
return type_name(self.holder_class, short=True)
|
||||
|
||||
def __str__(self) -> str:
|
||||
s = (
|
||||
f'Field "{self.field_name}" of type {self.field_type_name} '
|
||||
f"in {self.holder_class_name} has invalid value "
|
||||
f"{repr(self.field_value)}"
|
||||
)
|
||||
if self.msg:
|
||||
s += f": {self.msg}"
|
||||
return s
|
||||
|
||||
|
||||
class MissingDiscriminatorError(LookupError):
|
||||
def __init__(self, field_name: str):
|
||||
self.field_name = field_name
|
||||
|
||||
def __str__(self) -> str:
|
||||
return f"Discriminator '{self.field_name}' is missing"
|
||||
|
||||
|
||||
class SuitableVariantNotFoundError(ValueError):
|
||||
def __init__(
|
||||
self,
|
||||
variants_type: Type,
|
||||
discriminator_name: Optional[str] = None,
|
||||
discriminator_value: Any = None,
|
||||
):
|
||||
self.variants_type = variants_type
|
||||
self.discriminator_name = discriminator_name
|
||||
self.discriminator_value = discriminator_value
|
||||
|
||||
def __str__(self) -> str:
|
||||
s = f"{type_name(self.variants_type)} has no "
|
||||
if self.discriminator_value is not None:
|
||||
s += (
|
||||
f"subtype with attribute '{self.discriminator_name}' "
|
||||
f"equal to {self.discriminator_value!r}"
|
||||
)
|
||||
else:
|
||||
s += "suitable subtype"
|
||||
return s
|
||||
|
||||
|
||||
class BadHookSignature(TypeError):
|
||||
pass
|
||||
|
||||
|
||||
class ThirdPartyModuleNotFoundError(ModuleNotFoundError):
|
||||
def __init__(self, module_name: str, field_name: str, holder_class: Type):
|
||||
self.module_name = module_name
|
||||
self.field_name = field_name
|
||||
self.holder_class = holder_class
|
||||
|
||||
@property
|
||||
def holder_class_name(self) -> str:
|
||||
return type_name(self.holder_class, short=True)
|
||||
|
||||
def __str__(self) -> str:
|
||||
s = (
|
||||
f'Install "{self.module_name}" to use it as the serialization '
|
||||
f'method for the field "{self.field_name}" '
|
||||
f"in {self.holder_class_name}"
|
||||
)
|
||||
return s
|
||||
|
||||
|
||||
class UnresolvedTypeReferenceError(NameError):
|
||||
def __init__(self, holder_class: Type, unresolved_type_name: str):
|
||||
self.holder_class = holder_class
|
||||
self.name = unresolved_type_name
|
||||
|
||||
@property
|
||||
def holder_class_name(self) -> str:
|
||||
return type_name(self.holder_class, short=True)
|
||||
|
||||
def __str__(self) -> str:
|
||||
return (
|
||||
f"Class {self.holder_class_name} has unresolved type reference "
|
||||
f"{self.name} in some of its fields"
|
||||
)
|
||||
|
||||
|
||||
class BadDialect(ValueError):
|
||||
pass
|
||||
@@ -0,0 +1,59 @@
|
||||
from collections.abc import Callable
|
||||
from typing import Any, Optional, TypeVar, Union
|
||||
|
||||
from typing_extensions import Literal
|
||||
|
||||
from mashumaro.types import SerializationStrategy
|
||||
|
||||
__all__ = [
|
||||
"field_options",
|
||||
"pass_through",
|
||||
]
|
||||
|
||||
|
||||
NamedTupleDeserializationEngine = Literal["as_dict", "as_list"]
|
||||
DateTimeDeserializationEngine = Literal["ciso8601", "pendulum"]
|
||||
AnyDeserializationEngine = Literal[
|
||||
NamedTupleDeserializationEngine, DateTimeDeserializationEngine
|
||||
]
|
||||
|
||||
NamedTupleSerializationEngine = Literal["as_dict", "as_list"]
|
||||
OmitSerializationEngine = Literal["omit"]
|
||||
AnySerializationEngine = Union[
|
||||
NamedTupleSerializationEngine, OmitSerializationEngine
|
||||
]
|
||||
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
def field_options(
|
||||
serialize: Optional[
|
||||
Union[AnySerializationEngine, Callable[[Any], Any]]
|
||||
] = None,
|
||||
deserialize: Optional[
|
||||
Union[AnyDeserializationEngine, Callable[[Any], Any]]
|
||||
] = None,
|
||||
serialization_strategy: Optional[SerializationStrategy] = None,
|
||||
alias: Optional[str] = None,
|
||||
) -> dict[str, Any]:
|
||||
return {
|
||||
"serialize": serialize,
|
||||
"deserialize": deserialize,
|
||||
"serialization_strategy": serialization_strategy,
|
||||
"alias": alias,
|
||||
}
|
||||
|
||||
|
||||
class _PassThrough(SerializationStrategy):
|
||||
def __call__(self, *args: Any, **kwargs: Any) -> Any:
|
||||
raise NotImplementedError
|
||||
|
||||
def serialize(self, value: T) -> T:
|
||||
return value
|
||||
|
||||
def deserialize(self, value: T) -> T:
|
||||
return value
|
||||
|
||||
|
||||
pass_through = _PassThrough()
|
||||
@@ -0,0 +1,9 @@
|
||||
from .builder import JSONSchemaBuilder, build_json_schema
|
||||
from .dialects import DRAFT_2020_12, OPEN_API_3_1
|
||||
|
||||
__all__ = [
|
||||
"JSONSchemaBuilder",
|
||||
"build_json_schema",
|
||||
"DRAFT_2020_12",
|
||||
"OPEN_API_3_1",
|
||||
]
|
||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -0,0 +1,134 @@
|
||||
from dataclasses import dataclass
|
||||
|
||||
from mashumaro.jsonschema.models import JSONSchema, Number
|
||||
|
||||
|
||||
class Annotation:
|
||||
pass
|
||||
|
||||
|
||||
class Constraint(Annotation):
|
||||
pass
|
||||
|
||||
|
||||
class NumberConstraint(Constraint):
|
||||
pass
|
||||
|
||||
|
||||
@dataclass(unsafe_hash=True)
|
||||
class Minimum(NumberConstraint):
|
||||
value: Number
|
||||
|
||||
|
||||
@dataclass(unsafe_hash=True)
|
||||
class Maximum(NumberConstraint):
|
||||
value: Number
|
||||
|
||||
|
||||
@dataclass(unsafe_hash=True)
|
||||
class ExclusiveMinimum(NumberConstraint):
|
||||
value: Number
|
||||
|
||||
|
||||
@dataclass(unsafe_hash=True)
|
||||
class ExclusiveMaximum(NumberConstraint):
|
||||
value: Number
|
||||
|
||||
|
||||
@dataclass(unsafe_hash=True)
|
||||
class MultipleOf(NumberConstraint):
|
||||
value: Number
|
||||
|
||||
|
||||
class StringConstraint(Constraint):
|
||||
pass
|
||||
|
||||
|
||||
@dataclass(unsafe_hash=True)
|
||||
class MinLength(StringConstraint):
|
||||
value: int
|
||||
|
||||
|
||||
@dataclass(unsafe_hash=True)
|
||||
class MaxLength(StringConstraint):
|
||||
value: int
|
||||
|
||||
|
||||
@dataclass(unsafe_hash=True)
|
||||
class Pattern(StringConstraint):
|
||||
value: str
|
||||
|
||||
|
||||
class ArrayConstraint(Constraint):
|
||||
pass
|
||||
|
||||
|
||||
@dataclass(unsafe_hash=True)
|
||||
class MinItems(ArrayConstraint):
|
||||
value: int
|
||||
|
||||
|
||||
@dataclass(unsafe_hash=True)
|
||||
class MaxItems(ArrayConstraint):
|
||||
value: int
|
||||
|
||||
|
||||
@dataclass(unsafe_hash=True)
|
||||
class UniqueItems(ArrayConstraint):
|
||||
value: bool
|
||||
|
||||
|
||||
@dataclass(unsafe_hash=True)
|
||||
class Contains(ArrayConstraint):
|
||||
value: JSONSchema
|
||||
|
||||
|
||||
@dataclass(unsafe_hash=True)
|
||||
class MinContains(ArrayConstraint):
|
||||
value: int
|
||||
|
||||
|
||||
@dataclass(unsafe_hash=True)
|
||||
class MaxContains(ArrayConstraint):
|
||||
value: int
|
||||
|
||||
|
||||
class ObjectConstraint(Constraint):
|
||||
pass
|
||||
|
||||
|
||||
@dataclass(unsafe_hash=True)
|
||||
class MaxProperties(ObjectConstraint):
|
||||
value: int
|
||||
|
||||
|
||||
@dataclass(unsafe_hash=True)
|
||||
class MinProperties(ObjectConstraint):
|
||||
value: int
|
||||
|
||||
|
||||
@dataclass
|
||||
class DependentRequired(ObjectConstraint):
|
||||
value: dict[str, set[str]]
|
||||
|
||||
|
||||
__all__ = [
|
||||
"Annotation",
|
||||
"MultipleOf",
|
||||
"Maximum",
|
||||
"ExclusiveMaximum",
|
||||
"Minimum",
|
||||
"ExclusiveMinimum",
|
||||
"MaxLength",
|
||||
"MinLength",
|
||||
"Pattern",
|
||||
"MaxItems",
|
||||
"MinItems",
|
||||
"UniqueItems",
|
||||
"Contains",
|
||||
"MaxContains",
|
||||
"MinContains",
|
||||
"MaxProperties",
|
||||
"MinProperties",
|
||||
"DependentRequired",
|
||||
]
|
||||
@@ -0,0 +1,97 @@
|
||||
from collections.abc import Sequence
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Optional, Type
|
||||
|
||||
from mashumaro.jsonschema.dialects import DRAFT_2020_12, JSONSchemaDialect
|
||||
from mashumaro.jsonschema.models import Context, JSONSchema
|
||||
from mashumaro.jsonschema.plugins import BasePlugin
|
||||
from mashumaro.jsonschema.schema import Instance, get_schema
|
||||
|
||||
try:
|
||||
from mashumaro.mixins.orjson import (
|
||||
DataClassORJSONMixin as DataClassJSONMixin,
|
||||
)
|
||||
except ImportError: # pragma: no cover
|
||||
from mashumaro.mixins.json import DataClassJSONMixin # type: ignore
|
||||
|
||||
|
||||
def build_json_schema(
|
||||
instance_type: Type,
|
||||
context: Optional[Context] = None,
|
||||
with_definitions: bool = True,
|
||||
all_refs: Optional[bool] = None,
|
||||
with_dialect_uri: bool = False,
|
||||
dialect: Optional[JSONSchemaDialect] = None,
|
||||
ref_prefix: Optional[str] = None,
|
||||
plugins: Sequence[BasePlugin] = (),
|
||||
) -> JSONSchema:
|
||||
if context is None:
|
||||
context = Context()
|
||||
else:
|
||||
context = Context(
|
||||
dialect=context.dialect,
|
||||
definitions=context.definitions,
|
||||
all_refs=context.all_refs,
|
||||
ref_prefix=context.ref_prefix,
|
||||
plugins=context.plugins,
|
||||
)
|
||||
if dialect is not None:
|
||||
context.dialect = dialect
|
||||
if all_refs is not None:
|
||||
context.all_refs = all_refs
|
||||
elif context.all_refs is None:
|
||||
context.all_refs = context.dialect.all_refs
|
||||
if ref_prefix is not None:
|
||||
context.ref_prefix = ref_prefix.rstrip("/")
|
||||
elif context.ref_prefix is None:
|
||||
context.ref_prefix = context.dialect.definitions_root_pointer
|
||||
if plugins:
|
||||
context.plugins = plugins
|
||||
instance = Instance(instance_type)
|
||||
schema = get_schema(instance, context, with_dialect_uri=with_dialect_uri)
|
||||
if with_definitions and context.definitions:
|
||||
schema.definitions = context.definitions
|
||||
return schema
|
||||
|
||||
|
||||
@dataclass
|
||||
class JSONSchemaDefinitions(DataClassJSONMixin):
|
||||
definitions: dict[str, JSONSchema]
|
||||
|
||||
def __post_serialize__( # type: ignore
|
||||
self, d: dict[Any, Any]
|
||||
) -> list[dict[str, Any]]:
|
||||
return d["definitions"]
|
||||
|
||||
|
||||
class JSONSchemaBuilder:
|
||||
def __init__(
|
||||
self,
|
||||
dialect: JSONSchemaDialect = DRAFT_2020_12,
|
||||
all_refs: Optional[bool] = None,
|
||||
ref_prefix: Optional[str] = None,
|
||||
plugins: Sequence[BasePlugin] = (),
|
||||
):
|
||||
if all_refs is None:
|
||||
all_refs = dialect.all_refs
|
||||
if ref_prefix is None:
|
||||
ref_prefix = dialect.definitions_root_pointer
|
||||
self.context = Context(
|
||||
dialect=dialect,
|
||||
all_refs=all_refs,
|
||||
ref_prefix=ref_prefix.rstrip("/"),
|
||||
plugins=plugins,
|
||||
)
|
||||
|
||||
def build(self, instance_type: Type) -> JSONSchema:
|
||||
return build_json_schema(
|
||||
instance_type=instance_type,
|
||||
context=self.context,
|
||||
with_definitions=False,
|
||||
)
|
||||
|
||||
def get_definitions(self) -> JSONSchemaDefinitions:
|
||||
return JSONSchemaDefinitions(self.context.definitions)
|
||||
|
||||
|
||||
__all__ = ["JSONSchemaBuilder", "build_json_schema"]
|
||||
@@ -0,0 +1,29 @@
|
||||
from dataclasses import dataclass
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class JSONSchemaDialect:
|
||||
uri: str
|
||||
definitions_root_pointer: str
|
||||
all_refs: bool
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class JSONSchemaDraft202012Dialect(JSONSchemaDialect):
|
||||
uri: str = "https://json-schema.org/draft/2020-12/schema"
|
||||
definitions_root_pointer: str = "#/$defs"
|
||||
all_refs: bool = False
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class OpenAPISchema31Dialect(JSONSchemaDialect):
|
||||
uri: str = "https://spec.openapis.org/oas/3.1/dialect/base"
|
||||
definitions_root_pointer: str = "#/components/schemas"
|
||||
all_refs: bool = True
|
||||
|
||||
|
||||
DRAFT_2020_12 = JSONSchemaDraft202012Dialect()
|
||||
OPEN_API_3_1 = OpenAPISchema31Dialect()
|
||||
|
||||
|
||||
__all__ = ["JSONSchemaDialect", "DRAFT_2020_12", "OPEN_API_3_1"]
|
||||
@@ -0,0 +1,212 @@
|
||||
import datetime
|
||||
import ipaddress
|
||||
from collections.abc import Sequence
|
||||
from dataclasses import MISSING, dataclass, field
|
||||
from enum import Enum
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
|
||||
from typing_extensions import TYPE_CHECKING, Self, TypeAlias
|
||||
|
||||
from mashumaro.config import BaseConfig
|
||||
from mashumaro.core.meta.helpers import iter_all_subclasses
|
||||
from mashumaro.helper import pass_through
|
||||
from mashumaro.jsonschema.dialects import DRAFT_2020_12, JSONSchemaDialect
|
||||
|
||||
if TYPE_CHECKING: # pragma: no cover
|
||||
from mashumaro.jsonschema.plugins import BasePlugin
|
||||
else:
|
||||
BasePlugin = Any
|
||||
|
||||
try:
|
||||
from mashumaro.mixins.orjson import (
|
||||
DataClassORJSONMixin as DataClassJSONMixin,
|
||||
)
|
||||
except ImportError: # pragma: no cover
|
||||
from mashumaro.mixins.json import DataClassJSONMixin # type: ignore
|
||||
|
||||
|
||||
# https://github.com/python/mypy/issues/3186
|
||||
Number: TypeAlias = Union[int, float]
|
||||
|
||||
Null = object()
|
||||
|
||||
|
||||
class JSONSchemaInstanceType(Enum):
|
||||
NULL = "null"
|
||||
BOOLEAN = "boolean"
|
||||
OBJECT = "object"
|
||||
ARRAY = "array"
|
||||
NUMBER = "number"
|
||||
STRING = "string"
|
||||
INTEGER = "integer"
|
||||
|
||||
|
||||
class JSONSchemaInstanceFormat(Enum):
|
||||
pass
|
||||
|
||||
|
||||
class JSONSchemaStringFormat(JSONSchemaInstanceFormat):
|
||||
DATETIME = "date-time"
|
||||
DATE = "date"
|
||||
TIME = "time"
|
||||
DURATION = "duration"
|
||||
EMAIL = "email"
|
||||
IDN_EMAIL = "idn-email"
|
||||
HOSTNAME = "hostname"
|
||||
IDN_HOSTNAME = "idn-hostname"
|
||||
IPV4ADDRESS = "ipv4"
|
||||
IPV6ADDRESS = "ipv6"
|
||||
URI = "uri"
|
||||
URI_REFERENCE = "uri-reference"
|
||||
IRI = "iri"
|
||||
IRI_REFERENCE = "iri-reference"
|
||||
UUID = "uuid"
|
||||
URI_TEMPLATE = "uri-template"
|
||||
JSON_POINTER = "json-pointer"
|
||||
RELATIVE_JSON_POINTER = "relative-json-pointer"
|
||||
REGEX = "regex"
|
||||
|
||||
|
||||
class JSONSchemaInstanceFormatExtension(JSONSchemaInstanceFormat):
|
||||
TIMEDELTA = "time-delta"
|
||||
TIME_ZONE = "time-zone"
|
||||
IPV4NETWORK = "ipv4network"
|
||||
IPV6NETWORK = "ipv6network"
|
||||
IPV4INTERFACE = "ipv4interface"
|
||||
IPV6INTERFACE = "ipv6interface"
|
||||
DECIMAL = "decimal"
|
||||
FRACTION = "fraction"
|
||||
BASE64 = "base64"
|
||||
PATH = "path"
|
||||
|
||||
|
||||
DATETIME_FORMATS = {
|
||||
datetime.datetime: JSONSchemaStringFormat.DATETIME,
|
||||
datetime.date: JSONSchemaStringFormat.DATE,
|
||||
datetime.time: JSONSchemaStringFormat.TIME,
|
||||
}
|
||||
|
||||
|
||||
IPADDRESS_FORMATS = {
|
||||
ipaddress.IPv4Address: JSONSchemaStringFormat.IPV4ADDRESS,
|
||||
ipaddress.IPv6Address: JSONSchemaStringFormat.IPV6ADDRESS,
|
||||
ipaddress.IPv4Network: JSONSchemaInstanceFormatExtension.IPV4NETWORK,
|
||||
ipaddress.IPv6Network: JSONSchemaInstanceFormatExtension.IPV6NETWORK,
|
||||
ipaddress.IPv4Interface: JSONSchemaInstanceFormatExtension.IPV4INTERFACE,
|
||||
ipaddress.IPv6Interface: JSONSchemaInstanceFormatExtension.IPV6INTERFACE,
|
||||
}
|
||||
|
||||
|
||||
def _deserialize_json_schema_instance_format(
|
||||
value: Any,
|
||||
) -> JSONSchemaInstanceFormat:
|
||||
for cls in iter_all_subclasses(JSONSchemaInstanceFormat):
|
||||
try:
|
||||
return cls(value)
|
||||
except (ValueError, TypeError):
|
||||
pass
|
||||
raise ValueError(value)
|
||||
|
||||
|
||||
@dataclass(unsafe_hash=True)
|
||||
class JSONSchema(DataClassJSONMixin):
|
||||
# Common keywords
|
||||
schema: Optional[str] = None
|
||||
type: Optional[JSONSchemaInstanceType] = None
|
||||
enum: Optional[list[Any]] = None
|
||||
const: Optional[Any] = field(default_factory=lambda: MISSING)
|
||||
format: Optional[JSONSchemaInstanceFormat] = None
|
||||
title: Optional[str] = None
|
||||
description: Optional[str] = None
|
||||
anyOf: Optional[List["JSONSchema"]] = None
|
||||
reference: Optional[str] = None
|
||||
definitions: Optional[Dict[str, "JSONSchema"]] = None
|
||||
default: Optional[Any] = field(default_factory=lambda: MISSING)
|
||||
deprecated: Optional[bool] = None
|
||||
examples: Optional[list[Any]] = None
|
||||
# Keywords for Objects
|
||||
properties: Optional[Dict[str, "JSONSchema"]] = None
|
||||
patternProperties: Optional[Dict[str, "JSONSchema"]] = None
|
||||
additionalProperties: Union["JSONSchema", bool, None] = None
|
||||
propertyNames: Optional["JSONSchema"] = None
|
||||
# Keywords for Arrays
|
||||
prefixItems: Optional[List["JSONSchema"]] = None
|
||||
items: Optional["JSONSchema"] = None
|
||||
contains: Optional["JSONSchema"] = None
|
||||
# Validation keywords for numeric instances
|
||||
multipleOf: Optional[Number] = None
|
||||
maximum: Optional[Number] = None
|
||||
exclusiveMaximum: Optional[Number] = None
|
||||
minimum: Optional[Number] = None
|
||||
exclusiveMinimum: Optional[Number] = None
|
||||
# Validation keywords for Strings
|
||||
maxLength: Optional[int] = None
|
||||
minLength: Optional[int] = None
|
||||
pattern: Optional[str] = None
|
||||
# Validation keywords for Arrays
|
||||
maxItems: Optional[int] = None
|
||||
minItems: Optional[int] = None
|
||||
uniqueItems: Optional[bool] = None
|
||||
maxContains: Optional[int] = None
|
||||
minContains: Optional[int] = None
|
||||
# Validation keywords for Objects
|
||||
maxProperties: Optional[int] = None
|
||||
minProperties: Optional[int] = None
|
||||
required: Optional[list[str]] = None
|
||||
dependentRequired: Optional[dict[str, set[str]]] = None
|
||||
|
||||
class Config(BaseConfig):
|
||||
omit_none = True
|
||||
serialize_by_alias = True
|
||||
aliases = {
|
||||
"schema": "$schema",
|
||||
"reference": "$ref",
|
||||
"definitions": "$defs",
|
||||
}
|
||||
serialization_strategy = {
|
||||
int: pass_through,
|
||||
float: pass_through,
|
||||
Null: pass_through,
|
||||
JSONSchemaInstanceFormat: {
|
||||
"deserialize": _deserialize_json_schema_instance_format,
|
||||
},
|
||||
}
|
||||
|
||||
def __pre_serialize__(self) -> Self:
|
||||
if self.const is None:
|
||||
self.const = Null
|
||||
if self.default is None:
|
||||
self.default = Null
|
||||
return self
|
||||
|
||||
def __post_serialize__(self, d: dict[Any, Any]) -> dict[Any, Any]:
|
||||
const = d.get("const")
|
||||
if const is MISSING:
|
||||
d.pop("const")
|
||||
elif const is Null:
|
||||
d["const"] = None
|
||||
default = d.get("default")
|
||||
if default is MISSING:
|
||||
d.pop("default")
|
||||
elif default is Null:
|
||||
d["default"] = None
|
||||
return d
|
||||
|
||||
|
||||
@dataclass
|
||||
class JSONObjectSchema(JSONSchema):
|
||||
type: Optional[JSONSchemaInstanceType] = JSONSchemaInstanceType.OBJECT
|
||||
|
||||
|
||||
@dataclass
|
||||
class JSONArraySchema(JSONSchema):
|
||||
type: Optional[JSONSchemaInstanceType] = JSONSchemaInstanceType.ARRAY
|
||||
|
||||
|
||||
@dataclass
|
||||
class Context:
|
||||
dialect: JSONSchemaDialect = DRAFT_2020_12
|
||||
definitions: dict[str, JSONSchema] = field(default_factory=dict)
|
||||
all_refs: Optional[bool] = None
|
||||
ref_prefix: Optional[str] = None
|
||||
plugins: Sequence[BasePlugin] = ()
|
||||
@@ -0,0 +1,28 @@
|
||||
from dataclasses import is_dataclass
|
||||
from inspect import cleandoc
|
||||
from typing import Optional
|
||||
|
||||
from mashumaro.jsonschema.models import Context, JSONSchema
|
||||
from mashumaro.jsonschema.schema import Instance
|
||||
|
||||
|
||||
class BasePlugin:
|
||||
def get_schema(
|
||||
self,
|
||||
instance: Instance,
|
||||
ctx: Context,
|
||||
schema: Optional[JSONSchema] = None,
|
||||
) -> Optional[JSONSchema]:
|
||||
pass
|
||||
|
||||
|
||||
class DocstringDescriptionPlugin(BasePlugin):
|
||||
def get_schema(
|
||||
self,
|
||||
instance: Instance,
|
||||
ctx: Context,
|
||||
schema: Optional[JSONSchema] = None,
|
||||
) -> Optional[JSONSchema]:
|
||||
if schema and is_dataclass(instance.type) and instance.type.__doc__:
|
||||
schema.description = cleandoc(instance.type.__doc__)
|
||||
return None
|
||||
@@ -0,0 +1,886 @@
|
||||
import datetime
|
||||
import ipaddress
|
||||
import os
|
||||
import warnings
|
||||
from base64 import encodebytes
|
||||
from collections import ChainMap, Counter, deque
|
||||
from collections.abc import (
|
||||
ByteString,
|
||||
Callable,
|
||||
Collection,
|
||||
Iterable,
|
||||
Mapping,
|
||||
Sequence,
|
||||
Set,
|
||||
)
|
||||
from dataclasses import MISSING, dataclass, field, is_dataclass, replace
|
||||
from decimal import Decimal
|
||||
from enum import Enum
|
||||
from fractions import Fraction
|
||||
from functools import cached_property
|
||||
from typing import Any, ForwardRef, Optional, Tuple, Type, Union
|
||||
from uuid import UUID
|
||||
from zoneinfo import ZoneInfo
|
||||
|
||||
from typing_extensions import TypeAlias
|
||||
|
||||
from mashumaro.config import BaseConfig
|
||||
from mashumaro.core.const import PY_311_MIN
|
||||
from mashumaro.core.meta.code.builder import CodeBuilder
|
||||
from mashumaro.core.meta.helpers import (
|
||||
evaluate_forward_ref,
|
||||
get_args,
|
||||
get_forward_ref_referencing_globals,
|
||||
get_function_return_annotation,
|
||||
get_literal_values,
|
||||
get_type_origin,
|
||||
is_annotated,
|
||||
is_generic,
|
||||
is_literal,
|
||||
is_named_tuple,
|
||||
is_new_type,
|
||||
is_not_required,
|
||||
is_readonly,
|
||||
is_required,
|
||||
is_special_typing_primitive,
|
||||
is_type_var,
|
||||
is_type_var_any,
|
||||
is_type_var_tuple,
|
||||
is_typed_dict,
|
||||
is_union,
|
||||
is_unpack,
|
||||
resolve_type_params,
|
||||
type_name,
|
||||
)
|
||||
from mashumaro.core.meta.types.common import NoneType
|
||||
from mashumaro.helper import pass_through
|
||||
from mashumaro.jsonschema.annotations import (
|
||||
Annotation,
|
||||
Contains,
|
||||
DependentRequired,
|
||||
ExclusiveMaximum,
|
||||
ExclusiveMinimum,
|
||||
MaxContains,
|
||||
Maximum,
|
||||
MaxItems,
|
||||
MaxLength,
|
||||
MaxProperties,
|
||||
MinContains,
|
||||
Minimum,
|
||||
MinItems,
|
||||
MinLength,
|
||||
MinProperties,
|
||||
MultipleOf,
|
||||
Pattern,
|
||||
UniqueItems,
|
||||
)
|
||||
from mashumaro.jsonschema.models import (
|
||||
DATETIME_FORMATS,
|
||||
IPADDRESS_FORMATS,
|
||||
Context,
|
||||
JSONArraySchema,
|
||||
JSONObjectSchema,
|
||||
JSONSchema,
|
||||
JSONSchemaInstanceFormatExtension,
|
||||
JSONSchemaInstanceType,
|
||||
JSONSchemaStringFormat,
|
||||
)
|
||||
from mashumaro.types import SerializationStrategy
|
||||
|
||||
try:
|
||||
from mashumaro.mixins.orjson import (
|
||||
DataClassORJSONMixin as DataClassJSONMixin,
|
||||
)
|
||||
except ImportError: # pragma: no cover
|
||||
from mashumaro.mixins.json import DataClassJSONMixin # type: ignore
|
||||
|
||||
|
||||
UTC_OFFSET_PATTERN = r"^UTC([+-][0-2][0-9]:[0-5][0-9])?$"
|
||||
|
||||
|
||||
@dataclass
|
||||
class Instance:
|
||||
type: Type
|
||||
name: Optional[str] = None
|
||||
|
||||
__owner_builder: Optional[CodeBuilder] = None
|
||||
__self_builder: Optional[CodeBuilder] = None
|
||||
|
||||
# Original type despite custom serialization. To be revised.
|
||||
_original_type: Type = field(init=False)
|
||||
|
||||
origin_type: Type = field(init=False)
|
||||
annotations: list[Annotation] = field(init=False, default_factory=list)
|
||||
|
||||
@cached_property
|
||||
def metadata(self) -> dict[str, Any]:
|
||||
if self.name and self.__owner_builder:
|
||||
return dict(**self.__owner_builder.metadatas.get(self.name, {}))
|
||||
else:
|
||||
return {}
|
||||
|
||||
@property
|
||||
def _self_builder(self) -> CodeBuilder:
|
||||
assert self.__self_builder
|
||||
return self.__self_builder
|
||||
|
||||
@property
|
||||
def alias(self) -> Optional[str]:
|
||||
alias = self.metadata.get("alias")
|
||||
if alias is None:
|
||||
aliases_config = self.get_owner_config().aliases
|
||||
alias = aliases_config.get(self.name) # type: ignore
|
||||
if alias is None:
|
||||
alias = self.name
|
||||
return alias
|
||||
|
||||
@property
|
||||
def owner_class(self) -> Optional[Type]:
|
||||
if self.__owner_builder:
|
||||
return self.__owner_builder.cls
|
||||
return None
|
||||
|
||||
def derive(self, **changes: Any) -> "Instance":
|
||||
new_type = changes.get("type")
|
||||
if isinstance(new_type, ForwardRef):
|
||||
changes["type"] = evaluate_forward_ref(
|
||||
new_type,
|
||||
get_forward_ref_referencing_globals(new_type, self.type),
|
||||
self.__dict__,
|
||||
)
|
||||
new_instance = replace(self, **changes)
|
||||
if is_dataclass(self.origin_type):
|
||||
new_instance.__owner_builder = self.__self_builder
|
||||
return new_instance
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
self._original_type = self.type
|
||||
self.update_type(self.type)
|
||||
if is_annotated(self.type):
|
||||
self.annotations = getattr(self.type, "__metadata__", [])
|
||||
self.type = get_args(self.type)[0]
|
||||
self.origin_type = get_type_origin(self.type)
|
||||
|
||||
def update_type(self, new_type: Type) -> None:
|
||||
if self.__owner_builder:
|
||||
self.type = self.__owner_builder.get_real_type(
|
||||
field_name=self.name, # type: ignore
|
||||
field_type=new_type,
|
||||
)
|
||||
self.origin_type = get_type_origin(self.type)
|
||||
if is_dataclass(self.origin_type):
|
||||
type_args = get_args(self.type)
|
||||
self.__self_builder = CodeBuilder(self.origin_type, type_args)
|
||||
self.__self_builder.reset()
|
||||
else:
|
||||
self.__self_builder = None
|
||||
|
||||
def fields(self) -> Iterable[tuple[str, Type, bool, Any]]:
|
||||
for f_name, f_type in self._self_builder.get_field_types(
|
||||
include_extras=True
|
||||
).items():
|
||||
f = self._self_builder.dataclass_fields.get(f_name)
|
||||
if not f or f and not f.init:
|
||||
continue
|
||||
f_default = f.default
|
||||
if f_default is MISSING:
|
||||
f_default = self._self_builder.namespace.get(f_name, MISSING)
|
||||
if f_default is not MISSING:
|
||||
f_default = _default(f_type, f_default, self.get_self_config())
|
||||
|
||||
has_default = (
|
||||
f.default is not MISSING or f.default_factory is not MISSING
|
||||
)
|
||||
|
||||
yield f_name, f_type, has_default, f_default
|
||||
|
||||
def get_overridden_serialization_method(
|
||||
self,
|
||||
) -> Optional[Union[Callable, str]]:
|
||||
if not self.__owner_builder:
|
||||
return None
|
||||
serialize_option = self.metadata.get("serialize")
|
||||
if serialize_option is not None:
|
||||
if callable(serialize_option):
|
||||
self.metadata.pop("serialize", None) # prevent recursion
|
||||
return serialize_option
|
||||
for strategy in self.__owner_builder.iter_serialization_strategies(
|
||||
self.metadata, self.type
|
||||
):
|
||||
if strategy is pass_through:
|
||||
return pass_through
|
||||
elif isinstance(strategy, dict):
|
||||
serialize_option = strategy.get("serialize")
|
||||
elif isinstance(strategy, SerializationStrategy):
|
||||
serialize_option = strategy.serialize
|
||||
if serialize_option is not None:
|
||||
return serialize_option
|
||||
return None
|
||||
|
||||
def get_owner_config(self) -> Type[BaseConfig]:
|
||||
if self.__owner_builder:
|
||||
return self.__owner_builder.get_config()
|
||||
else:
|
||||
return BaseConfig
|
||||
|
||||
def get_owner_dialect_or_config_option(
|
||||
self, option: str, default: Any
|
||||
) -> Any:
|
||||
if self.__owner_builder:
|
||||
return self.__owner_builder.get_dialect_or_config_option(
|
||||
option, default
|
||||
)
|
||||
else:
|
||||
return default
|
||||
|
||||
def get_self_config(self) -> Type[BaseConfig]:
|
||||
if self.__self_builder:
|
||||
return self.__self_builder.get_config()
|
||||
else:
|
||||
return BaseConfig
|
||||
|
||||
|
||||
InstanceSchemaCreator: TypeAlias = Callable[
|
||||
[Instance, Context], Optional[JSONSchema]
|
||||
]
|
||||
|
||||
|
||||
@dataclass
|
||||
class InstanceSchemaCreatorRegistry:
|
||||
_registry: list[InstanceSchemaCreator] = field(default_factory=list)
|
||||
|
||||
def register(self, func: InstanceSchemaCreator) -> InstanceSchemaCreator:
|
||||
self._registry.append(func)
|
||||
return func
|
||||
|
||||
def iter(self) -> Iterable[InstanceSchemaCreator]:
|
||||
yield from self._registry
|
||||
|
||||
|
||||
@dataclass
|
||||
class EmptyJSONSchema(JSONSchema):
|
||||
pass
|
||||
|
||||
|
||||
def get_schema(
|
||||
instance: Instance, ctx: Context, with_dialect_uri: bool = False
|
||||
) -> JSONSchema:
|
||||
schema = None
|
||||
for schema_creator in Registry.iter():
|
||||
schema = schema_creator(instance, ctx)
|
||||
if schema is not None:
|
||||
if with_dialect_uri:
|
||||
schema.schema = ctx.dialect.uri
|
||||
break
|
||||
for plugin in ctx.plugins:
|
||||
try:
|
||||
new_schema = plugin.get_schema(instance, ctx, schema)
|
||||
if new_schema:
|
||||
schema = new_schema
|
||||
except NotImplementedError:
|
||||
continue
|
||||
if schema:
|
||||
return schema
|
||||
raise NotImplementedError(
|
||||
f'Type {type_name(instance.type)} of field "{instance.name}" '
|
||||
f"in {type_name(instance.owner_class)} isn't supported"
|
||||
)
|
||||
|
||||
|
||||
def _get_schema_or_none(
|
||||
instance: Instance, ctx: Context
|
||||
) -> Optional[JSONSchema]:
|
||||
schema = get_schema(instance, ctx)
|
||||
if isinstance(schema, EmptyJSONSchema):
|
||||
return None
|
||||
return schema
|
||||
|
||||
|
||||
def _default(f_type: Type, f_value: Any, config_cls: Type[BaseConfig]) -> Any:
|
||||
@dataclass
|
||||
class CC(DataClassJSONMixin):
|
||||
x: f_type = f_value # type: ignore
|
||||
|
||||
class Config(config_cls): # type: ignore
|
||||
pass
|
||||
|
||||
return CC(f_value).to_dict()["x"]
|
||||
|
||||
|
||||
Registry = InstanceSchemaCreatorRegistry()
|
||||
register = Registry.register
|
||||
|
||||
|
||||
BASIC_TYPES = {str, int, float, bool}
|
||||
|
||||
|
||||
@register
|
||||
def on_type_with_overridden_serialization(
|
||||
instance: Instance, ctx: Context
|
||||
) -> Optional[JSONSchema]:
|
||||
def override_with_any(reason: Any) -> None:
|
||||
if instance.owner_class is not None:
|
||||
name = f"{type_name(instance.owner_class)}.{instance.name}"
|
||||
else: # pragma: no cover
|
||||
# we will have an owner class, but leave this here just in case
|
||||
name = type_name(instance.type)
|
||||
warnings.warn(
|
||||
f"Type Any will be used for {name} with "
|
||||
f"overridden serialization method: {reason}"
|
||||
)
|
||||
instance.update_type(Any) # type: ignore[arg-type]
|
||||
|
||||
overridden_method = instance.get_overridden_serialization_method()
|
||||
if overridden_method is pass_through:
|
||||
return None
|
||||
elif overridden_method in BASIC_TYPES:
|
||||
instance.update_type(overridden_method) # type: ignore
|
||||
elif callable(overridden_method):
|
||||
try:
|
||||
new_type = get_function_return_annotation(overridden_method)
|
||||
if new_type is instance.type:
|
||||
return None
|
||||
else:
|
||||
instance.update_type(new_type)
|
||||
except Exception as e:
|
||||
override_with_any(e)
|
||||
return get_schema(instance, ctx)
|
||||
|
||||
|
||||
@register
|
||||
def on_dataclass(instance: Instance, ctx: Context) -> Optional[JSONSchema]:
|
||||
# TODO: Self references might not work
|
||||
if is_dataclass(instance.origin_type):
|
||||
jsonschema_config = instance.get_self_config().json_schema
|
||||
schema = JSONObjectSchema(
|
||||
title=instance.origin_type.__name__,
|
||||
additionalProperties=jsonschema_config.get(
|
||||
"additionalProperties", False
|
||||
),
|
||||
)
|
||||
properties: dict[str, JSONSchema] = {}
|
||||
required = []
|
||||
field_schema_overrides = jsonschema_config.get("properties", {})
|
||||
for f_name, f_type, has_default, f_default in instance.fields():
|
||||
override = field_schema_overrides.get(f_name)
|
||||
f_instance = instance.derive(type=f_type, name=f_name)
|
||||
if override:
|
||||
f_schema = JSONSchema.from_dict(override)
|
||||
else:
|
||||
f_schema = get_schema(f_instance, ctx)
|
||||
if f_instance.alias:
|
||||
f_name = f_instance.alias
|
||||
if f_default is not MISSING:
|
||||
f_schema.default = f_default
|
||||
description = f_instance.metadata.get("description")
|
||||
if description:
|
||||
f_schema.description = description
|
||||
|
||||
if not has_default:
|
||||
required.append(f_name)
|
||||
|
||||
properties[f_name] = f_schema
|
||||
if properties:
|
||||
schema.properties = properties
|
||||
if required:
|
||||
schema.required = required
|
||||
if ctx.all_refs:
|
||||
ctx.definitions[instance.origin_type.__name__] = schema
|
||||
ref_prefix = ctx.ref_prefix or ctx.dialect.definitions_root_pointer
|
||||
return JSONSchema(
|
||||
reference=f"{ref_prefix}/{instance.origin_type.__name__}"
|
||||
)
|
||||
else:
|
||||
return schema
|
||||
|
||||
|
||||
@register
|
||||
def on_any(instance: Instance, ctx: Context) -> Optional[JSONSchema]:
|
||||
if instance.type is Any:
|
||||
return EmptyJSONSchema()
|
||||
|
||||
|
||||
def on_literal(instance: Instance, ctx: Context) -> Optional[JSONSchema]:
|
||||
enum_values = []
|
||||
for value in get_literal_values(instance.type):
|
||||
if isinstance(value, Enum):
|
||||
enum_values.append(value.value)
|
||||
elif isinstance(value, (int, str, bool, NoneType)): # type: ignore
|
||||
enum_values.append(value)
|
||||
elif isinstance(value, bytes):
|
||||
enum_values.append(encodebytes(value).decode())
|
||||
if len(enum_values) == 1:
|
||||
return JSONSchema(const=enum_values[0])
|
||||
else:
|
||||
return JSONSchema(enum=enum_values)
|
||||
|
||||
|
||||
@register
|
||||
def on_special_typing_primitive(
|
||||
instance: Instance, ctx: Context
|
||||
) -> Optional[JSONSchema]:
|
||||
if not is_special_typing_primitive(instance.origin_type):
|
||||
return None
|
||||
|
||||
args = get_args(instance.type)
|
||||
|
||||
if is_union(instance.type):
|
||||
return JSONSchema(
|
||||
anyOf=[get_schema(instance.derive(type=arg), ctx) for arg in args]
|
||||
)
|
||||
elif is_type_var_any(instance.type):
|
||||
return EmptyJSONSchema()
|
||||
elif is_type_var(instance.type):
|
||||
constraints = getattr(instance.type, "__constraints__")
|
||||
if constraints:
|
||||
return JSONSchema(
|
||||
anyOf=[
|
||||
get_schema(instance.derive(type=arg), ctx)
|
||||
for arg in constraints
|
||||
]
|
||||
)
|
||||
else:
|
||||
bound = getattr(instance.type, "__bound__")
|
||||
return get_schema(instance.derive(type=bound), ctx)
|
||||
elif is_new_type(instance.type):
|
||||
return get_schema(
|
||||
instance.derive(type=instance.type.__supertype__), ctx
|
||||
)
|
||||
elif is_literal(instance.type):
|
||||
return on_literal(instance, ctx)
|
||||
# elif is_self(instance.type):
|
||||
# raise NotImplementedError
|
||||
elif is_required(instance.type) or is_not_required(instance.type):
|
||||
return get_schema(instance.derive(type=args[0]), ctx)
|
||||
elif is_unpack(instance.type):
|
||||
return get_schema(
|
||||
instance.derive(type=get_args(instance.type)[0]), ctx
|
||||
)
|
||||
elif is_type_var_tuple(instance.type):
|
||||
return get_schema(instance.derive(type=tuple[Any, ...]), ctx)
|
||||
elif is_readonly(instance.type):
|
||||
return get_schema(instance.derive(type=args[0]), ctx)
|
||||
elif isinstance(instance.type, ForwardRef):
|
||||
evaluated = evaluate_forward_ref(
|
||||
instance.type,
|
||||
get_forward_ref_referencing_globals(instance.type),
|
||||
None,
|
||||
)
|
||||
if evaluated is not None:
|
||||
return get_schema(instance.derive(type=evaluated), ctx)
|
||||
|
||||
|
||||
@register
|
||||
def on_number(instance: Instance, ctx: Context) -> Optional[JSONSchema]:
|
||||
if instance.origin_type is int:
|
||||
schema = JSONSchema(type=JSONSchemaInstanceType.INTEGER)
|
||||
elif instance.origin_type is float:
|
||||
schema = JSONSchema(type=JSONSchemaInstanceType.NUMBER)
|
||||
else:
|
||||
return None
|
||||
for annotation in instance.annotations:
|
||||
if isinstance(annotation, Maximum):
|
||||
schema.maximum = annotation.value
|
||||
elif isinstance(annotation, Minimum):
|
||||
schema.minimum = annotation.value
|
||||
elif isinstance(annotation, ExclusiveMaximum):
|
||||
schema.exclusiveMaximum = annotation.value
|
||||
elif isinstance(annotation, ExclusiveMinimum):
|
||||
schema.exclusiveMinimum = annotation.value
|
||||
elif isinstance(annotation, MultipleOf):
|
||||
schema.multipleOf = annotation.value
|
||||
return schema
|
||||
|
||||
|
||||
@register
|
||||
def on_bool(instance: Instance, ctx: Context) -> Optional[JSONSchema]:
|
||||
if instance.origin_type is bool:
|
||||
return JSONSchema(type=JSONSchemaInstanceType.BOOLEAN)
|
||||
|
||||
|
||||
@register
|
||||
def on_none(instance: Instance, ctx: Context) -> Optional[JSONSchema]:
|
||||
if instance.origin_type in (NoneType, None):
|
||||
return JSONSchema(type=JSONSchemaInstanceType.NULL)
|
||||
|
||||
|
||||
@register
|
||||
def on_date_objects(instance: Instance, ctx: Context) -> Optional[JSONSchema]:
|
||||
if instance.origin_type in (
|
||||
datetime.datetime,
|
||||
datetime.date,
|
||||
datetime.time,
|
||||
):
|
||||
return JSONSchema(
|
||||
type=JSONSchemaInstanceType.STRING,
|
||||
format=DATETIME_FORMATS[instance.origin_type],
|
||||
)
|
||||
|
||||
|
||||
@register
|
||||
def on_timedelta(instance: Instance, ctx: Context) -> Optional[JSONSchema]:
|
||||
if instance.origin_type is datetime.timedelta:
|
||||
return JSONSchema(
|
||||
type=JSONSchemaInstanceType.NUMBER,
|
||||
format=JSONSchemaInstanceFormatExtension.TIMEDELTA,
|
||||
)
|
||||
|
||||
|
||||
@register
|
||||
def on_timezone(instance: Instance, ctx: Context) -> Optional[JSONSchema]:
|
||||
if instance.origin_type is datetime.timezone:
|
||||
return JSONSchema(
|
||||
type=JSONSchemaInstanceType.STRING, pattern=UTC_OFFSET_PATTERN
|
||||
)
|
||||
|
||||
|
||||
@register
|
||||
def on_zone_info(instance: Instance, ctx: Context) -> Optional[JSONSchema]:
|
||||
if instance.origin_type is ZoneInfo:
|
||||
return JSONSchema(
|
||||
type=JSONSchemaInstanceType.STRING,
|
||||
format=JSONSchemaInstanceFormatExtension.TIME_ZONE,
|
||||
)
|
||||
|
||||
|
||||
@register
|
||||
def on_uuid(instance: Instance, ctx: Context) -> Optional[JSONSchema]:
|
||||
if instance.origin_type is UUID:
|
||||
return JSONSchema(
|
||||
type=JSONSchemaInstanceType.STRING,
|
||||
format=JSONSchemaStringFormat.UUID,
|
||||
)
|
||||
|
||||
|
||||
@register
|
||||
def on_ipaddress(instance: Instance, ctx: Context) -> Optional[JSONSchema]:
|
||||
if instance.origin_type in (
|
||||
ipaddress.IPv4Address,
|
||||
ipaddress.IPv6Address,
|
||||
ipaddress.IPv4Network,
|
||||
ipaddress.IPv6Network,
|
||||
ipaddress.IPv4Interface,
|
||||
ipaddress.IPv6Interface,
|
||||
):
|
||||
return JSONSchema(
|
||||
type=JSONSchemaInstanceType.STRING,
|
||||
format=IPADDRESS_FORMATS[instance.origin_type], # type: ignore
|
||||
)
|
||||
|
||||
|
||||
@register
|
||||
def on_decimal(instance: Instance, ctx: Context) -> Optional[JSONSchema]:
|
||||
if instance.origin_type is Decimal:
|
||||
return JSONSchema(
|
||||
type=JSONSchemaInstanceType.STRING,
|
||||
format=JSONSchemaInstanceFormatExtension.DECIMAL,
|
||||
)
|
||||
|
||||
|
||||
@register
|
||||
def on_fraction(instance: Instance, ctx: Context) -> Optional[JSONSchema]:
|
||||
if instance.origin_type is Fraction:
|
||||
return JSONSchema(
|
||||
type=JSONSchemaInstanceType.STRING,
|
||||
format=JSONSchemaInstanceFormatExtension.FRACTION,
|
||||
)
|
||||
|
||||
|
||||
def on_tuple(instance: Instance, ctx: Context) -> JSONArraySchema:
|
||||
args = get_args(instance.type)
|
||||
if not args:
|
||||
if instance.type in (Tuple, tuple):
|
||||
args = [Any, ...] # type: ignore
|
||||
else:
|
||||
return JSONArraySchema(maxItems=0)
|
||||
elif len(args) == 1 and args[0] == ():
|
||||
if not PY_311_MIN:
|
||||
return JSONArraySchema(maxItems=0)
|
||||
if len(args) == 2 and args[1] is Ellipsis:
|
||||
items_schema = _get_schema_or_none(instance.derive(type=args[0]), ctx)
|
||||
return JSONArraySchema(items=items_schema)
|
||||
else:
|
||||
min_items = 0
|
||||
max_items = 0
|
||||
prefix_items = []
|
||||
items: Optional[JSONSchema] = None
|
||||
unpack_schema: Optional[JSONSchema] = None
|
||||
unpack_idx = 0
|
||||
for arg_idx, arg in enumerate(args, start=1):
|
||||
if not is_unpack(arg):
|
||||
min_items += 1
|
||||
if not unpack_schema:
|
||||
prefix_items.append(
|
||||
get_schema(instance.derive(type=arg), ctx)
|
||||
)
|
||||
else:
|
||||
unpack_schema = get_schema(instance.derive(type=arg), ctx)
|
||||
unpack_idx = arg_idx
|
||||
if unpack_schema:
|
||||
prefix_items.extend(unpack_schema.prefixItems or [])
|
||||
min_items += unpack_schema.minItems or 0
|
||||
max_items += unpack_schema.maxItems or 0
|
||||
if unpack_idx == len(args):
|
||||
items = unpack_schema.items
|
||||
else:
|
||||
min_items = len(args)
|
||||
max_items = len(args)
|
||||
return JSONArraySchema(
|
||||
prefixItems=prefix_items or None,
|
||||
items=items,
|
||||
minItems=min_items or None,
|
||||
maxItems=max_items or None,
|
||||
)
|
||||
|
||||
|
||||
def on_named_tuple(instance: Instance, ctx: Context) -> JSONSchema:
|
||||
resolved = resolve_type_params(
|
||||
instance.origin_type, get_args(instance.type)
|
||||
)[instance.origin_type]
|
||||
annotations = {
|
||||
k: resolved.get(v, v)
|
||||
for k, v in getattr(
|
||||
instance.origin_type, "__annotations__", {}
|
||||
).items()
|
||||
}
|
||||
fields = getattr(instance.type, "_fields", ())
|
||||
defaults = getattr(instance.type, "_field_defaults", {})
|
||||
as_dict = instance.get_owner_dialect_or_config_option(
|
||||
"namedtuple_as_dict", False
|
||||
)
|
||||
serialize_option = instance.get_overridden_serialization_method()
|
||||
if serialize_option == "as_dict":
|
||||
as_dict = True
|
||||
elif serialize_option == "as_list":
|
||||
as_dict = False
|
||||
properties = {}
|
||||
for f_name in fields:
|
||||
f_type = annotations.get(f_name, Any)
|
||||
f_schema = get_schema(instance.derive(type=f_type), ctx)
|
||||
f_default = defaults.get(f_name, MISSING)
|
||||
if f_default is not MISSING:
|
||||
if isinstance(f_schema, EmptyJSONSchema):
|
||||
f_schema = JSONSchema()
|
||||
f_schema.default = _default(
|
||||
f_type, f_default, instance.get_self_config()
|
||||
)
|
||||
properties[f_name] = f_schema
|
||||
if as_dict:
|
||||
return JSONObjectSchema(
|
||||
properties=properties or None,
|
||||
required=list(fields),
|
||||
additionalProperties=False,
|
||||
)
|
||||
else:
|
||||
return JSONArraySchema(
|
||||
prefixItems=list(properties.values()) or None,
|
||||
maxItems=len(properties) or None,
|
||||
minItems=len(properties) or None,
|
||||
)
|
||||
|
||||
|
||||
def on_typed_dict(instance: Instance, ctx: Context) -> JSONObjectSchema:
|
||||
resolved = resolve_type_params(
|
||||
instance.origin_type, get_args(instance.type)
|
||||
)[instance.origin_type]
|
||||
annotations = {
|
||||
k: resolved.get(v, v)
|
||||
for k, v in instance.origin_type.__annotations__.items()
|
||||
}
|
||||
all_keys = list(annotations.keys())
|
||||
required_keys = getattr(instance.type, "__required_keys__", all_keys)
|
||||
return JSONObjectSchema(
|
||||
properties={
|
||||
key: get_schema(instance.derive(type=annotations[key]), ctx)
|
||||
for key in all_keys
|
||||
}
|
||||
or None,
|
||||
required=sorted(required_keys) or None,
|
||||
additionalProperties=False,
|
||||
)
|
||||
|
||||
|
||||
def apply_array_constraints(
|
||||
instance: Instance,
|
||||
schema: JSONSchema,
|
||||
) -> JSONSchema:
|
||||
has_contains = False
|
||||
min_contains: Optional[int] = None
|
||||
max_contains: Optional[int] = None
|
||||
for annotation in instance.annotations:
|
||||
if isinstance(annotation, MinItems):
|
||||
schema.minItems = annotation.value
|
||||
elif isinstance(annotation, MaxItems):
|
||||
schema.maxItems = annotation.value
|
||||
elif isinstance(annotation, UniqueItems):
|
||||
schema.uniqueItems = annotation.value
|
||||
elif isinstance(annotation, Contains):
|
||||
schema.contains = annotation.value
|
||||
has_contains = True
|
||||
elif isinstance(annotation, MinContains):
|
||||
min_contains = annotation.value
|
||||
elif isinstance(annotation, MaxContains):
|
||||
max_contains = annotation.value
|
||||
if has_contains:
|
||||
if min_contains is not None:
|
||||
schema.minContains = min_contains
|
||||
if max_contains is not None:
|
||||
schema.maxContains = max_contains
|
||||
return schema
|
||||
|
||||
|
||||
def apply_object_constraints(
|
||||
instance: Instance, schema: JSONSchema
|
||||
) -> JSONSchema:
|
||||
for annotation in instance.annotations:
|
||||
if isinstance(annotation, MaxProperties):
|
||||
schema.maxProperties = annotation.value
|
||||
elif isinstance(annotation, MinProperties):
|
||||
schema.minProperties = annotation.value
|
||||
elif isinstance(annotation, DependentRequired):
|
||||
schema.dependentRequired = annotation.value
|
||||
return schema
|
||||
|
||||
|
||||
@register
|
||||
def on_collection(instance: Instance, ctx: Context) -> Optional[JSONSchema]:
|
||||
if not issubclass(instance.origin_type, Collection):
|
||||
return None
|
||||
elif issubclass(instance.origin_type, Enum):
|
||||
return None
|
||||
|
||||
args = get_args(instance.type)
|
||||
|
||||
if issubclass(instance.origin_type, ByteString): # type: ignore[arg-type]
|
||||
return JSONSchema(
|
||||
type=JSONSchemaInstanceType.STRING,
|
||||
format=JSONSchemaInstanceFormatExtension.BASE64,
|
||||
)
|
||||
elif issubclass(instance.origin_type, str):
|
||||
schema = JSONSchema(type=JSONSchemaInstanceType.STRING)
|
||||
for annotation in instance.annotations:
|
||||
if isinstance(annotation, MinLength):
|
||||
schema.minLength = annotation.value
|
||||
elif isinstance(annotation, MaxLength):
|
||||
schema.maxLength = annotation.value
|
||||
elif isinstance(annotation, Pattern):
|
||||
schema.pattern = annotation.value
|
||||
return schema
|
||||
elif is_generic(instance.type) and issubclass(
|
||||
instance.origin_type, (list, deque)
|
||||
):
|
||||
return apply_array_constraints(
|
||||
instance,
|
||||
JSONArraySchema(
|
||||
items=(
|
||||
_get_schema_or_none(instance.derive(type=args[0]), ctx)
|
||||
if args
|
||||
else None
|
||||
)
|
||||
),
|
||||
)
|
||||
elif issubclass(instance.origin_type, tuple):
|
||||
if is_named_tuple(instance.origin_type):
|
||||
return apply_array_constraints(
|
||||
instance, on_named_tuple(instance, ctx)
|
||||
)
|
||||
elif is_generic(instance.type):
|
||||
return apply_array_constraints(instance, on_tuple(instance, ctx))
|
||||
elif is_generic(instance.type) and issubclass(
|
||||
instance.origin_type, (frozenset, Set)
|
||||
):
|
||||
return apply_array_constraints(
|
||||
instance,
|
||||
JSONArraySchema(
|
||||
items=(
|
||||
_get_schema_or_none(instance.derive(type=args[0]), ctx)
|
||||
if args
|
||||
else None
|
||||
),
|
||||
uniqueItems=True,
|
||||
),
|
||||
)
|
||||
elif is_generic(instance.type) and issubclass(
|
||||
instance.origin_type, ChainMap
|
||||
):
|
||||
return apply_array_constraints(
|
||||
instance,
|
||||
JSONArraySchema(
|
||||
items=get_schema(
|
||||
instance=instance.derive(
|
||||
type=(
|
||||
dict[args[0], args[1]] # type: ignore
|
||||
if args
|
||||
else dict
|
||||
)
|
||||
),
|
||||
ctx=ctx,
|
||||
)
|
||||
),
|
||||
)
|
||||
elif is_generic(instance.type) and issubclass(
|
||||
instance.origin_type, Counter
|
||||
):
|
||||
schema = JSONObjectSchema(
|
||||
additionalProperties=get_schema(instance.derive(type=int), ctx),
|
||||
)
|
||||
if args:
|
||||
schema.propertyNames = _get_schema_or_none(
|
||||
instance.derive(type=args[0]), ctx
|
||||
)
|
||||
return apply_object_constraints(instance, schema)
|
||||
elif is_typed_dict(instance.origin_type):
|
||||
return on_typed_dict(instance, ctx)
|
||||
elif is_generic(instance.type) and issubclass(
|
||||
instance.origin_type, Mapping
|
||||
):
|
||||
schema = JSONObjectSchema(
|
||||
additionalProperties=(
|
||||
_get_schema_or_none(instance.derive(type=args[1]), ctx)
|
||||
if args
|
||||
else None
|
||||
),
|
||||
propertyNames=(
|
||||
_get_schema_or_none(instance.derive(type=args[0]), ctx)
|
||||
if args
|
||||
else None
|
||||
),
|
||||
)
|
||||
return apply_object_constraints(instance, schema)
|
||||
elif is_generic(instance.type) and issubclass(
|
||||
instance.origin_type, Sequence
|
||||
):
|
||||
return apply_array_constraints(
|
||||
instance,
|
||||
JSONArraySchema(
|
||||
items=(
|
||||
_get_schema_or_none(instance.derive(type=args[0]), ctx)
|
||||
if args
|
||||
else None
|
||||
)
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
@register
|
||||
def on_pathlike(instance: Instance, ctx: Context) -> Optional[JSONSchema]:
|
||||
if issubclass(instance.origin_type, os.PathLike):
|
||||
schema = JSONSchema(
|
||||
type=JSONSchemaInstanceType.STRING,
|
||||
format=JSONSchemaInstanceFormatExtension.PATH,
|
||||
)
|
||||
for annotation in instance.annotations:
|
||||
if isinstance(annotation, MaxLength):
|
||||
schema.maxLength = annotation.value
|
||||
elif isinstance(annotation, MinLength):
|
||||
schema.minLength = annotation.value
|
||||
return schema
|
||||
|
||||
|
||||
@register
|
||||
def on_enum(instance: Instance, ctx: Context) -> Optional[JSONSchema]:
|
||||
if issubclass(instance.origin_type, Enum):
|
||||
return JSONSchema(enum=[m.value for m in instance.origin_type])
|
||||
|
||||
|
||||
__all__ = ["Instance", "get_schema"]
|
||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -0,0 +1,68 @@
|
||||
from collections.abc import Mapping
|
||||
from typing import Any, Type, TypeVar, final
|
||||
|
||||
from mashumaro.core.meta.mixin import (
|
||||
compile_mixin_packer,
|
||||
compile_mixin_unpacker,
|
||||
)
|
||||
|
||||
__all__ = ["DataClassDictMixin"]
|
||||
|
||||
|
||||
T = TypeVar("T", bound="DataClassDictMixin")
|
||||
|
||||
|
||||
class DataClassDictMixin:
|
||||
__slots__ = ()
|
||||
|
||||
__mashumaro_builder_params = {"packer": {}, "unpacker": {}} # type: ignore
|
||||
|
||||
def __init_subclass__(cls: Type[T], **kwargs: Any):
|
||||
super().__init_subclass__(**kwargs)
|
||||
for ancestor in cls.__mro__[-1:0:-1]:
|
||||
builder_params_ = f"_{ancestor.__name__}__mashumaro_builder_params"
|
||||
builder_params = getattr(ancestor, builder_params_, None)
|
||||
if builder_params:
|
||||
compile_mixin_unpacker(cls, **builder_params["unpacker"])
|
||||
compile_mixin_packer(cls, **builder_params["packer"])
|
||||
|
||||
@final
|
||||
def to_dict(
|
||||
self: T,
|
||||
# *
|
||||
# keyword-only arguments that exist with the code generation options:
|
||||
# omit_none: bool = False
|
||||
# by_alias: bool = False
|
||||
# dialect: Type[Dialect] = None
|
||||
**kwargs: Any,
|
||||
) -> dict[Any, Any]: ...
|
||||
|
||||
@classmethod
|
||||
@final
|
||||
def from_dict(
|
||||
cls: Type[T],
|
||||
d: Mapping,
|
||||
# *
|
||||
# keyword-only arguments that exist with the code generation options:
|
||||
# dialect: Type[Dialect] = None
|
||||
**kwargs: Any,
|
||||
) -> T: ...
|
||||
|
||||
@classmethod
|
||||
def __pre_deserialize__(
|
||||
cls: Type[T], d: dict[Any, Any]
|
||||
) -> dict[Any, Any]: ...
|
||||
|
||||
@classmethod
|
||||
def __post_deserialize__(cls: Type[T], obj: T) -> T: ...
|
||||
|
||||
def __pre_serialize__(
|
||||
self: T,
|
||||
# context: Any = None, # added with ADD_SERIALIZATION_CONTEXT option
|
||||
) -> T: ...
|
||||
|
||||
def __post_serialize__(
|
||||
self: T,
|
||||
d: dict[Any, Any],
|
||||
# context: Any = None, # added with ADD_SERIALIZATION_CONTEXT option
|
||||
) -> dict[Any, Any]: ...
|
||||
@@ -0,0 +1,32 @@
|
||||
import json
|
||||
from collections.abc import Callable
|
||||
from typing import Any, Type, TypeVar, Union
|
||||
|
||||
from mashumaro.mixins.dict import DataClassDictMixin
|
||||
|
||||
T = TypeVar("T", bound="DataClassJSONMixin")
|
||||
|
||||
|
||||
EncodedData = Union[str, bytes, bytearray]
|
||||
Encoder = Callable[[Any], EncodedData]
|
||||
Decoder = Callable[[EncodedData], dict[Any, Any]]
|
||||
|
||||
|
||||
class DataClassJSONMixin(DataClassDictMixin):
|
||||
__slots__ = ()
|
||||
|
||||
def to_json(
|
||||
self: T,
|
||||
encoder: Encoder = json.dumps,
|
||||
**to_dict_kwargs: Any,
|
||||
) -> EncodedData:
|
||||
return encoder(self.to_dict(**to_dict_kwargs))
|
||||
|
||||
@classmethod
|
||||
def from_json(
|
||||
cls: Type[T],
|
||||
data: EncodedData,
|
||||
decoder: Decoder = json.loads,
|
||||
**from_dict_kwargs: Any,
|
||||
) -> T:
|
||||
return cls.from_dict(decoder(data), **from_dict_kwargs)
|
||||
@@ -0,0 +1,67 @@
|
||||
from collections.abc import Callable
|
||||
from typing import Any, Type, TypeVar, final
|
||||
|
||||
import msgpack
|
||||
|
||||
from mashumaro.dialect import Dialect
|
||||
from mashumaro.helper import pass_through
|
||||
from mashumaro.mixins.dict import DataClassDictMixin
|
||||
|
||||
T = TypeVar("T", bound="DataClassMessagePackMixin")
|
||||
|
||||
|
||||
EncodedData = bytes
|
||||
Encoder = Callable[[Any], EncodedData]
|
||||
Decoder = Callable[[EncodedData], dict[Any, Any]]
|
||||
|
||||
|
||||
class MessagePackDialect(Dialect):
|
||||
no_copy_collections = (list, dict)
|
||||
serialization_strategy = {
|
||||
bytes: pass_through,
|
||||
bytearray: {
|
||||
"deserialize": bytearray,
|
||||
"serialize": pass_through,
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
def default_encoder(data: Any) -> EncodedData:
|
||||
return msgpack.packb(data, use_bin_type=True)
|
||||
|
||||
|
||||
def default_decoder(data: EncodedData) -> dict[Any, Any]:
|
||||
return msgpack.unpackb(data, raw=False)
|
||||
|
||||
|
||||
class DataClassMessagePackMixin(DataClassDictMixin):
|
||||
__slots__ = ()
|
||||
|
||||
__mashumaro_builder_params = {
|
||||
"packer": {
|
||||
"format_name": "msgpack",
|
||||
"dialect": MessagePackDialect,
|
||||
"encoder": default_encoder,
|
||||
},
|
||||
"unpacker": {
|
||||
"format_name": "msgpack",
|
||||
"dialect": MessagePackDialect,
|
||||
"decoder": default_decoder,
|
||||
},
|
||||
}
|
||||
|
||||
@final
|
||||
def to_msgpack(
|
||||
self: T,
|
||||
encoder: Encoder = default_encoder,
|
||||
**to_dict_kwargs: Any,
|
||||
) -> EncodedData: ...
|
||||
|
||||
@classmethod
|
||||
@final
|
||||
def from_msgpack(
|
||||
cls: Type[T],
|
||||
data: EncodedData,
|
||||
decoder: Decoder = default_decoder,
|
||||
**from_dict_kwargs: Any,
|
||||
) -> T: ...
|
||||
@@ -0,0 +1,69 @@
|
||||
from collections.abc import Callable
|
||||
from datetime import date, datetime, time
|
||||
from typing import Any, Type, TypeVar, Union, final
|
||||
from uuid import UUID
|
||||
|
||||
import orjson
|
||||
|
||||
from mashumaro.core.helpers import ConfigValue
|
||||
from mashumaro.dialect import Dialect
|
||||
from mashumaro.helper import pass_through
|
||||
from mashumaro.mixins.dict import DataClassDictMixin
|
||||
|
||||
T = TypeVar("T", bound="DataClassORJSONMixin")
|
||||
|
||||
|
||||
EncodedData = Union[str, bytes, bytearray]
|
||||
Encoder = Callable[[Any], EncodedData]
|
||||
Decoder = Callable[[EncodedData], dict[Any, Any]]
|
||||
|
||||
|
||||
class OrjsonDialect(Dialect):
|
||||
no_copy_collections = (list, dict)
|
||||
serialization_strategy = {
|
||||
datetime: {"serialize": pass_through},
|
||||
date: {"serialize": pass_through},
|
||||
time: {"serialize": pass_through},
|
||||
UUID: {"serialize": pass_through},
|
||||
}
|
||||
|
||||
|
||||
class DataClassORJSONMixin(DataClassDictMixin):
|
||||
__slots__ = ()
|
||||
|
||||
__mashumaro_builder_params = {
|
||||
"packer": {
|
||||
"format_name": "jsonb",
|
||||
"dialect": OrjsonDialect,
|
||||
"encoder": orjson.dumps,
|
||||
"encoder_kwargs": {
|
||||
"option": ("orjson_options", ConfigValue("orjson_options")),
|
||||
},
|
||||
},
|
||||
"unpacker": {
|
||||
"format_name": "json",
|
||||
"dialect": OrjsonDialect,
|
||||
"decoder": orjson.loads,
|
||||
},
|
||||
}
|
||||
|
||||
@final
|
||||
def to_jsonb(
|
||||
self: T,
|
||||
encoder: Encoder = orjson.dumps,
|
||||
*,
|
||||
orjson_options: int = ...,
|
||||
**to_dict_kwargs: Any,
|
||||
) -> bytes: ...
|
||||
|
||||
def to_json(self: T, **kwargs: Any) -> str:
|
||||
return self.to_jsonb(**kwargs).decode()
|
||||
|
||||
@classmethod
|
||||
@final
|
||||
def from_json(
|
||||
cls: Type[T],
|
||||
data: EncodedData,
|
||||
decoder: Decoder = orjson.loads,
|
||||
**from_dict_kwargs: Any,
|
||||
) -> T: ...
|
||||
@@ -0,0 +1,42 @@
|
||||
from collections.abc import Callable
|
||||
from typing import Any, Type, TypeVar, Union, final
|
||||
|
||||
import orjson
|
||||
|
||||
from mashumaro.dialect import Dialect
|
||||
from mashumaro.mixins.dict import DataClassDictMixin
|
||||
|
||||
T = TypeVar("T", bound="DataClassORJSONMixin")
|
||||
|
||||
EncodedData = Union[str, bytes, bytearray]
|
||||
Encoder = Callable[[Any], EncodedData]
|
||||
Decoder = Callable[[EncodedData], dict[Any, Any]]
|
||||
|
||||
class OrjsonDialect(Dialect):
|
||||
serialization_strategy: Any
|
||||
|
||||
class DataClassORJSONMixin(DataClassDictMixin):
|
||||
__slots__ = ()
|
||||
@final
|
||||
def to_jsonb(
|
||||
self: T,
|
||||
encoder: Encoder = orjson.dumps,
|
||||
*,
|
||||
orjson_options: int = ...,
|
||||
**to_dict_kwargs: Any,
|
||||
) -> bytes: ...
|
||||
def to_json(
|
||||
self: T,
|
||||
encoder: Encoder = orjson.dumps,
|
||||
*,
|
||||
orjson_options: int = ...,
|
||||
**to_dict_kwargs: Any,
|
||||
) -> str: ...
|
||||
@classmethod
|
||||
@final
|
||||
def from_json(
|
||||
cls: Type[T],
|
||||
data: EncodedData,
|
||||
decoder: Decoder = orjson.loads,
|
||||
**from_dict_kwargs: Any,
|
||||
) -> T: ...
|
||||
@@ -0,0 +1,64 @@
|
||||
from collections.abc import Callable
|
||||
from datetime import date, datetime, time
|
||||
from typing import Any, Type, TypeVar, final
|
||||
|
||||
import tomli_w
|
||||
|
||||
from mashumaro.dialect import Dialect
|
||||
from mashumaro.helper import pass_through
|
||||
from mashumaro.mixins.dict import DataClassDictMixin
|
||||
|
||||
try:
|
||||
import tomllib
|
||||
except ModuleNotFoundError:
|
||||
import tomli as tomllib # type: ignore
|
||||
|
||||
T = TypeVar("T", bound="DataClassTOMLMixin")
|
||||
|
||||
|
||||
EncodedData = str
|
||||
Encoder = Callable[[Any], EncodedData]
|
||||
Decoder = Callable[[EncodedData], dict[Any, Any]]
|
||||
|
||||
|
||||
class TOMLDialect(Dialect):
|
||||
no_copy_collections = (list, dict)
|
||||
omit_none = True
|
||||
serialization_strategy = {
|
||||
datetime: pass_through,
|
||||
date: pass_through,
|
||||
time: pass_through,
|
||||
}
|
||||
|
||||
|
||||
class DataClassTOMLMixin(DataClassDictMixin):
|
||||
__slots__ = ()
|
||||
|
||||
__mashumaro_builder_params = {
|
||||
"packer": {
|
||||
"format_name": "toml",
|
||||
"dialect": TOMLDialect,
|
||||
"encoder": tomli_w.dumps,
|
||||
},
|
||||
"unpacker": {
|
||||
"format_name": "toml",
|
||||
"dialect": TOMLDialect,
|
||||
"decoder": tomllib.loads,
|
||||
},
|
||||
}
|
||||
|
||||
@final
|
||||
def to_toml(
|
||||
self: T,
|
||||
encoder: Encoder = tomli_w.dumps,
|
||||
**to_dict_kwargs: Any,
|
||||
) -> EncodedData: ...
|
||||
|
||||
@classmethod
|
||||
@final
|
||||
def from_toml(
|
||||
cls: Type[T],
|
||||
data: EncodedData,
|
||||
decoder: Decoder = tomllib.loads,
|
||||
**from_dict_kwargs: Any,
|
||||
) -> T: ...
|
||||
@@ -0,0 +1,45 @@
|
||||
from collections.abc import Callable
|
||||
from typing import Any, Type, TypeVar, Union
|
||||
|
||||
import yaml
|
||||
|
||||
from mashumaro.mixins.dict import DataClassDictMixin
|
||||
|
||||
T = TypeVar("T", bound="DataClassYAMLMixin")
|
||||
|
||||
|
||||
EncodedData = Union[str, bytes]
|
||||
Encoder = Callable[[Any], EncodedData]
|
||||
Decoder = Callable[[EncodedData], dict[Any, Any]]
|
||||
|
||||
|
||||
DefaultLoader = getattr(yaml, "CSafeLoader", yaml.SafeLoader)
|
||||
DefaultDumper = getattr(yaml, "CDumper", yaml.Dumper)
|
||||
|
||||
|
||||
def default_encoder(data: Any) -> EncodedData:
|
||||
return yaml.dump(data, Dumper=DefaultDumper)
|
||||
|
||||
|
||||
def default_decoder(data: EncodedData) -> dict[Any, Any]:
|
||||
return yaml.load(data, DefaultLoader)
|
||||
|
||||
|
||||
class DataClassYAMLMixin(DataClassDictMixin):
|
||||
__slots__ = ()
|
||||
|
||||
def to_yaml(
|
||||
self: T,
|
||||
encoder: Encoder = default_encoder,
|
||||
**to_dict_kwargs: Any,
|
||||
) -> EncodedData:
|
||||
return encoder(self.to_dict(**to_dict_kwargs))
|
||||
|
||||
@classmethod
|
||||
def from_yaml(
|
||||
cls: Type[T],
|
||||
data: EncodedData,
|
||||
decoder: Decoder = default_decoder,
|
||||
**from_dict_kwargs: Any,
|
||||
) -> T:
|
||||
return cls.from_dict(decoder(data), **from_dict_kwargs)
|
||||
@@ -0,0 +1,127 @@
|
||||
import decimal
|
||||
from collections.abc import Callable
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Optional, Type, Union
|
||||
|
||||
from typing_extensions import Literal
|
||||
|
||||
from mashumaro.core.const import Sentinel
|
||||
|
||||
__all__ = [
|
||||
"SerializableType",
|
||||
"GenericSerializableType",
|
||||
"SerializationStrategy",
|
||||
"RoundedDecimal",
|
||||
"Discriminator",
|
||||
"Alias",
|
||||
]
|
||||
|
||||
|
||||
class SerializableType:
|
||||
__slots__ = ()
|
||||
|
||||
__use_annotations__ = False
|
||||
|
||||
def __init_subclass__(
|
||||
cls,
|
||||
use_annotations: Union[
|
||||
bool, Literal[Sentinel.MISSING]
|
||||
] = Sentinel.MISSING,
|
||||
**kwargs: Any,
|
||||
):
|
||||
super().__init_subclass__(**kwargs)
|
||||
if use_annotations is not Sentinel.MISSING:
|
||||
cls.__use_annotations__ = use_annotations
|
||||
|
||||
def _serialize(self) -> Any:
|
||||
raise NotImplementedError
|
||||
|
||||
@classmethod
|
||||
def _deserialize(cls, value: Any) -> Any:
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class GenericSerializableType:
|
||||
__slots__ = ()
|
||||
|
||||
def _serialize(self, types: list[Type]) -> Any:
|
||||
raise NotImplementedError
|
||||
|
||||
@classmethod
|
||||
def _deserialize(cls, value: Any, types: list[Type]) -> Any:
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class SerializationStrategy:
|
||||
__use_annotations__ = False
|
||||
|
||||
def __init_subclass__(
|
||||
cls,
|
||||
use_annotations: Union[
|
||||
bool, Literal[Sentinel.MISSING]
|
||||
] = Sentinel.MISSING,
|
||||
**kwargs: Any,
|
||||
):
|
||||
super().__init_subclass__(**kwargs)
|
||||
if use_annotations is not Sentinel.MISSING:
|
||||
cls.__use_annotations__ = use_annotations
|
||||
|
||||
def serialize(self, value: Any) -> Any:
|
||||
raise NotImplementedError
|
||||
|
||||
def deserialize(self, value: Any) -> Any:
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class RoundedDecimal(SerializationStrategy):
|
||||
def __init__(
|
||||
self, places: Optional[int] = None, rounding: Optional[str] = None
|
||||
):
|
||||
if places is not None:
|
||||
self.exp = decimal.Decimal((0, (1,), -places))
|
||||
else:
|
||||
self.exp = None # type: ignore
|
||||
self.rounding = rounding
|
||||
|
||||
def serialize(self, value: decimal.Decimal) -> str:
|
||||
if self.exp:
|
||||
if self.rounding:
|
||||
return str(value.quantize(self.exp, rounding=self.rounding))
|
||||
else:
|
||||
return str(value.quantize(self.exp))
|
||||
else:
|
||||
return str(value)
|
||||
|
||||
def deserialize(self, value: str) -> decimal.Decimal:
|
||||
return decimal.Decimal(str(value))
|
||||
|
||||
|
||||
@dataclass(unsafe_hash=True)
|
||||
class Discriminator:
|
||||
field: Optional[str] = None
|
||||
include_supertypes: bool = False
|
||||
include_subtypes: bool = False
|
||||
variant_tagger_fn: Optional[Callable[[Any], Any]] = None
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
if not self.include_supertypes and not self.include_subtypes:
|
||||
raise ValueError(
|
||||
"Either 'include_supertypes' or 'include_subtypes' "
|
||||
"must be enabled"
|
||||
)
|
||||
|
||||
|
||||
class Alias:
|
||||
def __init__(self, name: str, /):
|
||||
self.name = name
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"Alias(name='{self.name}')"
|
||||
|
||||
def __eq__(self, other: Any) -> bool:
|
||||
if not isinstance(other, Alias):
|
||||
return False
|
||||
return self.name == other.name
|
||||
|
||||
def __hash__(self) -> int:
|
||||
return hash(self.name)
|
||||
Reference in New Issue
Block a user