98 lines
3.0 KiB
Python
98 lines
3.0 KiB
Python
from collections.abc import Sequence
|
|
from dataclasses import dataclass
|
|
from typing import Any, Optional, Type
|
|
|
|
from mashumaro.jsonschema.dialects import DRAFT_2020_12, JSONSchemaDialect
|
|
from mashumaro.jsonschema.models import Context, JSONSchema
|
|
from mashumaro.jsonschema.plugins import BasePlugin
|
|
from mashumaro.jsonschema.schema import Instance, get_schema
|
|
|
|
try:
|
|
from mashumaro.mixins.orjson import (
|
|
DataClassORJSONMixin as DataClassJSONMixin,
|
|
)
|
|
except ImportError: # pragma: no cover
|
|
from mashumaro.mixins.json import DataClassJSONMixin # type: ignore
|
|
|
|
|
|
def build_json_schema(
    instance_type: Type,
    context: Optional[Context] = None,
    with_definitions: bool = True,
    all_refs: Optional[bool] = None,
    with_dialect_uri: bool = False,
    dialect: Optional[JSONSchemaDialect] = None,
    ref_prefix: Optional[str] = None,
    plugins: Sequence[BasePlugin] = (),
) -> JSONSchema:
    """Build a JSON Schema describing ``instance_type``.

    Args:
        instance_type: The type to describe.
        context: Optional starting context. Its settings are copied onto a
            new ``Context``, so the caller's ``dialect``/``all_refs``/
            ``ref_prefix``/``plugins`` attributes are not reassigned here.
            Its ``definitions`` object is handed to the new context as-is
            (not copied here), so definitions gathered while building are
            presumably visible through the caller's context as well.
        with_definitions: When true and definitions were collected, attach
            them to the returned schema.
        all_refs: Overrides the context's ``all_refs``. If both are None, the
            dialect's ``all_refs`` is used.
        with_dialect_uri: Passed through to ``get_schema``.
        dialect: Overrides the context's dialect.
        ref_prefix: Overrides the context's ``ref_prefix``; trailing ``/``
            characters are stripped. If both are None, the dialect's
            ``definitions_root_pointer`` is used.
        plugins: When non-empty, replaces the context's plugins.

    Returns:
        The schema produced by ``get_schema``.
    """
    # Work on a fresh Context so attribute overrides below never touch the
    # caller's object.
    ctx = (
        Context()
        if context is None
        else Context(
            dialect=context.dialect,
            definitions=context.definitions,
            all_refs=context.all_refs,
            ref_prefix=context.ref_prefix,
            plugins=context.plugins,
        )
    )

    # Settle the dialect first: the fallbacks below read from it.
    if dialect is not None:
        ctx.dialect = dialect

    if all_refs is None:
        if ctx.all_refs is None:
            ctx.all_refs = ctx.dialect.all_refs
    else:
        ctx.all_refs = all_refs

    if ref_prefix is None:
        if ctx.ref_prefix is None:
            ctx.ref_prefix = ctx.dialect.definitions_root_pointer
    else:
        ctx.ref_prefix = ref_prefix.rstrip("/")

    if plugins:
        ctx.plugins = plugins

    result = get_schema(
        Instance(instance_type), ctx, with_dialect_uri=with_dialect_uri
    )
    # Only embed definitions when asked to and when there are any.
    if with_definitions and ctx.definitions:
        result.definitions = ctx.definitions
    return result
|
|
|
|
|
|
@dataclass
class JSONSchemaDefinitions(DataClassJSONMixin):
    """Container for collected schema definitions.

    It is serialized as the bare ``definitions`` mapping rather than as an
    object wrapping it (see ``__post_serialize__``).
    """

    # Mapping of definition keys to their schemas.
    definitions: dict[str, JSONSchema]

    def __post_serialize__(  # type: ignore
        self, d: dict[Any, Any]
    ) -> dict[str, Any]:
        """Unwrap the serialized dataclass to its ``definitions`` value.

        Args:
            d: The serialized form of this dataclass, holding a
                ``"definitions"`` key.

        Returns:
            The serialized ``definitions`` mapping. ``definitions`` is a
            ``dict``, so the result is a dict keyed by the definition keys.
        """
        return d["definitions"]
|
|
|
|
|
|
class JSONSchemaBuilder:
    """Build several schemas that share one context and one definitions set.

    ``build`` returns schemas without embedded definitions. The definitions
    accumulated in ``self.context`` are exposed through ``get_definitions``.
    """

    def __init__(
        self,
        dialect: JSONSchemaDialect = DRAFT_2020_12,
        all_refs: Optional[bool] = None,
        ref_prefix: Optional[str] = None,
        plugins: Sequence[BasePlugin] = (),
    ):
        """Create the shared context.

        Args:
            dialect: JSON Schema dialect to target.
            all_refs: Defaults to ``dialect.all_refs`` when None.
            ref_prefix: Defaults to ``dialect.definitions_root_pointer`` when
                None; trailing ``/`` characters are stripped in either case.
            plugins: Plugins stored on the context.
        """
        refs = dialect.all_refs if all_refs is None else all_refs
        prefix = (
            dialect.definitions_root_pointer if ref_prefix is None else ref_prefix
        )
        self.context = Context(
            dialect=dialect,
            all_refs=refs,
            ref_prefix=prefix.rstrip("/"),
            plugins=plugins,
        )

    def build(self, instance_type: Type) -> JSONSchema:
        """Return the schema for ``instance_type`` without a definitions section.

        Definitions are not attached to the result. Use ``get_definitions``
        to retrieve them.
        """
        return build_json_schema(
            instance_type,
            context=self.context,
            with_definitions=False,
        )

    def get_definitions(self) -> JSONSchemaDefinitions:
        """Wrap the definitions held by the shared context."""
        return JSONSchemaDefinitions(definitions=self.context.definitions)
|
|
|
|
|
|
# Names exported by ``from ... import *``.
__all__ = [
    "JSONSchemaBuilder",
    "build_json_schema",
]
|