some random stuff. caelestia incoming
This commit is contained in:
@@ -0,0 +1,886 @@
|
||||
import datetime
|
||||
import ipaddress
|
||||
import os
|
||||
import warnings
|
||||
from base64 import encodebytes
|
||||
from collections import ChainMap, Counter, deque
|
||||
from collections.abc import (
|
||||
ByteString,
|
||||
Callable,
|
||||
Collection,
|
||||
Iterable,
|
||||
Mapping,
|
||||
Sequence,
|
||||
Set,
|
||||
)
|
||||
from dataclasses import MISSING, dataclass, field, is_dataclass, replace
|
||||
from decimal import Decimal
|
||||
from enum import Enum
|
||||
from fractions import Fraction
|
||||
from functools import cached_property
|
||||
from typing import Any, ForwardRef, Optional, Tuple, Type, Union
|
||||
from uuid import UUID
|
||||
from zoneinfo import ZoneInfo
|
||||
|
||||
from typing_extensions import TypeAlias
|
||||
|
||||
from mashumaro.config import BaseConfig
|
||||
from mashumaro.core.const import PY_311_MIN
|
||||
from mashumaro.core.meta.code.builder import CodeBuilder
|
||||
from mashumaro.core.meta.helpers import (
|
||||
evaluate_forward_ref,
|
||||
get_args,
|
||||
get_forward_ref_referencing_globals,
|
||||
get_function_return_annotation,
|
||||
get_literal_values,
|
||||
get_type_origin,
|
||||
is_annotated,
|
||||
is_generic,
|
||||
is_literal,
|
||||
is_named_tuple,
|
||||
is_new_type,
|
||||
is_not_required,
|
||||
is_readonly,
|
||||
is_required,
|
||||
is_special_typing_primitive,
|
||||
is_type_var,
|
||||
is_type_var_any,
|
||||
is_type_var_tuple,
|
||||
is_typed_dict,
|
||||
is_union,
|
||||
is_unpack,
|
||||
resolve_type_params,
|
||||
type_name,
|
||||
)
|
||||
from mashumaro.core.meta.types.common import NoneType
|
||||
from mashumaro.helper import pass_through
|
||||
from mashumaro.jsonschema.annotations import (
|
||||
Annotation,
|
||||
Contains,
|
||||
DependentRequired,
|
||||
ExclusiveMaximum,
|
||||
ExclusiveMinimum,
|
||||
MaxContains,
|
||||
Maximum,
|
||||
MaxItems,
|
||||
MaxLength,
|
||||
MaxProperties,
|
||||
MinContains,
|
||||
Minimum,
|
||||
MinItems,
|
||||
MinLength,
|
||||
MinProperties,
|
||||
MultipleOf,
|
||||
Pattern,
|
||||
UniqueItems,
|
||||
)
|
||||
from mashumaro.jsonschema.models import (
|
||||
DATETIME_FORMATS,
|
||||
IPADDRESS_FORMATS,
|
||||
Context,
|
||||
JSONArraySchema,
|
||||
JSONObjectSchema,
|
||||
JSONSchema,
|
||||
JSONSchemaInstanceFormatExtension,
|
||||
JSONSchemaInstanceType,
|
||||
JSONSchemaStringFormat,
|
||||
)
|
||||
from mashumaro.types import SerializationStrategy
|
||||
|
||||
try:
|
||||
from mashumaro.mixins.orjson import (
|
||||
DataClassORJSONMixin as DataClassJSONMixin,
|
||||
)
|
||||
except ImportError: # pragma: no cover
|
||||
from mashumaro.mixins.json import DataClassJSONMixin # type: ignore
|
||||
|
||||
|
||||
UTC_OFFSET_PATTERN = r"^UTC([+-][0-2][0-9]:[0-5][0-9])?$"
|
||||
|
||||
|
||||
@dataclass
|
||||
class Instance:
|
||||
type: Type
|
||||
name: Optional[str] = None
|
||||
|
||||
__owner_builder: Optional[CodeBuilder] = None
|
||||
__self_builder: Optional[CodeBuilder] = None
|
||||
|
||||
# Original type despite custom serialization. To be revised.
|
||||
_original_type: Type = field(init=False)
|
||||
|
||||
origin_type: Type = field(init=False)
|
||||
annotations: list[Annotation] = field(init=False, default_factory=list)
|
||||
|
||||
@cached_property
|
||||
def metadata(self) -> dict[str, Any]:
|
||||
if self.name and self.__owner_builder:
|
||||
return dict(**self.__owner_builder.metadatas.get(self.name, {}))
|
||||
else:
|
||||
return {}
|
||||
|
||||
@property
|
||||
def _self_builder(self) -> CodeBuilder:
|
||||
assert self.__self_builder
|
||||
return self.__self_builder
|
||||
|
||||
@property
|
||||
def alias(self) -> Optional[str]:
|
||||
alias = self.metadata.get("alias")
|
||||
if alias is None:
|
||||
aliases_config = self.get_owner_config().aliases
|
||||
alias = aliases_config.get(self.name) # type: ignore
|
||||
if alias is None:
|
||||
alias = self.name
|
||||
return alias
|
||||
|
||||
@property
|
||||
def owner_class(self) -> Optional[Type]:
|
||||
if self.__owner_builder:
|
||||
return self.__owner_builder.cls
|
||||
return None
|
||||
|
||||
def derive(self, **changes: Any) -> "Instance":
|
||||
new_type = changes.get("type")
|
||||
if isinstance(new_type, ForwardRef):
|
||||
changes["type"] = evaluate_forward_ref(
|
||||
new_type,
|
||||
get_forward_ref_referencing_globals(new_type, self.type),
|
||||
self.__dict__,
|
||||
)
|
||||
new_instance = replace(self, **changes)
|
||||
if is_dataclass(self.origin_type):
|
||||
new_instance.__owner_builder = self.__self_builder
|
||||
return new_instance
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
self._original_type = self.type
|
||||
self.update_type(self.type)
|
||||
if is_annotated(self.type):
|
||||
self.annotations = getattr(self.type, "__metadata__", [])
|
||||
self.type = get_args(self.type)[0]
|
||||
self.origin_type = get_type_origin(self.type)
|
||||
|
||||
def update_type(self, new_type: Type) -> None:
|
||||
if self.__owner_builder:
|
||||
self.type = self.__owner_builder.get_real_type(
|
||||
field_name=self.name, # type: ignore
|
||||
field_type=new_type,
|
||||
)
|
||||
self.origin_type = get_type_origin(self.type)
|
||||
if is_dataclass(self.origin_type):
|
||||
type_args = get_args(self.type)
|
||||
self.__self_builder = CodeBuilder(self.origin_type, type_args)
|
||||
self.__self_builder.reset()
|
||||
else:
|
||||
self.__self_builder = None
|
||||
|
||||
def fields(self) -> Iterable[tuple[str, Type, bool, Any]]:
|
||||
for f_name, f_type in self._self_builder.get_field_types(
|
||||
include_extras=True
|
||||
).items():
|
||||
f = self._self_builder.dataclass_fields.get(f_name)
|
||||
if not f or f and not f.init:
|
||||
continue
|
||||
f_default = f.default
|
||||
if f_default is MISSING:
|
||||
f_default = self._self_builder.namespace.get(f_name, MISSING)
|
||||
if f_default is not MISSING:
|
||||
f_default = _default(f_type, f_default, self.get_self_config())
|
||||
|
||||
has_default = (
|
||||
f.default is not MISSING or f.default_factory is not MISSING
|
||||
)
|
||||
|
||||
yield f_name, f_type, has_default, f_default
|
||||
|
||||
def get_overridden_serialization_method(
|
||||
self,
|
||||
) -> Optional[Union[Callable, str]]:
|
||||
if not self.__owner_builder:
|
||||
return None
|
||||
serialize_option = self.metadata.get("serialize")
|
||||
if serialize_option is not None:
|
||||
if callable(serialize_option):
|
||||
self.metadata.pop("serialize", None) # prevent recursion
|
||||
return serialize_option
|
||||
for strategy in self.__owner_builder.iter_serialization_strategies(
|
||||
self.metadata, self.type
|
||||
):
|
||||
if strategy is pass_through:
|
||||
return pass_through
|
||||
elif isinstance(strategy, dict):
|
||||
serialize_option = strategy.get("serialize")
|
||||
elif isinstance(strategy, SerializationStrategy):
|
||||
serialize_option = strategy.serialize
|
||||
if serialize_option is not None:
|
||||
return serialize_option
|
||||
return None
|
||||
|
||||
def get_owner_config(self) -> Type[BaseConfig]:
|
||||
if self.__owner_builder:
|
||||
return self.__owner_builder.get_config()
|
||||
else:
|
||||
return BaseConfig
|
||||
|
||||
def get_owner_dialect_or_config_option(
|
||||
self, option: str, default: Any
|
||||
) -> Any:
|
||||
if self.__owner_builder:
|
||||
return self.__owner_builder.get_dialect_or_config_option(
|
||||
option, default
|
||||
)
|
||||
else:
|
||||
return default
|
||||
|
||||
def get_self_config(self) -> Type[BaseConfig]:
|
||||
if self.__self_builder:
|
||||
return self.__self_builder.get_config()
|
||||
else:
|
||||
return BaseConfig
|
||||
|
||||
|
||||
InstanceSchemaCreator: TypeAlias = Callable[
|
||||
[Instance, Context], Optional[JSONSchema]
|
||||
]
|
||||
|
||||
|
||||
@dataclass
|
||||
class InstanceSchemaCreatorRegistry:
|
||||
_registry: list[InstanceSchemaCreator] = field(default_factory=list)
|
||||
|
||||
def register(self, func: InstanceSchemaCreator) -> InstanceSchemaCreator:
|
||||
self._registry.append(func)
|
||||
return func
|
||||
|
||||
def iter(self) -> Iterable[InstanceSchemaCreator]:
|
||||
yield from self._registry
|
||||
|
||||
|
||||
@dataclass
|
||||
class EmptyJSONSchema(JSONSchema):
|
||||
pass
|
||||
|
||||
|
||||
def get_schema(
|
||||
instance: Instance, ctx: Context, with_dialect_uri: bool = False
|
||||
) -> JSONSchema:
|
||||
schema = None
|
||||
for schema_creator in Registry.iter():
|
||||
schema = schema_creator(instance, ctx)
|
||||
if schema is not None:
|
||||
if with_dialect_uri:
|
||||
schema.schema = ctx.dialect.uri
|
||||
break
|
||||
for plugin in ctx.plugins:
|
||||
try:
|
||||
new_schema = plugin.get_schema(instance, ctx, schema)
|
||||
if new_schema:
|
||||
schema = new_schema
|
||||
except NotImplementedError:
|
||||
continue
|
||||
if schema:
|
||||
return schema
|
||||
raise NotImplementedError(
|
||||
f'Type {type_name(instance.type)} of field "{instance.name}" '
|
||||
f"in {type_name(instance.owner_class)} isn't supported"
|
||||
)
|
||||
|
||||
|
||||
def _get_schema_or_none(
|
||||
instance: Instance, ctx: Context
|
||||
) -> Optional[JSONSchema]:
|
||||
schema = get_schema(instance, ctx)
|
||||
if isinstance(schema, EmptyJSONSchema):
|
||||
return None
|
||||
return schema
|
||||
|
||||
|
||||
def _default(f_type: Type, f_value: Any, config_cls: Type[BaseConfig]) -> Any:
|
||||
@dataclass
|
||||
class CC(DataClassJSONMixin):
|
||||
x: f_type = f_value # type: ignore
|
||||
|
||||
class Config(config_cls): # type: ignore
|
||||
pass
|
||||
|
||||
return CC(f_value).to_dict()["x"]
|
||||
|
||||
|
||||
Registry = InstanceSchemaCreatorRegistry()
|
||||
register = Registry.register
|
||||
|
||||
|
||||
BASIC_TYPES = {str, int, float, bool}
|
||||
|
||||
|
||||
@register
|
||||
def on_type_with_overridden_serialization(
|
||||
instance: Instance, ctx: Context
|
||||
) -> Optional[JSONSchema]:
|
||||
def override_with_any(reason: Any) -> None:
|
||||
if instance.owner_class is not None:
|
||||
name = f"{type_name(instance.owner_class)}.{instance.name}"
|
||||
else: # pragma: no cover
|
||||
# we will have an owner class, but leave this here just in case
|
||||
name = type_name(instance.type)
|
||||
warnings.warn(
|
||||
f"Type Any will be used for {name} with "
|
||||
f"overridden serialization method: {reason}"
|
||||
)
|
||||
instance.update_type(Any) # type: ignore[arg-type]
|
||||
|
||||
overridden_method = instance.get_overridden_serialization_method()
|
||||
if overridden_method is pass_through:
|
||||
return None
|
||||
elif overridden_method in BASIC_TYPES:
|
||||
instance.update_type(overridden_method) # type: ignore
|
||||
elif callable(overridden_method):
|
||||
try:
|
||||
new_type = get_function_return_annotation(overridden_method)
|
||||
if new_type is instance.type:
|
||||
return None
|
||||
else:
|
||||
instance.update_type(new_type)
|
||||
except Exception as e:
|
||||
override_with_any(e)
|
||||
return get_schema(instance, ctx)
|
||||
|
||||
|
||||
@register
|
||||
def on_dataclass(instance: Instance, ctx: Context) -> Optional[JSONSchema]:
|
||||
# TODO: Self references might not work
|
||||
if is_dataclass(instance.origin_type):
|
||||
jsonschema_config = instance.get_self_config().json_schema
|
||||
schema = JSONObjectSchema(
|
||||
title=instance.origin_type.__name__,
|
||||
additionalProperties=jsonschema_config.get(
|
||||
"additionalProperties", False
|
||||
),
|
||||
)
|
||||
properties: dict[str, JSONSchema] = {}
|
||||
required = []
|
||||
field_schema_overrides = jsonschema_config.get("properties", {})
|
||||
for f_name, f_type, has_default, f_default in instance.fields():
|
||||
override = field_schema_overrides.get(f_name)
|
||||
f_instance = instance.derive(type=f_type, name=f_name)
|
||||
if override:
|
||||
f_schema = JSONSchema.from_dict(override)
|
||||
else:
|
||||
f_schema = get_schema(f_instance, ctx)
|
||||
if f_instance.alias:
|
||||
f_name = f_instance.alias
|
||||
if f_default is not MISSING:
|
||||
f_schema.default = f_default
|
||||
description = f_instance.metadata.get("description")
|
||||
if description:
|
||||
f_schema.description = description
|
||||
|
||||
if not has_default:
|
||||
required.append(f_name)
|
||||
|
||||
properties[f_name] = f_schema
|
||||
if properties:
|
||||
schema.properties = properties
|
||||
if required:
|
||||
schema.required = required
|
||||
if ctx.all_refs:
|
||||
ctx.definitions[instance.origin_type.__name__] = schema
|
||||
ref_prefix = ctx.ref_prefix or ctx.dialect.definitions_root_pointer
|
||||
return JSONSchema(
|
||||
reference=f"{ref_prefix}/{instance.origin_type.__name__}"
|
||||
)
|
||||
else:
|
||||
return schema
|
||||
|
||||
|
||||
@register
|
||||
def on_any(instance: Instance, ctx: Context) -> Optional[JSONSchema]:
|
||||
if instance.type is Any:
|
||||
return EmptyJSONSchema()
|
||||
|
||||
|
||||
def on_literal(instance: Instance, ctx: Context) -> Optional[JSONSchema]:
|
||||
enum_values = []
|
||||
for value in get_literal_values(instance.type):
|
||||
if isinstance(value, Enum):
|
||||
enum_values.append(value.value)
|
||||
elif isinstance(value, (int, str, bool, NoneType)): # type: ignore
|
||||
enum_values.append(value)
|
||||
elif isinstance(value, bytes):
|
||||
enum_values.append(encodebytes(value).decode())
|
||||
if len(enum_values) == 1:
|
||||
return JSONSchema(const=enum_values[0])
|
||||
else:
|
||||
return JSONSchema(enum=enum_values)
|
||||
|
||||
|
||||
@register
|
||||
def on_special_typing_primitive(
|
||||
instance: Instance, ctx: Context
|
||||
) -> Optional[JSONSchema]:
|
||||
if not is_special_typing_primitive(instance.origin_type):
|
||||
return None
|
||||
|
||||
args = get_args(instance.type)
|
||||
|
||||
if is_union(instance.type):
|
||||
return JSONSchema(
|
||||
anyOf=[get_schema(instance.derive(type=arg), ctx) for arg in args]
|
||||
)
|
||||
elif is_type_var_any(instance.type):
|
||||
return EmptyJSONSchema()
|
||||
elif is_type_var(instance.type):
|
||||
constraints = getattr(instance.type, "__constraints__")
|
||||
if constraints:
|
||||
return JSONSchema(
|
||||
anyOf=[
|
||||
get_schema(instance.derive(type=arg), ctx)
|
||||
for arg in constraints
|
||||
]
|
||||
)
|
||||
else:
|
||||
bound = getattr(instance.type, "__bound__")
|
||||
return get_schema(instance.derive(type=bound), ctx)
|
||||
elif is_new_type(instance.type):
|
||||
return get_schema(
|
||||
instance.derive(type=instance.type.__supertype__), ctx
|
||||
)
|
||||
elif is_literal(instance.type):
|
||||
return on_literal(instance, ctx)
|
||||
# elif is_self(instance.type):
|
||||
# raise NotImplementedError
|
||||
elif is_required(instance.type) or is_not_required(instance.type):
|
||||
return get_schema(instance.derive(type=args[0]), ctx)
|
||||
elif is_unpack(instance.type):
|
||||
return get_schema(
|
||||
instance.derive(type=get_args(instance.type)[0]), ctx
|
||||
)
|
||||
elif is_type_var_tuple(instance.type):
|
||||
return get_schema(instance.derive(type=tuple[Any, ...]), ctx)
|
||||
elif is_readonly(instance.type):
|
||||
return get_schema(instance.derive(type=args[0]), ctx)
|
||||
elif isinstance(instance.type, ForwardRef):
|
||||
evaluated = evaluate_forward_ref(
|
||||
instance.type,
|
||||
get_forward_ref_referencing_globals(instance.type),
|
||||
None,
|
||||
)
|
||||
if evaluated is not None:
|
||||
return get_schema(instance.derive(type=evaluated), ctx)
|
||||
|
||||
|
||||
@register
|
||||
def on_number(instance: Instance, ctx: Context) -> Optional[JSONSchema]:
|
||||
if instance.origin_type is int:
|
||||
schema = JSONSchema(type=JSONSchemaInstanceType.INTEGER)
|
||||
elif instance.origin_type is float:
|
||||
schema = JSONSchema(type=JSONSchemaInstanceType.NUMBER)
|
||||
else:
|
||||
return None
|
||||
for annotation in instance.annotations:
|
||||
if isinstance(annotation, Maximum):
|
||||
schema.maximum = annotation.value
|
||||
elif isinstance(annotation, Minimum):
|
||||
schema.minimum = annotation.value
|
||||
elif isinstance(annotation, ExclusiveMaximum):
|
||||
schema.exclusiveMaximum = annotation.value
|
||||
elif isinstance(annotation, ExclusiveMinimum):
|
||||
schema.exclusiveMinimum = annotation.value
|
||||
elif isinstance(annotation, MultipleOf):
|
||||
schema.multipleOf = annotation.value
|
||||
return schema
|
||||
|
||||
|
||||
@register
|
||||
def on_bool(instance: Instance, ctx: Context) -> Optional[JSONSchema]:
|
||||
if instance.origin_type is bool:
|
||||
return JSONSchema(type=JSONSchemaInstanceType.BOOLEAN)
|
||||
|
||||
|
||||
@register
|
||||
def on_none(instance: Instance, ctx: Context) -> Optional[JSONSchema]:
|
||||
if instance.origin_type in (NoneType, None):
|
||||
return JSONSchema(type=JSONSchemaInstanceType.NULL)
|
||||
|
||||
|
||||
@register
|
||||
def on_date_objects(instance: Instance, ctx: Context) -> Optional[JSONSchema]:
|
||||
if instance.origin_type in (
|
||||
datetime.datetime,
|
||||
datetime.date,
|
||||
datetime.time,
|
||||
):
|
||||
return JSONSchema(
|
||||
type=JSONSchemaInstanceType.STRING,
|
||||
format=DATETIME_FORMATS[instance.origin_type],
|
||||
)
|
||||
|
||||
|
||||
@register
|
||||
def on_timedelta(instance: Instance, ctx: Context) -> Optional[JSONSchema]:
|
||||
if instance.origin_type is datetime.timedelta:
|
||||
return JSONSchema(
|
||||
type=JSONSchemaInstanceType.NUMBER,
|
||||
format=JSONSchemaInstanceFormatExtension.TIMEDELTA,
|
||||
)
|
||||
|
||||
|
||||
@register
|
||||
def on_timezone(instance: Instance, ctx: Context) -> Optional[JSONSchema]:
|
||||
if instance.origin_type is datetime.timezone:
|
||||
return JSONSchema(
|
||||
type=JSONSchemaInstanceType.STRING, pattern=UTC_OFFSET_PATTERN
|
||||
)
|
||||
|
||||
|
||||
@register
|
||||
def on_zone_info(instance: Instance, ctx: Context) -> Optional[JSONSchema]:
|
||||
if instance.origin_type is ZoneInfo:
|
||||
return JSONSchema(
|
||||
type=JSONSchemaInstanceType.STRING,
|
||||
format=JSONSchemaInstanceFormatExtension.TIME_ZONE,
|
||||
)
|
||||
|
||||
|
||||
@register
|
||||
def on_uuid(instance: Instance, ctx: Context) -> Optional[JSONSchema]:
|
||||
if instance.origin_type is UUID:
|
||||
return JSONSchema(
|
||||
type=JSONSchemaInstanceType.STRING,
|
||||
format=JSONSchemaStringFormat.UUID,
|
||||
)
|
||||
|
||||
|
||||
@register
|
||||
def on_ipaddress(instance: Instance, ctx: Context) -> Optional[JSONSchema]:
|
||||
if instance.origin_type in (
|
||||
ipaddress.IPv4Address,
|
||||
ipaddress.IPv6Address,
|
||||
ipaddress.IPv4Network,
|
||||
ipaddress.IPv6Network,
|
||||
ipaddress.IPv4Interface,
|
||||
ipaddress.IPv6Interface,
|
||||
):
|
||||
return JSONSchema(
|
||||
type=JSONSchemaInstanceType.STRING,
|
||||
format=IPADDRESS_FORMATS[instance.origin_type], # type: ignore
|
||||
)
|
||||
|
||||
|
||||
@register
|
||||
def on_decimal(instance: Instance, ctx: Context) -> Optional[JSONSchema]:
|
||||
if instance.origin_type is Decimal:
|
||||
return JSONSchema(
|
||||
type=JSONSchemaInstanceType.STRING,
|
||||
format=JSONSchemaInstanceFormatExtension.DECIMAL,
|
||||
)
|
||||
|
||||
|
||||
@register
|
||||
def on_fraction(instance: Instance, ctx: Context) -> Optional[JSONSchema]:
|
||||
if instance.origin_type is Fraction:
|
||||
return JSONSchema(
|
||||
type=JSONSchemaInstanceType.STRING,
|
||||
format=JSONSchemaInstanceFormatExtension.FRACTION,
|
||||
)
|
||||
|
||||
|
||||
def on_tuple(instance: Instance, ctx: Context) -> JSONArraySchema:
|
||||
args = get_args(instance.type)
|
||||
if not args:
|
||||
if instance.type in (Tuple, tuple):
|
||||
args = [Any, ...] # type: ignore
|
||||
else:
|
||||
return JSONArraySchema(maxItems=0)
|
||||
elif len(args) == 1 and args[0] == ():
|
||||
if not PY_311_MIN:
|
||||
return JSONArraySchema(maxItems=0)
|
||||
if len(args) == 2 and args[1] is Ellipsis:
|
||||
items_schema = _get_schema_or_none(instance.derive(type=args[0]), ctx)
|
||||
return JSONArraySchema(items=items_schema)
|
||||
else:
|
||||
min_items = 0
|
||||
max_items = 0
|
||||
prefix_items = []
|
||||
items: Optional[JSONSchema] = None
|
||||
unpack_schema: Optional[JSONSchema] = None
|
||||
unpack_idx = 0
|
||||
for arg_idx, arg in enumerate(args, start=1):
|
||||
if not is_unpack(arg):
|
||||
min_items += 1
|
||||
if not unpack_schema:
|
||||
prefix_items.append(
|
||||
get_schema(instance.derive(type=arg), ctx)
|
||||
)
|
||||
else:
|
||||
unpack_schema = get_schema(instance.derive(type=arg), ctx)
|
||||
unpack_idx = arg_idx
|
||||
if unpack_schema:
|
||||
prefix_items.extend(unpack_schema.prefixItems or [])
|
||||
min_items += unpack_schema.minItems or 0
|
||||
max_items += unpack_schema.maxItems or 0
|
||||
if unpack_idx == len(args):
|
||||
items = unpack_schema.items
|
||||
else:
|
||||
min_items = len(args)
|
||||
max_items = len(args)
|
||||
return JSONArraySchema(
|
||||
prefixItems=prefix_items or None,
|
||||
items=items,
|
||||
minItems=min_items or None,
|
||||
maxItems=max_items or None,
|
||||
)
|
||||
|
||||
|
||||
def on_named_tuple(instance: Instance, ctx: Context) -> JSONSchema:
|
||||
resolved = resolve_type_params(
|
||||
instance.origin_type, get_args(instance.type)
|
||||
)[instance.origin_type]
|
||||
annotations = {
|
||||
k: resolved.get(v, v)
|
||||
for k, v in getattr(
|
||||
instance.origin_type, "__annotations__", {}
|
||||
).items()
|
||||
}
|
||||
fields = getattr(instance.type, "_fields", ())
|
||||
defaults = getattr(instance.type, "_field_defaults", {})
|
||||
as_dict = instance.get_owner_dialect_or_config_option(
|
||||
"namedtuple_as_dict", False
|
||||
)
|
||||
serialize_option = instance.get_overridden_serialization_method()
|
||||
if serialize_option == "as_dict":
|
||||
as_dict = True
|
||||
elif serialize_option == "as_list":
|
||||
as_dict = False
|
||||
properties = {}
|
||||
for f_name in fields:
|
||||
f_type = annotations.get(f_name, Any)
|
||||
f_schema = get_schema(instance.derive(type=f_type), ctx)
|
||||
f_default = defaults.get(f_name, MISSING)
|
||||
if f_default is not MISSING:
|
||||
if isinstance(f_schema, EmptyJSONSchema):
|
||||
f_schema = JSONSchema()
|
||||
f_schema.default = _default(
|
||||
f_type, f_default, instance.get_self_config()
|
||||
)
|
||||
properties[f_name] = f_schema
|
||||
if as_dict:
|
||||
return JSONObjectSchema(
|
||||
properties=properties or None,
|
||||
required=list(fields),
|
||||
additionalProperties=False,
|
||||
)
|
||||
else:
|
||||
return JSONArraySchema(
|
||||
prefixItems=list(properties.values()) or None,
|
||||
maxItems=len(properties) or None,
|
||||
minItems=len(properties) or None,
|
||||
)
|
||||
|
||||
|
||||
def on_typed_dict(instance: Instance, ctx: Context) -> JSONObjectSchema:
|
||||
resolved = resolve_type_params(
|
||||
instance.origin_type, get_args(instance.type)
|
||||
)[instance.origin_type]
|
||||
annotations = {
|
||||
k: resolved.get(v, v)
|
||||
for k, v in instance.origin_type.__annotations__.items()
|
||||
}
|
||||
all_keys = list(annotations.keys())
|
||||
required_keys = getattr(instance.type, "__required_keys__", all_keys)
|
||||
return JSONObjectSchema(
|
||||
properties={
|
||||
key: get_schema(instance.derive(type=annotations[key]), ctx)
|
||||
for key in all_keys
|
||||
}
|
||||
or None,
|
||||
required=sorted(required_keys) or None,
|
||||
additionalProperties=False,
|
||||
)
|
||||
|
||||
|
||||
def apply_array_constraints(
|
||||
instance: Instance,
|
||||
schema: JSONSchema,
|
||||
) -> JSONSchema:
|
||||
has_contains = False
|
||||
min_contains: Optional[int] = None
|
||||
max_contains: Optional[int] = None
|
||||
for annotation in instance.annotations:
|
||||
if isinstance(annotation, MinItems):
|
||||
schema.minItems = annotation.value
|
||||
elif isinstance(annotation, MaxItems):
|
||||
schema.maxItems = annotation.value
|
||||
elif isinstance(annotation, UniqueItems):
|
||||
schema.uniqueItems = annotation.value
|
||||
elif isinstance(annotation, Contains):
|
||||
schema.contains = annotation.value
|
||||
has_contains = True
|
||||
elif isinstance(annotation, MinContains):
|
||||
min_contains = annotation.value
|
||||
elif isinstance(annotation, MaxContains):
|
||||
max_contains = annotation.value
|
||||
if has_contains:
|
||||
if min_contains is not None:
|
||||
schema.minContains = min_contains
|
||||
if max_contains is not None:
|
||||
schema.maxContains = max_contains
|
||||
return schema
|
||||
|
||||
|
||||
def apply_object_constraints(
|
||||
instance: Instance, schema: JSONSchema
|
||||
) -> JSONSchema:
|
||||
for annotation in instance.annotations:
|
||||
if isinstance(annotation, MaxProperties):
|
||||
schema.maxProperties = annotation.value
|
||||
elif isinstance(annotation, MinProperties):
|
||||
schema.minProperties = annotation.value
|
||||
elif isinstance(annotation, DependentRequired):
|
||||
schema.dependentRequired = annotation.value
|
||||
return schema
|
||||
|
||||
|
||||
@register
|
||||
def on_collection(instance: Instance, ctx: Context) -> Optional[JSONSchema]:
|
||||
if not issubclass(instance.origin_type, Collection):
|
||||
return None
|
||||
elif issubclass(instance.origin_type, Enum):
|
||||
return None
|
||||
|
||||
args = get_args(instance.type)
|
||||
|
||||
if issubclass(instance.origin_type, ByteString): # type: ignore[arg-type]
|
||||
return JSONSchema(
|
||||
type=JSONSchemaInstanceType.STRING,
|
||||
format=JSONSchemaInstanceFormatExtension.BASE64,
|
||||
)
|
||||
elif issubclass(instance.origin_type, str):
|
||||
schema = JSONSchema(type=JSONSchemaInstanceType.STRING)
|
||||
for annotation in instance.annotations:
|
||||
if isinstance(annotation, MinLength):
|
||||
schema.minLength = annotation.value
|
||||
elif isinstance(annotation, MaxLength):
|
||||
schema.maxLength = annotation.value
|
||||
elif isinstance(annotation, Pattern):
|
||||
schema.pattern = annotation.value
|
||||
return schema
|
||||
elif is_generic(instance.type) and issubclass(
|
||||
instance.origin_type, (list, deque)
|
||||
):
|
||||
return apply_array_constraints(
|
||||
instance,
|
||||
JSONArraySchema(
|
||||
items=(
|
||||
_get_schema_or_none(instance.derive(type=args[0]), ctx)
|
||||
if args
|
||||
else None
|
||||
)
|
||||
),
|
||||
)
|
||||
elif issubclass(instance.origin_type, tuple):
|
||||
if is_named_tuple(instance.origin_type):
|
||||
return apply_array_constraints(
|
||||
instance, on_named_tuple(instance, ctx)
|
||||
)
|
||||
elif is_generic(instance.type):
|
||||
return apply_array_constraints(instance, on_tuple(instance, ctx))
|
||||
elif is_generic(instance.type) and issubclass(
|
||||
instance.origin_type, (frozenset, Set)
|
||||
):
|
||||
return apply_array_constraints(
|
||||
instance,
|
||||
JSONArraySchema(
|
||||
items=(
|
||||
_get_schema_or_none(instance.derive(type=args[0]), ctx)
|
||||
if args
|
||||
else None
|
||||
),
|
||||
uniqueItems=True,
|
||||
),
|
||||
)
|
||||
elif is_generic(instance.type) and issubclass(
|
||||
instance.origin_type, ChainMap
|
||||
):
|
||||
return apply_array_constraints(
|
||||
instance,
|
||||
JSONArraySchema(
|
||||
items=get_schema(
|
||||
instance=instance.derive(
|
||||
type=(
|
||||
dict[args[0], args[1]] # type: ignore
|
||||
if args
|
||||
else dict
|
||||
)
|
||||
),
|
||||
ctx=ctx,
|
||||
)
|
||||
),
|
||||
)
|
||||
elif is_generic(instance.type) and issubclass(
|
||||
instance.origin_type, Counter
|
||||
):
|
||||
schema = JSONObjectSchema(
|
||||
additionalProperties=get_schema(instance.derive(type=int), ctx),
|
||||
)
|
||||
if args:
|
||||
schema.propertyNames = _get_schema_or_none(
|
||||
instance.derive(type=args[0]), ctx
|
||||
)
|
||||
return apply_object_constraints(instance, schema)
|
||||
elif is_typed_dict(instance.origin_type):
|
||||
return on_typed_dict(instance, ctx)
|
||||
elif is_generic(instance.type) and issubclass(
|
||||
instance.origin_type, Mapping
|
||||
):
|
||||
schema = JSONObjectSchema(
|
||||
additionalProperties=(
|
||||
_get_schema_or_none(instance.derive(type=args[1]), ctx)
|
||||
if args
|
||||
else None
|
||||
),
|
||||
propertyNames=(
|
||||
_get_schema_or_none(instance.derive(type=args[0]), ctx)
|
||||
if args
|
||||
else None
|
||||
),
|
||||
)
|
||||
return apply_object_constraints(instance, schema)
|
||||
elif is_generic(instance.type) and issubclass(
|
||||
instance.origin_type, Sequence
|
||||
):
|
||||
return apply_array_constraints(
|
||||
instance,
|
||||
JSONArraySchema(
|
||||
items=(
|
||||
_get_schema_or_none(instance.derive(type=args[0]), ctx)
|
||||
if args
|
||||
else None
|
||||
)
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
@register
|
||||
def on_pathlike(instance: Instance, ctx: Context) -> Optional[JSONSchema]:
|
||||
if issubclass(instance.origin_type, os.PathLike):
|
||||
schema = JSONSchema(
|
||||
type=JSONSchemaInstanceType.STRING,
|
||||
format=JSONSchemaInstanceFormatExtension.PATH,
|
||||
)
|
||||
for annotation in instance.annotations:
|
||||
if isinstance(annotation, MaxLength):
|
||||
schema.maxLength = annotation.value
|
||||
elif isinstance(annotation, MinLength):
|
||||
schema.minLength = annotation.value
|
||||
return schema
|
||||
|
||||
|
||||
@register
|
||||
def on_enum(instance: Instance, ctx: Context) -> Optional[JSONSchema]:
|
||||
if issubclass(instance.origin_type, Enum):
|
||||
return JSONSchema(enum=[m.value for m in instance.origin_type])
|
||||
|
||||
|
||||
__all__ = ["Instance", "get_schema"]
|
||||
Reference in New Issue
Block a user