v0.1.0: CRM/ERP system internal beta release - security hardening complete
- Docker bridge network isolation (port 8000 sealed off)
- Gunicorn running 4 worker processes
- Alembic database migration baseline
- Log rotation: 20m x 3 files
- JWT secret, DB password, and CORS policy tightened
- 3-2-1 backup chain (NAS + R740-B cold standby)
- Connection pool: pool_pre_ping + pool_recycle=3600
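A minimal sketch of the connection-pool settings named above, assuming the service uses SQLAlchemy (where `pool_pre_ping` and `pool_recycle` come from); the DSN and deployment details are illustrative, not taken from this repo:

from sqlalchemy import create_engine

# pool_pre_ping issues a cheap liveness check before handing out a pooled
# connection; pool_recycle=3600 retires connections older than one hour so
# they are replaced before a firewall or DB-side idle timeout drops them.
engine = create_engine(
    'postgresql+psycopg2://crm_user:REDACTED@db:5432/crm',  # hypothetical DSN
    pool_pre_ping=True,
    pool_recycle=3600,
)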
@@ -0,0 +1,431 @@
import typing
from importlib import import_module
from warnings import warn

from ._migration import getattr_migration
from .version import VERSION

if typing.TYPE_CHECKING:
    # import of virtually everything is supported via `__getattr__` below,
    # but we need them here for type checking and IDE support
    import pydantic_core
    from pydantic_core.core_schema import (
        FieldSerializationInfo,
        SerializationInfo,
        SerializerFunctionWrapHandler,
        ValidationInfo,
        ValidatorFunctionWrapHandler,
    )

    from . import dataclasses
    from .aliases import AliasChoices, AliasGenerator, AliasPath
    from .annotated_handlers import GetCoreSchemaHandler, GetJsonSchemaHandler
    from .config import ConfigDict, with_config
    from .errors import *
    from .fields import Field, PrivateAttr, computed_field
    from .functional_serializers import (
        PlainSerializer,
        SerializeAsAny,
        WrapSerializer,
        field_serializer,
        model_serializer,
    )
    from .functional_validators import (
        AfterValidator,
        BeforeValidator,
        InstanceOf,
        ModelWrapValidatorHandler,
        PlainValidator,
        SkipValidation,
        WrapValidator,
        field_validator,
        model_validator,
    )
    from .json_schema import WithJsonSchema
    from .main import *
    from .networks import *
    from .type_adapter import TypeAdapter
    from .types import *
    from .validate_call_decorator import validate_call
    from .warnings import (
        PydanticDeprecatedSince20,
        PydanticDeprecatedSince26,
        PydanticDeprecatedSince29,
        PydanticDeprecationWarning,
        PydanticExperimentalWarning,
    )

    # this encourages pycharm to import `ValidationError` from here, not pydantic_core
    ValidationError = pydantic_core.ValidationError
    from .deprecated.class_validators import root_validator, validator
    from .deprecated.config import BaseConfig, Extra
    from .deprecated.tools import *
    from .root_model import RootModel

__version__ = VERSION
__all__ = (
    # dataclasses
    'dataclasses',
    # functional validators
    'field_validator',
    'model_validator',
    'AfterValidator',
    'BeforeValidator',
    'PlainValidator',
    'WrapValidator',
    'SkipValidation',
    'InstanceOf',
    'ModelWrapValidatorHandler',
    # JSON Schema
    'WithJsonSchema',
    # deprecated V1 functional validators, these are imported via `__getattr__` below
    'root_validator',
    'validator',
    # functional serializers
    'field_serializer',
    'model_serializer',
    'PlainSerializer',
    'SerializeAsAny',
    'WrapSerializer',
    # config
    'ConfigDict',
    'with_config',
    # deprecated V1 config, these are imported via `__getattr__` below
    'BaseConfig',
    'Extra',
    # validate_call
    'validate_call',
    # errors
    'PydanticErrorCodes',
    'PydanticUserError',
    'PydanticSchemaGenerationError',
    'PydanticImportError',
    'PydanticUndefinedAnnotation',
    'PydanticInvalidForJsonSchema',
    # fields
    'Field',
    'computed_field',
    'PrivateAttr',
    # alias
    'AliasChoices',
    'AliasGenerator',
    'AliasPath',
    # main
    'BaseModel',
    'create_model',
    # network
    'AnyUrl',
    'AnyHttpUrl',
    'FileUrl',
    'HttpUrl',
    'FtpUrl',
    'WebsocketUrl',
    'AnyWebsocketUrl',
    'UrlConstraints',
    'EmailStr',
    'NameEmail',
    'IPvAnyAddress',
    'IPvAnyInterface',
    'IPvAnyNetwork',
    'PostgresDsn',
    'CockroachDsn',
    'AmqpDsn',
    'RedisDsn',
    'MongoDsn',
    'KafkaDsn',
    'NatsDsn',
    'MySQLDsn',
    'MariaDBDsn',
    'ClickHouseDsn',
    'SnowflakeDsn',
    'validate_email',
    # root_model
    'RootModel',
    # deprecated tools, these are imported via `__getattr__` below
    'parse_obj_as',
    'schema_of',
    'schema_json_of',
    # types
    'Strict',
    'StrictStr',
    'conbytes',
    'conlist',
    'conset',
    'confrozenset',
    'constr',
    'StringConstraints',
    'ImportString',
    'conint',
    'PositiveInt',
    'NegativeInt',
    'NonNegativeInt',
    'NonPositiveInt',
    'confloat',
    'PositiveFloat',
    'NegativeFloat',
    'NonNegativeFloat',
    'NonPositiveFloat',
    'FiniteFloat',
    'condecimal',
    'condate',
    'UUID1',
    'UUID3',
    'UUID4',
    'UUID5',
    'FilePath',
    'DirectoryPath',
    'NewPath',
    'Json',
    'Secret',
    'SecretStr',
    'SecretBytes',
    'SocketPath',
    'StrictBool',
    'StrictBytes',
    'StrictInt',
    'StrictFloat',
    'PaymentCardNumber',
    'ByteSize',
    'PastDate',
    'FutureDate',
    'PastDatetime',
    'FutureDatetime',
    'AwareDatetime',
    'NaiveDatetime',
    'AllowInfNan',
    'EncoderProtocol',
    'EncodedBytes',
    'EncodedStr',
    'Base64Encoder',
    'Base64Bytes',
    'Base64Str',
    'Base64UrlBytes',
    'Base64UrlStr',
    'GetPydanticSchema',
    'Tag',
    'Discriminator',
    'JsonValue',
    'FailFast',
    # type_adapter
    'TypeAdapter',
    # version
    '__version__',
    'VERSION',
    # warnings
    'PydanticDeprecatedSince20',
    'PydanticDeprecatedSince26',
    'PydanticDeprecatedSince29',
    'PydanticDeprecationWarning',
    'PydanticExperimentalWarning',
    # annotated handlers
    'GetCoreSchemaHandler',
    'GetJsonSchemaHandler',
    # pydantic_core
    'ValidationError',
    'ValidationInfo',
    'SerializationInfo',
    'ValidatorFunctionWrapHandler',
    'FieldSerializationInfo',
    'SerializerFunctionWrapHandler',
    'OnErrorOmit',
)

# A mapping of {<member name>: (package, <module name>)} defining dynamic imports
_dynamic_imports: 'dict[str, tuple[str, str]]' = {
    'dataclasses': (__spec__.parent, '__module__'),
    # functional validators
    'field_validator': (__spec__.parent, '.functional_validators'),
    'model_validator': (__spec__.parent, '.functional_validators'),
    'AfterValidator': (__spec__.parent, '.functional_validators'),
    'BeforeValidator': (__spec__.parent, '.functional_validators'),
    'PlainValidator': (__spec__.parent, '.functional_validators'),
    'WrapValidator': (__spec__.parent, '.functional_validators'),
    'SkipValidation': (__spec__.parent, '.functional_validators'),
    'InstanceOf': (__spec__.parent, '.functional_validators'),
    'ModelWrapValidatorHandler': (__spec__.parent, '.functional_validators'),
    # JSON Schema
    'WithJsonSchema': (__spec__.parent, '.json_schema'),
    # functional serializers
    'field_serializer': (__spec__.parent, '.functional_serializers'),
    'model_serializer': (__spec__.parent, '.functional_serializers'),
    'PlainSerializer': (__spec__.parent, '.functional_serializers'),
    'SerializeAsAny': (__spec__.parent, '.functional_serializers'),
    'WrapSerializer': (__spec__.parent, '.functional_serializers'),
    # config
    'ConfigDict': (__spec__.parent, '.config'),
    'with_config': (__spec__.parent, '.config'),
    # validate call
    'validate_call': (__spec__.parent, '.validate_call_decorator'),
    # errors
    'PydanticErrorCodes': (__spec__.parent, '.errors'),
    'PydanticUserError': (__spec__.parent, '.errors'),
    'PydanticSchemaGenerationError': (__spec__.parent, '.errors'),
    'PydanticImportError': (__spec__.parent, '.errors'),
    'PydanticUndefinedAnnotation': (__spec__.parent, '.errors'),
    'PydanticInvalidForJsonSchema': (__spec__.parent, '.errors'),
    # fields
    'Field': (__spec__.parent, '.fields'),
    'computed_field': (__spec__.parent, '.fields'),
    'PrivateAttr': (__spec__.parent, '.fields'),
    # alias
    'AliasChoices': (__spec__.parent, '.aliases'),
    'AliasGenerator': (__spec__.parent, '.aliases'),
    'AliasPath': (__spec__.parent, '.aliases'),
    # main
    'BaseModel': (__spec__.parent, '.main'),
    'create_model': (__spec__.parent, '.main'),
    # network
    'AnyUrl': (__spec__.parent, '.networks'),
    'AnyHttpUrl': (__spec__.parent, '.networks'),
    'FileUrl': (__spec__.parent, '.networks'),
    'HttpUrl': (__spec__.parent, '.networks'),
    'FtpUrl': (__spec__.parent, '.networks'),
    'WebsocketUrl': (__spec__.parent, '.networks'),
    'AnyWebsocketUrl': (__spec__.parent, '.networks'),
    'UrlConstraints': (__spec__.parent, '.networks'),
    'EmailStr': (__spec__.parent, '.networks'),
    'NameEmail': (__spec__.parent, '.networks'),
    'IPvAnyAddress': (__spec__.parent, '.networks'),
    'IPvAnyInterface': (__spec__.parent, '.networks'),
    'IPvAnyNetwork': (__spec__.parent, '.networks'),
    'PostgresDsn': (__spec__.parent, '.networks'),
    'CockroachDsn': (__spec__.parent, '.networks'),
    'AmqpDsn': (__spec__.parent, '.networks'),
    'RedisDsn': (__spec__.parent, '.networks'),
    'MongoDsn': (__spec__.parent, '.networks'),
    'KafkaDsn': (__spec__.parent, '.networks'),
    'NatsDsn': (__spec__.parent, '.networks'),
    'MySQLDsn': (__spec__.parent, '.networks'),
    'MariaDBDsn': (__spec__.parent, '.networks'),
    'ClickHouseDsn': (__spec__.parent, '.networks'),
    'SnowflakeDsn': (__spec__.parent, '.networks'),
    'validate_email': (__spec__.parent, '.networks'),
    # root_model
    'RootModel': (__spec__.parent, '.root_model'),
    # types
    'Strict': (__spec__.parent, '.types'),
    'StrictStr': (__spec__.parent, '.types'),
    'conbytes': (__spec__.parent, '.types'),
    'conlist': (__spec__.parent, '.types'),
    'conset': (__spec__.parent, '.types'),
    'confrozenset': (__spec__.parent, '.types'),
    'constr': (__spec__.parent, '.types'),
    'StringConstraints': (__spec__.parent, '.types'),
    'ImportString': (__spec__.parent, '.types'),
    'conint': (__spec__.parent, '.types'),
    'PositiveInt': (__spec__.parent, '.types'),
    'NegativeInt': (__spec__.parent, '.types'),
    'NonNegativeInt': (__spec__.parent, '.types'),
    'NonPositiveInt': (__spec__.parent, '.types'),
    'confloat': (__spec__.parent, '.types'),
    'PositiveFloat': (__spec__.parent, '.types'),
    'NegativeFloat': (__spec__.parent, '.types'),
    'NonNegativeFloat': (__spec__.parent, '.types'),
    'NonPositiveFloat': (__spec__.parent, '.types'),
    'FiniteFloat': (__spec__.parent, '.types'),
    'condecimal': (__spec__.parent, '.types'),
    'condate': (__spec__.parent, '.types'),
    'UUID1': (__spec__.parent, '.types'),
    'UUID3': (__spec__.parent, '.types'),
    'UUID4': (__spec__.parent, '.types'),
    'UUID5': (__spec__.parent, '.types'),
    'FilePath': (__spec__.parent, '.types'),
    'DirectoryPath': (__spec__.parent, '.types'),
    'NewPath': (__spec__.parent, '.types'),
    'Json': (__spec__.parent, '.types'),
    'Secret': (__spec__.parent, '.types'),
    'SecretStr': (__spec__.parent, '.types'),
    'SecretBytes': (__spec__.parent, '.types'),
    'StrictBool': (__spec__.parent, '.types'),
    'StrictBytes': (__spec__.parent, '.types'),
    'StrictInt': (__spec__.parent, '.types'),
    'StrictFloat': (__spec__.parent, '.types'),
    'PaymentCardNumber': (__spec__.parent, '.types'),
    'ByteSize': (__spec__.parent, '.types'),
    'PastDate': (__spec__.parent, '.types'),
    'SocketPath': (__spec__.parent, '.types'),
    'FutureDate': (__spec__.parent, '.types'),
    'PastDatetime': (__spec__.parent, '.types'),
    'FutureDatetime': (__spec__.parent, '.types'),
    'AwareDatetime': (__spec__.parent, '.types'),
    'NaiveDatetime': (__spec__.parent, '.types'),
    'AllowInfNan': (__spec__.parent, '.types'),
    'EncoderProtocol': (__spec__.parent, '.types'),
    'EncodedBytes': (__spec__.parent, '.types'),
    'EncodedStr': (__spec__.parent, '.types'),
    'Base64Encoder': (__spec__.parent, '.types'),
    'Base64Bytes': (__spec__.parent, '.types'),
    'Base64Str': (__spec__.parent, '.types'),
    'Base64UrlBytes': (__spec__.parent, '.types'),
    'Base64UrlStr': (__spec__.parent, '.types'),
    'GetPydanticSchema': (__spec__.parent, '.types'),
    'Tag': (__spec__.parent, '.types'),
    'Discriminator': (__spec__.parent, '.types'),
    'JsonValue': (__spec__.parent, '.types'),
    'OnErrorOmit': (__spec__.parent, '.types'),
    'FailFast': (__spec__.parent, '.types'),
    # type_adapter
    'TypeAdapter': (__spec__.parent, '.type_adapter'),
    # warnings
    'PydanticDeprecatedSince20': (__spec__.parent, '.warnings'),
    'PydanticDeprecatedSince26': (__spec__.parent, '.warnings'),
    'PydanticDeprecatedSince29': (__spec__.parent, '.warnings'),
    'PydanticDeprecationWarning': (__spec__.parent, '.warnings'),
    'PydanticExperimentalWarning': (__spec__.parent, '.warnings'),
    # annotated handlers
    'GetCoreSchemaHandler': (__spec__.parent, '.annotated_handlers'),
    'GetJsonSchemaHandler': (__spec__.parent, '.annotated_handlers'),
    # pydantic_core stuff
    'ValidationError': ('pydantic_core', '.'),
    'ValidationInfo': ('pydantic_core', '.core_schema'),
    'SerializationInfo': ('pydantic_core', '.core_schema'),
    'ValidatorFunctionWrapHandler': ('pydantic_core', '.core_schema'),
    'FieldSerializationInfo': ('pydantic_core', '.core_schema'),
    'SerializerFunctionWrapHandler': ('pydantic_core', '.core_schema'),
    # deprecated, mostly not included in __all__
    'root_validator': (__spec__.parent, '.deprecated.class_validators'),
    'validator': (__spec__.parent, '.deprecated.class_validators'),
    'BaseConfig': (__spec__.parent, '.deprecated.config'),
    'Extra': (__spec__.parent, '.deprecated.config'),
    'parse_obj_as': (__spec__.parent, '.deprecated.tools'),
    'schema_of': (__spec__.parent, '.deprecated.tools'),
    'schema_json_of': (__spec__.parent, '.deprecated.tools'),
    # deprecated dynamic imports
    'FieldValidationInfo': ('pydantic_core', '.core_schema'),
    'GenerateSchema': (__spec__.parent, '._internal._generate_schema'),
}
_deprecated_dynamic_imports = {'FieldValidationInfo', 'GenerateSchema'}

_getattr_migration = getattr_migration(__name__)


def __getattr__(attr_name: str) -> object:
    if attr_name in _deprecated_dynamic_imports:
        warn(
            f'Importing {attr_name} from `pydantic` is deprecated. This feature is either no longer supported, or is not public.',
            DeprecationWarning,
            stacklevel=2,
        )

    dynamic_attr = _dynamic_imports.get(attr_name)
    if dynamic_attr is None:
        return _getattr_migration(attr_name)

    package, module_name = dynamic_attr

    if module_name == '__module__':
        result = import_module(f'.{attr_name}', package=package)
        globals()[attr_name] = result
        return result
    else:
        module = import_module(module_name, package=package)
        result = getattr(module, attr_name)
        g = globals()
        for k, (_, v_module_name) in _dynamic_imports.items():
            if v_module_name == module_name and k not in _deprecated_dynamic_imports:
                g[k] = getattr(module, k)
        return result


def __dir__() -> 'list[str]':
    return list(__all__)
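# Editor's note: a hedged, self-contained sketch (not part of the diff) showing
# the observable behavior of the lazy-import machinery above. Module-level
# `__getattr__` (PEP 562) defers importing a submodule until one of its names
# is first accessed:

import warnings

import pydantic

# Triggers __getattr__('BaseModel'): the (__spec__.parent, '.main') entry in
# _dynamic_imports is resolved via import_module, and every other export of
# '.main' is cached into globals() so later lookups skip __getattr__ entirely.
Model = pydantic.BaseModel

# Names listed in _deprecated_dynamic_imports still resolve, but warn first:
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter('always')
    pydantic.FieldValidationInfo
assert any(issubclass(w.category, DeprecationWarning) for w in caught)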
@@ -0,0 +1,345 @@
from __future__ import annotations as _annotations

import warnings
from contextlib import contextmanager
from re import Pattern
from typing import (
    TYPE_CHECKING,
    Any,
    Callable,
    cast,
)

from pydantic_core import core_schema
from typing_extensions import (
    Literal,
    Self,
)

from ..aliases import AliasGenerator
from ..config import ConfigDict, ExtraValues, JsonDict, JsonEncoder, JsonSchemaExtraCallable
from ..errors import PydanticUserError
from ..warnings import PydanticDeprecatedSince20, PydanticDeprecatedSince210

if not TYPE_CHECKING:
    # See PyCharm issues https://youtrack.jetbrains.com/issue/PY-21915
    # and https://youtrack.jetbrains.com/issue/PY-51428
    DeprecationWarning = PydanticDeprecatedSince20

if TYPE_CHECKING:
    from .._internal._schema_generation_shared import GenerateSchema
    from ..fields import ComputedFieldInfo, FieldInfo

DEPRECATION_MESSAGE = 'Support for class-based `config` is deprecated, use ConfigDict instead.'


class ConfigWrapper:
    """Internal wrapper for Config which exposes ConfigDict items as attributes."""

    __slots__ = ('config_dict',)

    config_dict: ConfigDict

    # all annotations are copied directly from ConfigDict, and should be kept up to date, a test will fail if they
    # stop matching
    title: str | None
    str_to_lower: bool
    str_to_upper: bool
    str_strip_whitespace: bool
    str_min_length: int
    str_max_length: int | None
    extra: ExtraValues | None
    frozen: bool
    populate_by_name: bool
    use_enum_values: bool
    validate_assignment: bool
    arbitrary_types_allowed: bool
    from_attributes: bool
    # whether to use the actual key provided in the data (e.g. alias or first alias for "field required" errors) instead of field_names
    # to construct error `loc`s, default `True`
    loc_by_alias: bool
    alias_generator: Callable[[str], str] | AliasGenerator | None
    model_title_generator: Callable[[type], str] | None
    field_title_generator: Callable[[str, FieldInfo | ComputedFieldInfo], str] | None
    ignored_types: tuple[type, ...]
    allow_inf_nan: bool
    json_schema_extra: JsonDict | JsonSchemaExtraCallable | None
    json_encoders: dict[type[object], JsonEncoder] | None

    # new in V2
    strict: bool
    # whether instances of models and dataclasses (including subclass instances) should re-validate, default 'never'
    revalidate_instances: Literal['always', 'never', 'subclass-instances']
    ser_json_timedelta: Literal['iso8601', 'float']
    ser_json_bytes: Literal['utf8', 'base64', 'hex']
    val_json_bytes: Literal['utf8', 'base64', 'hex']
    ser_json_inf_nan: Literal['null', 'constants', 'strings']
    # whether to validate default values during validation, default False
    validate_default: bool
    validate_return: bool
    protected_namespaces: tuple[str | Pattern[str], ...]
    hide_input_in_errors: bool
    defer_build: bool
    plugin_settings: dict[str, object] | None
    schema_generator: type[GenerateSchema] | None
    json_schema_serialization_defaults_required: bool
    json_schema_mode_override: Literal['validation', 'serialization', None]
    coerce_numbers_to_str: bool
    regex_engine: Literal['rust-regex', 'python-re']
    validation_error_cause: bool
    use_attribute_docstrings: bool
    cache_strings: bool | Literal['all', 'keys', 'none']

    def __init__(self, config: ConfigDict | dict[str, Any] | type[Any] | None, *, check: bool = True):
        if check:
            self.config_dict = prepare_config(config)
        else:
            self.config_dict = cast(ConfigDict, config)

    @classmethod
    def for_model(cls, bases: tuple[type[Any], ...], namespace: dict[str, Any], kwargs: dict[str, Any]) -> Self:
        """Build a new `ConfigWrapper` instance for a `BaseModel`.

        The config wrapper is built based on (in descending order of priority):
        - options from `kwargs`
        - options from the `namespace`
        - options from the base classes (`bases`)

        Args:
            bases: A tuple of base classes.
            namespace: The namespace of the class being created.
            kwargs: The kwargs passed to the class being created.

        Returns:
            A `ConfigWrapper` instance for `BaseModel`.
        """
        config_new = ConfigDict()
        for base in bases:
            config = getattr(base, 'model_config', None)
            if config:
                config_new.update(config.copy())

        config_class_from_namespace = namespace.get('Config')
        config_dict_from_namespace = namespace.get('model_config')

        raw_annotations = namespace.get('__annotations__', {})
        if raw_annotations.get('model_config') and config_dict_from_namespace is None:
            raise PydanticUserError(
                '`model_config` cannot be used as a model field name. Use `model_config` for model configuration.',
                code='model-config-invalid-field-name',
            )

        if config_class_from_namespace and config_dict_from_namespace:
            raise PydanticUserError('"Config" and "model_config" cannot be used together', code='config-both')

        config_from_namespace = config_dict_from_namespace or prepare_config(config_class_from_namespace)

        config_new.update(config_from_namespace)

        for k in list(kwargs.keys()):
            if k in config_keys:
                config_new[k] = kwargs.pop(k)

        return cls(config_new)

    # we don't show `__getattr__` to type checkers so missing attributes cause errors
    if not TYPE_CHECKING:  # pragma: no branch

        def __getattr__(self, name: str) -> Any:
            try:
                return self.config_dict[name]
            except KeyError:
                try:
                    return config_defaults[name]
                except KeyError:
                    raise AttributeError(f'Config has no attribute {name!r}') from None

    def core_config(self, title: str | None) -> core_schema.CoreConfig:
        """Create a pydantic-core config.

        We don't use getattr here since we don't want to populate with defaults.

        Args:
            title: The title to use if not set in config.

        Returns:
            A `CoreConfig` object created from config.
        """
        config = self.config_dict

        if config.get('schema_generator') is not None:
            warnings.warn(
                'The `schema_generator` setting has been deprecated since v2.10. This setting no longer has any effect.',
                PydanticDeprecatedSince210,
                stacklevel=2,
            )

        core_config_values = {
            'title': config.get('title') or title or None,
            'extra_fields_behavior': config.get('extra'),
            'allow_inf_nan': config.get('allow_inf_nan'),
            'populate_by_name': config.get('populate_by_name'),
            'str_strip_whitespace': config.get('str_strip_whitespace'),
            'str_to_lower': config.get('str_to_lower'),
            'str_to_upper': config.get('str_to_upper'),
            'strict': config.get('strict'),
            'ser_json_timedelta': config.get('ser_json_timedelta'),
            'ser_json_bytes': config.get('ser_json_bytes'),
            'val_json_bytes': config.get('val_json_bytes'),
            'ser_json_inf_nan': config.get('ser_json_inf_nan'),
            'from_attributes': config.get('from_attributes'),
            'loc_by_alias': config.get('loc_by_alias'),
            'revalidate_instances': config.get('revalidate_instances'),
            'validate_default': config.get('validate_default'),
            'str_max_length': config.get('str_max_length'),
            'str_min_length': config.get('str_min_length'),
            'hide_input_in_errors': config.get('hide_input_in_errors'),
            'coerce_numbers_to_str': config.get('coerce_numbers_to_str'),
            'regex_engine': config.get('regex_engine'),
            'validation_error_cause': config.get('validation_error_cause'),
            'cache_strings': config.get('cache_strings'),
        }

        return core_schema.CoreConfig(**{k: v for k, v in core_config_values.items() if v is not None})

    def __repr__(self):
        c = ', '.join(f'{k}={v!r}' for k, v in self.config_dict.items())
        return f'ConfigWrapper({c})'


class ConfigWrapperStack:
    """A stack of `ConfigWrapper` instances."""

    def __init__(self, config_wrapper: ConfigWrapper):
        self._config_wrapper_stack: list[ConfigWrapper] = [config_wrapper]

    @property
    def tail(self) -> ConfigWrapper:
        return self._config_wrapper_stack[-1]

    @contextmanager
    def push(self, config_wrapper: ConfigWrapper | ConfigDict | None):
        if config_wrapper is None:
            yield
            return

        if not isinstance(config_wrapper, ConfigWrapper):
            config_wrapper = ConfigWrapper(config_wrapper, check=False)

        self._config_wrapper_stack.append(config_wrapper)
        try:
            yield
        finally:
            self._config_wrapper_stack.pop()


config_defaults = ConfigDict(
    title=None,
    str_to_lower=False,
    str_to_upper=False,
    str_strip_whitespace=False,
    str_min_length=0,
    str_max_length=None,
    # let the model / dataclass decide how to handle it
    extra=None,
    frozen=False,
    populate_by_name=False,
    use_enum_values=False,
    validate_assignment=False,
    arbitrary_types_allowed=False,
    from_attributes=False,
    loc_by_alias=True,
    alias_generator=None,
    model_title_generator=None,
    field_title_generator=None,
    ignored_types=(),
    allow_inf_nan=True,
    json_schema_extra=None,
    strict=False,
    revalidate_instances='never',
    ser_json_timedelta='iso8601',
    ser_json_bytes='utf8',
    val_json_bytes='utf8',
    ser_json_inf_nan='null',
    validate_default=False,
    validate_return=False,
    protected_namespaces=('model_validate', 'model_dump'),
    hide_input_in_errors=False,
    json_encoders=None,
    defer_build=False,
    schema_generator=None,
    plugin_settings=None,
    json_schema_serialization_defaults_required=False,
    json_schema_mode_override=None,
    coerce_numbers_to_str=False,
    regex_engine='rust-regex',
    validation_error_cause=False,
    use_attribute_docstrings=False,
    cache_strings=True,
)


def prepare_config(config: ConfigDict | dict[str, Any] | type[Any] | None) -> ConfigDict:
    """Create a `ConfigDict` instance from an existing dict, a class (e.g. old class-based config) or None.

    Args:
        config: The input config.

    Returns:
        A ConfigDict object created from config.
    """
    if config is None:
        return ConfigDict()

    if not isinstance(config, dict):
        warnings.warn(DEPRECATION_MESSAGE, DeprecationWarning)
        config = {k: getattr(config, k) for k in dir(config) if not k.startswith('__')}

    config_dict = cast(ConfigDict, config)
    check_deprecated(config_dict)
    return config_dict


config_keys = set(ConfigDict.__annotations__.keys())


V2_REMOVED_KEYS = {
    'allow_mutation',
    'error_msg_templates',
    'fields',
    'getter_dict',
    'smart_union',
    'underscore_attrs_are_private',
    'json_loads',
    'json_dumps',
    'copy_on_model_validation',
    'post_init_call',
}
V2_RENAMED_KEYS = {
    'allow_population_by_field_name': 'populate_by_name',
    'anystr_lower': 'str_to_lower',
    'anystr_strip_whitespace': 'str_strip_whitespace',
    'anystr_upper': 'str_to_upper',
    'keep_untouched': 'ignored_types',
    'max_anystr_length': 'str_max_length',
    'min_anystr_length': 'str_min_length',
    'orm_mode': 'from_attributes',
    'schema_extra': 'json_schema_extra',
    'validate_all': 'validate_default',
}


def check_deprecated(config_dict: ConfigDict) -> None:
    """Check for deprecated config keys and warn the user.

    Args:
        config_dict: The input config.
    """
    deprecated_removed_keys = V2_REMOVED_KEYS & config_dict.keys()
    deprecated_renamed_keys = V2_RENAMED_KEYS.keys() & config_dict.keys()
    if deprecated_removed_keys or deprecated_renamed_keys:
        renamings = {k: V2_RENAMED_KEYS[k] for k in sorted(deprecated_renamed_keys)}
        renamed_bullets = [f'* {k!r} has been renamed to {v!r}' for k, v in renamings.items()]
        removed_bullets = [f'* {k!r} has been removed' for k in sorted(deprecated_removed_keys)]
        message = '\n'.join(['Valid config keys have changed in V2:'] + renamed_bullets + removed_bullets)
        warnings.warn(message, UserWarning)
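# Editor's note: a hedged illustration (model names invented, not part of this
# file) of the precedence implemented in `ConfigWrapper.for_model` above: class
# keyword arguments beat the namespace's `model_config`, which beats config
# inherited from bases.

from pydantic import BaseModel, ConfigDict

class Base(BaseModel):
    model_config = ConfigDict(str_max_length=10, frozen=True)

class Child(Base, strict=True):  # `strict` is absorbed from kwargs into config
    model_config = ConfigDict(str_max_length=20, frozen=False)

# merged result: bases < namespace < kwargs
assert Child.model_config['str_max_length'] == 20
assert Child.model_config['frozen'] is False
assert Child.model_config['strict'] is True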
@@ -0,0 +1,91 @@
from __future__ import annotations as _annotations

from typing import TYPE_CHECKING, Any, TypedDict, cast
from warnings import warn

if TYPE_CHECKING:
    from ..config import JsonDict, JsonSchemaExtraCallable
    from ._schema_generation_shared import (
        GetJsonSchemaFunction,
    )


class CoreMetadata(TypedDict, total=False):
    """A `TypedDict` for holding the metadata dict of the schema.

    Attributes:
        pydantic_js_functions: List of JSON schema functions that resolve refs during application.
        pydantic_js_annotation_functions: List of JSON schema functions that don't resolve refs during application.
        pydantic_js_prefer_positional_arguments: Whether JSON schema generator will
            prefer positional over keyword arguments for an 'arguments' schema.
        pydantic_js_updates: key / value pair updates to apply to the JSON schema for a type.
        pydantic_js_extra: WIP, either key/value pair updates to apply to the JSON schema, or a custom callable.

    TODO: Perhaps we should move this structure to pydantic-core. At the moment, though,
    it's easier to iterate on if we leave it in pydantic until we feel there is a semi-stable API.

    TODO: It's unfortunate how functionally oriented JSON schema generation is, especially that which occurs during
    the core schema generation process. It's inevitable that we need to store some json schema related information
    on core schemas, given that we generate JSON schemas directly from core schemas. That being said, debugging related
    issues is quite difficult when JSON schema information is disguised via dynamically defined functions.
    """

    pydantic_js_functions: list[GetJsonSchemaFunction]
    pydantic_js_annotation_functions: list[GetJsonSchemaFunction]
    pydantic_js_prefer_positional_arguments: bool
    pydantic_js_updates: JsonDict
    pydantic_js_extra: JsonDict | JsonSchemaExtraCallable


def update_core_metadata(
    core_metadata: Any,
    /,
    *,
    pydantic_js_functions: list[GetJsonSchemaFunction] | None = None,
    pydantic_js_annotation_functions: list[GetJsonSchemaFunction] | None = None,
    pydantic_js_updates: JsonDict | None = None,
    pydantic_js_extra: JsonDict | JsonSchemaExtraCallable | None = None,
) -> None:
    """Update CoreMetadata instance in place. When we make modifications in this function, they
    take effect on the `core_metadata` reference passed in as the first (and only) positional argument.

    First, cast to `CoreMetadata`, then finish with a cast to `dict[str, Any]` for core schema compatibility.
    We do this here, instead of before / after each call to this function so that this typing hack
    can be easily removed if/when we move `CoreMetadata` to `pydantic-core`.

    For parameter descriptions, see `CoreMetadata` above.
    """
    from ..json_schema import PydanticJsonSchemaWarning

    core_metadata = cast(CoreMetadata, core_metadata)

    if pydantic_js_functions:
        core_metadata.setdefault('pydantic_js_functions', []).extend(pydantic_js_functions)

    if pydantic_js_annotation_functions:
        core_metadata.setdefault('pydantic_js_annotation_functions', []).extend(pydantic_js_annotation_functions)

    if pydantic_js_updates:
        if (existing_updates := core_metadata.get('pydantic_js_updates')) is not None:
            core_metadata['pydantic_js_updates'] = {**existing_updates, **pydantic_js_updates}
        else:
            core_metadata['pydantic_js_updates'] = pydantic_js_updates

    if pydantic_js_extra is not None:
        existing_pydantic_js_extra = core_metadata.get('pydantic_js_extra')
        if existing_pydantic_js_extra is None:
            core_metadata['pydantic_js_extra'] = pydantic_js_extra
        if isinstance(existing_pydantic_js_extra, dict):
            if isinstance(pydantic_js_extra, dict):
                core_metadata['pydantic_js_extra'] = {**existing_pydantic_js_extra, **pydantic_js_extra}
            if callable(pydantic_js_extra):
                warn(
                    'Composing `dict` and `callable` type `json_schema_extra` is not supported. '
                    'The `callable` type is being ignored. '
                    "If you'd like support for this behavior, please open an issue on pydantic.",
                    PydanticJsonSchemaWarning,
                )
        if callable(existing_pydantic_js_extra):
            # if ever there's a case of a callable, we'll just keep the last json schema extra spec
            core_metadata['pydantic_js_extra'] = pydantic_js_extra
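# Editor's note: a hedged sketch of the merge semantics implemented by
# `update_core_metadata` above. Repeated calls accumulate; `pydantic_js_updates`
# dict-merges (later keys win); a dict + callable collision on
# `pydantic_js_extra` keeps the dict and emits PydanticJsonSchemaWarning.

metadata: dict = {}

update_core_metadata(metadata, pydantic_js_updates={'title': 'A'})
update_core_metadata(metadata, pydantic_js_updates={'examples': [1]})
assert metadata['pydantic_js_updates'] == {'title': 'A', 'examples': [1]}

update_core_metadata(metadata, pydantic_js_extra={'deprecated': True})
update_core_metadata(metadata, pydantic_js_extra=lambda js: js.update(x=1))
assert metadata['pydantic_js_extra'] == {'deprecated': True}  # callable ignored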
@@ -0,0 +1,610 @@
from __future__ import annotations

import os
from collections import defaultdict
from typing import Any, Callable, Hashable, TypeVar, Union

from pydantic_core import CoreSchema, core_schema
from pydantic_core import validate_core_schema as _validate_core_schema
from typing_extensions import TypeGuard, get_args, get_origin

from ..errors import PydanticUserError
from . import _repr
from ._core_metadata import CoreMetadata
from ._typing_extra import is_generic_alias, is_type_alias_type

AnyFunctionSchema = Union[
    core_schema.AfterValidatorFunctionSchema,
    core_schema.BeforeValidatorFunctionSchema,
    core_schema.WrapValidatorFunctionSchema,
    core_schema.PlainValidatorFunctionSchema,
]


FunctionSchemaWithInnerSchema = Union[
    core_schema.AfterValidatorFunctionSchema,
    core_schema.BeforeValidatorFunctionSchema,
    core_schema.WrapValidatorFunctionSchema,
]

CoreSchemaField = Union[
    core_schema.ModelField, core_schema.DataclassField, core_schema.TypedDictField, core_schema.ComputedField
]
CoreSchemaOrField = Union[core_schema.CoreSchema, CoreSchemaField]

_CORE_SCHEMA_FIELD_TYPES = {'typed-dict-field', 'dataclass-field', 'model-field', 'computed-field'}
_FUNCTION_WITH_INNER_SCHEMA_TYPES = {'function-before', 'function-after', 'function-wrap'}
_LIST_LIKE_SCHEMA_WITH_ITEMS_TYPES = {'list', 'set', 'frozenset'}

TAGGED_UNION_TAG_KEY = 'pydantic.internal.tagged_union_tag'
"""
Used in a `Tag` schema to specify the tag used for a discriminated union.
"""


def is_core_schema(
    schema: CoreSchemaOrField,
) -> TypeGuard[CoreSchema]:
    return schema['type'] not in _CORE_SCHEMA_FIELD_TYPES


def is_core_schema_field(
    schema: CoreSchemaOrField,
) -> TypeGuard[CoreSchemaField]:
    return schema['type'] in _CORE_SCHEMA_FIELD_TYPES


def is_function_with_inner_schema(
    schema: CoreSchemaOrField,
) -> TypeGuard[FunctionSchemaWithInnerSchema]:
    return schema['type'] in _FUNCTION_WITH_INNER_SCHEMA_TYPES


def is_list_like_schema_with_items_schema(
    schema: CoreSchema,
) -> TypeGuard[core_schema.ListSchema | core_schema.SetSchema | core_schema.FrozenSetSchema]:
    return schema['type'] in _LIST_LIKE_SCHEMA_WITH_ITEMS_TYPES


def get_type_ref(type_: type[Any], args_override: tuple[type[Any], ...] | None = None) -> str:
    """Produces the ref to be used for this type by pydantic_core's core schemas.

    This `args_override` argument was added for the purpose of creating valid recursive references
    when creating generic models without needing to create a concrete class.
    """
    origin = get_origin(type_) or type_

    args = get_args(type_) if is_generic_alias(type_) else (args_override or ())
    generic_metadata = getattr(type_, '__pydantic_generic_metadata__', None)
    if generic_metadata:
        origin = generic_metadata['origin'] or origin
        args = generic_metadata['args'] or args

    module_name = getattr(origin, '__module__', '<No __module__>')
    if is_type_alias_type(origin):
        type_ref = f'{module_name}.{origin.__name__}:{id(origin)}'
    else:
        try:
            qualname = getattr(origin, '__qualname__', f'<No __qualname__: {origin}>')
        except Exception:
            qualname = getattr(origin, '__qualname__', '<No __qualname__>')
        type_ref = f'{module_name}.{qualname}:{id(origin)}'

    arg_refs: list[str] = []
    for arg in args:
        if isinstance(arg, str):
            # Handle string literals as a special case; we may be able to remove this special handling if we
            # wrap them in a ForwardRef at some point.
            arg_ref = f'{arg}:str-{id(arg)}'
        else:
            arg_ref = f'{_repr.display_as_type(arg)}:{id(arg)}'
        arg_refs.append(arg_ref)
    if arg_refs:
        type_ref = f'{type_ref}[{",".join(arg_refs)}]'
    return type_ref


def get_ref(s: core_schema.CoreSchema) -> None | str:
    """Get the ref from the schema if it has one.
    This exists just for type checking to work correctly.
    """
    return s.get('ref', None)


def collect_definitions(schema: core_schema.CoreSchema) -> dict[str, core_schema.CoreSchema]:
    defs: dict[str, CoreSchema] = {}

    def _record_valid_refs(s: core_schema.CoreSchema, recurse: Recurse) -> core_schema.CoreSchema:
        ref = get_ref(s)
        if ref:
            defs[ref] = s
        return recurse(s, _record_valid_refs)

    walk_core_schema(schema, _record_valid_refs, copy=False)

    return defs


def define_expected_missing_refs(
    schema: core_schema.CoreSchema, allowed_missing_refs: set[str]
) -> core_schema.CoreSchema | None:
    if not allowed_missing_refs:
        # in this case, there are no missing refs to potentially substitute, so there's no need to walk the schema
        # this is a common case (will be hit for all non-generic models), so it's worth optimizing for
        return None

    refs = collect_definitions(schema).keys()

    expected_missing_refs = allowed_missing_refs.difference(refs)
    if expected_missing_refs:
        definitions: list[core_schema.CoreSchema] = [
            core_schema.invalid_schema(ref=ref) for ref in expected_missing_refs
        ]
        return core_schema.definitions_schema(schema, definitions)
    return None


def collect_invalid_schemas(schema: core_schema.CoreSchema) -> bool:
    invalid = False

    def _is_schema_valid(s: core_schema.CoreSchema, recurse: Recurse) -> core_schema.CoreSchema:
        nonlocal invalid

        if s['type'] == 'invalid':
            invalid = True
            return s

        return recurse(s, _is_schema_valid)

    walk_core_schema(schema, _is_schema_valid, copy=False)
    return invalid


T = TypeVar('T')


Recurse = Callable[[core_schema.CoreSchema, 'Walk'], core_schema.CoreSchema]
Walk = Callable[[core_schema.CoreSchema, Recurse], core_schema.CoreSchema]

# TODO: Should we move _WalkCoreSchema into pydantic_core proper?
# Issue: https://github.com/pydantic/pydantic-core/issues/615

CoreSchemaT = TypeVar('CoreSchemaT')


class _WalkCoreSchema:
    def __init__(self, *, copy: bool = True):
        self._schema_type_to_method = self._build_schema_type_to_method()
        self._copy = copy

    def _copy_schema(self, schema: CoreSchemaT) -> CoreSchemaT:
        return schema.copy() if self._copy else schema  # pyright: ignore[reportAttributeAccessIssue]

    def _build_schema_type_to_method(self) -> dict[core_schema.CoreSchemaType, Recurse]:
        mapping: dict[core_schema.CoreSchemaType, Recurse] = {}
        key: core_schema.CoreSchemaType
        for key in get_args(core_schema.CoreSchemaType):
            method_name = f"handle_{key.replace('-', '_')}_schema"
            mapping[key] = getattr(self, method_name, self._handle_other_schemas)
        return mapping

    def walk(self, schema: core_schema.CoreSchema, f: Walk) -> core_schema.CoreSchema:
        return f(schema, self._walk)

    def _walk(self, schema: core_schema.CoreSchema, f: Walk) -> core_schema.CoreSchema:
        schema = self._schema_type_to_method[schema['type']](self._copy_schema(schema), f)
        ser_schema: core_schema.SerSchema | None = schema.get('serialization')  # type: ignore
        if ser_schema:
            schema['serialization'] = self._handle_ser_schemas(ser_schema, f)
        return schema

    def _handle_other_schemas(self, schema: core_schema.CoreSchema, f: Walk) -> core_schema.CoreSchema:
        sub_schema = schema.get('schema', None)
        if sub_schema is not None:
            schema['schema'] = self.walk(sub_schema, f)  # type: ignore
        return schema

    def _handle_ser_schemas(self, ser_schema: core_schema.SerSchema, f: Walk) -> core_schema.SerSchema:
        schema: core_schema.CoreSchema | None = ser_schema.get('schema', None)
        return_schema: core_schema.CoreSchema | None = ser_schema.get('return_schema', None)
        if schema is not None or return_schema is not None:
            ser_schema = self._copy_schema(ser_schema)
            if schema is not None:
                ser_schema['schema'] = self.walk(schema, f)  # type: ignore
            if return_schema is not None:
                ser_schema['return_schema'] = self.walk(return_schema, f)  # type: ignore
        return ser_schema

    def handle_definitions_schema(self, schema: core_schema.DefinitionsSchema, f: Walk) -> core_schema.CoreSchema:
        new_definitions: list[core_schema.CoreSchema] = []
        for definition in schema['definitions']:
            if 'schema_ref' in definition and 'ref' in definition:
                # This indicates a purposely indirect reference
                # We want to keep such references around for implications related to JSON schema, etc.:
                new_definitions.append(definition)
                # However, we still need to walk the referenced definition:
                self.walk(definition, f)
                continue

            updated_definition = self.walk(definition, f)
            if 'ref' in updated_definition:
                # If the updated definition schema doesn't have a 'ref', it shouldn't go in the definitions
                # This is most likely to happen due to replacing something with a definition reference, in
                # which case it should certainly not go in the definitions list
                new_definitions.append(updated_definition)
        new_inner_schema = self.walk(schema['schema'], f)

        if not new_definitions and len(schema) == 3:
            # This means we'd be returning a "trivial" definitions schema that just wrapped the inner schema
            return new_inner_schema

        new_schema = self._copy_schema(schema)
        new_schema['schema'] = new_inner_schema
        new_schema['definitions'] = new_definitions
        return new_schema

    def handle_list_schema(self, schema: core_schema.ListSchema, f: Walk) -> core_schema.CoreSchema:
        items_schema = schema.get('items_schema')
        if items_schema is not None:
            schema['items_schema'] = self.walk(items_schema, f)
        return schema

    def handle_set_schema(self, schema: core_schema.SetSchema, f: Walk) -> core_schema.CoreSchema:
        items_schema = schema.get('items_schema')
        if items_schema is not None:
            schema['items_schema'] = self.walk(items_schema, f)
        return schema

    def handle_frozenset_schema(self, schema: core_schema.FrozenSetSchema, f: Walk) -> core_schema.CoreSchema:
        items_schema = schema.get('items_schema')
        if items_schema is not None:
            schema['items_schema'] = self.walk(items_schema, f)
        return schema

    def handle_generator_schema(self, schema: core_schema.GeneratorSchema, f: Walk) -> core_schema.CoreSchema:
        items_schema = schema.get('items_schema')
        if items_schema is not None:
            schema['items_schema'] = self.walk(items_schema, f)
        return schema

    def handle_tuple_schema(self, schema: core_schema.TupleSchema, f: Walk) -> core_schema.CoreSchema:
        schema['items_schema'] = [self.walk(v, f) for v in schema['items_schema']]
        return schema

    def handle_dict_schema(self, schema: core_schema.DictSchema, f: Walk) -> core_schema.CoreSchema:
        keys_schema = schema.get('keys_schema')
        if keys_schema is not None:
            schema['keys_schema'] = self.walk(keys_schema, f)
        values_schema = schema.get('values_schema')
        if values_schema:
            schema['values_schema'] = self.walk(values_schema, f)
        return schema

    def handle_function_after_schema(
        self, schema: core_schema.AfterValidatorFunctionSchema, f: Walk
    ) -> core_schema.CoreSchema:
        schema['schema'] = self.walk(schema['schema'], f)
        return schema

    def handle_function_before_schema(
        self, schema: core_schema.BeforeValidatorFunctionSchema, f: Walk
    ) -> core_schema.CoreSchema:
        schema['schema'] = self.walk(schema['schema'], f)
        if 'json_schema_input_schema' in schema:
            schema['json_schema_input_schema'] = self.walk(schema['json_schema_input_schema'], f)
        return schema

    # TODO duplicate schema types for serializers and validators, needs to be deduplicated:
    def handle_function_plain_schema(
        self, schema: core_schema.PlainValidatorFunctionSchema | core_schema.PlainSerializerFunctionSerSchema, f: Walk
    ) -> core_schema.CoreSchema:
        if 'json_schema_input_schema' in schema:
            schema['json_schema_input_schema'] = self.walk(schema['json_schema_input_schema'], f)
        return schema  # pyright: ignore[reportReturnType]

    # TODO duplicate schema types for serializers and validators, needs to be deduplicated:
    def handle_function_wrap_schema(
        self, schema: core_schema.WrapValidatorFunctionSchema | core_schema.WrapSerializerFunctionSerSchema, f: Walk
    ) -> core_schema.CoreSchema:
        if 'schema' in schema:
            schema['schema'] = self.walk(schema['schema'], f)
        if 'json_schema_input_schema' in schema:
            schema['json_schema_input_schema'] = self.walk(schema['json_schema_input_schema'], f)
        return schema  # pyright: ignore[reportReturnType]

    def handle_union_schema(self, schema: core_schema.UnionSchema, f: Walk) -> core_schema.CoreSchema:
        new_choices: list[CoreSchema | tuple[CoreSchema, str]] = []
        for v in schema['choices']:
            if isinstance(v, tuple):
                new_choices.append((self.walk(v[0], f), v[1]))
            else:
                new_choices.append(self.walk(v, f))
        schema['choices'] = new_choices
        return schema

    def handle_tagged_union_schema(self, schema: core_schema.TaggedUnionSchema, f: Walk) -> core_schema.CoreSchema:
        new_choices: dict[Hashable, core_schema.CoreSchema] = {}
        for k, v in schema['choices'].items():
            new_choices[k] = v if isinstance(v, (str, int)) else self.walk(v, f)
        schema['choices'] = new_choices
        return schema

    def handle_chain_schema(self, schema: core_schema.ChainSchema, f: Walk) -> core_schema.CoreSchema:
        schema['steps'] = [self.walk(v, f) for v in schema['steps']]
        return schema

    def handle_lax_or_strict_schema(self, schema: core_schema.LaxOrStrictSchema, f: Walk) -> core_schema.CoreSchema:
        schema['lax_schema'] = self.walk(schema['lax_schema'], f)
        schema['strict_schema'] = self.walk(schema['strict_schema'], f)
        return schema

    def handle_json_or_python_schema(self, schema: core_schema.JsonOrPythonSchema, f: Walk) -> core_schema.CoreSchema:
        schema['json_schema'] = self.walk(schema['json_schema'], f)
        schema['python_schema'] = self.walk(schema['python_schema'], f)
        return schema

    def handle_model_fields_schema(self, schema: core_schema.ModelFieldsSchema, f: Walk) -> core_schema.CoreSchema:
        extras_schema = schema.get('extras_schema')
        if extras_schema is not None:
            schema['extras_schema'] = self.walk(extras_schema, f)
        replaced_fields: dict[str, core_schema.ModelField] = {}
        replaced_computed_fields: list[core_schema.ComputedField] = []
        for computed_field in schema.get('computed_fields', ()):
            replaced_field = self._copy_schema(computed_field)
            replaced_field['return_schema'] = self.walk(computed_field['return_schema'], f)
            replaced_computed_fields.append(replaced_field)
        if replaced_computed_fields:
            schema['computed_fields'] = replaced_computed_fields
        for k, v in schema['fields'].items():
            replaced_field = self._copy_schema(v)
            replaced_field['schema'] = self.walk(v['schema'], f)
            replaced_fields[k] = replaced_field
        schema['fields'] = replaced_fields
        return schema

    def handle_typed_dict_schema(self, schema: core_schema.TypedDictSchema, f: Walk) -> core_schema.CoreSchema:
        extras_schema = schema.get('extras_schema')
        if extras_schema is not None:
            schema['extras_schema'] = self.walk(extras_schema, f)
        replaced_computed_fields: list[core_schema.ComputedField] = []
        for computed_field in schema.get('computed_fields', ()):
            replaced_field = self._copy_schema(computed_field)
            replaced_field['return_schema'] = self.walk(computed_field['return_schema'], f)
            replaced_computed_fields.append(replaced_field)
        if replaced_computed_fields:
            schema['computed_fields'] = replaced_computed_fields
        replaced_fields: dict[str, core_schema.TypedDictField] = {}
        for k, v in schema['fields'].items():
            replaced_field = self._copy_schema(v)
            replaced_field['schema'] = self.walk(v['schema'], f)
            replaced_fields[k] = replaced_field
        schema['fields'] = replaced_fields
        return schema

    def handle_dataclass_args_schema(self, schema: core_schema.DataclassArgsSchema, f: Walk) -> core_schema.CoreSchema:
        replaced_fields: list[core_schema.DataclassField] = []
        replaced_computed_fields: list[core_schema.ComputedField] = []
        for computed_field in schema.get('computed_fields', ()):
            replaced_field = self._copy_schema(computed_field)
            replaced_field['return_schema'] = self.walk(computed_field['return_schema'], f)
            replaced_computed_fields.append(replaced_field)
        if replaced_computed_fields:
            schema['computed_fields'] = replaced_computed_fields
        for field in schema['fields']:
            replaced_field = self._copy_schema(field)
            replaced_field['schema'] = self.walk(field['schema'], f)
            replaced_fields.append(replaced_field)
        schema['fields'] = replaced_fields
        return schema

    def handle_arguments_schema(self, schema: core_schema.ArgumentsSchema, f: Walk) -> core_schema.CoreSchema:
        replaced_arguments_schema: list[core_schema.ArgumentsParameter] = []
        for param in schema['arguments_schema']:
            replaced_param = self._copy_schema(param)
            replaced_param['schema'] = self.walk(param['schema'], f)
            replaced_arguments_schema.append(replaced_param)
        schema['arguments_schema'] = replaced_arguments_schema
        if 'var_args_schema' in schema:
            schema['var_args_schema'] = self.walk(schema['var_args_schema'], f)
        if 'var_kwargs_schema' in schema:
            schema['var_kwargs_schema'] = self.walk(schema['var_kwargs_schema'], f)
        return schema

    def handle_call_schema(self, schema: core_schema.CallSchema, f: Walk) -> core_schema.CoreSchema:
        schema['arguments_schema'] = self.walk(schema['arguments_schema'], f)
        if 'return_schema' in schema:
            schema['return_schema'] = self.walk(schema['return_schema'], f)
        return schema


_dispatch = _WalkCoreSchema().walk
_dispatch_no_copy = _WalkCoreSchema(copy=False).walk


def walk_core_schema(schema: core_schema.CoreSchema, f: Walk, *, copy: bool = True) -> core_schema.CoreSchema:
    """Recursively traverse a CoreSchema.

    Args:
        schema (core_schema.CoreSchema): The CoreSchema to process, it will not be modified.
        f (Walk): A function to apply. This function takes two arguments:
            1. The current CoreSchema that is being processed
               (not the same one you passed into this function, one level down).
            2. The "next" `f` to call. This lets you for example use `f=functools.partial(some_method, some_context)`
               to pass data down the recursive calls without using globals or other mutable state.
        copy: Whether schema should be recursively copied.

    Returns:
        core_schema.CoreSchema: A processed CoreSchema.
    """
    return f(schema.copy() if copy else schema, _dispatch if copy else _dispatch_no_copy)
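# Editor's note: a hedged usage sketch of `walk_core_schema` above (not part of
# the diff) — a visitor that counts how many sub-schemas of each type appear,
# threading itself through `recurse` exactly as the docstring describes:

from collections import Counter

def count_schema_types(schema: core_schema.CoreSchema) -> Counter:
    counts: Counter = Counter()

    def visit(s: core_schema.CoreSchema, recurse: Recurse) -> core_schema.CoreSchema:
        counts[s['type']] += 1    # record this node
        return recurse(s, visit)  # descend into children with the same visitor

    walk_core_schema(schema, visit, copy=False)  # read-only walk, no copies needed
    return counts

# count_schema_types(core_schema.list_schema(core_schema.int_schema()))
# -> Counter({'list': 1, 'int': 1})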
def simplify_schema_references(schema: core_schema.CoreSchema) -> core_schema.CoreSchema:  # noqa: C901
    definitions: dict[str, core_schema.CoreSchema] = {}
    ref_counts: dict[str, int] = defaultdict(int)
    involved_in_recursion: dict[str, bool] = {}
    current_recursion_ref_count: dict[str, int] = defaultdict(int)

    def collect_refs(s: core_schema.CoreSchema, recurse: Recurse) -> core_schema.CoreSchema:
        if s['type'] == 'definitions':
            for definition in s['definitions']:
                ref = get_ref(definition)
                assert ref is not None
                if ref not in definitions:
                    definitions[ref] = definition
                recurse(definition, collect_refs)
            return recurse(s['schema'], collect_refs)
        else:
            ref = get_ref(s)
            if ref is not None:
                new = recurse(s, collect_refs)
                new_ref = get_ref(new)
                if new_ref:
                    definitions[new_ref] = new
                return core_schema.definition_reference_schema(schema_ref=ref)
            else:
                return recurse(s, collect_refs)

    schema = walk_core_schema(schema, collect_refs)

    def count_refs(s: core_schema.CoreSchema, recurse: Recurse) -> core_schema.CoreSchema:
        if s['type'] != 'definition-ref':
            return recurse(s, count_refs)
        ref = s['schema_ref']
        ref_counts[ref] += 1

        if ref_counts[ref] >= 2:
            # If this model is involved in a recursion this should be detected
            # on its second encounter, we can safely stop the walk here.
            if current_recursion_ref_count[ref] != 0:
                involved_in_recursion[ref] = True
            return s

        current_recursion_ref_count[ref] += 1
        if 'serialization' in s:
            # Even though this is a `'definition-ref'` schema, there might
            # be more references inside the serialization schema:
            recurse(s, count_refs)

        next_s = definitions[ref]
        visited: set[str] = set()
        while next_s['type'] == 'definition-ref':
            if next_s['schema_ref'] in visited:
                raise PydanticUserError(
                    f'{ref} contains a circular reference to itself.', code='circular-reference-schema'
                )

            visited.add(next_s['schema_ref'])
            ref_counts[next_s['schema_ref']] += 1
            next_s = definitions[next_s['schema_ref']]

        recurse(next_s, count_refs)
        current_recursion_ref_count[ref] -= 1
        return s

    schema = walk_core_schema(schema, count_refs, copy=False)

    assert all(c == 0 for c in current_recursion_ref_count.values()), 'this is a bug! please report it'

    def can_be_inlined(s: core_schema.DefinitionReferenceSchema, ref: str) -> bool:
        if ref_counts[ref] > 1:
            return False
        if involved_in_recursion.get(ref, False):
            return False
        if 'serialization' in s:
            return False
        if 'metadata' in s:
            metadata = s['metadata']
            for k in [
                *CoreMetadata.__annotations__.keys(),
                'pydantic.internal.union_discriminator',
                'pydantic.internal.tagged_union_tag',
            ]:
                if k in metadata:
                    # we need to keep this as a ref
                    return False
        return True

    def inline_refs(s: core_schema.CoreSchema, recurse: Recurse) -> core_schema.CoreSchema:
        # Assume there are no infinite loops, because we already checked for that in `count_refs`
        while s['type'] == 'definition-ref':
            ref = s['schema_ref']

            # Check if the reference is only used once, not involved in recursion and does not have
            # any extra keys (like 'serialization')
            if can_be_inlined(s, ref):
                # Inline the reference by replacing the reference with the actual schema
                new = definitions.pop(ref)
                ref_counts[ref] -= 1  # because we just replaced it!
                # put all other keys that were on the def-ref schema into the inlined version
                # in particular this is needed for `serialization`
                if 'serialization' in s:
                    new['serialization'] = s['serialization']
                s = new
            else:
                break
        return recurse(s, inline_refs)

    schema = walk_core_schema(schema, inline_refs, copy=False)

    def_values = [v for v in definitions.values() if ref_counts[v['ref']] > 0]  # type: ignore

    if def_values:
        schema = core_schema.definitions_schema(schema=schema, definitions=def_values)
    return schema


def _strip_metadata(schema: CoreSchema) -> CoreSchema:
    def strip_metadata(s: CoreSchema, recurse: Recurse) -> CoreSchema:
        s = s.copy()
        s.pop('metadata', None)
        if s['type'] == 'model-fields':
            s = s.copy()
            s['fields'] = {k: v.copy() for k, v in s['fields'].items()}
            for field_name, field_schema in s['fields'].items():
                field_schema.pop('metadata', None)
                s['fields'][field_name] = field_schema
            computed_fields = s.get('computed_fields', None)
            if computed_fields:
                s['computed_fields'] = [cf.copy() for cf in computed_fields]
                for cf in computed_fields:
                    cf.pop('metadata', None)
            else:
                s.pop('computed_fields', None)
        elif s['type'] == 'model':
            # remove some defaults
            if s.get('custom_init', True) is False:
                s.pop('custom_init')
            if s.get('root_model', True) is False:
                s.pop('root_model')
            if {'title'}.issuperset(s.get('config', {}).keys()):
                s.pop('config', None)

        return recurse(s, strip_metadata)

    return walk_core_schema(schema, strip_metadata)


def pretty_print_core_schema(
    schema: CoreSchema,
    include_metadata: bool = False,
) -> None:
    """Pretty print a CoreSchema using rich.
|
||||
This is intended for debugging purposes.
|
||||
|
||||
Args:
|
||||
schema: The CoreSchema to print.
|
||||
include_metadata: Whether to include metadata in the output. Defaults to `False`.
|
||||
"""
|
||||
from rich import print # type: ignore # install it manually in your dev env
|
||||
|
||||
if not include_metadata:
|
||||
schema = _strip_metadata(schema)
|
||||
|
||||
return print(schema)
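
# Illustrative use (editorial sketch, not part of the upstream module): with
# `rich` installed, pass any model's core schema to inspect its structure.
#
#     from pydantic import BaseModel
#
#     class User(BaseModel):
#         id: int
#         name: str = 'Jane'
#
#     pretty_print_core_schema(User.__pydantic_core_schema__)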


def validate_core_schema(schema: CoreSchema) -> CoreSchema:
    if 'PYDANTIC_SKIP_VALIDATING_CORE_SCHEMAS' in os.environ:
        return schema
    return _validate_core_schema(schema)
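
# Editorial note: setting the `PYDANTIC_SKIP_VALIDATING_CORE_SCHEMAS`
# environment variable (to any value) bypasses the validation pass above,
# which can reduce schema-build time; treat this as a debugging escape hatch
# rather than a supported configuration.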
@@ -0,0 +1,246 @@
"""Private logic for creating pydantic dataclasses."""

from __future__ import annotations as _annotations

import dataclasses
import typing
import warnings
from functools import partial, wraps
from typing import Any, ClassVar

from pydantic_core import (
    ArgsKwargs,
    SchemaSerializer,
    SchemaValidator,
    core_schema,
)
from typing_extensions import TypeGuard

from ..errors import PydanticUndefinedAnnotation
from ..plugin._schema_validator import PluggableSchemaValidator, create_schema_validator
from ..warnings import PydanticDeprecatedSince20
from . import _config, _decorators
from ._fields import collect_dataclass_fields
from ._generate_schema import GenerateSchema
from ._generics import get_standard_typevars_map
from ._mock_val_ser import set_dataclass_mocks
from ._namespace_utils import NsResolver
from ._schema_generation_shared import CallbackGetCoreSchemaHandler
from ._signature import generate_pydantic_signature
from ._utils import LazyClassAttribute

if typing.TYPE_CHECKING:
    from _typeshed import DataclassInstance as StandardDataclass

    from ..config import ConfigDict
    from ..fields import FieldInfo

    class PydanticDataclass(StandardDataclass, typing.Protocol):
        """A protocol containing attributes only available once a class has been decorated as a Pydantic dataclass.

        Attributes:
            __pydantic_config__: Pydantic-specific configuration settings for the dataclass.
            __pydantic_complete__: Whether dataclass building is completed, or if there are still undefined fields.
            __pydantic_core_schema__: The pydantic-core schema used to build the SchemaValidator and SchemaSerializer.
            __pydantic_decorators__: Metadata containing the decorators defined on the dataclass.
            __pydantic_fields__: Metadata about the fields defined on the dataclass.
            __pydantic_serializer__: The pydantic-core SchemaSerializer used to dump instances of the dataclass.
            __pydantic_validator__: The pydantic-core SchemaValidator used to validate instances of the dataclass.
        """

        __pydantic_config__: ClassVar[ConfigDict]
        __pydantic_complete__: ClassVar[bool]
        __pydantic_core_schema__: ClassVar[core_schema.CoreSchema]
        __pydantic_decorators__: ClassVar[_decorators.DecoratorInfos]
        __pydantic_fields__: ClassVar[dict[str, FieldInfo]]
        __pydantic_serializer__: ClassVar[SchemaSerializer]
        __pydantic_validator__: ClassVar[SchemaValidator | PluggableSchemaValidator]

else:
    # See PyCharm issues https://youtrack.jetbrains.com/issue/PY-21915
    # and https://youtrack.jetbrains.com/issue/PY-51428
    DeprecationWarning = PydanticDeprecatedSince20


def set_dataclass_fields(
    cls: type[StandardDataclass],
    ns_resolver: NsResolver | None = None,
    config_wrapper: _config.ConfigWrapper | None = None,
) -> None:
    """Collect and set `cls.__pydantic_fields__`.

    Args:
        cls: The class.
        ns_resolver: Namespace resolver to use when getting dataclass annotations.
        config_wrapper: The config wrapper instance, defaults to `None`.
    """
    typevars_map = get_standard_typevars_map(cls)
    fields = collect_dataclass_fields(
        cls, ns_resolver=ns_resolver, typevars_map=typevars_map, config_wrapper=config_wrapper
    )

    cls.__pydantic_fields__ = fields  # type: ignore


def complete_dataclass(
    cls: type[Any],
    config_wrapper: _config.ConfigWrapper,
    *,
    raise_errors: bool = True,
    ns_resolver: NsResolver | None = None,
    _force_build: bool = False,
) -> bool:
    """Finish building a pydantic dataclass.

    This logic is called on a class which has already been wrapped in `dataclasses.dataclass()`.

    This is somewhat analogous to `pydantic._internal._model_construction.complete_model_class`.

    Args:
        cls: The class.
        config_wrapper: The config wrapper instance.
        raise_errors: Whether to raise errors, defaults to `True`.
        ns_resolver: The namespace resolver instance to use when collecting dataclass fields
            and during schema building.
        _force_build: Whether to force building the dataclass, no matter if
            [`defer_build`][pydantic.config.ConfigDict.defer_build] is set.

    Returns:
        `True` if building a pydantic dataclass is successfully completed, `False` otherwise.

    Raises:
        PydanticUndefinedAnnotation: If `raise_errors` is `True` and there are undefined annotations.
    """
    original_init = cls.__init__

    # dataclass.__init__ must be defined here so its `__qualname__` can be changed since functions can't be copied,
    # and so that the mock validator is used if building was deferred:
    def __init__(__dataclass_self__: PydanticDataclass, *args: Any, **kwargs: Any) -> None:
        __tracebackhide__ = True
        s = __dataclass_self__
        s.__pydantic_validator__.validate_python(ArgsKwargs(args, kwargs), self_instance=s)

    __init__.__qualname__ = f'{cls.__qualname__}.__init__'

    cls.__init__ = __init__  # type: ignore
    cls.__pydantic_config__ = config_wrapper.config_dict  # type: ignore

    set_dataclass_fields(cls, ns_resolver, config_wrapper=config_wrapper)

    if not _force_build and config_wrapper.defer_build:
        set_dataclass_mocks(cls, cls.__name__)
        return False

    if hasattr(cls, '__post_init_post_parse__'):
        warnings.warn(
            'Support for `__post_init_post_parse__` has been dropped, the method will not be called', DeprecationWarning
        )

    typevars_map = get_standard_typevars_map(cls)
    gen_schema = GenerateSchema(
        config_wrapper,
        ns_resolver=ns_resolver,
        typevars_map=typevars_map,
    )

    # set __signature__ attr only for the class, but not for its instances
    # (because instances can define `__call__`, and `inspect.signature` shouldn't
    # use the `__signature__` attribute and instead generate from `__call__`).
    cls.__signature__ = LazyClassAttribute(
        '__signature__',
        partial(
            generate_pydantic_signature,
            # It's important that we reference the `original_init` here,
            # as it is the one synthesized by the stdlib `dataclass` module:
            init=original_init,
            fields=cls.__pydantic_fields__,  # type: ignore
            populate_by_name=config_wrapper.populate_by_name,
            extra=config_wrapper.extra,
            is_dataclass=True,
        ),
    )
    get_core_schema = getattr(cls, '__get_pydantic_core_schema__', None)
    try:
        if get_core_schema:
            schema = get_core_schema(
                cls,
                CallbackGetCoreSchemaHandler(
                    partial(gen_schema.generate_schema, from_dunder_get_core_schema=False),
                    gen_schema,
                    ref_mode='unpack',
                ),
            )
        else:
            schema = gen_schema.generate_schema(cls, from_dunder_get_core_schema=False)
    except PydanticUndefinedAnnotation as e:
        if raise_errors:
            raise
        set_dataclass_mocks(cls, cls.__name__, f'`{e.name}`')
        return False

    core_config = config_wrapper.core_config(title=cls.__name__)

    try:
        schema = gen_schema.clean_schema(schema)
    except gen_schema.CollectedInvalid:
        set_dataclass_mocks(cls, cls.__name__, 'all referenced types')
        return False

    # We are about to set all the remaining required properties expected for this cast;
    # __pydantic_decorators__ and __pydantic_fields__ should already be set
    cls = typing.cast('type[PydanticDataclass]', cls)
    # debug(schema)

    cls.__pydantic_core_schema__ = schema
    cls.__pydantic_validator__ = validator = create_schema_validator(
        schema, cls, cls.__module__, cls.__qualname__, 'dataclass', core_config, config_wrapper.plugin_settings
    )
    cls.__pydantic_serializer__ = SchemaSerializer(schema, core_config)

    if config_wrapper.validate_assignment:

        @wraps(cls.__setattr__)
        def validated_setattr(instance: Any, field: str, value: Any, /) -> None:
            validator.validate_assignment(instance, field, value)

        cls.__setattr__ = validated_setattr.__get__(None, cls)  # type: ignore

    cls.__pydantic_complete__ = True
    return True
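
# Editorial sketch of the behaviour wired up above, via the public API
# (`validate_assignment` routes attribute writes through the validator):
#
#     import pydantic
#
#     @pydantic.dataclasses.dataclass(config=dict(validate_assignment=True))
#     class Point:
#         x: int
#
#     p = Point(x='1')    # '1' is coerced to 1 by the generated __init__
#     p.x = 'not an int'  # raises pydantic.ValidationError via validated_setattr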


def is_builtin_dataclass(_cls: type[Any]) -> TypeGuard[type[StandardDataclass]]:
    """Returns True if a class is a stdlib dataclass and *not* a pydantic dataclass.

    We check that
    - `_cls` is a dataclass
    - `_cls` does not inherit from a processed pydantic dataclass (and thus have a `__pydantic_validator__`)
    - `_cls` does not have any annotations that are not dataclass fields
    e.g.
    ```python
    import dataclasses

    import pydantic.dataclasses

    @dataclasses.dataclass
    class A:
        x: int

    @pydantic.dataclasses.dataclass
    class B(A):
        y: int
    ```
    In this case, when we first check `B`, we make an extra check and look at the annotations ('y'),
    which won't be a superset of all the dataclass fields (only the stdlib fields i.e. 'x')

    Args:
        _cls: The class.

    Returns:
        `True` if the class is a stdlib dataclass, `False` otherwise.
    """
    return (
        dataclasses.is_dataclass(_cls)
        and not hasattr(_cls, '__pydantic_validator__')
        and set(_cls.__dataclass_fields__).issuperset(set(getattr(_cls, '__annotations__', {})))
    )
@@ -0,0 +1,823 @@
"""Logic related to validators applied to models etc. via the `@field_validator` and `@model_validator` decorators."""

from __future__ import annotations as _annotations

from collections import deque
from dataclasses import dataclass, field
from functools import cached_property, partial, partialmethod
from inspect import Parameter, Signature, isdatadescriptor, ismethoddescriptor, signature
from itertools import islice
from typing import TYPE_CHECKING, Any, Callable, ClassVar, Generic, Iterable, TypeVar, Union

from pydantic_core import PydanticUndefined, core_schema
from typing_extensions import Literal, TypeAlias, is_typeddict

from ..errors import PydanticUserError
from ._core_utils import get_type_ref
from ._internal_dataclass import slots_true
from ._namespace_utils import GlobalsNamespace, MappingNamespace
from ._typing_extra import get_function_type_hints
from ._utils import can_be_positional

if TYPE_CHECKING:
    from ..fields import ComputedFieldInfo
    from ..functional_validators import FieldValidatorModes


@dataclass(**slots_true)
class ValidatorDecoratorInfo:
    """A container for data from `@validator` so that we can access it
    while building the pydantic-core schema.

    Attributes:
        decorator_repr: A class variable representing the decorator string, '@validator'.
        fields: A tuple of field names the validator should be called on.
        mode: The proposed validator mode.
        each_item: For complex objects (sets, lists etc.) whether to validate individual
            elements rather than the whole object.
        always: Whether this method and other validators should be called even if the value is missing.
        check_fields: Whether to check that the fields actually exist on the model.
    """

    decorator_repr: ClassVar[str] = '@validator'

    fields: tuple[str, ...]
    mode: Literal['before', 'after']
    each_item: bool
    always: bool
    check_fields: bool | None


@dataclass(**slots_true)
class FieldValidatorDecoratorInfo:
    """A container for data from `@field_validator` so that we can access it
    while building the pydantic-core schema.

    Attributes:
        decorator_repr: A class variable representing the decorator string, '@field_validator'.
        fields: A tuple of field names the validator should be called on.
        mode: The proposed validator mode.
        check_fields: Whether to check that the fields actually exist on the model.
        json_schema_input_type: The input type of the function. This is only used to generate
            the appropriate JSON Schema (in validation mode) and can only be specified
            when `mode` is either `'before'`, `'plain'` or `'wrap'`.
    """

    decorator_repr: ClassVar[str] = '@field_validator'

    fields: tuple[str, ...]
    mode: FieldValidatorModes
    check_fields: bool | None
    json_schema_input_type: Any


@dataclass(**slots_true)
class RootValidatorDecoratorInfo:
    """A container for data from `@root_validator` so that we can access it
    while building the pydantic-core schema.

    Attributes:
        decorator_repr: A class variable representing the decorator string, '@root_validator'.
        mode: The proposed validator mode.
    """

    decorator_repr: ClassVar[str] = '@root_validator'
    mode: Literal['before', 'after']


@dataclass(**slots_true)
class FieldSerializerDecoratorInfo:
    """A container for data from `@field_serializer` so that we can access it
    while building the pydantic-core schema.

    Attributes:
        decorator_repr: A class variable representing the decorator string, '@field_serializer'.
        fields: A tuple of field names the serializer should be called on.
        mode: The proposed serializer mode.
        return_type: The type of the serializer's return value.
        when_used: The serialization condition. Accepts a string with values `'always'`, `'unless-none'`, `'json'`,
            and `'json-unless-none'`.
        check_fields: Whether to check that the fields actually exist on the model.
    """

    decorator_repr: ClassVar[str] = '@field_serializer'
    fields: tuple[str, ...]
    mode: Literal['plain', 'wrap']
    return_type: Any
    when_used: core_schema.WhenUsed
    check_fields: bool | None


@dataclass(**slots_true)
class ModelSerializerDecoratorInfo:
    """A container for data from `@model_serializer` so that we can access it
    while building the pydantic-core schema.

    Attributes:
        decorator_repr: A class variable representing the decorator string, '@model_serializer'.
        mode: The proposed serializer mode.
        return_type: The type of the serializer's return value.
        when_used: The serialization condition. Accepts a string with values `'always'`, `'unless-none'`, `'json'`,
            and `'json-unless-none'`.
    """

    decorator_repr: ClassVar[str] = '@model_serializer'
    mode: Literal['plain', 'wrap']
    return_type: Any
    when_used: core_schema.WhenUsed


@dataclass(**slots_true)
class ModelValidatorDecoratorInfo:
    """A container for data from `@model_validator` so that we can access it
    while building the pydantic-core schema.

    Attributes:
        decorator_repr: A class variable representing the decorator string, '@model_validator'.
        mode: The proposed validator mode.
    """

    decorator_repr: ClassVar[str] = '@model_validator'
    mode: Literal['wrap', 'before', 'after']


DecoratorInfo: TypeAlias = """Union[
    ValidatorDecoratorInfo,
    FieldValidatorDecoratorInfo,
    RootValidatorDecoratorInfo,
    FieldSerializerDecoratorInfo,
    ModelSerializerDecoratorInfo,
    ModelValidatorDecoratorInfo,
    ComputedFieldInfo,
]"""

ReturnType = TypeVar('ReturnType')
DecoratedType: TypeAlias = (
    'Union[classmethod[Any, Any, ReturnType], staticmethod[Any, ReturnType], Callable[..., ReturnType], property]'
)


@dataclass  # can't use slots here since we set attributes on `__post_init__`
class PydanticDescriptorProxy(Generic[ReturnType]):
    """Wrap a classmethod, staticmethod, property or unbound function
    and act as a descriptor that allows us to detect decorated items
    from the class' attributes.

    This class' __get__ returns the wrapped item's __get__ result,
    which makes it transparent for classmethods and staticmethods.

    Attributes:
        wrapped: The decorator that has to be wrapped.
        decorator_info: The decorator info.
        shim: A wrapper function to wrap V1 style function.
    """

    wrapped: DecoratedType[ReturnType]
    decorator_info: DecoratorInfo
    shim: Callable[[Callable[..., Any]], Callable[..., Any]] | None = None

    def __post_init__(self):
        for attr in 'setter', 'deleter':
            if hasattr(self.wrapped, attr):
                f = partial(self._call_wrapped_attr, name=attr)
                setattr(self, attr, f)

    def _call_wrapped_attr(self, func: Callable[[Any], None], *, name: str) -> PydanticDescriptorProxy[ReturnType]:
        self.wrapped = getattr(self.wrapped, name)(func)
        if isinstance(self.wrapped, property):
            # update ComputedFieldInfo.wrapped_property
            from ..fields import ComputedFieldInfo

            if isinstance(self.decorator_info, ComputedFieldInfo):
                self.decorator_info.wrapped_property = self.wrapped
        return self

    def __get__(self, obj: object | None, obj_type: type[object] | None = None) -> PydanticDescriptorProxy[ReturnType]:
        try:
            return self.wrapped.__get__(obj, obj_type)
        except AttributeError:
            # not a descriptor, e.g. a partial object
            return self.wrapped  # type: ignore[return-value]

    def __set_name__(self, instance: Any, name: str) -> None:
        if hasattr(self.wrapped, '__set_name__'):
            self.wrapped.__set_name__(instance, name)  # pyright: ignore[reportFunctionMemberAccess]

    def __getattr__(self, __name: str) -> Any:
        """Forward checks for __isabstractmethod__ and such."""
        return getattr(self.wrapped, __name)


DecoratorInfoType = TypeVar('DecoratorInfoType', bound=DecoratorInfo)


@dataclass(**slots_true)
class Decorator(Generic[DecoratorInfoType]):
    """A generic container class to join together the decorator metadata
    (metadata from decorator itself, which we have when the
    decorator is called but not when we are building the core-schema)
    and the bound function (which we have after the class itself is created).

    Attributes:
        cls_ref: The class ref.
        cls_var_name: The decorated function name.
        func: The decorated function.
        shim: A wrapper function to wrap V1 style function.
        info: The decorator info.
    """

    cls_ref: str
    cls_var_name: str
    func: Callable[..., Any]
    shim: Callable[[Any], Any] | None
    info: DecoratorInfoType

    @staticmethod
    def build(
        cls_: Any,
        *,
        cls_var_name: str,
        shim: Callable[[Any], Any] | None,
        info: DecoratorInfoType,
    ) -> Decorator[DecoratorInfoType]:
        """Build a new decorator.

        Args:
            cls_: The class.
            cls_var_name: The decorated function name.
            shim: A wrapper function to wrap V1 style function.
            info: The decorator info.

        Returns:
            The new decorator instance.
        """
        func = get_attribute_from_bases(cls_, cls_var_name)
        if shim is not None:
            func = shim(func)
        func = unwrap_wrapped_function(func, unwrap_partial=False)
        if not callable(func):
            # This branch will get hit for classmethod properties
            attribute = get_attribute_from_base_dicts(cls_, cls_var_name)  # prevents the binding call to `__get__`
            if isinstance(attribute, PydanticDescriptorProxy):
                func = unwrap_wrapped_function(attribute.wrapped)
        return Decorator(
            cls_ref=get_type_ref(cls_),
            cls_var_name=cls_var_name,
            func=func,
            shim=shim,
            info=info,
        )

    def bind_to_cls(self, cls: Any) -> Decorator[DecoratorInfoType]:
        """Bind the decorator to a class.

        Args:
            cls: the class.

        Returns:
            The new decorator instance.
        """
        return self.build(
            cls,
            cls_var_name=self.cls_var_name,
            shim=self.shim,
            info=self.info,
        )


def get_bases(tp: type[Any]) -> tuple[type[Any], ...]:
    """Get the base classes of a class or typeddict.

    Args:
        tp: The type or class to get the bases.

    Returns:
        The base classes.
    """
    if is_typeddict(tp):
        return tp.__orig_bases__  # type: ignore
    try:
        return tp.__bases__
    except AttributeError:
        return ()


def mro(tp: type[Any]) -> tuple[type[Any], ...]:
    """Calculate the Method Resolution Order of bases using the C3 algorithm.

    See https://www.python.org/download/releases/2.3/mro/
    """
    # try to use the existing mro, for performance mainly
    # but also because it helps verify the implementation below
    if not is_typeddict(tp):
        try:
            return tp.__mro__
        except AttributeError:
            # GenericAlias and some other cases
            pass

    bases = get_bases(tp)
    return (tp,) + mro_for_bases(bases)


def mro_for_bases(bases: tuple[type[Any], ...]) -> tuple[type[Any], ...]:
    def merge_seqs(seqs: list[deque[type[Any]]]) -> Iterable[type[Any]]:
        while True:
            non_empty = [seq for seq in seqs if seq]
            if not non_empty:
                # Nothing left to process, we're done.
                return
            candidate: type[Any] | None = None
            for seq in non_empty:  # Find merge candidates among seq heads.
                candidate = seq[0]
                not_head = [s for s in non_empty if candidate in islice(s, 1, None)]
                if not_head:
                    # Reject the candidate.
                    candidate = None
                else:
                    break
            if not candidate:
                raise TypeError('Inconsistent hierarchy, no C3 MRO is possible')
            yield candidate
            for seq in non_empty:
                # Remove candidate.
                if seq[0] == candidate:
                    seq.popleft()

    seqs = [deque(mro(base)) for base in bases] + [deque(bases)]
    return tuple(merge_seqs(seqs))
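
# Editorial sketch: for ordinary classes the C3 linearisation above agrees
# with CPython's own MRO (and `mro()` simply returns `__mro__` when it
# exists); the hand-rolled merge only kicks in for TypedDicts and similar.
#
#     class A: ...
#     class B(A): ...
#     class C(A): ...
#     class D(B, C): ...
#
#     assert mro(D) == D.__mro__  # (D, B, C, A, object)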


_sentinel = object()


def get_attribute_from_bases(tp: type[Any] | tuple[type[Any], ...], name: str) -> Any:
    """Get the attribute from the next class in the MRO that has it,
    aiming to simulate calling the method on the actual class.

    The reason for iterating over the mro instead of just getting
    the attribute (which would do that for us) is to support TypedDict,
    which lacks a real __mro__, but can have a virtual one constructed
    from its bases (as done here).

    Args:
        tp: The type or class to search for the attribute. If a tuple, this is treated as a set of base classes.
        name: The name of the attribute to retrieve.

    Returns:
        Any: The attribute value, if found.

    Raises:
        AttributeError: If the attribute is not found in any class in the MRO.
    """
    if isinstance(tp, tuple):
        for base in mro_for_bases(tp):
            attribute = base.__dict__.get(name, _sentinel)
            if attribute is not _sentinel:
                attribute_get = getattr(attribute, '__get__', None)
                if attribute_get is not None:
                    return attribute_get(None, tp)
                return attribute
        raise AttributeError(f'{name} not found in {tp}')
    else:
        try:
            return getattr(tp, name)
        except AttributeError:
            return get_attribute_from_bases(mro(tp), name)


def get_attribute_from_base_dicts(tp: type[Any], name: str) -> Any:
    """Get an attribute out of the `__dict__` following the MRO.
    This prevents the call to `__get__` on the descriptor, and allows
    us to get the original function for classmethod properties.

    Args:
        tp: The type or class to search for the attribute.
        name: The name of the attribute to retrieve.

    Returns:
        Any: The attribute value, if found.

    Raises:
        KeyError: If the attribute is not found in any class's `__dict__` in the MRO.
    """
    for base in reversed(mro(tp)):
        if name in base.__dict__:
            return base.__dict__[name]
    return tp.__dict__[name]  # raise the error


@dataclass(**slots_true)
class DecoratorInfos:
"""Mapping of name in the class namespace to decorator info.
|
||||
|
||||
note that the name in the class namespace is the function or attribute name
|
||||
not the field name!
|
||||
"""

    validators: dict[str, Decorator[ValidatorDecoratorInfo]] = field(default_factory=dict)
    field_validators: dict[str, Decorator[FieldValidatorDecoratorInfo]] = field(default_factory=dict)
    root_validators: dict[str, Decorator[RootValidatorDecoratorInfo]] = field(default_factory=dict)
    field_serializers: dict[str, Decorator[FieldSerializerDecoratorInfo]] = field(default_factory=dict)
    model_serializers: dict[str, Decorator[ModelSerializerDecoratorInfo]] = field(default_factory=dict)
    model_validators: dict[str, Decorator[ModelValidatorDecoratorInfo]] = field(default_factory=dict)
    computed_fields: dict[str, Decorator[ComputedFieldInfo]] = field(default_factory=dict)

    @staticmethod
    def build(model_dc: type[Any]) -> DecoratorInfos:  # noqa: C901 (ignore complexity)
"""We want to collect all DecFunc instances that exist as
|
||||
attributes in the namespace of the class (a BaseModel or dataclass)
|
||||
that called us
|
||||
But we want to collect these in the order of the bases
|
||||
So instead of getting them all from the leaf class (the class that called us),
|
||||
we traverse the bases from root (the oldest ancestor class) to leaf
|
||||
and collect all of the instances as we go, taking care to replace
|
||||
any duplicate ones with the last one we see to mimic how function overriding
|
||||
works with inheritance.
|
||||
If we do replace any functions we put the replacement into the position
|
||||
the replaced function was in; that is, we maintain the order.
|
||||
"""
        # reminder: dicts are ordered and replacement does not alter the order
        res = DecoratorInfos()
        for base in reversed(mro(model_dc)[1:]):
            existing: DecoratorInfos | None = base.__dict__.get('__pydantic_decorators__')
            if existing is None:
                existing = DecoratorInfos.build(base)
            res.validators.update({k: v.bind_to_cls(model_dc) for k, v in existing.validators.items()})
            res.field_validators.update({k: v.bind_to_cls(model_dc) for k, v in existing.field_validators.items()})
            res.root_validators.update({k: v.bind_to_cls(model_dc) for k, v in existing.root_validators.items()})
            res.field_serializers.update({k: v.bind_to_cls(model_dc) for k, v in existing.field_serializers.items()})
            res.model_serializers.update({k: v.bind_to_cls(model_dc) for k, v in existing.model_serializers.items()})
            res.model_validators.update({k: v.bind_to_cls(model_dc) for k, v in existing.model_validators.items()})
            res.computed_fields.update({k: v.bind_to_cls(model_dc) for k, v in existing.computed_fields.items()})

        to_replace: list[tuple[str, Any]] = []

        for var_name, var_value in vars(model_dc).items():
            if isinstance(var_value, PydanticDescriptorProxy):
                info = var_value.decorator_info
                if isinstance(info, ValidatorDecoratorInfo):
                    res.validators[var_name] = Decorator.build(
                        model_dc, cls_var_name=var_name, shim=var_value.shim, info=info
                    )
                elif isinstance(info, FieldValidatorDecoratorInfo):
                    res.field_validators[var_name] = Decorator.build(
                        model_dc, cls_var_name=var_name, shim=var_value.shim, info=info
                    )
                elif isinstance(info, RootValidatorDecoratorInfo):
                    res.root_validators[var_name] = Decorator.build(
                        model_dc, cls_var_name=var_name, shim=var_value.shim, info=info
                    )
                elif isinstance(info, FieldSerializerDecoratorInfo):
                    # check whether a serializer function is already registered for fields
                    for field_serializer_decorator in res.field_serializers.values():
                        # check that each field has at most one serializer function.
                        # serializer functions for the same field in subclasses are allowed,
                        # and are treated as overrides
                        if field_serializer_decorator.cls_var_name == var_name:
                            continue
                        for f in info.fields:
                            if f in field_serializer_decorator.info.fields:
                                raise PydanticUserError(
                                    'Multiple field serializer functions were defined '
                                    f'for field {f!r}, this is not allowed.',
                                    code='multiple-field-serializers',
                                )
                    res.field_serializers[var_name] = Decorator.build(
                        model_dc, cls_var_name=var_name, shim=var_value.shim, info=info
                    )
                elif isinstance(info, ModelValidatorDecoratorInfo):
                    res.model_validators[var_name] = Decorator.build(
                        model_dc, cls_var_name=var_name, shim=var_value.shim, info=info
                    )
                elif isinstance(info, ModelSerializerDecoratorInfo):
                    res.model_serializers[var_name] = Decorator.build(
                        model_dc, cls_var_name=var_name, shim=var_value.shim, info=info
                    )
                else:
                    from ..fields import ComputedFieldInfo

                    # narrow the type: at this point `info` can only be a `ComputedFieldInfo`
                    assert isinstance(info, ComputedFieldInfo)
                    res.computed_fields[var_name] = Decorator.build(
                        model_dc, cls_var_name=var_name, shim=None, info=info
                    )
                to_replace.append((var_name, var_value.wrapped))
        if to_replace:
            # If we can save `__pydantic_decorators__` on the class we'll be able to check for it above
            # so then we don't need to re-process the type, which means we can discard our descriptor wrappers
            # and replace them with the thing they are wrapping (see the other setattr call below)
            # which allows validator class methods to also function as regular class methods
            model_dc.__pydantic_decorators__ = res
            for name, value in to_replace:
                setattr(model_dc, name, value)
        return res


def inspect_validator(validator: Callable[..., Any], mode: FieldValidatorModes) -> bool:
    """Look at a field or model validator function and determine whether it takes an info argument.

    An error is raised if the function has an invalid signature.

    Args:
        validator: The validator function to inspect.
        mode: The proposed validator mode.

    Returns:
        Whether the validator takes an info argument.
    """
    try:
        sig = signature(validator)
    except (ValueError, TypeError):
        # `inspect.signature` might not be able to infer a signature, e.g. with C objects.
        # In this case, we assume no info argument is present:
        return False
    n_positional = count_positional_required_params(sig)
    if mode == 'wrap':
        if n_positional == 3:
            return True
        elif n_positional == 2:
            return False
    else:
        assert mode in {'before', 'after', 'plain'}, f"invalid mode: {mode!r}, expected 'before', 'after' or 'plain'"
        if n_positional == 2:
            return True
        elif n_positional == 1:
            return False

    raise PydanticUserError(
        f'Unrecognized field_validator function signature for {validator} with `mode={mode}`:{sig}',
        code='validator-signature',
    )
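
# Editorial sketch: for `mode='after'` the second positional parameter is what
# flips the result, e.g.
#
#     def strip(v: str) -> str: ...
#     def strip_with_info(v: str, info) -> str: ...
#
#     inspect_validator(strip, 'after')            # False: no info argument
#     inspect_validator(strip_with_info, 'after')  # True: info argument expected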


def inspect_field_serializer(serializer: Callable[..., Any], mode: Literal['plain', 'wrap']) -> tuple[bool, bool]:
    """Look at a field serializer function and determine if it is a field serializer,
    and whether it takes an info argument.

    An error is raised if the function has an invalid signature.

    Args:
        serializer: The serializer function to inspect.
        mode: The serializer mode, either 'plain' or 'wrap'.

    Returns:
        Tuple of (is_field_serializer, info_arg).
    """
    try:
        sig = signature(serializer)
    except (ValueError, TypeError):
        # `inspect.signature` might not be able to infer a signature, e.g. with C objects.
        # In this case, we assume no info argument is present and this is not a method:
        return (False, False)

    first = next(iter(sig.parameters.values()), None)
    is_field_serializer = first is not None and first.name == 'self'

    n_positional = count_positional_required_params(sig)
    if is_field_serializer:
        # -1 to correct for self parameter
        info_arg = _serializer_info_arg(mode, n_positional - 1)
    else:
        info_arg = _serializer_info_arg(mode, n_positional)

    if info_arg is None:
        raise PydanticUserError(
            f'Unrecognized field_serializer function signature for {serializer} with `mode={mode}`:{sig}',
            code='field-serializer-signature',
        )

    return is_field_serializer, info_arg


def inspect_annotated_serializer(serializer: Callable[..., Any], mode: Literal['plain', 'wrap']) -> bool:
    """Look at a serializer function used via `Annotated` and determine whether it takes an info argument.

    An error is raised if the function has an invalid signature.

    Args:
        serializer: The serializer function to check.
        mode: The serializer mode, either 'plain' or 'wrap'.

    Returns:
        info_arg
    """
    try:
        sig = signature(serializer)
    except (ValueError, TypeError):
        # `inspect.signature` might not be able to infer a signature, e.g. with C objects.
        # In this case, we assume no info argument is present:
        return False
    info_arg = _serializer_info_arg(mode, count_positional_required_params(sig))
    if info_arg is None:
        raise PydanticUserError(
            f'Unrecognized field_serializer function signature for {serializer} with `mode={mode}`:{sig}',
            code='field-serializer-signature',
        )
    else:
        return info_arg


def inspect_model_serializer(serializer: Callable[..., Any], mode: Literal['plain', 'wrap']) -> bool:
    """Look at a model serializer function and determine whether it takes an info argument.

    An error is raised if the function has an invalid signature.

    Args:
        serializer: The serializer function to check.
        mode: The serializer mode, either 'plain' or 'wrap'.

    Returns:
        `info_arg` - whether the function expects an info argument.
    """
    if isinstance(serializer, (staticmethod, classmethod)) or not is_instance_method_from_sig(serializer):
        raise PydanticUserError(
            '`@model_serializer` must be applied to instance methods', code='model-serializer-instance-method'
        )

    sig = signature(serializer)
    info_arg = _serializer_info_arg(mode, count_positional_required_params(sig))
    if info_arg is None:
        raise PydanticUserError(
            f'Unrecognized model_serializer function signature for {serializer} with `mode={mode}`:{sig}',
            code='model-serializer-signature',
        )
    else:
        return info_arg


def _serializer_info_arg(mode: Literal['plain', 'wrap'], n_positional: int) -> bool | None:
    if mode == 'plain':
        if n_positional == 1:
            # (input_value: Any, /) -> Any
            return False
        elif n_positional == 2:
            # (model: Any, input_value: Any, /) -> Any
            return True
    else:
        assert mode == 'wrap', f"invalid mode: {mode!r}, expected 'plain' or 'wrap'"
        if n_positional == 2:
            # (input_value: Any, serializer: SerializerFunctionWrapHandler, /) -> Any
            return False
        elif n_positional == 3:
            # (input_value: Any, serializer: SerializerFunctionWrapHandler, info: SerializationInfo, /) -> Any
            return True

    return None


AnyDecoratorCallable: TypeAlias = (
    'Union[classmethod[Any, Any, Any], staticmethod[Any, Any], partialmethod[Any], Callable[..., Any]]'
)


def is_instance_method_from_sig(function: AnyDecoratorCallable) -> bool:
    """Whether the function is an instance method.

    It will consider a function as instance method if the first parameter of
    function is `self`.

    Args:
        function: The function to check.

    Returns:
        `True` if the function is an instance method, `False` otherwise.
    """
    sig = signature(unwrap_wrapped_function(function))
    first = next(iter(sig.parameters.values()), None)
    if first and first.name == 'self':
        return True
    return False


def ensure_classmethod_based_on_signature(function: AnyDecoratorCallable) -> Any:
    """Apply the `@classmethod` decorator on the function.

    Args:
        function: The function to apply the decorator on.

    Return:
        The `@classmethod` decorator applied function.
    """
    if not isinstance(
        unwrap_wrapped_function(function, unwrap_class_static_method=False), classmethod
    ) and _is_classmethod_from_sig(function):
        return classmethod(function)  # type: ignore[arg-type]
    return function


def _is_classmethod_from_sig(function: AnyDecoratorCallable) -> bool:
    sig = signature(unwrap_wrapped_function(function))
    first = next(iter(sig.parameters.values()), None)
    if first and first.name == 'cls':
        return True
    return False


def unwrap_wrapped_function(
    func: Any,
    *,
    unwrap_partial: bool = True,
    unwrap_class_static_method: bool = True,
) -> Any:
    """Recursively unwraps a wrapped function until the underlying function is reached.
    This handles property, functools.partial, functools.partialmethod, staticmethod, and classmethod.

    Args:
        func: The function to unwrap.
        unwrap_partial: If True (default), unwrap partial and partialmethod decorators.
        unwrap_class_static_method: If True (default), also unwrap classmethod and staticmethod
            decorators. If False, only unwrap partial and partialmethod decorators.

    Returns:
        The underlying function of the wrapped function.
    """
    # Define the types we want to check against as a single tuple.
    unwrap_types = (
        (property, cached_property)
        + ((partial, partialmethod) if unwrap_partial else ())
        + ((staticmethod, classmethod) if unwrap_class_static_method else ())
    )

    while isinstance(func, unwrap_types):
        if unwrap_class_static_method and isinstance(func, (classmethod, staticmethod)):
            func = func.__func__
        elif isinstance(func, (partial, partialmethod)):
            func = func.func
        elif isinstance(func, property):
            func = func.fget  # arbitrary choice, convenient for computed fields
        else:
            # Make coverage happy as it can only get here in the last possible case
            assert isinstance(func, cached_property)
            func = func.func  # type: ignore

    return func
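
# Editorial sketch:
#
#     from functools import partial
#
#     def f(a, b): ...
#
#     assert unwrap_wrapped_function(partial(f, 1)) is f   # partial -> .func
#     assert unwrap_wrapped_function(classmethod(f)) is f  # classmethod -> .__func__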


def get_function_return_type(
    func: Any,
    explicit_return_type: Any,
    globalns: GlobalsNamespace | None = None,
    localns: MappingNamespace | None = None,
) -> Any:
    """Get the function return type.

    It gets the return type from the type annotation if `explicit_return_type` is `PydanticUndefined`.
    Otherwise, it returns `explicit_return_type`.

    Args:
        func: The function to get its return type.
        explicit_return_type: The explicit return type.
        globalns: The globals namespace to use during type annotation evaluation.
        localns: The locals namespace to use during type annotation evaluation.

    Returns:
        The function return type.
    """
    if explicit_return_type is PydanticUndefined:
        # try to get it from the type annotation
        hints = get_function_type_hints(
            unwrap_wrapped_function(func),
            include_keys={'return'},
            globalns=globalns,
            localns=localns,
        )
        return hints.get('return', PydanticUndefined)
    else:
        return explicit_return_type


def count_positional_required_params(sig: Signature) -> int:
    """Get the number of positional (required) arguments of a signature.

    This function should only be used to inspect signatures of validation and serialization functions.
    The first argument (the value being serialized or validated) is counted as a required argument
    even if a default value exists.

    Returns:
        The number of positional arguments of a signature.
    """
    parameters = list(sig.parameters.values())
    return sum(
        1
        for param in parameters
        if can_be_positional(param)
        # First argument is the value being validated/serialized, and can have a default value
        # (e.g. `float`, which has signature `(x=0, /)`). We assume other parameters (the info arg
        # for instance) should be required, and thus without any default value.
        and (param.default is Parameter.empty or param is parameters[0])
    )
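
# Editorial sketch: only the leading value parameter is exempt from the
# "no default" rule, so `info=None` below is not counted.
#
#     from inspect import signature
#
#     def v(value, info=None, *, strict=False): ...
#
#     count_positional_required_params(signature(v))  # == 1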


def ensure_property(f: Any) -> Any:
    """Ensure that a function is a `property` or `cached_property`, or is a valid descriptor.

    Args:
        f: The function to check.

    Returns:
        The function, or a `property` or `cached_property` instance wrapping the function.
    """
    if ismethoddescriptor(f) or isdatadescriptor(f):
        return f
    else:
        return property(f)
@@ -0,0 +1,174 @@
"""Logic for V1 validators, e.g. `@validator` and `@root_validator`."""

from __future__ import annotations as _annotations

from inspect import Parameter, signature
from typing import Any, Dict, Tuple, Union, cast

from pydantic_core import core_schema
from typing_extensions import Protocol

from ..errors import PydanticUserError
from ._utils import can_be_positional


class V1OnlyValueValidator(Protocol):
    """A simple validator, supported for V1 validators and V2 validators."""

    def __call__(self, __value: Any) -> Any: ...


class V1ValidatorWithValues(Protocol):
    """A validator with `values` argument, supported for V1 validators and V2 validators."""

    def __call__(self, __value: Any, values: dict[str, Any]) -> Any: ...


class V1ValidatorWithValuesKwOnly(Protocol):
    """A validator with keyword only `values` argument, supported for V1 validators and V2 validators."""

    def __call__(self, __value: Any, *, values: dict[str, Any]) -> Any: ...


class V1ValidatorWithKwargs(Protocol):
    """A validator with `kwargs` argument, supported for V1 validators and V2 validators."""

    def __call__(self, __value: Any, **kwargs: Any) -> Any: ...


class V1ValidatorWithValuesAndKwargs(Protocol):
    """A validator with `values` and `kwargs` arguments, supported for V1 validators and V2 validators."""

    def __call__(self, __value: Any, values: dict[str, Any], **kwargs: Any) -> Any: ...


V1Validator = Union[
    V1ValidatorWithValues, V1ValidatorWithValuesKwOnly, V1ValidatorWithKwargs, V1ValidatorWithValuesAndKwargs
]


def can_be_keyword(param: Parameter) -> bool:
    return param.kind in (Parameter.POSITIONAL_OR_KEYWORD, Parameter.KEYWORD_ONLY)


def make_generic_v1_field_validator(validator: V1Validator) -> core_schema.WithInfoValidatorFunction:
    """Wrap a V1 style field validator for V2 compatibility.

    Args:
        validator: The V1 style field validator.

    Returns:
        A wrapped V2 style field validator.

    Raises:
        PydanticUserError: If the signature is not supported or the parameters are
            not available in Pydantic V2.
    """
    sig = signature(validator)

    needs_values_kw = False

    for param_num, (param_name, parameter) in enumerate(sig.parameters.items()):
        if can_be_keyword(parameter) and param_name in ('field', 'config'):
            raise PydanticUserError(
                'The `field` and `config` parameters are not available in Pydantic V2, '
                'please use the `info` parameter instead.',
                code='validator-field-config-info',
            )
        if parameter.kind is Parameter.VAR_KEYWORD:
            needs_values_kw = True
        elif can_be_keyword(parameter) and param_name == 'values':
            needs_values_kw = True
        elif can_be_positional(parameter) and param_num == 0:
            # value
            continue
        elif parameter.default is Parameter.empty:  # ignore params with defaults e.g. bound by functools.partial
            raise PydanticUserError(
                f'Unsupported signature for V1 style validator {validator}: {sig} is not supported.',
                code='validator-v1-signature',
            )

    if needs_values_kw:
        # (v, **kwargs), (v, values, **kwargs), (v, *, values, **kwargs) or (v, *, values)
        val1 = cast(V1ValidatorWithValues, validator)

        def wrapper1(value: Any, info: core_schema.ValidationInfo) -> Any:
            return val1(value, values=info.data)

        return wrapper1
    else:
        val2 = cast(V1OnlyValueValidator, validator)

        def wrapper2(value: Any, _: core_schema.ValidationInfo) -> Any:
            return val2(value)

        return wrapper2


RootValidatorValues = Dict[str, Any]
# technically tuple[model_dict, model_extra, fields_set] | tuple[dataclass_dict, init_vars]
RootValidatorFieldsTuple = Tuple[Any, ...]


class V1RootValidatorFunction(Protocol):
    """A simple root validator, supported for V1 validators and V2 validators."""

    def __call__(self, __values: RootValidatorValues) -> RootValidatorValues: ...


class V2CoreBeforeRootValidator(Protocol):
    """V2 validator with mode='before'."""

    def __call__(self, __values: RootValidatorValues, __info: core_schema.ValidationInfo) -> RootValidatorValues: ...


class V2CoreAfterRootValidator(Protocol):
    """V2 validator with mode='after'."""

    def __call__(
        self, __fields_tuple: RootValidatorFieldsTuple, __info: core_schema.ValidationInfo
    ) -> RootValidatorFieldsTuple: ...


def make_v1_generic_root_validator(
    validator: V1RootValidatorFunction, pre: bool
) -> V2CoreBeforeRootValidator | V2CoreAfterRootValidator:
    """Wrap a V1 style root validator for V2 compatibility.

    Args:
        validator: The V1 style field validator.
        pre: Whether the validator is a pre validator.

    Returns:
        A wrapped V2 style validator.
    """
    if pre is True:
        # mode='before' for pydantic-core
        def _wrapper1(values: RootValidatorValues, _: core_schema.ValidationInfo) -> RootValidatorValues:
            return validator(values)

        return _wrapper1

    # mode='after' for pydantic-core
    def _wrapper2(fields_tuple: RootValidatorFieldsTuple, _: core_schema.ValidationInfo) -> RootValidatorFieldsTuple:
        if len(fields_tuple) == 2:
            # dataclass, this is easy
            values, init_vars = fields_tuple
            values = validator(values)
            return values, init_vars
        else:
            # ugly hack: to match v1 behaviour, we merge values and model_extra, then split them up based on fields
            # afterwards
            model_dict, model_extra, fields_set = fields_tuple
            if model_extra:
                fields = set(model_dict.keys())
                model_dict.update(model_extra)
                model_dict_new = validator(model_dict)
                for k in list(model_dict_new.keys()):
                    if k not in fields:
                        model_extra[k] = model_dict_new.pop(k)
            else:
                model_dict_new = validator(model_dict)
            return model_dict_new, model_extra, fields_set

    return _wrapper2
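
# Editorial sketch: a `pre=True` root validator is re-exposed with the V2
# `mode='before'` calling convention.
#
#     def ensure_id(values):
#         values.setdefault('id', 0)
#         return values
#
#     before = make_v1_generic_root_validator(ensure_id, pre=True)
#     # pydantic-core calls: before(raw_values, info)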
@@ -0,0 +1,503 @@
from __future__ import annotations as _annotations

from typing import TYPE_CHECKING, Any, Hashable, Sequence

from pydantic_core import CoreSchema, core_schema

from ..errors import PydanticUserError
from . import _core_utils
from ._core_utils import (
    CoreSchemaField,
    collect_definitions,
)

if TYPE_CHECKING:
    from ..types import Discriminator

CORE_SCHEMA_METADATA_DISCRIMINATOR_PLACEHOLDER_KEY = 'pydantic.internal.union_discriminator'


class MissingDefinitionForUnionRef(Exception):
    """Raised when applying a discriminated union discriminator to a schema
    requires a definition that is not yet defined
    """

    def __init__(self, ref: str) -> None:
        self.ref = ref
        super().__init__(f'Missing definition for ref {self.ref!r}')


def set_discriminator_in_metadata(schema: CoreSchema, discriminator: Any) -> None:
    schema.setdefault('metadata', {})
    metadata = schema.get('metadata')
    assert metadata is not None
    metadata[CORE_SCHEMA_METADATA_DISCRIMINATOR_PLACEHOLDER_KEY] = discriminator


def apply_discriminators(schema: core_schema.CoreSchema) -> core_schema.CoreSchema:
    # We recursively walk through the `schema` passed to `apply_discriminators`, applying discriminators
    # where necessary at each level. During this recursion, we allow references to be resolved from the definitions
    # that are originally present on the original, outermost `schema`. Before `apply_discriminators` is called,
    # `simplify_schema_references` is called on the schema (in the `clean_schema` function),
    # which often puts the definitions in the outermost schema.
    global_definitions: dict[str, CoreSchema] = collect_definitions(schema)

    def inner(s: core_schema.CoreSchema, recurse: _core_utils.Recurse) -> core_schema.CoreSchema:
        nonlocal global_definitions

        s = recurse(s, inner)
        if s['type'] == 'tagged-union':
            return s

        metadata = s.get('metadata', {})
        discriminator = metadata.pop(CORE_SCHEMA_METADATA_DISCRIMINATOR_PLACEHOLDER_KEY, None)
        if discriminator is not None:
            s = apply_discriminator(s, discriminator, global_definitions)
        return s

    return _core_utils.walk_core_schema(schema, inner, copy=False)
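
# Editorial sketch of the public behaviour this machinery enables:
#
#     from typing import Literal, Union
#     from pydantic import BaseModel, Field
#
#     class Cat(BaseModel):
#         pet_type: Literal['cat']
#
#     class Dog(BaseModel):
#         pet_type: Literal['dog']
#
#     class Owner(BaseModel):
#         pet: Union[Cat, Dog] = Field(discriminator='pet_type')
#
#     Owner.model_validate({'pet': {'pet_type': 'dog'}})  # dispatches straight to Dog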
|
||||
|
||||
|
||||
def apply_discriminator(
|
||||
schema: core_schema.CoreSchema,
|
||||
discriminator: str | Discriminator,
|
||||
definitions: dict[str, core_schema.CoreSchema] | None = None,
|
||||
) -> core_schema.CoreSchema:
|
||||
"""Applies the discriminator and returns a new core schema.
|
||||
|
||||
Args:
|
||||
schema: The input schema.
|
||||
discriminator: The name of the field which will serve as the discriminator.
|
||||
definitions: A mapping of schema ref to schema.
|
||||
|
||||
Returns:
|
||||
The new core schema.
|
||||
|
||||
Raises:
|
||||
TypeError:
|
||||
- If `discriminator` is used with invalid union variant.
|
||||
- If `discriminator` is used with `Union` type with one variant.
|
||||
- If `discriminator` value mapped to multiple choices.
|
||||
MissingDefinitionForUnionRef:
|
||||
If the definition for ref is missing.
|
||||
PydanticUserError:
|
||||
- If a model in union doesn't have a discriminator field.
|
||||
- If discriminator field has a non-string alias.
|
||||
- If discriminator fields have different aliases.
|
||||
- If discriminator field not of type `Literal`.
|
||||
"""
|
||||
from ..types import Discriminator
|
||||
|
||||
if isinstance(discriminator, Discriminator):
|
||||
if isinstance(discriminator.discriminator, str):
|
||||
discriminator = discriminator.discriminator
|
||||
else:
|
||||
return discriminator._convert_schema(schema)
|
||||
|
||||
return _ApplyInferredDiscriminator(discriminator, definitions or {}).apply(schema)
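

# Editor's note: a minimal usage sketch, not part of the original module. It
# applies `apply_discriminator` to a hand-built union of two typed-dict core
# schemas; the `pet_type` field and the 'cat'/'dog' values are fabricated for
# illustration.
def _demo_apply_discriminator() -> None:
    cat = core_schema.typed_dict_schema(
        {'pet_type': core_schema.typed_dict_field(core_schema.literal_schema(['cat']))}
    )
    dog = core_schema.typed_dict_schema(
        {'pet_type': core_schema.typed_dict_field(core_schema.literal_schema(['dog']))}
    )
    union = core_schema.union_schema([cat, dog])
    tagged = apply_discriminator(union, 'pet_type')
    # The plain union has been replaced by a tagged-union keyed on 'pet_type'.
    assert tagged['type'] == 'tagged-union'
    assert set(tagged['choices']) == {'cat', 'dog'}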


class _ApplyInferredDiscriminator:
    """This class is used to convert an input schema containing a union schema into one where that union is
    replaced with a tagged-union, with all the associated debugging and performance benefits.

    This is done by:
    * Validating that the input schema is compatible with the provided discriminator
    * Introspecting the schema to determine which discriminator values should map to which union choices
    * Handling various edge cases such as 'definitions', 'default', 'nullable' schemas, and more

    I have chosen to implement the conversion algorithm in this class, rather than a function,
    to make it easier to maintain state while recursively walking the provided CoreSchema.
    """

    def __init__(self, discriminator: str, definitions: dict[str, core_schema.CoreSchema]):
        # `discriminator` should be the name of the field which will serve as the discriminator.
        # It must be the python name of the field, and *not* the field's alias. Note that as of now,
        # all members of a discriminated union _must_ use a field with the same name as the discriminator.
        # This may change if/when we expose a way to manually specify the TaggedUnionSchema's choices.
        self.discriminator = discriminator

        # `definitions` should contain a mapping of schema ref to schema for all schemas which might
        # be referenced by some choice
        self.definitions = definitions

        # `_discriminator_alias` will hold the value, if present, of the alias for the discriminator
        #
        # Note: following the v1 implementation, we currently disallow the use of different aliases
        # for different choices. This is not a limitation of pydantic_core, but if we try to handle
        # this, the inference logic gets complicated very quickly, and could result in confusing
        # debugging challenges for users making subtle mistakes.
        #
        # Rather than trying to do the most powerful inference possible, I think we should eventually
        # expose a way to more-manually control the way the TaggedUnionSchema is constructed through
        # the use of a new type which would be placed as an Annotation on the Union type. This would
        # provide the full flexibility/power of pydantic_core's TaggedUnionSchema where necessary for
        # more complex cases, without over-complicating the inference logic for the common cases.
        self._discriminator_alias: str | None = None

        # `_should_be_nullable` indicates whether the converted union has `None` as an allowed value.
        # If `None` is an acceptable value of the (possibly-wrapped) union, we ignore it while
        # constructing the TaggedUnionSchema, but set the `_should_be_nullable` attribute to True.
        # Once we have constructed the TaggedUnionSchema, if `_should_be_nullable` is True, we ensure
        # that the final schema gets wrapped as a NullableSchema. This has the same semantics on the
        # python side, but resolves the issue that `None` cannot correspond to any discriminator values.
        self._should_be_nullable = False

        # `_is_nullable` is used to track if the final produced schema will definitely be nullable;
        # we set it to True if the input schema is wrapped in a nullable schema that we know will be preserved
        # as an indication that, even if None is discovered as one of the union choices, we will not need to wrap
        # the final value in another nullable schema.
        #
        # This is more complicated than just checking for the final outermost schema having type 'nullable' thanks
        # to the possible presence of other wrapper schemas such as DefinitionsSchema, WithDefaultSchema, etc.
        self._is_nullable = False

        # `_choices_to_handle` serves as a stack of choices to add to the tagged union. Initially, choices
        # from the union in the wrapped schema will be appended to this list, and the recursive choice-handling
        # algorithm may add more choices to this stack as (nested) unions are encountered.
        self._choices_to_handle: list[core_schema.CoreSchema] = []

        # `_tagged_union_choices` is built during the call to `apply`, and will hold the choices to be included
        # in the output TaggedUnionSchema that will replace the union from the input schema
        self._tagged_union_choices: dict[Hashable, core_schema.CoreSchema] = {}

        # `_used` is changed to True after applying the discriminator to prevent accidental reuse
        self._used = False

    def apply(self, schema: core_schema.CoreSchema) -> core_schema.CoreSchema:
        """Return a new CoreSchema based on `schema` that uses a tagged-union with the discriminator provided
        to this class.

        Args:
            schema: The input schema.

        Returns:
            The new core schema.

        Raises:
            TypeError:
                - If `discriminator` is used with invalid union variant.
                - If `discriminator` is used with `Union` type with one variant.
                - If `discriminator` value mapped to multiple choices.
            MissingDefinitionForUnionRef:
                If the definition for ref is missing.
            PydanticUserError:
                - If a model in union doesn't have a discriminator field.
                - If discriminator field has a non-string alias.
                - If discriminator fields have different aliases.
                - If discriminator field not of type `Literal`.
        """
        assert not self._used
        schema = self._apply_to_root(schema)
        if self._should_be_nullable and not self._is_nullable:
            schema = core_schema.nullable_schema(schema)
        self._used = True
        return schema

    def _apply_to_root(self, schema: core_schema.CoreSchema) -> core_schema.CoreSchema:
        """This method handles the outer-most stage of recursion over the input schema:
        unwrapping nullable or definitions schemas, and calling the `_handle_choice`
        method iteratively on the choices extracted (recursively) from the possibly-wrapped union.
        """
        if schema['type'] == 'nullable':
            self._is_nullable = True
            wrapped = self._apply_to_root(schema['schema'])
            nullable_wrapper = schema.copy()
            nullable_wrapper['schema'] = wrapped
            return nullable_wrapper

        if schema['type'] == 'definitions':
            wrapped = self._apply_to_root(schema['schema'])
            definitions_wrapper = schema.copy()
            definitions_wrapper['schema'] = wrapped
            return definitions_wrapper

        if schema['type'] != 'union':
            # If the schema is not a union, it probably means it just had a single member and
            # was flattened by pydantic_core.
            # However, it still may make sense to apply the discriminator to this schema,
            # as a way to get discriminated-union-style error messages, so we allow this here.
            schema = core_schema.union_schema([schema])

        # Reverse the choices list before extending the stack so that they get handled in the order they occur
        choices_schemas = [v[0] if isinstance(v, tuple) else v for v in schema['choices'][::-1]]
        self._choices_to_handle.extend(choices_schemas)
        while self._choices_to_handle:
            choice = self._choices_to_handle.pop()
            self._handle_choice(choice)

        if self._discriminator_alias is not None and self._discriminator_alias != self.discriminator:
            # * We need to annotate `discriminator` as a union here to handle both branches of this conditional
            # * We need to annotate `discriminator` as list[list[str | int]] and not list[list[str]] due to the
            #   invariance of list, and because list[list[str | int]] is the type of the discriminator argument
            #   to tagged_union_schema below
            # * See the docstring of pydantic_core.core_schema.tagged_union_schema for more details about how to
            #   interpret the value of the discriminator argument to tagged_union_schema. (The list[list[str]] here
            #   is the appropriate way to provide a list of fallback attributes to check for a discriminator value.)
            discriminator: str | list[list[str | int]] = [[self.discriminator], [self._discriminator_alias]]
        else:
            discriminator = self.discriminator
        return core_schema.tagged_union_schema(
            choices=self._tagged_union_choices,
            discriminator=discriminator,
            custom_error_type=schema.get('custom_error_type'),
            custom_error_message=schema.get('custom_error_message'),
            custom_error_context=schema.get('custom_error_context'),
            strict=False,
            from_attributes=True,
            ref=schema.get('ref'),
            metadata=schema.get('metadata'),
            serialization=schema.get('serialization'),
        )

    def _handle_choice(self, choice: core_schema.CoreSchema) -> None:
        """This method handles the "middle" stage of recursion over the input schema.
        Specifically, it is responsible for handling each choice of the outermost union
        (and any "coalesced" choices obtained from inner unions).

        Here, "handling" entails:
        * Coalescing nested unions and compatible tagged-unions
        * Tracking the presence of 'none' and 'nullable' schemas occurring as choices
        * Validating that each allowed discriminator value maps to a unique choice
        * Updating the _tagged_union_choices mapping that will ultimately be used to build the TaggedUnionSchema.
        """
        if choice['type'] == 'definition-ref':
            if choice['schema_ref'] not in self.definitions:
                raise MissingDefinitionForUnionRef(choice['schema_ref'])

        if choice['type'] == 'none':
            self._should_be_nullable = True
        elif choice['type'] == 'definitions':
            self._handle_choice(choice['schema'])
        elif choice['type'] == 'nullable':
            self._should_be_nullable = True
            self._handle_choice(choice['schema'])  # unwrap the nullable schema
        elif choice['type'] == 'union':
            # Reverse the choices list before extending the stack so that they get handled in the order they occur
            choices_schemas = [v[0] if isinstance(v, tuple) else v for v in choice['choices'][::-1]]
            self._choices_to_handle.extend(choices_schemas)
        elif choice['type'] not in {
            'model',
            'typed-dict',
            'tagged-union',
            'lax-or-strict',
            'dataclass',
            'dataclass-args',
            'definition-ref',
        } and not _core_utils.is_function_with_inner_schema(choice):
            # We should eventually handle 'definition-ref' as well
            raise TypeError(
                f'{choice["type"]!r} is not a valid discriminated union variant;'
                ' should be a `BaseModel` or `dataclass`'
            )
        else:
            if choice['type'] == 'tagged-union' and self._is_discriminator_shared(choice):
                # In this case, this inner tagged-union is compatible with the outer tagged-union,
                # and its choices can be coalesced into the outer TaggedUnionSchema.
                subchoices = [x for x in choice['choices'].values() if not isinstance(x, (str, int))]
                # Reverse the choices list before extending the stack so that they get handled in the order they occur
                self._choices_to_handle.extend(subchoices[::-1])
                return

            inferred_discriminator_values = self._infer_discriminator_values_for_choice(choice, source_name=None)
            self._set_unique_choice_for_values(choice, inferred_discriminator_values)

    def _is_discriminator_shared(self, choice: core_schema.TaggedUnionSchema) -> bool:
        """This method returns a boolean indicating whether the discriminator for the `choice`
        is the same as that being used for the outermost tagged union. This is used to
        determine whether this TaggedUnionSchema choice should be "coalesced" into the top level,
        or whether it should be treated as a separate (nested) choice.
        """
        inner_discriminator = choice['discriminator']
        return inner_discriminator == self.discriminator or (
            isinstance(inner_discriminator, list)
            and (self.discriminator in inner_discriminator or [self.discriminator] in inner_discriminator)
        )

    def _infer_discriminator_values_for_choice(  # noqa: C901
        self, choice: core_schema.CoreSchema, source_name: str | None
    ) -> list[str | int]:
        """This function recurses over `choice`, extracting all discriminator values that should map to this choice.

        `source_name` is accepted for the purpose of producing useful error messages.
        """
        if choice['type'] == 'definitions':
            return self._infer_discriminator_values_for_choice(choice['schema'], source_name=source_name)
        elif choice['type'] == 'function-plain':
            raise TypeError(
                f'{choice["type"]!r} is not a valid discriminated union variant;'
                ' should be a `BaseModel` or `dataclass`'
            )
        elif _core_utils.is_function_with_inner_schema(choice):
            return self._infer_discriminator_values_for_choice(choice['schema'], source_name=source_name)
        elif choice['type'] == 'lax-or-strict':
            return sorted(
                set(
                    self._infer_discriminator_values_for_choice(choice['lax_schema'], source_name=None)
                    + self._infer_discriminator_values_for_choice(choice['strict_schema'], source_name=None)
                )
            )

        elif choice['type'] == 'tagged-union':
            values: list[str | int] = []
            # Ignore str/int "choices" since these are just references to other choices
            subchoices = [x for x in choice['choices'].values() if not isinstance(x, (str, int))]
            for subchoice in subchoices:
                subchoice_values = self._infer_discriminator_values_for_choice(subchoice, source_name=None)
                values.extend(subchoice_values)
            return values

        elif choice['type'] == 'union':
            values = []
            for subchoice in choice['choices']:
                subchoice_schema = subchoice[0] if isinstance(subchoice, tuple) else subchoice
                subchoice_values = self._infer_discriminator_values_for_choice(subchoice_schema, source_name=None)
                values.extend(subchoice_values)
            return values

        elif choice['type'] == 'nullable':
            self._should_be_nullable = True
            return self._infer_discriminator_values_for_choice(choice['schema'], source_name=None)

        elif choice['type'] == 'model':
            return self._infer_discriminator_values_for_choice(choice['schema'], source_name=choice['cls'].__name__)

        elif choice['type'] == 'dataclass':
            return self._infer_discriminator_values_for_choice(choice['schema'], source_name=choice['cls'].__name__)

        elif choice['type'] == 'model-fields':
            return self._infer_discriminator_values_for_model_choice(choice, source_name=source_name)

        elif choice['type'] == 'dataclass-args':
            return self._infer_discriminator_values_for_dataclass_choice(choice, source_name=source_name)

        elif choice['type'] == 'typed-dict':
            return self._infer_discriminator_values_for_typed_dict_choice(choice, source_name=source_name)

        elif choice['type'] == 'definition-ref':
            schema_ref = choice['schema_ref']
            if schema_ref not in self.definitions:
                raise MissingDefinitionForUnionRef(schema_ref)
            return self._infer_discriminator_values_for_choice(self.definitions[schema_ref], source_name=source_name)
        else:
            raise TypeError(
                f'{choice["type"]!r} is not a valid discriminated union variant;'
                ' should be a `BaseModel` or `dataclass`'
            )

    def _infer_discriminator_values_for_typed_dict_choice(
        self, choice: core_schema.TypedDictSchema, source_name: str | None = None
    ) -> list[str | int]:
        """This method just extracts the _infer_discriminator_values_for_choice logic specific to TypedDictSchema
        for the sake of readability.
        """
        source = 'TypedDict' if source_name is None else f'TypedDict {source_name!r}'
        field = choice['fields'].get(self.discriminator)
        if field is None:
            raise PydanticUserError(
                f'{source} needs a discriminator field for key {self.discriminator!r}', code='discriminator-no-field'
            )
        return self._infer_discriminator_values_for_field(field, source)

    def _infer_discriminator_values_for_model_choice(
        self, choice: core_schema.ModelFieldsSchema, source_name: str | None = None
    ) -> list[str | int]:
        source = 'ModelFields' if source_name is None else f'Model {source_name!r}'
        field = choice['fields'].get(self.discriminator)
        if field is None:
            raise PydanticUserError(
                f'{source} needs a discriminator field for key {self.discriminator!r}', code='discriminator-no-field'
            )
        return self._infer_discriminator_values_for_field(field, source)

    def _infer_discriminator_values_for_dataclass_choice(
        self, choice: core_schema.DataclassArgsSchema, source_name: str | None = None
    ) -> list[str | int]:
        source = 'DataclassArgs' if source_name is None else f'Dataclass {source_name!r}'
        for field in choice['fields']:
            if field['name'] == self.discriminator:
                break
        else:
            raise PydanticUserError(
                f'{source} needs a discriminator field for key {self.discriminator!r}', code='discriminator-no-field'
            )
        return self._infer_discriminator_values_for_field(field, source)

    def _infer_discriminator_values_for_field(self, field: CoreSchemaField, source: str) -> list[str | int]:
        if field['type'] == 'computed-field':
            # This should never occur as a discriminator, as it is only relevant to serialization
            return []
        alias = field.get('validation_alias', self.discriminator)
        if not isinstance(alias, str):
            raise PydanticUserError(
                f'Alias {alias!r} is not supported in a discriminated union', code='discriminator-alias-type'
            )
        if self._discriminator_alias is None:
            self._discriminator_alias = alias
        elif self._discriminator_alias != alias:
            raise PydanticUserError(
                f'Aliases for discriminator {self.discriminator!r} must be the same '
                f'(got {alias}, {self._discriminator_alias})',
                code='discriminator-alias',
            )
        return self._infer_discriminator_values_for_inner_schema(field['schema'], source)

    def _infer_discriminator_values_for_inner_schema(
        self, schema: core_schema.CoreSchema, source: str
    ) -> list[str | int]:
        """When inferring discriminator values for a field, we typically extract the expected values from a literal
        schema. This function does that, but also handles nested unions and defaults.
        """
        if schema['type'] == 'literal':
            return schema['expected']

        elif schema['type'] == 'union':
            # Generally when multiple values are allowed they should be placed in a single `Literal`, but
            # we add this case to handle the situation where a field is annotated as a `Union` of `Literal`s.
            # For example, this lets us handle `Union[Literal['key'], Union[Literal['Key'], Literal['KEY']]]`
            values: list[Any] = []
            for choice in schema['choices']:
                choice_schema = choice[0] if isinstance(choice, tuple) else choice
                choice_values = self._infer_discriminator_values_for_inner_schema(choice_schema, source)
                values.extend(choice_values)
            return values

        elif schema['type'] == 'default':
            # This will happen if the field has a default value; we ignore it while extracting the discriminator values
            return self._infer_discriminator_values_for_inner_schema(schema['schema'], source)

        elif schema['type'] == 'function-after':
            # After validators don't affect the discriminator values
            return self._infer_discriminator_values_for_inner_schema(schema['schema'], source)

        elif schema['type'] in {'function-before', 'function-wrap', 'function-plain'}:
            validator_type = repr(schema['type'].split('-')[1])
            raise PydanticUserError(
                f'Cannot use a mode={validator_type} validator in the'
                f' discriminator field {self.discriminator!r} of {source}',
                code='discriminator-validator',
            )

        else:
            raise PydanticUserError(
                f'{source} needs field {self.discriminator!r} to be of type `Literal`',
                code='discriminator-needs-literal',
            )

    def _set_unique_choice_for_values(self, choice: core_schema.CoreSchema, values: Sequence[str | int]) -> None:
        """This method updates `self.tagged_union_choices` so that all provided (discriminator) `values` map to the
        provided `choice`, validating that none of these values already map to another (different) choice.
        """
        for discriminator_value in values:
            if discriminator_value in self._tagged_union_choices:
                # It is okay if `value` is already in tagged_union_choices as long as it maps to the same value.
                # Because tagged_union_choices may map values to other values, we need to walk the choices dict
                # until we get to a "real" choice, and confirm that is equal to the one assigned.
                existing_choice = self._tagged_union_choices[discriminator_value]
                if existing_choice != choice:
                    raise TypeError(
                        f'Value {discriminator_value!r} for discriminator '
                        f'{self.discriminator!r} mapped to multiple choices'
                    )
            else:
                self._tagged_union_choices[discriminator_value] = choice
@@ -0,0 +1,108 @@
"""Utilities related to attribute docstring extraction."""

from __future__ import annotations

import ast
import inspect
import textwrap
from typing import Any


class DocstringVisitor(ast.NodeVisitor):
    def __init__(self) -> None:
        super().__init__()

        self.target: str | None = None
        self.attrs: dict[str, str] = {}
        self.previous_node_type: type[ast.AST] | None = None

    def visit(self, node: ast.AST) -> Any:
        node_result = super().visit(node)
        self.previous_node_type = type(node)
        return node_result

    def visit_AnnAssign(self, node: ast.AnnAssign) -> Any:
        if isinstance(node.target, ast.Name):
            self.target = node.target.id

    def visit_Expr(self, node: ast.Expr) -> Any:
        if (
            isinstance(node.value, ast.Constant)
            and isinstance(node.value.value, str)
            and self.previous_node_type is ast.AnnAssign
        ):
            docstring = inspect.cleandoc(node.value.value)
            if self.target:
                self.attrs[self.target] = docstring
            self.target = None


def _dedent_source_lines(source: list[str]) -> str:
    # Required for nested class definitions, e.g. in a function block
    dedent_source = textwrap.dedent(''.join(source))
    if dedent_source.startswith((' ', '\t')):
        # We are in the case where there's a dedented (usually multiline) string
        # at a lower indentation level than the class itself. We wrap our class
        # in a function as a workaround.
        dedent_source = f'def dedent_workaround():\n{dedent_source}'
    return dedent_source
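

# Editor's note: a minimal sketch of the wrapping behaviour above, not part of
# the original module. The source lines are fabricated; the final unindented
# line prevents `textwrap.dedent` from stripping the class's own indentation,
# so the `dedent_workaround` wrapper kicks in.
def _demo_dedent_source_lines() -> None:
    src = ['    class C:\n', '        x: int = 1\n', 'TRAILER = 1\n']
    out = _dedent_source_lines(src)
    assert out.startswith('def dedent_workaround():')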


def _extract_source_from_frame(cls: type[Any]) -> list[str] | None:
    frame = inspect.currentframe()

    while frame:
        if inspect.getmodule(frame) is inspect.getmodule(cls):
            lnum = frame.f_lineno
            try:
                lines, _ = inspect.findsource(frame)
            except OSError:
                # Source can't be retrieved (maybe because running in an interactive terminal),
                # we don't want to error here.
                pass
            else:
                block_lines = inspect.getblock(lines[lnum - 1 :])
                dedent_source = _dedent_source_lines(block_lines)
                try:
                    block_tree = ast.parse(dedent_source)
                except SyntaxError:
                    pass
                else:
                    stmt = block_tree.body[0]
                    if isinstance(stmt, ast.FunctionDef) and stmt.name == 'dedent_workaround':
                        # `_dedent_source_lines` wrapped the class in the workaround function
                        stmt = stmt.body[0]
                    if isinstance(stmt, ast.ClassDef) and stmt.name == cls.__name__:
                        return block_lines

        frame = frame.f_back


def extract_docstrings_from_cls(cls: type[Any], use_inspect: bool = False) -> dict[str, str]:
    """Map model attributes and their corresponding docstring.

    Args:
        cls: The class of the Pydantic model to inspect.
        use_inspect: Whether to skip usage of frames to find the object and use
            the `inspect` module instead.

    Returns:
        A mapping containing attribute names and their corresponding docstring.
    """
    if use_inspect:
        # Might not work as expected if two classes have the same name in the same source file.
        try:
            source, _ = inspect.getsourcelines(cls)
        except OSError:
            return {}
    else:
        source = _extract_source_from_frame(cls)

    if not source:
        return {}

    dedent_source = _dedent_source_lines(source)

    visitor = DocstringVisitor()
    visitor.visit(ast.parse(dedent_source))
    return visitor.attrs
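

# Editor's note: a minimal sketch, not part of the original module, showing the
# mapping this produces. `use_inspect=True` is used because the demo class is
# inspected after creation rather than while its defining frame is live, and it
# assumes the enclosing module is loaded from a real source file; the class and
# its attribute docstrings are fabricated.
def _demo_extract_docstrings() -> None:
    class Point:
        x: int
        """Horizontal coordinate."""
        y: int
        """Vertical coordinate."""

    docs = extract_docstrings_from_cls(Point, use_inspect=True)
    assert docs == {'x': 'Horizontal coordinate.', 'y': 'Vertical coordinate.'}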
@@ -0,0 +1,392 @@
"""Private logic related to fields (the `Field()` function and `FieldInfo` class), and arguments to `Annotated`."""

from __future__ import annotations as _annotations

import dataclasses
import warnings
from copy import copy
from functools import lru_cache
from inspect import Parameter, ismethoddescriptor, signature
from typing import TYPE_CHECKING, Any, Callable, Pattern

from pydantic_core import PydanticUndefined
from typing_extensions import TypeIs

from pydantic.errors import PydanticUserError

from . import _typing_extra
from ._config import ConfigWrapper
from ._docs_extraction import extract_docstrings_from_cls
from ._import_utils import import_cached_base_model, import_cached_field_info
from ._namespace_utils import NsResolver
from ._repr import Representation
from ._utils import can_be_positional

if TYPE_CHECKING:
    from annotated_types import BaseMetadata

    from ..fields import FieldInfo
    from ..main import BaseModel
    from ._dataclasses import StandardDataclass
    from ._decorators import DecoratorInfos


class PydanticMetadata(Representation):
    """Base class for annotation markers like `Strict`."""

    __slots__ = ()


def pydantic_general_metadata(**metadata: Any) -> BaseMetadata:
    """Create a new `_PydanticGeneralMetadata` instance with the given metadata.

    Args:
        **metadata: The metadata to add.

    Returns:
        The new `_PydanticGeneralMetadata` instance.
    """
    return _general_metadata_cls()(metadata)  # type: ignore


@lru_cache(maxsize=None)
def _general_metadata_cls() -> type[BaseMetadata]:
    """Do it this way to avoid importing `annotated_types` at import time."""
    from annotated_types import BaseMetadata

    class _PydanticGeneralMetadata(PydanticMetadata, BaseMetadata):
        """Pydantic general metadata like `max_digits`."""

        def __init__(self, metadata: Any):
            self.__dict__ = metadata

    return _PydanticGeneralMetadata  # type: ignore


def _update_fields_from_docstrings(cls: type[Any], fields: dict[str, FieldInfo], config_wrapper: ConfigWrapper) -> None:
    if config_wrapper.use_attribute_docstrings:
        fields_docs = extract_docstrings_from_cls(cls)
        for ann_name, field_info in fields.items():
            if field_info.description is None and ann_name in fields_docs:
                field_info.description = fields_docs[ann_name]


def collect_model_fields(  # noqa: C901
    cls: type[BaseModel],
    bases: tuple[type[Any], ...],
    config_wrapper: ConfigWrapper,
    ns_resolver: NsResolver | None,
    *,
    typevars_map: dict[Any, Any] | None = None,
) -> tuple[dict[str, FieldInfo], set[str]]:
    """Collect the fields of a nascent pydantic model.

    Also collect the names of any ClassVars present in the type hints.

    The returned value is a tuple of two items: the fields dict, and the set of ClassVar names.

    Args:
        cls: BaseModel or dataclass.
        bases: Parents of the class, generally `cls.__bases__`.
        config_wrapper: The config wrapper instance.
        ns_resolver: Namespace resolver to use when getting model annotations.
        typevars_map: A dictionary mapping type variables to their concrete types.

    Returns:
        A tuple containing the fields dict and the set of ClassVar names.

    Raises:
        NameError:
            - If there is a conflict between a field name and protected namespaces.
            - If there is a field other than `root` in `RootModel`.
            - If a field shadows an attribute in the parent model.
    """
    BaseModel = import_cached_base_model()
    FieldInfo_ = import_cached_field_info()

    parent_fields_lookup: dict[str, FieldInfo] = {}
    for base in reversed(bases):
        if model_fields := getattr(base, '__pydantic_fields__', None):
            parent_fields_lookup.update(model_fields)

    type_hints = _typing_extra.get_model_type_hints(cls, ns_resolver=ns_resolver)

    # https://docs.python.org/3/howto/annotations.html#accessing-the-annotations-dict-of-an-object-in-python-3-9-and-older
    # annotations is only used for finding fields in parent classes
    annotations = cls.__dict__.get('__annotations__', {})
    fields: dict[str, FieldInfo] = {}

    class_vars: set[str] = set()
    for ann_name, (ann_type, evaluated) in type_hints.items():
        if ann_name == 'model_config':
            # We never want to treat `model_config` as a field
            # Note: we may need to change this logic if/when we introduce a `BareModel` class with no
            # protected namespaces (where `model_config` might be allowed as a field name)
            continue

        for protected_namespace in config_wrapper.protected_namespaces:
            ns_violation: bool = False
            if isinstance(protected_namespace, Pattern):
                ns_violation = protected_namespace.match(ann_name) is not None
            elif isinstance(protected_namespace, str):
                ns_violation = ann_name.startswith(protected_namespace)

            if ns_violation:
                for b in bases:
                    if hasattr(b, ann_name):
                        if not (issubclass(b, BaseModel) and ann_name in getattr(b, '__pydantic_fields__', {})):
                            raise NameError(
                                f'Field "{ann_name}" conflicts with member {getattr(b, ann_name)}'
                                f' of protected namespace "{protected_namespace}".'
                            )
            else:
                valid_namespaces = ()
                for pn in config_wrapper.protected_namespaces:
                    if isinstance(pn, Pattern):
                        if not pn.match(ann_name):
                            valid_namespaces += (f're.compile({pn.pattern})',)
                    else:
                        if not ann_name.startswith(pn):
                            valid_namespaces += (pn,)

                warnings.warn(
                    f'Field "{ann_name}" in {cls.__name__} has conflict with protected namespace "{protected_namespace}".'
                    '\n\nYou may be able to resolve this warning by setting'
                    f" `model_config['protected_namespaces'] = {valid_namespaces}`.",
                    UserWarning,
                )
        if _typing_extra.is_classvar_annotation(ann_type):
            class_vars.add(ann_name)
            continue
        if _is_finalvar_with_default_val(ann_type, getattr(cls, ann_name, PydanticUndefined)):
            class_vars.add(ann_name)
            continue
        if not is_valid_field_name(ann_name):
            continue
        if cls.__pydantic_root_model__ and ann_name != 'root':
            raise NameError(
                f"Unexpected field with name {ann_name!r}; only 'root' is allowed as a field of a `RootModel`"
            )

        # when building a generic model with `MyModel[int]`, the generic_origin check makes sure we don't get
        # "... shadows an attribute" warnings
        generic_origin = getattr(cls, '__pydantic_generic_metadata__', {}).get('origin')
        for base in bases:
            dataclass_fields = {
                field.name for field in (dataclasses.fields(base) if dataclasses.is_dataclass(base) else ())
            }
            if hasattr(base, ann_name):
                if base is generic_origin:
                    # Don't warn when "shadowing" of attributes in parametrized generics
                    continue

                if ann_name in dataclass_fields:
                    # Don't warn when inheriting stdlib dataclasses whose fields are "shadowed" by defaults being set
                    # on the class instance.
                    continue

                if ann_name not in annotations:
                    # Don't warn when a field exists in a parent class but has not been defined in the current class
                    continue

                warnings.warn(
                    f'Field name "{ann_name}" in "{cls.__qualname__}" shadows an attribute in parent '
                    f'"{base.__qualname__}"',
                    UserWarning,
                )

        try:
            default = getattr(cls, ann_name, PydanticUndefined)
            if default is PydanticUndefined:
                raise AttributeError
        except AttributeError:
            if ann_name in annotations:
                field_info = FieldInfo_.from_annotation(ann_type)
                field_info.evaluated = evaluated
            else:
                # if field has no default value and is not in __annotations__ this means that it is
                # defined in a base class and we can take it from there
                if ann_name in parent_fields_lookup:
                    # The field was present on one of the (possibly multiple) base classes
                    # copy the field to make sure typevar substitutions don't cause issues with the base classes
                    field_info = copy(parent_fields_lookup[ann_name])
                else:
                    # The field was not found on any base classes; this seems to be caused by fields not getting
                    # generated thanks to models not being fully defined while initializing recursive models.
                    # Nothing stops us from just creating a new FieldInfo for this type hint, so we do this.
                    field_info = FieldInfo_.from_annotation(ann_type)
                    field_info.evaluated = evaluated
        else:
            _warn_on_nested_alias_in_annotation(ann_type, ann_name)
            if isinstance(default, FieldInfo_) and ismethoddescriptor(default.default):
                # the `getattr` call above triggers a call to `__get__` for descriptors, so we do
                # the same if the `= field(default=...)` form is used. Note that we only do this
                # for method descriptors for now, we might want to extend this to any descriptor
                # in the future (by simply checking for `hasattr(default.default, '__get__')`).
                default.default = default.default.__get__(None, cls)

            field_info = FieldInfo_.from_annotated_attribute(ann_type, default)
            field_info.evaluated = evaluated
            # attributes which are fields are removed from the class namespace:
            # 1. To match the behaviour of annotation-only fields
            # 2. To avoid false positives in the NameError check above
            try:
                delattr(cls, ann_name)
            except AttributeError:
                pass  # indicates the attribute was on a parent class

        # Use cls.__dict__['__pydantic_decorators__'] instead of cls.__pydantic_decorators__
        # to make sure the decorators have already been built for this exact class
        decorators: DecoratorInfos = cls.__dict__['__pydantic_decorators__']
        if ann_name in decorators.computed_fields:
            raise ValueError("you can't override a field with a computed field")
        fields[ann_name] = field_info

    if typevars_map:
        for field in fields.values():
            field.apply_typevars_map(typevars_map)

    _update_fields_from_docstrings(cls, fields, config_wrapper)
    return fields, class_vars


def _warn_on_nested_alias_in_annotation(ann_type: type[Any], ann_name: str) -> None:
    FieldInfo = import_cached_field_info()

    args = getattr(ann_type, '__args__', None)
    if args:
        for anno_arg in args:
            if _typing_extra.is_annotated(anno_arg):
                for anno_type_arg in _typing_extra.get_args(anno_arg):
                    if isinstance(anno_type_arg, FieldInfo) and anno_type_arg.alias is not None:
                        warnings.warn(
                            f'`alias` specification on field "{ann_name}" must be set on outermost annotation to take effect.',
                            UserWarning,
                        )
                        return


def _is_finalvar_with_default_val(type_: type[Any], val: Any) -> bool:
    FieldInfo = import_cached_field_info()

    if not _typing_extra.is_finalvar(type_):
        return False
    elif val is PydanticUndefined:
        return False
    elif isinstance(val, FieldInfo) and (val.default is PydanticUndefined and val.default_factory is None):
        return False
    else:
        return True
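

# Editor's note: a minimal sketch, not part of the original module, of the
# helper's semantics: a `Final` annotation only counts as a class variable
# when it actually carries a usable default value. The values are fabricated.
def _demo_finalvar_detection() -> None:
    from typing import Final

    assert _is_finalvar_with_default_val(Final[int], 3)
    assert not _is_finalvar_with_default_val(Final[int], PydanticUndefined)
    assert not _is_finalvar_with_default_val(int, 3)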


def collect_dataclass_fields(
    cls: type[StandardDataclass],
    *,
    ns_resolver: NsResolver | None = None,
    typevars_map: dict[Any, Any] | None = None,
    config_wrapper: ConfigWrapper | None = None,
) -> dict[str, FieldInfo]:
    """Collect the fields of a dataclass.

    Args:
        cls: dataclass.
        ns_resolver: Namespace resolver to use when getting dataclass annotations.
            Defaults to an empty instance.
        typevars_map: A dictionary mapping type variables to their concrete types.
        config_wrapper: The config wrapper instance.

    Returns:
        The dataclass fields.
    """
    FieldInfo_ = import_cached_field_info()

    fields: dict[str, FieldInfo] = {}
    ns_resolver = ns_resolver or NsResolver()
    dataclass_fields = cls.__dataclass_fields__

    # The logic here is similar to `_typing_extra.get_cls_type_hints`,
    # although we do it manually as stdlib dataclasses already have annotations
    # collected in each class:
    for base in reversed(cls.__mro__):
        if not dataclasses.is_dataclass(base):
            continue

        with ns_resolver.push(base):
            for ann_name, dataclass_field in dataclass_fields.items():
                if ann_name not in base.__dict__.get('__annotations__', {}):
                    # `__dataclass_fields__` contains every field, even the ones from base classes.
                    # Only collect the ones defined on `base`.
                    continue

                globalns, localns = ns_resolver.types_namespace
                ann_type, _ = _typing_extra.try_eval_type(dataclass_field.type, globalns, localns)

                if _typing_extra.is_classvar_annotation(ann_type):
                    continue

                if (
                    not dataclass_field.init
                    and dataclass_field.default is dataclasses.MISSING
                    and dataclass_field.default_factory is dataclasses.MISSING
                ):
                    # TODO: We should probably do something with this so that validate_assignment behaves properly
                    #   Issue: https://github.com/pydantic/pydantic/issues/5470
                    continue

                if isinstance(dataclass_field.default, FieldInfo_):
                    if dataclass_field.default.init_var:
                        if dataclass_field.default.init is False:
                            raise PydanticUserError(
                                f'Dataclass field {ann_name} has init=False and init_var=True, but these are mutually exclusive.',
                                code='clashing-init-and-init-var',
                            )

                        # TODO: same note as above re validate_assignment
                        continue
                    field_info = FieldInfo_.from_annotated_attribute(ann_type, dataclass_field.default)
                else:
                    field_info = FieldInfo_.from_annotated_attribute(ann_type, dataclass_field)

                fields[ann_name] = field_info

                if field_info.default is not PydanticUndefined and isinstance(
                    getattr(cls, ann_name, field_info), FieldInfo_
                ):
                    # We need this to fix the default when the "default" from __dataclass_fields__ is a pydantic.FieldInfo
                    setattr(cls, ann_name, field_info.default)

    if typevars_map:
        for field in fields.values():
            # We don't pass any ns, as `field.annotation`
            # was already evaluated. TODO: is this method relevant?
            # Can't we just use `_generics.replace_types`?
            field.apply_typevars_map(typevars_map)

    if config_wrapper is not None:
        _update_fields_from_docstrings(cls, fields, config_wrapper)

    return fields


def is_valid_field_name(name: str) -> bool:
    return not name.startswith('_')


def is_valid_privateattr_name(name: str) -> bool:
    return name.startswith('_') and not name.startswith('__')


def takes_validated_data_argument(
    default_factory: Callable[[], Any] | Callable[[dict[str, Any]], Any],
) -> TypeIs[Callable[[dict[str, Any]], Any]]:
    """Whether the provided default factory callable has a validated data parameter."""
    try:
        sig = signature(default_factory)
    except (ValueError, TypeError):
        # `inspect.signature` might not be able to infer a signature, e.g. with C objects.
        # In this case, we assume no data argument is present:
        return False

    parameters = list(sig.parameters.values())

    return len(parameters) == 1 and can_be_positional(parameters[0]) and parameters[0].default is Parameter.empty
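

# Editor's note: a minimal sketch, not part of the original module. A
# zero-argument factory is rejected, while a factory taking a single required
# positional argument (the validated data) is accepted; both lambdas are
# fabricated for illustration.
def _demo_takes_validated_data_argument() -> None:
    assert not takes_validated_data_argument(lambda: [])
    assert takes_validated_data_argument(lambda data: data)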
@@ -0,0 +1,23 @@
from __future__ import annotations as _annotations

from dataclasses import dataclass
from typing import Union


@dataclass
class PydanticRecursiveRef:
    type_ref: str

    __name__ = 'PydanticRecursiveRef'
    __hash__ = object.__hash__

    def __call__(self) -> None:
        """Defining __call__ is necessary for the `typing` module to let you use an instance of
        this class as the result of resolving a standard ForwardRef.
        """

    def __or__(self, other):
        return Union[self, other]  # type: ignore

    def __ror__(self, other):
        return Union[other, self]  # type: ignore
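

# Editor's note: a minimal sketch, not part of the original module. The
# instance can participate in `X | Y` annotations while a recursive model is
# still being built; the `type_ref` value is fabricated for illustration.
def _demo_recursive_ref() -> None:
    ref = PydanticRecursiveRef(type_ref='MyModel')
    optional_ref = ref | None  # builds typing.Union[ref, None] via __or__
    assert ref.type_ref == 'MyModel'
    assert optional_ref is not None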
File diff suppressed because it is too large
@@ -0,0 +1,536 @@
from __future__ import annotations

import sys
import types
import typing
from collections import ChainMap
from contextlib import contextmanager
from contextvars import ContextVar
from types import prepare_class
from typing import TYPE_CHECKING, Any, Iterator, Mapping, MutableMapping, Tuple, TypeVar
from weakref import WeakValueDictionary

import typing_extensions

from . import _typing_extra
from ._core_utils import get_type_ref
from ._forward_ref import PydanticRecursiveRef
from ._utils import all_identical, is_model_class

if sys.version_info >= (3, 10):
    from typing import _UnionGenericAlias  # type: ignore[attr-defined]

if TYPE_CHECKING:
    from ..main import BaseModel

GenericTypesCacheKey = Tuple[Any, Any, Tuple[Any, ...]]

# Note: We want to remove LimitedDict, but to do this, we'd need to improve the handling of generics caching.
# Right now, to handle recursive generics, some types must remain cached for brief periods without references.
# By chaining the WeakValueDictionary with a LimitedDict, we have a way to retain caching for all types with references,
# while also retaining a limited number of types even without references. This is generally enough to build
# specific recursive generic models without losing required items out of the cache.

KT = TypeVar('KT')
VT = TypeVar('VT')
_LIMITED_DICT_SIZE = 100
if TYPE_CHECKING:

    class LimitedDict(dict, MutableMapping[KT, VT]):
        def __init__(self, size_limit: int = _LIMITED_DICT_SIZE): ...

else:

    class LimitedDict(dict):
        """Limit the size/length of a dict used for caching to avoid unlimited increase in memory usage.

        Since the dict is ordered, and we always remove elements from the beginning, this is effectively a FIFO cache.
        """

        def __init__(self, size_limit: int = _LIMITED_DICT_SIZE):
            self.size_limit = size_limit
            super().__init__()

        def __setitem__(self, key: Any, value: Any, /) -> None:
            super().__setitem__(key, value)
            if len(self) > self.size_limit:
                excess = len(self) - self.size_limit + self.size_limit // 10
                to_remove = list(self.keys())[:excess]
                for k in to_remove:
                    del self[k]
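

# Editor's note: a minimal sketch, not part of the original module, of the
# FIFO eviction above: once the limit is crossed, the oldest keys (plus ~10%
# headroom) are dropped. The size limit of 10 is fabricated for illustration.
def _demo_limited_dict() -> None:
    d = LimitedDict(size_limit=10)
    for i in range(11):
        d[i] = i
    # Inserting the 11th item evicted the oldest keys (0 and 1).
    assert 0 not in d and len(d) <= 10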


# weak dictionaries allow the dynamically created parametrized versions of generic models to get collected
# once they are no longer referenced by the caller.
if sys.version_info >= (3, 9):  # Typing for weak dictionaries available at 3.9
    GenericTypesCache = WeakValueDictionary[GenericTypesCacheKey, 'type[BaseModel]']
else:
    GenericTypesCache = WeakValueDictionary

if TYPE_CHECKING:

    class DeepChainMap(ChainMap[KT, VT]):  # type: ignore
        ...

else:

    class DeepChainMap(ChainMap):
        """Variant of ChainMap that allows direct updates to inner scopes.

        Taken from https://docs.python.org/3/library/collections.html#collections.ChainMap,
        with some light modifications for this use case.
        """

        def clear(self) -> None:
            for mapping in self.maps:
                mapping.clear()

        def __setitem__(self, key: KT, value: VT) -> None:
            for mapping in self.maps:
                mapping[key] = value

        def __delitem__(self, key: KT) -> None:
            hit = False
            for mapping in self.maps:
                if key in mapping:
                    del mapping[key]
                    hit = True
            if not hit:
                raise KeyError(key)


# Despite the fact that LimitedDict _seems_ no longer necessary, I'm very nervous to actually remove it
# and discover later on that we need to re-add all this infrastructure...
# _GENERIC_TYPES_CACHE = DeepChainMap(GenericTypesCache(), LimitedDict())

_GENERIC_TYPES_CACHE = GenericTypesCache()


class PydanticGenericMetadata(typing_extensions.TypedDict):
    origin: type[BaseModel] | None  # analogous to typing._GenericAlias.__origin__
    args: tuple[Any, ...]  # analogous to typing._GenericAlias.__args__
    parameters: tuple[TypeVar, ...]  # analogous to typing.Generic.__parameters__


def create_generic_submodel(
    model_name: str, origin: type[BaseModel], args: tuple[Any, ...], params: tuple[Any, ...]
) -> type[BaseModel]:
    """Dynamically create a submodel of a provided (generic) BaseModel.

    This is used when producing concrete parametrizations of generic models. This function
    only *creates* the new subclass; the schema/validators/serialization must be updated to
    reflect a concrete parametrization elsewhere.

    Args:
        model_name: The name of the newly created model.
        origin: The base class for the new model to inherit from.
        args: A tuple of generic metadata arguments.
        params: A tuple of generic metadata parameters.

    Returns:
        The created submodel.
    """
    namespace: dict[str, Any] = {'__module__': origin.__module__}
    bases = (origin,)
    meta, ns, kwds = prepare_class(model_name, bases)
    namespace.update(ns)
    created_model = meta(
        model_name,
        bases,
        namespace,
        __pydantic_generic_metadata__={
            'origin': origin,
            'args': args,
            'parameters': params,
        },
        __pydantic_reset_parent_namespace__=False,
        **kwds,
    )

    model_module, called_globally = _get_caller_frame_info(depth=3)
    if called_globally:  # create global reference and therefore allow pickling
        object_by_reference = None
        reference_name = model_name
        reference_module_globals = sys.modules[created_model.__module__].__dict__
        while object_by_reference is not created_model:
            object_by_reference = reference_module_globals.setdefault(reference_name, created_model)
            reference_name += '_'

    return created_model


def _get_caller_frame_info(depth: int = 2) -> tuple[str | None, bool]:
    """Used inside a function to check whether it was called globally.

    Args:
        depth: The depth to get the frame.

    Returns:
        A tuple containing `module_name` and `called_globally`.

    Raises:
        RuntimeError: If the function is not called inside a function.
    """
    try:
        previous_caller_frame = sys._getframe(depth)
    except ValueError as e:
        raise RuntimeError('This function must be used inside another function') from e
    except AttributeError:  # sys module does not have _getframe function, so there's nothing we can do about it
        return None, False
    frame_globals = previous_caller_frame.f_globals
    return frame_globals.get('__name__'), previous_caller_frame.f_locals is frame_globals
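

# Editor's note: a minimal sketch, not part of the original module. Called
# from inside another function, the inspected frame's locals differ from its
# globals, so `called_globally` comes back False; `depth=1` targets this
# demo's own frame.
def _demo_caller_frame_info() -> None:
    module_name, called_globally = _get_caller_frame_info(depth=1)
    assert module_name == __name__
    assert called_globally is False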


DictValues: type[Any] = {}.values().__class__


def iter_contained_typevars(v: Any) -> Iterator[TypeVar]:
    """Recursively iterate through all subtypes and type args of `v` and yield any typevars that are found.

    This is intended as an alternative to directly accessing the `__parameters__` attribute of a GenericAlias,
    since __parameters__ of (nested) generic BaseModel subclasses won't show up in that list.
    """
    if isinstance(v, TypeVar):
        yield v
    elif is_model_class(v):
        yield from v.__pydantic_generic_metadata__['parameters']
    elif isinstance(v, (DictValues, list)):
        for var in v:
            yield from iter_contained_typevars(var)
    else:
        args = get_args(v)
        for arg in args:
            yield from iter_contained_typevars(arg)


def get_args(v: Any) -> Any:
    pydantic_generic_metadata: PydanticGenericMetadata | None = getattr(v, '__pydantic_generic_metadata__', None)
    if pydantic_generic_metadata:
        return pydantic_generic_metadata.get('args')
    return typing_extensions.get_args(v)


def get_origin(v: Any) -> Any:
    pydantic_generic_metadata: PydanticGenericMetadata | None = getattr(v, '__pydantic_generic_metadata__', None)
    if pydantic_generic_metadata:
        return pydantic_generic_metadata.get('origin')
    return typing_extensions.get_origin(v)
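

# Editor's note: a minimal sketch, not part of the original module. For plain
# typing aliases with no pydantic generic metadata, both helpers fall back to
# the typing_extensions equivalents.
def _demo_get_args_origin() -> None:
    from typing import List

    assert get_origin(List[int]) is list
    assert get_args(List[int]) == (int,)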


def get_standard_typevars_map(cls: Any) -> dict[TypeVar, Any] | None:
    """Package a generic type's typevars and parametrization (if present) into a dictionary compatible with the
    `replace_types` function. Specifically, this works with standard typing generics and typing._GenericAlias.
    """
    origin = get_origin(cls)
    if origin is None:
        return None
    if not hasattr(origin, '__parameters__'):
        return None

    # In this case, we know that cls is a _GenericAlias, and origin is the generic type
    # So it is safe to access cls.__args__ and origin.__parameters__
    args: tuple[Any, ...] = cls.__args__  # type: ignore
    parameters: tuple[TypeVar, ...] = origin.__parameters__
    return dict(zip(parameters, args))
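

# Editor's note: a minimal sketch, not part of the original module. The
# fabricated `MyMap` generic shows each TypeVar of the origin being mapped to
# the concrete argument from the parametrized alias.
def _demo_standard_typevars_map() -> None:
    from typing import Dict

    K = TypeVar('K')
    V = TypeVar('V')

    class MyMap(Dict[K, V]):
        pass

    assert get_standard_typevars_map(MyMap[str, int]) == {K: str, V: int}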


def get_model_typevars_map(cls: type[BaseModel]) -> dict[TypeVar, Any] | None:
    """Package a generic BaseModel's typevars and concrete parametrization (if present) into a dictionary compatible
    with the `replace_types` function.

    Since BaseModel.__class_getitem__ does not produce a typing._GenericAlias, and the BaseModel generic info is
    stored in the __pydantic_generic_metadata__ attribute, we need special handling here.
    """
    # TODO: This could be unified with `get_standard_typevars_map` if we stored the generic metadata
    #   in the __origin__, __args__, and __parameters__ attributes of the model.
    generic_metadata = cls.__pydantic_generic_metadata__
    origin = generic_metadata['origin']
    args = generic_metadata['args']
    return dict(zip(iter_contained_typevars(origin), args))


def replace_types(type_: Any, type_map: Mapping[Any, Any] | None) -> Any:
    """Return type with all occurrences of `type_map` keys recursively replaced with their values.

    Args:
        type_: The class or generic alias.
        type_map: Mapping from `TypeVar` instance to concrete types.

    Returns:
        A new type representing the basic structure of `type_` with all
        `typevar_map` keys recursively replaced.

    Example:
        ```python
        from typing import List, Tuple, Union

        from pydantic._internal._generics import replace_types

        replace_types(Tuple[str, Union[List[str], float]], {str: int})
        #> Tuple[int, Union[List[int], float]]
        ```
    """
    if not type_map:
        return type_

    type_args = get_args(type_)

    if _typing_extra.is_annotated(type_):
        annotated_type, *annotations = type_args
        annotated = replace_types(annotated_type, type_map)
        for annotation in annotations:
            annotated = typing_extensions.Annotated[annotated, annotation]
        return annotated

    origin_type = get_origin(type_)

    # Having type args is a good indicator that this is a typing special form
    # instance or a generic alias of some sort.
    if type_args:
        resolved_type_args = tuple(replace_types(arg, type_map) for arg in type_args)
        if all_identical(type_args, resolved_type_args):
            # If all arguments are the same, there is no need to modify the
            # type or create a new object at all
            return type_

        if (
            origin_type is not None
            and isinstance(type_, _typing_extra.typing_base)
            and not isinstance(origin_type, _typing_extra.typing_base)
            and getattr(type_, '_name', None) is not None
        ):
            # In python < 3.9 generic aliases don't exist so any of these like `list`,
            # `type` or `collections.abc.Callable` need to be translated.
            # See: https://www.python.org/dev/peps/pep-0585
            origin_type = getattr(typing, type_._name)
        assert origin_type is not None

        if _typing_extra.origin_is_union(origin_type):
            if any(_typing_extra.is_any(arg) for arg in resolved_type_args):
                # `Any | T` ~ `Any`:
                resolved_type_args = (Any,)
            # `Never | T` ~ `T`:
            resolved_type_args = tuple(
                arg
                for arg in resolved_type_args
                if not (_typing_extra.is_no_return(arg) or _typing_extra.is_never(arg))
            )

        # PEP-604 syntax (Ex.: list | str) is represented with a types.UnionType object that does not have __getitem__.
        # We also cannot use isinstance() since we have to compare types.
        if sys.version_info >= (3, 10) and origin_type is types.UnionType:
            return _UnionGenericAlias(origin_type, resolved_type_args)
        # NotRequired[T] and Required[T] don't support tuple type resolved_type_args, hence the condition below
        return origin_type[resolved_type_args[0] if len(resolved_type_args) == 1 else resolved_type_args]

    # We handle pydantic generic models separately as they don't have the same
    # semantics as "typing" classes or generic aliases

    if not origin_type and is_model_class(type_):
        parameters = type_.__pydantic_generic_metadata__['parameters']
        if not parameters:
            return type_
        resolved_type_args = tuple(replace_types(t, type_map) for t in parameters)
        if all_identical(parameters, resolved_type_args):
            return type_
        return type_[resolved_type_args]

    # Handle special case for typehints that can have lists as arguments.
    # `typing.Callable[[int, str], int]` is an example for this.
    if isinstance(type_, list):
        resolved_list = [replace_types(element, type_map) for element in type_]
        if all_identical(type_, resolved_list):
            return type_
        return resolved_list

    # If all else fails, we try to resolve the type directly and otherwise just
    # return the input with no modifications.
    return type_map.get(type_, type_)
|
||||
|
||||
|
||||
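
# Illustrative sketch (not part of the original module): `replace_types` is what
# substitutes concrete types for TypeVars when a generic model is parametrized.
# A minimal standalone example, assuming only the behavior documented above:
#
#     from typing import Dict, List, TypeVar
#     T = TypeVar('T')
#     K = TypeVar('K')
#     replace_types(Dict[K, List[T]], {K: str, T: int})  # -> Dict[str, List[int]]
#     replace_types(T, {T: int})                         # -> int (direct lookup fallback)
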
def has_instance_in_type(type_: Any, isinstance_target: Any) -> bool:
    """Checks whether the type, or any of its arbitrarily nested args, satisfies
    `isinstance(<type>, isinstance_target)`.
    """
    if isinstance(type_, isinstance_target):
        return True
    if _typing_extra.is_annotated(type_):
        return has_instance_in_type(type_.__origin__, isinstance_target)
    if _typing_extra.is_literal(type_):
        return False

    type_args = get_args(type_)

    # Having type args is a good indicator that this is a typing module
    # class instantiation or a generic alias of some sort.
    for arg in type_args:
        if has_instance_in_type(arg, isinstance_target):
            return True

    # Handle special case for typehints that can have lists as arguments.
    # `typing.Callable[[int, str], int]` is an example for this.
    if (
        isinstance(type_, list)
        # On Python < 3.10, typing_extensions implements `ParamSpec` as a subclass of `list`:
        and not isinstance(type_, typing_extensions.ParamSpec)
    ):
        for element in type_:
            if has_instance_in_type(element, isinstance_target):
                return True

    return False
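
# Usage sketch (illustrative): this is how the nested-arg scan behaves, e.g. when
# looking for a TypeVar anywhere inside an annotation:
#
#     from typing import Dict, List, TypeVar
#     T = TypeVar('T')
#     has_instance_in_type(Dict[str, List[T]], TypeVar)    # True
#     has_instance_in_type(Dict[str, List[int]], TypeVar)  # False
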
def check_parameters_count(cls: type[BaseModel], parameters: tuple[Any, ...]) -> None:
    """Check that the number of parameters passed to the generic model matches the number it declares.

    Args:
        cls: The generic model.
        parameters: A tuple of passed parameters to the generic model.

    Raises:
        TypeError: If the passed parameters count is not equal to the generic model parameters count.
    """
    actual = len(parameters)
    expected = len(cls.__pydantic_generic_metadata__['parameters'])
    if actual != expected:
        description = 'many' if actual > expected else 'few'
        raise TypeError(f'Too {description} parameters for {cls}; actual {actual}, expected {expected}')
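
# Usage sketch (illustrative): passing the wrong number of type arguments raises,
# e.g. for a hypothetical `class Model(BaseModel, Generic[T]): ...`:
#
#     check_parameters_count(Model, (int, str))
#     # TypeError: Too many parameters for <class 'Model'>; actual 2, expected 1
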
_generic_recursion_cache: ContextVar[set[str] | None] = ContextVar('_generic_recursion_cache', default=None)


@contextmanager
def generic_recursion_self_type(
    origin: type[BaseModel], args: tuple[Any, ...]
) -> Iterator[PydanticRecursiveRef | None]:
    """This contextmanager should be placed around the recursive calls used to build a generic type,
    and accept as arguments the generic origin type and the type arguments being passed to it.

    If the same origin and arguments are observed twice, it implies that a self-reference placeholder
    can be used while building the core schema, and will produce a schema_ref that will be valid in the
    final parent schema.
    """
    previously_seen_type_refs = _generic_recursion_cache.get()
    if previously_seen_type_refs is None:
        previously_seen_type_refs = set()
        token = _generic_recursion_cache.set(previously_seen_type_refs)
    else:
        token = None

    try:
        type_ref = get_type_ref(origin, args_override=args)
        if type_ref in previously_seen_type_refs:
            self_type = PydanticRecursiveRef(type_ref=type_ref)
            yield self_type
        else:
            previously_seen_type_refs.add(type_ref)
            yield
            previously_seen_type_refs.remove(type_ref)
    finally:
        if token:
            _generic_recursion_cache.reset(token)
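
# Usage sketch (illustrative pseudocode, not a verbatim excerpt): schema-building
# code wraps its recursive calls like this, so a repeated (origin, args) pair
# yields a placeholder instead of recursing forever:
#
#     with generic_recursion_self_type(origin, args) as maybe_self_type:
#         if maybe_self_type is not None:
#             return core_schema.definition_reference_schema(maybe_self_type.type_ref)
#         ...  # build the schema normally; nested identical calls get the placeholder
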
def recursively_defined_type_refs() -> set[str]:
    visited = _generic_recursion_cache.get()
    if not visited:
        return set()  # not in a generic recursion, so there are no types

    return visited.copy()  # don't allow modifications


def get_cached_generic_type_early(parent: type[BaseModel], typevar_values: Any) -> type[BaseModel] | None:
    """The use of a two-stage cache lookup approach was necessary to have the highest performance possible for
    repeated calls to `__class_getitem__` on generic types (which may happen in tighter loops during runtime),
    while still ensuring that certain alternative parametrizations ultimately resolve to the same type.

    As a concrete example, this approach was necessary to make Model[List[T]][int] equal to Model[List[int]].
    The approach could be modified to not use two different cache keys at different points, but the
    _early_cache_key is optimized to be as quick to compute as possible (for repeated-access speed), and the
    _late_cache_key is optimized to be as "correct" as possible, so that two types that will ultimately be the
    same after resolving the type arguments will always produce cache hits.

    If we wanted to move to only using a single cache key per type, we would either need to always use the
    slower/more computationally intensive logic associated with _late_cache_key, or would need to accept
    that Model[List[T]][int] is a different type than Model[List[int]]. Because we rely on subclass relationships
    during validation, I think it is worthwhile to ensure that types that are functionally equivalent are actually
    equal.
    """
    return _GENERIC_TYPES_CACHE.get(_early_cache_key(parent, typevar_values))


def get_cached_generic_type_late(
    parent: type[BaseModel], typevar_values: Any, origin: type[BaseModel], args: tuple[Any, ...]
) -> type[BaseModel] | None:
    """See the docstring of `get_cached_generic_type_early` for more information about the two-stage cache lookup."""
    cached = _GENERIC_TYPES_CACHE.get(_late_cache_key(origin, args, typevar_values))
    if cached is not None:
        set_cached_generic_type(parent, typevar_values, cached, origin, args)
    return cached


def set_cached_generic_type(
    parent: type[BaseModel],
    typevar_values: tuple[Any, ...],
    type_: type[BaseModel],
    origin: type[BaseModel] | None = None,
    args: tuple[Any, ...] | None = None,
) -> None:
    """See the docstring of `get_cached_generic_type_early` for more information about why items are cached with
    two different keys.
    """
    _GENERIC_TYPES_CACHE[_early_cache_key(parent, typevar_values)] = type_
    if len(typevar_values) == 1:
        _GENERIC_TYPES_CACHE[_early_cache_key(parent, typevar_values[0])] = type_
    if origin and args:
        _GENERIC_TYPES_CACHE[_late_cache_key(origin, args, typevar_values)] = type_


def _union_orderings_key(typevar_values: Any) -> Any:
    """This is intended to help differentiate between Union types with the same arguments in different order.

    Thanks to caching internal to the `typing` module, it is not possible to distinguish between
    List[Union[int, float]] and List[Union[float, int]] (and similarly for other "parent" origins besides List)
    because `typing` considers Union[int, float] to be equal to Union[float, int].

    However, you _can_ distinguish between (top-level) Union[int, float] vs. Union[float, int].
    Because we parse items as the first Union type that is successful, we get slightly more consistent behavior
    if we make an effort to distinguish the ordering of items in a union. It would be best if we could _always_
    get the exact-correct order of items in the union, but that would require a change to the `typing` module itself.
    (See https://github.com/python/cpython/issues/86483 for reference.)
    """
    if isinstance(typevar_values, tuple):
        args_data = []
        for value in typevar_values:
            args_data.append(_union_orderings_key(value))
        return tuple(args_data)
    elif _typing_extra.is_union(typevar_values):
        return get_args(typevar_values)
    else:
        return ()
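
# Illustrative example of the distinction described above: at the top level the
# ordering of a Union is still observable, so it participates in the cache key:
#
#     from typing import Union
#     _union_orderings_key(Union[int, float])  # (int, float)
#     _union_orderings_key(Union[float, int])  # (float, int)
#     _union_orderings_key((int, str))         # ((), ())
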
def _early_cache_key(cls: type[BaseModel], typevar_values: Any) -> GenericTypesCacheKey:
    """This is intended for minimal computational overhead during lookups of cached types.

    Note that this is overly simplistic, and it's possible that two different cls/typevar_values
    inputs would ultimately result in the same type being created in BaseModel.__class_getitem__.
    To handle this, we have a fallback _late_cache_key that is checked later if the _early_cache_key
    lookup fails, and should result in a cache hit _precisely_ when the inputs to __class_getitem__
    would result in the same type.
    """
    return cls, typevar_values, _union_orderings_key(typevar_values)


def _late_cache_key(origin: type[BaseModel], args: tuple[Any, ...], typevar_values: Any) -> GenericTypesCacheKey:
    """This is intended for use later in the process of creating a new type, when we have more information
    about the exact args that will be passed. If it turns out that a different set of inputs to
    __class_getitem__ resulted in the same inputs to the generic type creation process, we can still
    return the cached type, and update the cache with the _early_cache_key as well.
    """
    # The _union_orderings_key is placed at the start here to ensure there cannot be a collision with an
    # _early_cache_key, as that function will always produce a BaseModel subclass as the first item in the key,
    # whereas this function will always produce a tuple as the first item in the key.
    return _union_orderings_key(typevar_values), origin, args
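
# Sketch of the two-stage lookup described above (illustrative pseudocode of the
# flow in BaseModel.__class_getitem__, not a verbatim excerpt):
#
#     cached = get_cached_generic_type_early(parent, typevar_values)  # cheap key
#     if cached is not None:
#         return cached
#     origin, args = ...  # resolve typevars (the expensive part)
#     cached = get_cached_generic_type_late(parent, typevar_values, origin, args)
#     if cached is not None:
#         return cached  # also back-fills the early key for next time
#     ...  # create the new type, then set_cached_generic_type(...)
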
@@ -0,0 +1,27 @@
"""Git utilities, adapted from mypy's git utilities (https://github.com/python/mypy/blob/master/mypy/git.py)."""

from __future__ import annotations

import os
import subprocess


def is_git_repo(dir: str) -> bool:
    """Is the given directory version-controlled with git?"""
    return os.path.exists(os.path.join(dir, '.git'))


def have_git() -> bool:
    """Can we run the git executable?"""
    try:
        subprocess.check_output(['git', '--help'])
        return True
    except subprocess.CalledProcessError:
        return False
    except OSError:
        return False


def git_revision(dir: str) -> str:
    """Get the SHA-1 of the HEAD of a git repository."""
    return subprocess.check_output(['git', 'rev-parse', '--short', 'HEAD'], cwd=dir).decode('utf-8').strip()
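
# Usage sketch (illustrative), e.g. from a build or version script:
#
#     if have_git() and is_git_repo('.'):
#         print(git_revision('.'))  # e.g. 'a1b2c3d'
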
@@ -0,0 +1,20 @@
from functools import lru_cache
from typing import TYPE_CHECKING, Type

if TYPE_CHECKING:
    from pydantic import BaseModel
    from pydantic.fields import FieldInfo


@lru_cache(maxsize=None)
def import_cached_base_model() -> Type['BaseModel']:
    from pydantic import BaseModel

    return BaseModel


@lru_cache(maxsize=None)
def import_cached_field_info() -> Type['FieldInfo']:
    from pydantic.fields import FieldInfo

    return FieldInfo
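
# Usage sketch (illustrative): call sites use these helpers instead of module-level
# imports to avoid circular imports; `lru_cache` makes repeated calls free:
#
#     FieldInfo = import_cached_field_info()
#     assert import_cached_field_info() is FieldInfo  # cached, same object
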
@@ -0,0 +1,7 @@
import sys

# `slots` is available on Python >= 3.10
if sys.version_info >= (3, 10):
    slots_true = {'slots': True}
else:
    slots_true = {}
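
# Usage sketch (illustrative): spread into `dataclass(...)` so `slots=True` is
# passed only on Python >= 3.10, where the keyword exists:
#
#     from dataclasses import dataclass
#
#     @dataclass(frozen=True, **slots_true)
#     class Point:
#         x: int
#         y: int
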
@@ -0,0 +1,392 @@
from __future__ import annotations

from collections import defaultdict
from copy import copy
from functools import lru_cache, partial
from typing import TYPE_CHECKING, Any, Iterable

from pydantic_core import CoreSchema, PydanticCustomError, ValidationError, to_jsonable_python
from pydantic_core import core_schema as cs

from ._fields import PydanticMetadata
from ._import_utils import import_cached_field_info

if TYPE_CHECKING:
    pass

STRICT = {'strict'}
FAIL_FAST = {'fail_fast'}
LENGTH_CONSTRAINTS = {'min_length', 'max_length'}
INEQUALITY = {'le', 'ge', 'lt', 'gt'}
NUMERIC_CONSTRAINTS = {'multiple_of', *INEQUALITY}
ALLOW_INF_NAN = {'allow_inf_nan'}

STR_CONSTRAINTS = {
    *LENGTH_CONSTRAINTS,
    *STRICT,
    'strip_whitespace',
    'to_lower',
    'to_upper',
    'pattern',
    'coerce_numbers_to_str',
}
BYTES_CONSTRAINTS = {*LENGTH_CONSTRAINTS, *STRICT}

LIST_CONSTRAINTS = {*LENGTH_CONSTRAINTS, *STRICT, *FAIL_FAST}
TUPLE_CONSTRAINTS = {*LENGTH_CONSTRAINTS, *STRICT, *FAIL_FAST}
SET_CONSTRAINTS = {*LENGTH_CONSTRAINTS, *STRICT, *FAIL_FAST}
DICT_CONSTRAINTS = {*LENGTH_CONSTRAINTS, *STRICT}
GENERATOR_CONSTRAINTS = {*LENGTH_CONSTRAINTS, *STRICT}
SEQUENCE_CONSTRAINTS = {*LENGTH_CONSTRAINTS, *FAIL_FAST}

FLOAT_CONSTRAINTS = {*NUMERIC_CONSTRAINTS, *ALLOW_INF_NAN, *STRICT}
DECIMAL_CONSTRAINTS = {'max_digits', 'decimal_places', *FLOAT_CONSTRAINTS}
INT_CONSTRAINTS = {*NUMERIC_CONSTRAINTS, *ALLOW_INF_NAN, *STRICT}
BOOL_CONSTRAINTS = STRICT
UUID_CONSTRAINTS = STRICT

DATE_TIME_CONSTRAINTS = {*NUMERIC_CONSTRAINTS, *STRICT}
TIMEDELTA_CONSTRAINTS = {*NUMERIC_CONSTRAINTS, *STRICT}
TIME_CONSTRAINTS = {*NUMERIC_CONSTRAINTS, *STRICT}
LAX_OR_STRICT_CONSTRAINTS = STRICT
ENUM_CONSTRAINTS = STRICT
COMPLEX_CONSTRAINTS = STRICT

UNION_CONSTRAINTS = {'union_mode'}
URL_CONSTRAINTS = {
    'max_length',
    'allowed_schemes',
    'host_required',
    'default_host',
    'default_port',
    'default_path',
}

TEXT_SCHEMA_TYPES = ('str', 'bytes', 'url', 'multi-host-url')
SEQUENCE_SCHEMA_TYPES = ('list', 'tuple', 'set', 'frozenset', 'generator', *TEXT_SCHEMA_TYPES)
NUMERIC_SCHEMA_TYPES = ('float', 'int', 'date', 'time', 'timedelta', 'datetime')

CONSTRAINTS_TO_ALLOWED_SCHEMAS: dict[str, set[str]] = defaultdict(set)

constraint_schema_pairings: list[tuple[set[str], tuple[str, ...]]] = [
    (STR_CONSTRAINTS, TEXT_SCHEMA_TYPES),
    (BYTES_CONSTRAINTS, ('bytes',)),
    (LIST_CONSTRAINTS, ('list',)),
    (TUPLE_CONSTRAINTS, ('tuple',)),
    (SET_CONSTRAINTS, ('set', 'frozenset')),
    (DICT_CONSTRAINTS, ('dict',)),
    (GENERATOR_CONSTRAINTS, ('generator',)),
    (FLOAT_CONSTRAINTS, ('float',)),
    (INT_CONSTRAINTS, ('int',)),
    (DATE_TIME_CONSTRAINTS, ('date', 'time', 'datetime', 'timedelta')),
    # TODO: this is a bit redundant, we could probably avoid some of these
    (STRICT, (*TEXT_SCHEMA_TYPES, *SEQUENCE_SCHEMA_TYPES, *NUMERIC_SCHEMA_TYPES, 'typed-dict', 'model')),
    (UNION_CONSTRAINTS, ('union',)),
    (URL_CONSTRAINTS, ('url', 'multi-host-url')),
    (BOOL_CONSTRAINTS, ('bool',)),
    (UUID_CONSTRAINTS, ('uuid',)),
    (LAX_OR_STRICT_CONSTRAINTS, ('lax-or-strict',)),
    (ENUM_CONSTRAINTS, ('enum',)),
    (DECIMAL_CONSTRAINTS, ('decimal',)),
    (COMPLEX_CONSTRAINTS, ('complex',)),
]

for constraints, schemas in constraint_schema_pairings:
    for c in constraints:
        CONSTRAINTS_TO_ALLOWED_SCHEMAS[c].update(schemas)
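
# Illustrative example: after the loop above, the mapping answers "which core
# schema types accept this constraint?", e.g.:
#
#     CONSTRAINTS_TO_ALLOWED_SCHEMAS['max_length']
#     # {'str', 'bytes', 'url', 'multi-host-url', 'list', 'tuple', 'set',
#     #  'frozenset', 'dict', 'generator'}  (a set; ordering not significant)
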
def as_jsonable_value(v: Any) -> Any:
    if type(v) not in (int, str, float, bytes, bool, type(None)):
        return to_jsonable_python(v)
    return v


def expand_grouped_metadata(annotations: Iterable[Any]) -> Iterable[Any]:
    """Expand the annotations.

    Args:
        annotations: An iterable of annotations.

    Returns:
        An iterable of expanded annotations.

    Example:
        ```python
        from annotated_types import Ge, Len

        from pydantic._internal._known_annotated_metadata import expand_grouped_metadata

        print(list(expand_grouped_metadata([Ge(4), Len(5)])))
        #> [Ge(ge=4), MinLen(min_length=5)]
        ```
    """
    import annotated_types as at

    FieldInfo = import_cached_field_info()

    for annotation in annotations:
        if isinstance(annotation, at.GroupedMetadata):
            yield from annotation
        elif isinstance(annotation, FieldInfo):
            yield from annotation.metadata
            # this is a bit problematic in that it results in duplicate metadata
            # all of our "consumers" can handle it, but it is not ideal
            # we probably should split up FieldInfo into:
            # - annotated types metadata
            # - individual metadata known only to Pydantic
            annotation = copy(annotation)
            annotation.metadata = []
            yield annotation
        else:
            yield annotation


@lru_cache
def _get_at_to_constraint_map() -> dict[type, str]:
    """Return a mapping of annotated types to constraints.

    Normally, we would define a mapping like this in the module scope, but we can't do that
    because we don't permit module level imports of `annotated_types`, in an attempt to speed up
    the import time of `pydantic`. We still only want to have this dictionary defined in one place,
    so we use this function to cache the result.
    """
    import annotated_types as at

    return {
        at.Gt: 'gt',
        at.Ge: 'ge',
        at.Lt: 'lt',
        at.Le: 'le',
        at.MultipleOf: 'multiple_of',
        at.MinLen: 'min_length',
        at.MaxLen: 'max_length',
    }


def apply_known_metadata(annotation: Any, schema: CoreSchema) -> CoreSchema | None:  # noqa: C901
    """Apply `annotation` to `schema` if it is an annotation we know about (Gt, Le, etc.).
    Otherwise return `None`.

    This does not handle all known annotations. If / when it does, it can always
    return a CoreSchema and return the unmodified schema if the annotation should be ignored.

    Assumes that GroupedMetadata has already been expanded via `expand_grouped_metadata`.

    Args:
        annotation: The annotation.
        schema: The schema.

    Returns:
        An updated schema with annotation if it is an annotation we know about, `None` otherwise.

    Raises:
        PydanticCustomError: If `Predicate` fails.
    """
    import annotated_types as at

    from ._validators import NUMERIC_VALIDATOR_LOOKUP, forbid_inf_nan_check

    schema = schema.copy()
    schema_update, other_metadata = collect_known_metadata([annotation])
    schema_type = schema['type']

    chain_schema_constraints: set[str] = {
        'pattern',
        'strip_whitespace',
        'to_lower',
        'to_upper',
        'coerce_numbers_to_str',
    }
    chain_schema_steps: list[CoreSchema] = []

    for constraint, value in schema_update.items():
        if constraint not in CONSTRAINTS_TO_ALLOWED_SCHEMAS:
            raise ValueError(f'Unknown constraint {constraint}')
        allowed_schemas = CONSTRAINTS_TO_ALLOWED_SCHEMAS[constraint]

        # if it becomes necessary to handle more than one constraint
        # in this recursive case with function-after or function-wrap, we should refactor
        # this is a bit challenging because we sometimes want to apply constraints to the inner schema,
        # whereas other times we want to wrap the existing schema with a new one that enforces a new constraint.
        if schema_type in {'function-before', 'function-wrap', 'function-after'} and constraint == 'strict':
            schema['schema'] = apply_known_metadata(annotation, schema['schema'])  # type: ignore  # schema is function schema
            return schema

        # if we're allowed to apply constraint directly to the schema, like le to int, do that
        if schema_type in allowed_schemas:
            if constraint == 'union_mode' and schema_type == 'union':
                schema['mode'] = value  # type: ignore  # schema is UnionSchema
            else:
                schema[constraint] = value
            continue

        # else, apply a function after validator to the schema to enforce the corresponding constraint
        if constraint in chain_schema_constraints:

            def _apply_constraint_with_incompatibility_info(
                value: Any, handler: cs.ValidatorFunctionWrapHandler
            ) -> Any:
                try:
                    x = handler(value)
                except ValidationError as ve:
                    # if the error is about the type, it's likely that the constraint is incompatible with the type of the field
                    # for example, the following invalid schema wouldn't be caught during schema build, but rather at this point
                    # with a cryptic 'string_type' error coming from the string validator,
                    # that we'd rather express as a constraint incompatibility error (TypeError)
                    # Annotated[list[int], Field(pattern='abc')]
                    if 'type' in ve.errors()[0]['type']:
                        raise TypeError(
                            f"Unable to apply constraint '{constraint}' to supplied value {value} for schema of type '{schema_type}'"  # noqa: B023
                        )
                    raise ve
                return x

            chain_schema_steps.append(
                cs.no_info_wrap_validator_function(
                    _apply_constraint_with_incompatibility_info, cs.str_schema(**{constraint: value})
                )
            )
        elif constraint in NUMERIC_VALIDATOR_LOOKUP:
            if constraint in LENGTH_CONSTRAINTS:
                inner_schema = schema
                while inner_schema['type'] in {'function-before', 'function-wrap', 'function-after'}:
                    inner_schema = inner_schema['schema']  # type: ignore
                inner_schema_type = inner_schema['type']
                if inner_schema_type == 'list' or (
                    inner_schema_type == 'json-or-python' and inner_schema['json_schema']['type'] == 'list'  # type: ignore
                ):
                    js_constraint_key = 'minItems' if constraint == 'min_length' else 'maxItems'
                else:
                    js_constraint_key = 'minLength' if constraint == 'min_length' else 'maxLength'
            else:
                js_constraint_key = constraint

            schema = cs.no_info_after_validator_function(
                partial(NUMERIC_VALIDATOR_LOOKUP[constraint], **{constraint: value}), schema
            )
            metadata = schema.get('metadata', {})
            if (existing_json_schema_updates := metadata.get('pydantic_js_updates')) is not None:
                metadata['pydantic_js_updates'] = {
                    **existing_json_schema_updates,
                    **{js_constraint_key: as_jsonable_value(value)},
                }
            else:
                metadata['pydantic_js_updates'] = {js_constraint_key: as_jsonable_value(value)}
            schema['metadata'] = metadata
        elif constraint == 'allow_inf_nan' and value is False:
            schema = cs.no_info_after_validator_function(
                forbid_inf_nan_check,
                schema,
            )
        else:
            # It's rare that we'd get here, but it's possible if we add a new constraint and forget to handle it
            # Most constraint errors are caught at runtime during attempted application
            raise RuntimeError(f"Unable to apply constraint '{constraint}' to schema of type '{schema_type}'")

    for annotation in other_metadata:
        if (annotation_type := type(annotation)) in (at_to_constraint_map := _get_at_to_constraint_map()):
            constraint = at_to_constraint_map[annotation_type]
            validator = NUMERIC_VALIDATOR_LOOKUP.get(constraint)
            if validator is None:
                raise ValueError(f'Unknown constraint {constraint}')
            schema = cs.no_info_after_validator_function(
                partial(validator, {constraint: getattr(annotation, constraint)}), schema
            )
            continue
        elif isinstance(annotation, (at.Predicate, at.Not)):
            predicate_name = f'{annotation.func.__qualname__}' if hasattr(annotation.func, '__qualname__') else ''

            def val_func(v: Any) -> Any:
                predicate_satisfied = annotation.func(v)  # noqa: B023

                # annotation.func may also raise an exception, let it pass through
                if isinstance(annotation, at.Predicate):  # noqa: B023
                    if not predicate_satisfied:
                        raise PydanticCustomError(
                            'predicate_failed',
                            f'Predicate {predicate_name} failed',  # type: ignore  # noqa: B023
                        )
                else:
                    if predicate_satisfied:
                        raise PydanticCustomError(
                            'not_operation_failed',
                            f'Not of {predicate_name} failed',  # type: ignore  # noqa: B023
                        )

                return v

            schema = cs.no_info_after_validator_function(val_func, schema)
        else:
            # ignore any other unknown metadata
            return None

    if chain_schema_steps:
        chain_schema_steps = [schema] + chain_schema_steps
        return cs.chain_schema(chain_schema_steps)

    return schema
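
# Usage sketch (illustrative): applying a known annotation updates or wraps the
# schema, while unknown metadata yields None so callers can try other handlers:
#
#     import annotated_types as at
#     from pydantic_core import core_schema as cs
#
#     apply_known_metadata(at.Gt(0), cs.int_schema())  # {'type': 'int', 'gt': 0}
#     apply_known_metadata(object(), cs.int_schema())  # None (unknown metadata)
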
def collect_known_metadata(annotations: Iterable[Any]) -> tuple[dict[str, Any], list[Any]]:
    """Split `annotations` into known metadata and unknown annotations.

    Args:
        annotations: An iterable of annotations.

    Returns:
        A tuple containing a dict of known metadata and a list of unknown annotations.

    Example:
        ```python
        from annotated_types import Gt, Len

        from pydantic._internal._known_annotated_metadata import collect_known_metadata

        print(collect_known_metadata([Gt(1), Len(42), ...]))
        #> ({'gt': 1, 'min_length': 42}, [Ellipsis])
        ```
    """
    annotations = expand_grouped_metadata(annotations)

    res: dict[str, Any] = {}
    remaining: list[Any] = []

    for annotation in annotations:
        # isinstance(annotation, PydanticMetadata) also covers ._fields:_PydanticGeneralMetadata
        if isinstance(annotation, PydanticMetadata):
            res.update(annotation.__dict__)
            # we don't use dataclasses.asdict because that recursively calls asdict on the field values
        elif (annotation_type := type(annotation)) in (at_to_constraint_map := _get_at_to_constraint_map()):
            constraint = at_to_constraint_map[annotation_type]
            res[constraint] = getattr(annotation, constraint)
        elif isinstance(annotation, type) and issubclass(annotation, PydanticMetadata):
            # also support PydanticMetadata classes being used without initialisation,
            # e.g. `Annotated[int, Strict]` as well as `Annotated[int, Strict()]`
            res.update({k: v for k, v in vars(annotation).items() if not k.startswith('_')})
        else:
            remaining.append(annotation)
    # Nones can sneak in but pydantic-core will reject them
    # it'd be nice to clean things up so we don't put in None (we probably don't _need_ to, it was just easier)
    # but this is simple enough to kick that can down the road
    res = {k: v for k, v in res.items() if v is not None}
    return res, remaining


def check_metadata(metadata: dict[str, Any], allowed: Iterable[str], source_type: Any) -> None:
    """A small utility function to validate that the given metadata can be applied to the target.
    More than saving lines of code, this gives us a consistent error message for all of our internal implementations.

    Args:
        metadata: A dict of metadata.
        allowed: An iterable of allowed metadata.
        source_type: The source type.

    Raises:
        TypeError: If there is metadata that can't be applied to the source type.
    """
    unknown = metadata.keys() - set(allowed)
    if unknown:
        raise TypeError(
            f'The following constraints cannot be applied to {source_type!r}: {", ".join([f"{k!r}" for k in unknown])}'
        )
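
# Usage sketch (illustrative):
#
#     check_metadata({'gt': 0}, INT_CONSTRAINTS, int)        # ok, no-op
#     check_metadata({'pattern': 'a'}, INT_CONSTRAINTS, int)
#     # TypeError: The following constraints cannot be applied to <class 'int'>: 'pattern'
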
@@ -0,0 +1,235 @@
from __future__ import annotations

from typing import TYPE_CHECKING, Any, Callable, Generic, Iterator, Mapping, TypeVar, Union

from pydantic_core import CoreSchema, SchemaSerializer, SchemaValidator
from typing_extensions import Literal

from ..errors import PydanticErrorCodes, PydanticUserError
from ..plugin._schema_validator import PluggableSchemaValidator

if TYPE_CHECKING:
    from ..dataclasses import PydanticDataclass
    from ..main import BaseModel
    from ..type_adapter import TypeAdapter


ValSer = TypeVar('ValSer', bound=Union[SchemaValidator, PluggableSchemaValidator, SchemaSerializer])
T = TypeVar('T')


class MockCoreSchema(Mapping[str, Any]):
    """Mocker for `pydantic_core.CoreSchema` which optionally attempts to
    rebuild the thing it's mocking when one of its methods is accessed and raises an error if that fails.
    """

    __slots__ = '_error_message', '_code', '_attempt_rebuild', '_built_memo'

    def __init__(
        self,
        error_message: str,
        *,
        code: PydanticErrorCodes,
        attempt_rebuild: Callable[[], CoreSchema | None] | None = None,
    ) -> None:
        self._error_message = error_message
        self._code: PydanticErrorCodes = code
        self._attempt_rebuild = attempt_rebuild
        self._built_memo: CoreSchema | None = None

    def __getitem__(self, key: str) -> Any:
        return self._get_built().__getitem__(key)

    def __len__(self) -> int:
        return self._get_built().__len__()

    def __iter__(self) -> Iterator[str]:
        return self._get_built().__iter__()

    def _get_built(self) -> CoreSchema:
        if self._built_memo is not None:
            return self._built_memo

        if self._attempt_rebuild:
            schema = self._attempt_rebuild()
            if schema is not None:
                self._built_memo = schema
                return schema
        raise PydanticUserError(self._error_message, code=self._code)

    def rebuild(self) -> CoreSchema | None:
        self._built_memo = None
        if self._attempt_rebuild:
            schema = self._attempt_rebuild()
            if schema is not None:
                return schema
            else:
                raise PydanticUserError(self._error_message, code=self._code)
        return None


class MockValSer(Generic[ValSer]):
    """Mocker for `pydantic_core.SchemaValidator` or `pydantic_core.SchemaSerializer` which optionally attempts to
    rebuild the thing it's mocking when one of its methods is accessed and raises an error if that fails.
    """

    __slots__ = '_error_message', '_code', '_val_or_ser', '_attempt_rebuild'

    def __init__(
        self,
        error_message: str,
        *,
        code: PydanticErrorCodes,
        val_or_ser: Literal['validator', 'serializer'],
        attempt_rebuild: Callable[[], ValSer | None] | None = None,
    ) -> None:
        self._error_message = error_message
        self._val_or_ser = SchemaValidator if val_or_ser == 'validator' else SchemaSerializer
        self._code: PydanticErrorCodes = code
        self._attempt_rebuild = attempt_rebuild

    def __getattr__(self, item: str) -> Any:
        __tracebackhide__ = True
        if self._attempt_rebuild:
            val_ser = self._attempt_rebuild()
            if val_ser is not None:
                return getattr(val_ser, item)

        # raise an AttributeError if `item` doesn't exist
        getattr(self._val_or_ser, item)
        raise PydanticUserError(self._error_message, code=self._code)

    def rebuild(self) -> ValSer | None:
        if self._attempt_rebuild:
            val_ser = self._attempt_rebuild()
            if val_ser is not None:
                return val_ser
            else:
                raise PydanticUserError(self._error_message, code=self._code)
        return None


def set_type_adapter_mocks(adapter: TypeAdapter, type_repr: str) -> None:
    """Set `core_schema`, `validator` and `serializer` to mock core types on a type adapter instance.

    Args:
        adapter: The type adapter instance to set the mocks on
        type_repr: Name of the type used in the adapter, used in error messages
    """
    undefined_type_error_message = (
        f'`TypeAdapter[{type_repr}]` is not fully defined; you should define `{type_repr}` and all referenced types,'
        f' then call `.rebuild()` on the instance.'
    )

    def attempt_rebuild_fn(attr_fn: Callable[[TypeAdapter], T]) -> Callable[[], T | None]:
        def handler() -> T | None:
            if adapter.rebuild(raise_errors=False, _parent_namespace_depth=5) is not False:
                return attr_fn(adapter)
            else:
                return None

        return handler

    adapter.core_schema = MockCoreSchema(  # pyright: ignore[reportAttributeAccessIssue]
        undefined_type_error_message,
        code='class-not-fully-defined',
        attempt_rebuild=attempt_rebuild_fn(lambda ta: ta.core_schema),
    )
    adapter.validator = MockValSer(  # pyright: ignore[reportAttributeAccessIssue]
        undefined_type_error_message,
        code='class-not-fully-defined',
        val_or_ser='validator',
        attempt_rebuild=attempt_rebuild_fn(lambda ta: ta.validator),
    )
    adapter.serializer = MockValSer(  # pyright: ignore[reportAttributeAccessIssue]
        undefined_type_error_message,
        code='class-not-fully-defined',
        val_or_ser='serializer',
        attempt_rebuild=attempt_rebuild_fn(lambda ta: ta.serializer),
    )


def set_model_mocks(cls: type[BaseModel], cls_name: str, undefined_name: str = 'all referenced types') -> None:
    """Set `__pydantic_core_schema__`, `__pydantic_validator__` and `__pydantic_serializer__` to mock core types on a model.

    Args:
        cls: The model class to set the mocks on
        cls_name: Name of the model class, used in error messages
        undefined_name: Name of the undefined thing, used in error messages
    """
    undefined_type_error_message = (
        f'`{cls_name}` is not fully defined; you should define {undefined_name},'
        f' then call `{cls_name}.model_rebuild()`.'
    )

    def attempt_rebuild_fn(attr_fn: Callable[[type[BaseModel]], T]) -> Callable[[], T | None]:
        def handler() -> T | None:
            if cls.model_rebuild(raise_errors=False, _parent_namespace_depth=5) is not False:
                return attr_fn(cls)
            else:
                return None

        return handler

    cls.__pydantic_core_schema__ = MockCoreSchema(  # pyright: ignore[reportAttributeAccessIssue]
        undefined_type_error_message,
        code='class-not-fully-defined',
        attempt_rebuild=attempt_rebuild_fn(lambda c: c.__pydantic_core_schema__),
    )
    cls.__pydantic_validator__ = MockValSer(  # pyright: ignore[reportAttributeAccessIssue]
        undefined_type_error_message,
        code='class-not-fully-defined',
        val_or_ser='validator',
        attempt_rebuild=attempt_rebuild_fn(lambda c: c.__pydantic_validator__),
    )
    cls.__pydantic_serializer__ = MockValSer(  # pyright: ignore[reportAttributeAccessIssue]
        undefined_type_error_message,
        code='class-not-fully-defined',
        val_or_ser='serializer',
        attempt_rebuild=attempt_rebuild_fn(lambda c: c.__pydantic_serializer__),
    )
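
# Behavior sketch (illustrative, hypothetical model names): with an unresolved
# forward reference, the mocks installed above turn premature use into a clear
# error, and rebuilding replaces them with the real validator/serializer:
#
#     class Foo(BaseModel):
#         bar: 'Bar'  # 'Bar' not defined yet -> mocks installed
#
#     Foo(bar=...)  # PydanticUserError: `Foo` is not fully defined; you should
#                   # define `Bar`, then call `Foo.model_rebuild()`.
#
#     class Bar(BaseModel):
#         x: int
#
#     Foo.model_rebuild()  # mocks are swapped out; Foo is now usable
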
def set_dataclass_mocks(
    cls: type[PydanticDataclass], cls_name: str, undefined_name: str = 'all referenced types'
) -> None:
    """Set `__pydantic_validator__` and `__pydantic_serializer__` to `MockValSer`s on a dataclass.

    Args:
        cls: The dataclass to set the mocks on
        cls_name: Name of the dataclass, used in error messages
        undefined_name: Name of the undefined thing, used in error messages
    """
    from ..dataclasses import rebuild_dataclass

    undefined_type_error_message = (
        f'`{cls_name}` is not fully defined; you should define {undefined_name},'
        f' then call `pydantic.dataclasses.rebuild_dataclass({cls_name})`.'
    )

    def attempt_rebuild_fn(attr_fn: Callable[[type[PydanticDataclass]], T]) -> Callable[[], T | None]:
        def handler() -> T | None:
            if rebuild_dataclass(cls, raise_errors=False, _parent_namespace_depth=5) is not False:
                return attr_fn(cls)
            else:
                return None

        return handler

    cls.__pydantic_core_schema__ = MockCoreSchema(  # pyright: ignore[reportAttributeAccessIssue]
        undefined_type_error_message,
        code='class-not-fully-defined',
        attempt_rebuild=attempt_rebuild_fn(lambda c: c.__pydantic_core_schema__),
    )
    cls.__pydantic_validator__ = MockValSer(  # pyright: ignore[reportAttributeAccessIssue]
        undefined_type_error_message,
        code='class-not-fully-defined',
        val_or_ser='validator',
        attempt_rebuild=attempt_rebuild_fn(lambda c: c.__pydantic_validator__),
    )
    cls.__pydantic_serializer__ = MockValSer(  # pyright: ignore[reportAttributeAccessIssue]
        undefined_type_error_message,
        code='class-not-fully-defined',
        val_or_ser='serializer',
        attempt_rebuild=attempt_rebuild_fn(lambda c: c.__pydantic_serializer__),
    )
@@ -0,0 +1,848 @@
"""Private logic for creating models."""

from __future__ import annotations as _annotations

import builtins
import operator
import sys
import typing
import warnings
import weakref
from abc import ABCMeta
from functools import lru_cache, partial
from types import FunctionType
from typing import Any, Callable, Generic, Literal, NoReturn, TypeVar, cast

from pydantic_core import PydanticUndefined, SchemaSerializer
from typing_extensions import TypeAliasType, dataclass_transform, deprecated, get_args

from ..errors import PydanticUndefinedAnnotation, PydanticUserError
from ..plugin._schema_validator import create_schema_validator
from ..warnings import GenericBeforeBaseModelWarning, PydanticDeprecatedSince20
from ._config import ConfigWrapper
from ._decorators import DecoratorInfos, PydanticDescriptorProxy, get_attribute_from_bases, unwrap_wrapped_function
from ._fields import collect_model_fields, is_valid_field_name, is_valid_privateattr_name
from ._generate_schema import GenerateSchema
from ._generics import PydanticGenericMetadata, get_model_typevars_map
from ._import_utils import import_cached_base_model, import_cached_field_info
from ._mock_val_ser import set_model_mocks
from ._namespace_utils import NsResolver
from ._schema_generation_shared import CallbackGetCoreSchemaHandler
from ._signature import generate_pydantic_signature
from ._typing_extra import (
    _make_forward_ref,
    eval_type_backport,
    is_annotated,
    is_classvar_annotation,
    parent_frame_namespace,
)
from ._utils import LazyClassAttribute, SafeGetItemProxy

if typing.TYPE_CHECKING:
    from ..fields import ComputedFieldInfo, FieldInfo, ModelPrivateAttr
    from ..fields import Field as PydanticModelField
    from ..fields import PrivateAttr as PydanticModelPrivateAttr
    from ..main import BaseModel
else:
    # See PyCharm issues https://youtrack.jetbrains.com/issue/PY-21915
    # and https://youtrack.jetbrains.com/issue/PY-51428
    DeprecationWarning = PydanticDeprecatedSince20
    PydanticModelField = object()
    PydanticModelPrivateAttr = object()

object_setattr = object.__setattr__


class _ModelNamespaceDict(dict):
    """A dictionary subclass that intercepts attribute setting on model classes and
    warns about overriding of decorators.
    """

    def __setitem__(self, k: str, v: object) -> None:
        existing: Any = self.get(k, None)
        if existing and v is not existing and isinstance(existing, PydanticDescriptorProxy):
            warnings.warn(f'`{k}` overrides an existing Pydantic `{existing.decorator_info.decorator_repr}` decorator')

        return super().__setitem__(k, v)
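
# Behavior sketch (illustrative): model class bodies are evaluated in a
# `_ModelNamespaceDict` (see `__prepare__` below), so redefining a decorated name
# warns at class-creation time:
#
#     class Model(BaseModel):
#         @field_validator('x')
#         def check(cls, v): ...
#
#         check = 1  # UserWarning: `check` overrides an existing Pydantic
#                    # `field_validator` decorator
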
def NoInitField(
    *,
    init: Literal[False] = False,
) -> Any:
    """Only for typing purposes. Used as default value of `__pydantic_fields_set__`,
    `__pydantic_extra__`, `__pydantic_private__`, so they could be ignored when
    synthesizing the `__init__` signature.
    """


@dataclass_transform(kw_only_default=True, field_specifiers=(PydanticModelField, PydanticModelPrivateAttr, NoInitField))
class ModelMetaclass(ABCMeta):
    def __new__(
        mcs,
        cls_name: str,
        bases: tuple[type[Any], ...],
        namespace: dict[str, Any],
        __pydantic_generic_metadata__: PydanticGenericMetadata | None = None,
        __pydantic_reset_parent_namespace__: bool = True,
        _create_model_module: str | None = None,
        **kwargs: Any,
    ) -> type:
        """Metaclass for creating Pydantic models.

        Args:
            cls_name: The name of the class to be created.
            bases: The base classes of the class to be created.
            namespace: The attribute dictionary of the class to be created.
            __pydantic_generic_metadata__: Metadata for generic models.
            __pydantic_reset_parent_namespace__: Reset parent namespace.
            _create_model_module: The module of the class to be created, if created by `create_model`.
            **kwargs: Catch-all for any other keyword arguments.

        Returns:
            The new class created by the metaclass.
        """
        # Note `ModelMetaclass` refers to `BaseModel`, but is also used to *create* `BaseModel`, so we rely on the fact
        # that `BaseModel` itself won't have any bases, but any subclass of it will, to determine whether the `__new__`
        # call we're in the middle of is for the `BaseModel` class.
        if bases:
            base_field_names, class_vars, base_private_attributes = mcs._collect_bases_data(bases)

            config_wrapper = ConfigWrapper.for_model(bases, namespace, kwargs)
            namespace['model_config'] = config_wrapper.config_dict
            private_attributes = inspect_namespace(
                namespace, config_wrapper.ignored_types, class_vars, base_field_names
            )
            if private_attributes or base_private_attributes:
                original_model_post_init = get_model_post_init(namespace, bases)
                if original_model_post_init is not None:
                    # if there are private_attributes and a model_post_init function, we handle both

                    def wrapped_model_post_init(self: BaseModel, context: Any, /) -> None:
                        """We need to both initialize private attributes and call the user-defined model_post_init
                        method.
                        """
                        init_private_attributes(self, context)
                        original_model_post_init(self, context)

                    namespace['model_post_init'] = wrapped_model_post_init
                else:
                    namespace['model_post_init'] = init_private_attributes

            namespace['__class_vars__'] = class_vars
            namespace['__private_attributes__'] = {**base_private_attributes, **private_attributes}
            if __pydantic_generic_metadata__:
                namespace['__pydantic_generic_metadata__'] = __pydantic_generic_metadata__

            cls = cast('type[BaseModel]', super().__new__(mcs, cls_name, bases, namespace, **kwargs))
            BaseModel_ = import_cached_base_model()

            mro = cls.__mro__
            if Generic in mro and mro.index(Generic) < mro.index(BaseModel_):
                warnings.warn(
                    GenericBeforeBaseModelWarning(
                        'Classes should inherit from `BaseModel` before generic classes (e.g. `typing.Generic[T]`) '
                        'for pydantic generics to work properly.'
                    ),
                    stacklevel=2,
                )

            cls.__pydantic_custom_init__ = not getattr(cls.__init__, '__pydantic_base_init__', False)
            cls.__pydantic_post_init__ = (
                None if cls.model_post_init is BaseModel_.model_post_init else 'model_post_init'
            )

            cls.__pydantic_decorators__ = DecoratorInfos.build(cls)

            # Use the getattr below to grab the __parameters__ from the `typing.Generic` parent class
            if __pydantic_generic_metadata__:
                cls.__pydantic_generic_metadata__ = __pydantic_generic_metadata__
            else:
                parent_parameters = getattr(cls, '__pydantic_generic_metadata__', {}).get('parameters', ())
                parameters = getattr(cls, '__parameters__', None) or parent_parameters
                if parameters and parent_parameters and not all(x in parameters for x in parent_parameters):
                    from ..root_model import RootModelRootType

                    missing_parameters = tuple(x for x in parameters if x not in parent_parameters)
                    if RootModelRootType in parent_parameters and RootModelRootType not in parameters:
                        # This is a special case where the user has subclassed `RootModel`, but has not parametrized
                        # RootModel with the generic type identifiers being used. Ex:
                        # class MyModel(RootModel, Generic[T]):
                        #    root: T
                        # Should instead just be:
                        # class MyModel(RootModel[T]):
                        #    root: T
                        parameters_str = ', '.join([x.__name__ for x in missing_parameters])
                        error_message = (
                            f'{cls.__name__} is a subclass of `RootModel`, but does not include the generic type identifier(s) '
                            f'{parameters_str} in its parameters. '
                            f'You should parametrize RootModel directly, e.g., `class {cls.__name__}(RootModel[{parameters_str}]): ...`.'
                        )
                    else:
                        combined_parameters = parent_parameters + missing_parameters
                        parameters_str = ', '.join([str(x) for x in combined_parameters])
                        generic_type_label = f'typing.Generic[{parameters_str}]'
                        error_message = (
                            f'All parameters must be present on typing.Generic;'
                            f' you should inherit from {generic_type_label}.'
                        )
                        if Generic not in bases:  # pragma: no cover
                            # We raise an error here not because it is desirable, but because some cases are mishandled.
                            # It would be nice to remove this error and still have things behave as expected, it's just
                            # challenging because we are using a custom `__class_getitem__` to parametrize generic models,
                            # and not returning a typing._GenericAlias from it.
                            bases_str = ', '.join([x.__name__ for x in bases] + [generic_type_label])
                            error_message += (
                                f' Note: `typing.Generic` must go last: `class {cls.__name__}({bases_str}): ...`)'
                            )
                    raise TypeError(error_message)

                cls.__pydantic_generic_metadata__ = {
                    'origin': None,
                    'args': (),
                    'parameters': parameters,
                }

            cls.__pydantic_complete__ = False  # Ensure this specific class gets completed

            # preserve `__set_name__` protocol defined in https://peps.python.org/pep-0487
            # for attributes not in `new_namespace` (e.g. private attributes)
            for name, obj in private_attributes.items():
                obj.__set_name__(cls, name)

            if __pydantic_reset_parent_namespace__:
                cls.__pydantic_parent_namespace__ = build_lenient_weakvaluedict(parent_frame_namespace())
            parent_namespace: dict[str, Any] | None = getattr(cls, '__pydantic_parent_namespace__', None)
            if isinstance(parent_namespace, dict):
                parent_namespace = unpack_lenient_weakvaluedict(parent_namespace)

            ns_resolver = NsResolver(parent_namespace=parent_namespace)

            set_model_fields(cls, bases, config_wrapper, ns_resolver)

            if config_wrapper.frozen and '__hash__' not in namespace:
                set_default_hash_func(cls, bases)

            complete_model_class(
                cls,
                cls_name,
                config_wrapper,
                raise_errors=False,
                ns_resolver=ns_resolver,
                create_model_module=_create_model_module,
            )

            # If this is placed before the complete_model_class call above,
            # the generic computed fields return type is set to PydanticUndefined
            cls.__pydantic_computed_fields__ = {
                k: v.info for k, v in cls.__pydantic_decorators__.computed_fields.items()
            }

            set_deprecated_descriptors(cls)

            # using super(cls, cls) on the next line ensures we only call the parent class's __pydantic_init_subclass__
            # I believe the `type: ignore` is only necessary because mypy doesn't realize that this code branch is
            # only hit for _proper_ subclasses of BaseModel
            super(cls, cls).__pydantic_init_subclass__(**kwargs)  # type: ignore[misc]
            return cls
        else:
            # These are instance variables, but have been assigned to `NoInitField` to trick the type checker.
            for instance_slot in '__pydantic_fields_set__', '__pydantic_extra__', '__pydantic_private__':
                namespace.pop(
                    instance_slot,
                    None,  # In case the metaclass is used with a class other than `BaseModel`.
                )
            namespace.get('__annotations__', {}).clear()
            return super().__new__(mcs, cls_name, bases, namespace, **kwargs)

    def mro(cls) -> list[type[Any]]:
        original_mro = super().mro()

        if cls.__bases__ == (object,):
            return original_mro

        generic_metadata: PydanticGenericMetadata | None = cls.__dict__.get('__pydantic_generic_metadata__')
        if not generic_metadata:
            return original_mro

        assert_err_msg = 'Unexpected error occurred when generating MRO of generic subclass. Please report this issue on GitHub: https://github.com/pydantic/pydantic/issues.'

        origin: type[BaseModel] | None
        origin, args = (
            generic_metadata['origin'],
            generic_metadata['args'],
        )
        if not origin:
            return original_mro

        target_params = origin.__pydantic_generic_metadata__['parameters']
        param_dict = dict(zip(target_params, args))

        indexed_origins = {origin}

        new_mro: list[type[Any]] = [cls]
        for base in original_mro[1:]:
            base_origin: type[BaseModel] | None = getattr(base, '__pydantic_generic_metadata__', {}).get('origin')
            base_params: tuple[TypeVar, ...] = getattr(base, '__pydantic_generic_metadata__', {}).get('parameters', ())

            if base_origin in indexed_origins:
                continue
            elif base not in indexed_origins and base_params:
                assert set(base_params) <= param_dict.keys(), assert_err_msg
                new_base_args = tuple(param_dict[param] for param in base_params)
                new_base = base[new_base_args]  # type: ignore
                new_mro.append(new_base)

                indexed_origins.add(base_origin or base)

                if base_origin is not None:
                    # dropped previous indexed origins
                    continue
            else:
                indexed_origins.add(base_origin or base)

            # Avoid redundant cases such as
            # class A(BaseModel, Generic[T]): ...
            # A[T] is A  # True
            if base is not new_mro[-1]:
                new_mro.append(base)

        return new_mro

    if not typing.TYPE_CHECKING:  # pragma: no branch
        # We put `__getattr__` in a non-TYPE_CHECKING block because otherwise, mypy allows arbitrary attribute access

        def __getattr__(self, item: str) -> Any:
            """This is necessary to keep attribute access working for class attribute access."""
            private_attributes = self.__dict__.get('__private_attributes__')
            if private_attributes and item in private_attributes:
                return private_attributes[item]
            raise AttributeError(item)

    @classmethod
    def __prepare__(cls, *args: Any, **kwargs: Any) -> dict[str, object]:
        return _ModelNamespaceDict()

    def __instancecheck__(self, instance: Any) -> bool:
        """Avoid calling ABC _abc_subclasscheck unless we're pretty sure.

        See #3829 and python/cpython#92810
        """
        return hasattr(instance, '__pydantic_validator__') and super().__instancecheck__(instance)

    @staticmethod
    def _collect_bases_data(bases: tuple[type[Any], ...]) -> tuple[set[str], set[str], dict[str, ModelPrivateAttr]]:
        BaseModel = import_cached_base_model()

        field_names: set[str] = set()
        class_vars: set[str] = set()
        private_attributes: dict[str, ModelPrivateAttr] = {}
        for base in bases:
            if issubclass(base, BaseModel) and base is not BaseModel:
                # model_fields might not be defined yet in the case of generics, so we use getattr here:
                field_names.update(getattr(base, '__pydantic_fields__', {}).keys())
                class_vars.update(base.__class_vars__)
                private_attributes.update(base.__private_attributes__)
        return field_names, class_vars, private_attributes

    @property
    @deprecated('The `__fields__` attribute is deprecated, use `model_fields` instead.', category=None)
    def __fields__(self) -> dict[str, FieldInfo]:
        warnings.warn(
            'The `__fields__` attribute is deprecated, use `model_fields` instead.',
            PydanticDeprecatedSince20,
            stacklevel=2,
        )
        return self.model_fields

    @property
    def model_fields(self) -> dict[str, FieldInfo]:
        """Get metadata about the fields defined on the model.

        Returns:
            A mapping of field names to [`FieldInfo`][pydantic.fields.FieldInfo] objects.
        """
        return getattr(self, '__pydantic_fields__', {})

    @property
    def model_computed_fields(self) -> dict[str, ComputedFieldInfo]:
        """Get metadata about the computed fields defined on the model.

        Returns:
            A mapping of computed field names to [`ComputedFieldInfo`][pydantic.fields.ComputedFieldInfo] objects.
        """
        return getattr(self, '__pydantic_computed_fields__', {})

    def __dir__(self) -> list[str]:
        attributes = list(super().__dir__())
        if '__fields__' in attributes:
            attributes.remove('__fields__')
        return attributes


def init_private_attributes(self: BaseModel, context: Any, /) -> None:
    """This function is meant to behave like a BaseModel method to initialise private attributes.

    It takes context as an argument since that's what pydantic-core passes when calling it.

    Args:
        self: The BaseModel instance.
        context: The context.
    """
    if getattr(self, '__pydantic_private__', None) is None:
        pydantic_private = {}
        for name, private_attr in self.__private_attributes__.items():
            default = private_attr.get_default()
            if default is not PydanticUndefined:
                pydantic_private[name] = default
        object_setattr(self, '__pydantic_private__', pydantic_private)
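
# Behavior sketch (illustrative): this runs as `model_post_init` (installed by the
# metaclass above), so private-attribute defaults exist right after validation:
#
#     class Model(BaseModel):
#         _token: str = PrivateAttr(default='abc')
#
#     m = Model()
#     m._token  # 'abc', populated by init_private_attributes
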
def get_model_post_init(namespace: dict[str, Any], bases: tuple[type[Any], ...]) -> Callable[..., Any] | None:
|
||||
"""Get the `model_post_init` method from the namespace or the class bases, or `None` if not defined."""
|
||||
if 'model_post_init' in namespace:
|
||||
return namespace['model_post_init']
|
||||
|
||||
BaseModel = import_cached_base_model()
|
||||
|
||||
model_post_init = get_attribute_from_bases(bases, 'model_post_init')
|
||||
if model_post_init is not BaseModel.model_post_init:
|
||||
return model_post_init
|
||||
|
||||
|
||||
def inspect_namespace(  # noqa C901
    namespace: dict[str, Any],
    ignored_types: tuple[type[Any], ...],
    base_class_vars: set[str],
    base_class_fields: set[str],
) -> dict[str, ModelPrivateAttr]:
    """Iterate over the namespace and:

    * gather private attributes
    * check for items which look like fields but are not (e.g. have no annotation) and warn.

    Args:
        namespace: The attribute dictionary of the class to be created.
        ignored_types: A tuple of ignored types.
        base_class_vars: A set of base class class variables.
        base_class_fields: A set of base class fields.

    Returns:
        A dict containing private attribute info.

    Raises:
        TypeError: If there is a `__root__` field in the model.
        NameError: If a private attribute name is invalid.
        PydanticUserError:
            - If a field does not have a type annotation.
            - If a field on a base class was overridden by a non-annotated attribute.
    """
    from ..fields import ModelPrivateAttr, PrivateAttr

    FieldInfo = import_cached_field_info()

    all_ignored_types = ignored_types + default_ignored_types()

    private_attributes: dict[str, ModelPrivateAttr] = {}
    raw_annotations = namespace.get('__annotations__', {})

    if '__root__' in raw_annotations or '__root__' in namespace:
        raise TypeError("To define root models, use `pydantic.RootModel` rather than a field called '__root__'")

    ignored_names: set[str] = set()
    for var_name, value in list(namespace.items()):
        if var_name == 'model_config' or var_name == '__pydantic_extra__':
            continue
        elif (
            isinstance(value, type)
            and value.__module__ == namespace['__module__']
            and '__qualname__' in namespace
            and value.__qualname__.startswith(namespace['__qualname__'])
        ):
            # `value` is a nested type defined in this namespace; don't error
            continue
        elif isinstance(value, all_ignored_types) or value.__class__.__module__ == 'functools':
            ignored_names.add(var_name)
            continue
        elif isinstance(value, ModelPrivateAttr):
            if var_name.startswith('__'):
                raise NameError(
                    'Private attributes must not use dunder names;'
                    f' use a single underscore prefix instead of {var_name!r}.'
                )
            elif is_valid_field_name(var_name):
                raise NameError(
                    'Private attributes must not use valid field names;'
                    f' use sunder names, e.g. {"_" + var_name!r} instead of {var_name!r}.'
                )
            private_attributes[var_name] = value
            del namespace[var_name]
        elif isinstance(value, FieldInfo) and not is_valid_field_name(var_name):
            suggested_name = var_name.lstrip('_') or 'my_field'  # don't suggest '' for all-underscore name
            raise NameError(
                f'Fields must not use names with leading underscores;'
                f' e.g., use {suggested_name!r} instead of {var_name!r}.'
            )
        elif var_name.startswith('__'):
            continue
        elif is_valid_privateattr_name(var_name):
            if var_name not in raw_annotations or not is_classvar_annotation(raw_annotations[var_name]):
                private_attributes[var_name] = cast(ModelPrivateAttr, PrivateAttr(default=value))
                del namespace[var_name]
        elif var_name in base_class_vars:
            continue
        elif var_name not in raw_annotations:
            if var_name in base_class_fields:
                raise PydanticUserError(
                    f'Field {var_name!r} defined on a base class was overridden by a non-annotated attribute. '
                    f'All field definitions, including overrides, require a type annotation.',
                    code='model-field-overridden',
                )
            elif isinstance(value, FieldInfo):
                raise PydanticUserError(
                    f'Field {var_name!r} requires a type annotation', code='model-field-missing-annotation'
                )
            else:
                raise PydanticUserError(
                    f'A non-annotated attribute was detected: `{var_name} = {value!r}`. All model fields require a '
                    f'type annotation; if `{var_name}` is not meant to be a field, you may be able to resolve this '
                    f"error by annotating it as a `ClassVar` or updating `model_config['ignored_types']`.",
                    code='model-field-missing-annotation',
                )

    for ann_name, ann_type in raw_annotations.items():
        if (
            is_valid_privateattr_name(ann_name)
            and ann_name not in private_attributes
            and ann_name not in ignored_names
            # This condition can be a false negative when `ann_type` is stringified,
            # but it is handled in most cases in `set_model_fields`:
            and not is_classvar_annotation(ann_type)
            and ann_type not in all_ignored_types
            and getattr(ann_type, '__module__', None) != 'functools'
        ):
            if isinstance(ann_type, str):
                # Walking up the frames to get the module namespace where the model is defined
                # (as the model class wasn't created yet, we unfortunately can't use `cls.__module__`):
                frame = sys._getframe(2)
                if frame is not None:
                    try:
                        ann_type = eval_type_backport(
                            _make_forward_ref(ann_type, is_argument=False, is_class=True),
                            globalns=frame.f_globals,
                            localns=frame.f_locals,
                        )
                    except (NameError, TypeError):
                        pass

            if is_annotated(ann_type):
                _, *metadata = get_args(ann_type)
                private_attr = next((v for v in metadata if isinstance(v, ModelPrivateAttr)), None)
                if private_attr is not None:
                    private_attributes[ann_name] = private_attr
                    continue
            private_attributes[ann_name] = PrivateAttr()

    return private_attributes


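# Editor's note: a minimal, hypothetical sketch (not part of the original module) of how the
# naming checks above surface through the public API; kept as comments so the module's
# import-time behavior is unchanged.
#
#     from pydantic import BaseModel, PrivateAttr
#
#     class Ok(BaseModel):
#         _token: str = PrivateAttr(default='secret')  # valid sunder name
#
#     class Bad(BaseModel):
#         token: str = PrivateAttr(default='secret')   # raises NameError: private attributes
#                                                      # must not use valid field names
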
def set_default_hash_func(cls: type[BaseModel], bases: tuple[type[Any], ...]) -> None:
    base_hash_func = get_attribute_from_bases(bases, '__hash__')
    new_hash_func = make_hash_func(cls)
    if base_hash_func in {None, object.__hash__} or getattr(base_hash_func, '__code__', None) == new_hash_func.__code__:
        # If `__hash__` is some default, we generate a hash function.
        # It will be `None` if not overridden from BaseModel.
        # It may be `object.__hash__` if there is another
        # parent class earlier in the bases which doesn't override `__hash__` (e.g. `typing.Generic`).
        # It may be a value set by `set_default_hash_func` if `cls` is a subclass of another frozen model.
        # In the last case we still need a new hash function to account for new `model_fields`.
        cls.__hash__ = new_hash_func


def make_hash_func(cls: type[BaseModel]) -> Any:
    getter = operator.itemgetter(*cls.__pydantic_fields__.keys()) if cls.__pydantic_fields__ else lambda _: 0

    def hash_func(self: Any) -> int:
        try:
            return hash(getter(self.__dict__))
        except KeyError:
            # In rare cases (such as when using the deprecated copy method), the __dict__ may not contain
            # all model fields, which is how we can get here.
            # getter(self.__dict__) is much faster than any 'safe' method that accounts for missing keys,
            # and wrapping it in a `try` doesn't slow things down much in the common case.
            return hash(getter(SafeGetItemProxy(self.__dict__)))

    return hash_func


def set_model_fields(
    cls: type[BaseModel],
    bases: tuple[type[Any], ...],
    config_wrapper: ConfigWrapper,
    ns_resolver: NsResolver | None,
) -> None:
    """Collect and set `cls.__pydantic_fields__` and `cls.__class_vars__`.

    Args:
        cls: BaseModel or dataclass.
        bases: Parents of the class, generally `cls.__bases__`.
        config_wrapper: The config wrapper instance.
        ns_resolver: Namespace resolver to use when getting model annotations.
    """
    typevars_map = get_model_typevars_map(cls)
    fields, class_vars = collect_model_fields(cls, bases, config_wrapper, ns_resolver, typevars_map=typevars_map)

    cls.__pydantic_fields__ = fields
    cls.__class_vars__.update(class_vars)

    for k in class_vars:
        # Class vars should not be private attributes.
        # We remove them _here_ and not earlier because we rely on inspecting the class to determine its classvars,
        # but private attributes are determined by inspecting the namespace _prior_ to class creation.
        # In the case that a classvar with a leading-'_' is defined via a ForwardRef (e.g., when using
        # `__future__.annotations`), we want to remove the private attribute which was detected _before_ we knew it
        # evaluated to a classvar.
        value = cls.__private_attributes__.pop(k, None)
        if value is not None and value.default is not PydanticUndefined:
            setattr(cls, k, value.default)


def complete_model_class(
    cls: type[BaseModel],
    cls_name: str,
    config_wrapper: ConfigWrapper,
    *,
    raise_errors: bool = True,
    ns_resolver: NsResolver | None = None,
    create_model_module: str | None = None,
) -> bool:
    """Finish building a model class.

    This logic must be called after the class has been created, since validation functions must be bound
    and `get_type_hints` requires a class object.

    Args:
        cls: BaseModel or dataclass.
        cls_name: The model or dataclass name.
        config_wrapper: The config wrapper instance.
        raise_errors: Whether to raise errors.
        ns_resolver: The namespace resolver instance to use during schema building.
        create_model_module: The module of the class to be created, if created by `create_model`.

    Returns:
        `True` if the model is successfully completed, else `False`.

    Raises:
        PydanticUndefinedAnnotation: If `PydanticUndefinedAnnotation` occurs in `__get_pydantic_core_schema__`
            and `raise_errors=True`.
    """
    if config_wrapper.defer_build:
        set_model_mocks(cls, cls_name)
        return False

    typevars_map = get_model_typevars_map(cls)
    gen_schema = GenerateSchema(
        config_wrapper,
        ns_resolver,
        typevars_map,
    )

    handler = CallbackGetCoreSchemaHandler(
        partial(gen_schema.generate_schema, from_dunder_get_core_schema=False),
        gen_schema,
        ref_mode='unpack',
    )

    try:
        schema = cls.__get_pydantic_core_schema__(cls, handler)
    except PydanticUndefinedAnnotation as e:
        if raise_errors:
            raise
        set_model_mocks(cls, cls_name, f'`{e.name}`')
        return False

    core_config = config_wrapper.core_config(title=cls.__name__)

    try:
        schema = gen_schema.clean_schema(schema)
    except gen_schema.CollectedInvalid:
        set_model_mocks(cls, cls_name)
        return False

    # debug(schema)
    cls.__pydantic_core_schema__ = schema

    cls.__pydantic_validator__ = create_schema_validator(
        schema,
        cls,
        create_model_module or cls.__module__,
        cls.__qualname__,
        'create_model' if create_model_module else 'BaseModel',
        core_config,
        config_wrapper.plugin_settings,
    )
    cls.__pydantic_serializer__ = SchemaSerializer(schema, core_config)
    cls.__pydantic_complete__ = True

    # set __signature__ attr only for model class, but not for its instances
    # (because instances can define `__call__`, and `inspect.signature` shouldn't
    # use the `__signature__` attribute and instead generate from `__call__`).
    cls.__signature__ = LazyClassAttribute(
        '__signature__',
        partial(
            generate_pydantic_signature,
            init=cls.__init__,
            fields=cls.__pydantic_fields__,
            populate_by_name=config_wrapper.populate_by_name,
            extra=config_wrapper.extra,
        ),
    )
    return True


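# Editor's note: an illustrative sketch (not part of the original module). With
# `defer_build=True`, the early return above installs mocks instead of building the schema;
# a later `model_rebuild()` (or first use) triggers a full `complete_model_class` pass.
# Kept as comments so the module's import-time behavior is unchanged.
#
#     from pydantic import BaseModel, ConfigDict
#
#     class Lazy(BaseModel):
#         model_config = ConfigDict(defer_build=True)
#         x: int
#
#     Lazy.model_rebuild()  # builds the validator/serializer on demand
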
def set_deprecated_descriptors(cls: type[BaseModel]) -> None:
    """Set data descriptors on the class for deprecated fields."""
    for field, field_info in cls.__pydantic_fields__.items():
        if (msg := field_info.deprecation_message) is not None:
            desc = _DeprecatedFieldDescriptor(msg)
            desc.__set_name__(cls, field)
            setattr(cls, field, desc)

    for field, computed_field_info in cls.__pydantic_computed_fields__.items():
        if (
            (msg := computed_field_info.deprecation_message) is not None
            # Avoid having two warnings emitted:
            and not hasattr(unwrap_wrapped_function(computed_field_info.wrapped_property), '__deprecated__')
        ):
            desc = _DeprecatedFieldDescriptor(msg, computed_field_info.wrapped_property)
            desc.__set_name__(cls, field)
            setattr(cls, field, desc)


class _DeprecatedFieldDescriptor:
    """Read-only data descriptor used to emit a runtime deprecation warning before accessing a deprecated field.

    Attributes:
        msg: The deprecation message to be emitted.
        wrapped_property: The property instance if the deprecated field is a computed field, or `None`.
        field_name: The name of the field being deprecated.
    """

    field_name: str

    def __init__(self, msg: str, wrapped_property: property | None = None) -> None:
        self.msg = msg
        self.wrapped_property = wrapped_property

    def __set_name__(self, cls: type[BaseModel], name: str) -> None:
        self.field_name = name

    def __get__(self, obj: BaseModel | None, obj_type: type[BaseModel] | None = None) -> Any:
        if obj is None:
            if self.wrapped_property is not None:
                return self.wrapped_property.__get__(None, obj_type)
            raise AttributeError(self.field_name)

        warnings.warn(self.msg, builtins.DeprecationWarning, stacklevel=2)

        if self.wrapped_property is not None:
            return self.wrapped_property.__get__(obj, obj_type)
        return obj.__dict__[self.field_name]

    # Defined to make it a data descriptor and take precedence over the instance's dictionary.
    # Note that it will not be called when setting a value on a model instance
    # as `BaseModel.__setattr__` is defined and takes priority.
    def __set__(self, obj: Any, value: Any) -> NoReturn:
        raise AttributeError(self.field_name)


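# Editor's note: an illustrative sketch (not part of the original module) of the descriptor
# in action via the public API; accessing the field emits a DeprecationWarning. Kept as
# comments so the module's import-time behavior is unchanged.
#
#     from pydantic import BaseModel, Field
#
#     class M(BaseModel):
#         old: int = Field(default=0, deprecated='use `new` instead')
#
#     M().old  # emits DeprecationWarning: use `new` instead
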
class _PydanticWeakRef:
    """Wrapper for `weakref.ref` that enables `pickle` serialization.

    Cloudpickle fails to serialize `weakref.ref` objects due to an arcane error related
    to abstract base classes (`abc.ABC`). This class works around the issue by wrapping
    `weakref.ref` instead of subclassing it.

    See https://github.com/pydantic/pydantic/issues/6763 for context.

    Semantics:
        - If not pickled, behaves the same as a `weakref.ref`.
        - If pickled along with the referenced object, the same `weakref.ref` behavior
          will be maintained between them after unpickling.
        - If pickled without the referenced object, after unpickling the underlying
          reference will be cleared (`__call__` will always return `None`).
    """

    def __init__(self, obj: Any):
        if obj is None:
            # The object will be `None` upon deserialization if the serialized weakref
            # had lost its underlying object.
            self._wr = None
        else:
            self._wr = weakref.ref(obj)

    def __call__(self) -> Any:
        if self._wr is None:
            return None
        else:
            return self._wr()

    def __reduce__(self) -> tuple[Callable, tuple[weakref.ReferenceType | None]]:
        return _PydanticWeakRef, (self(),)


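# Editor's note: an illustrative sketch (not part of the original module) of the semantics
# described in the docstring above, assuming CPython's reference counting; kept as comments
# so the module's import-time behavior is unchanged.
#
#     import pickle
#
#     class Obj: ...
#
#     o = Obj()
#     wr = _PydanticWeakRef(o)
#     assert wr() is o  # behaves like weakref.ref while `o` is alive
#
#     # Round-trip through pickle: the unpickled wrapper only holds a weak reference,
#     # so with no other strong refs the referent is collected and the ref reads None.
#     assert pickle.loads(pickle.dumps(wr))() is None
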
def build_lenient_weakvaluedict(d: dict[str, Any] | None) -> dict[str, Any] | None:
    """Takes an input dictionary, and produces a new value that (invertibly) replaces the values with weakrefs.

    We can't just use a WeakValueDictionary because many types (including int, str, etc.) can't be stored as values
    in a WeakValueDictionary.

    The `unpack_lenient_weakvaluedict` function can be used to reverse this operation.
    """
    if d is None:
        return None
    result = {}
    for k, v in d.items():
        try:
            proxy = _PydanticWeakRef(v)
        except TypeError:
            proxy = v
        result[k] = proxy
    return result


def unpack_lenient_weakvaluedict(d: dict[str, Any] | None) -> dict[str, Any] | None:
    """Inverts the transform performed by `build_lenient_weakvaluedict`."""
    if d is None:
        return None

    result = {}
    for k, v in d.items():
        if isinstance(v, _PydanticWeakRef):
            v = v()
            if v is not None:
                result[k] = v
        else:
            result[k] = v
    return result


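# Editor's note: an illustrative sketch (not part of the original module) of the invertible
# transform implemented by the two functions above; kept as comments so the module's
# import-time behavior is unchanged.
#
#     class Owner: ...
#
#     owner = Owner()
#     ns = {'owner': owner, 'count': 3}         # int can't live in a WeakValueDictionary
#     packed = build_lenient_weakvaluedict(ns)  # weakref-able values wrapped, others kept as-is
#     assert unpack_lenient_weakvaluedict(packed) == {'owner': owner, 'count': 3}
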
@lru_cache(maxsize=None)
def default_ignored_types() -> tuple[type[Any], ...]:
    from ..fields import ComputedFieldInfo

    ignored_types = [
        FunctionType,
        property,
        classmethod,
        staticmethod,
        PydanticDescriptorProxy,
        ComputedFieldInfo,
        TypeAliasType,  # from `typing_extensions`
    ]

    if sys.version_info >= (3, 12):
        ignored_types.append(typing.TypeAliasType)

    return tuple(ignored_types)
@@ -0,0 +1,284 @@
from __future__ import annotations

import sys
from collections.abc import Generator
from contextlib import contextmanager
from functools import cached_property
from typing import Any, Callable, Iterator, Mapping, NamedTuple, TypeVar

from typing_extensions import ParamSpec, TypeAlias, TypeAliasType, TypeVarTuple

GlobalsNamespace: TypeAlias = 'dict[str, Any]'
"""A global namespace.

In most cases, this is a reference to the `__dict__` attribute of a module.
This namespace type is expected as the `globals` argument during annotations evaluation.
"""

MappingNamespace: TypeAlias = Mapping[str, Any]
"""Any kind of namespace.

In most cases, this is a local namespace (e.g. the `__dict__` attribute of a class,
the [`f_locals`][frame.f_locals] attribute of a frame object, when dealing with types
defined inside functions).
This namespace type is expected as the `locals` argument during annotations evaluation.
"""

_TypeVarLike: TypeAlias = 'TypeVar | ParamSpec | TypeVarTuple'


class NamespacesTuple(NamedTuple):
    """A tuple of globals and locals to be used during annotations evaluation.

    This data structure is defined as a named tuple so that it can easily be unpacked:

    ```python {lint="skip" test="skip"}
    def eval_type(typ: type[Any], ns: NamespacesTuple) -> None:
        return eval(typ, *ns)
    ```
    """

    globals: GlobalsNamespace
    """The namespace to be used as the `globals` argument during annotations evaluation."""

    locals: MappingNamespace
    """The namespace to be used as the `locals` argument during annotations evaluation."""


def get_module_ns_of(obj: Any) -> dict[str, Any]:
    """Get the namespace of the module where the object is defined.

    Caution: this function does not return a copy of the module namespace, so the result
    should not be mutated. The burden of enforcing this is on the caller.
    """
    module_name = getattr(obj, '__module__', None)
    if module_name:
        try:
            return sys.modules[module_name].__dict__
        except KeyError:
            # happens occasionally, see https://github.com/pydantic/pydantic/issues/2363
            return {}
    return {}


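# Editor's note: an illustrative sketch (not part of the original module); kept as comments
# so the module's import-time behavior is unchanged.
#
#     import json
#
#     assert get_module_ns_of(json.dumps) is vars(json)  # the module's own __dict__, not a copy
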
# Note that this class is almost identical to `collections.ChainMap`, but we need to
# enforce immutable mappings here:
class LazyLocalNamespace(Mapping[str, Any]):
    """A lazily evaluated mapping, to be used as the `locals` argument during annotations evaluation.

    While the [`eval`][eval] function expects a mapping as the `locals` argument, it only
    performs `__getitem__` calls. The [`Mapping`][collections.abc.Mapping] abstract base class
    is fully implemented only for type checking purposes.

    Args:
        *namespaces: The namespaces to consider, in ascending order of priority.

    Example:
        ```python {lint="skip" test="skip"}
        ns = LazyLocalNamespace({'a': 1, 'b': 2}, {'a': 3})
        ns['a']
        #> 3
        ns['b']
        #> 2
        ```
    """

    def __init__(self, *namespaces: MappingNamespace) -> None:
        self._namespaces = namespaces

    @cached_property
    def data(self) -> dict[str, Any]:
        return {k: v for ns in self._namespaces for k, v in ns.items()}

    def __len__(self) -> int:
        return len(self.data)

    def __getitem__(self, key: str) -> Any:
        return self.data[key]

    def __contains__(self, key: object) -> bool:
        return key in self.data

    def __iter__(self) -> Iterator[str]:
        return iter(self.data)


def ns_for_function(obj: Callable[..., Any], parent_namespace: MappingNamespace | None = None) -> NamespacesTuple:
    """Return the global and local namespaces to be used when evaluating annotations for the provided function.

    The global namespace will be the `__dict__` attribute of the module the function was defined in.
    The local namespace will contain the `__type_params__` introduced by PEP 695.

    Args:
        obj: The object to use when building namespaces.
        parent_namespace: Optional namespace to be added with the lowest priority in the local namespace.
            If the passed function is a method, the `parent_namespace` will be the namespace of the class
            the method is defined in. Thus, we also fetch `__type_params__` from there (i.e. the
            class-scoped type variables).
    """
    locals_list: list[MappingNamespace] = []
    if parent_namespace is not None:
        locals_list.append(parent_namespace)

    # Get the `__type_params__` attribute introduced by PEP 695.
    # Note that the `typing._eval_type` function expects type params to be
    # passed as a separate argument. However, internally, `_eval_type` calls
    # `ForwardRef._evaluate` which will merge type params with the localns,
    # essentially mimicking what we do here.
    type_params: tuple[_TypeVarLike, ...]
    if hasattr(obj, '__type_params__'):
        type_params = obj.__type_params__
    else:
        type_params = ()
    if parent_namespace is not None:
        # We also fetch type params from the parent namespace. If present, it probably
        # means the function was defined in a class. This is to support the following:
        # https://github.com/python/cpython/issues/124089.
        type_params += parent_namespace.get('__type_params__', ())

    locals_list.append({t.__name__: t for t in type_params})

    # What about short-circuiting to `obj.__globals__`?
    globalns = get_module_ns_of(obj)

    return NamespacesTuple(globalns, LazyLocalNamespace(*locals_list))


class NsResolver:
    """A class responsible for the namespaces resolving logic for annotations evaluation.

    This class handles the namespace logic when evaluating annotations mainly for class objects.

    It holds a stack of classes that are being inspected during the core schema building,
    and the `types_namespace` property exposes the globals and locals to be used for
    type annotation evaluation. Additionally -- if no class is present in the stack -- a
    fallback globals and locals can be provided using the `namespaces_tuple` argument
    (this is useful when generating a schema for a simple annotation, e.g. when using
    `TypeAdapter`).

    The namespace creation logic is unfortunately flawed in some cases, for backwards
    compatibility reasons and to better support valid edge cases. See the description
    for the `parent_namespace` argument and the example for more details.

    Args:
        namespaces_tuple: The default globals and locals to use if no class is present
            on the stack. This can be useful when using the `GenerateSchema` class
            with `TypeAdapter`, where the "type" being analyzed is a simple annotation.
        parent_namespace: An optional parent namespace that will be added to the locals
            with the lowest priority. For a given class defined in a function, the locals
            of this function are usually used as the parent namespace:

            ```python {lint="skip" test="skip"}
            from pydantic import BaseModel

            def func() -> None:
                SomeType = int

                class Model(BaseModel):
                    f: 'SomeType'

                # when collecting fields, a namespace resolver instance will be created
                # this way:
                # ns_resolver = NsResolver(parent_namespace={'SomeType': SomeType})
            ```

            For backwards compatibility reasons and to support valid edge cases, this parent
            namespace will be used for *every* type being pushed to the stack. In the future,
            we might want to be smarter by only doing so when the type being pushed is defined
            in the same module as the parent namespace.

    Example:
        ```python {lint="skip" test="skip"}
        ns_resolver = NsResolver(
            parent_namespace={'fallback': 1},
        )

        class Sub:
            m: 'Model'

        class Model:
            some_local = 1
            sub: Sub

        ns_resolver = NsResolver()

        # This is roughly what happens when we build a core schema for `Model`:
        with ns_resolver.push(Model):
            ns_resolver.types_namespace
            #> NamespacesTuple({'Sub': Sub}, {'Model': Model, 'some_local': 1})
            # First thing to notice here, the model being pushed is added to the locals.
            # Because `NsResolver` is being used during the model definition, it is not
            # yet added to the globals. This is useful when resolving self-referencing annotations.

            with ns_resolver.push(Sub):
                ns_resolver.types_namespace
                #> NamespacesTuple({'Sub': Sub}, {'Sub': Sub, 'Model': Model})
                # Second thing to notice: `Sub` is present in both the globals and locals.
                # This is not an issue, just that as described above, the model being pushed
                # is added to the locals, but it happens to be present in the globals as well
                # because it is already defined.
                # Third thing to notice: `Model` is also added in locals. This is a backwards
                # compatibility workaround that allows for `Sub` to be able to resolve `'Model'`
                # correctly (as otherwise models would have to be rebuilt even though this
                # doesn't look necessary).
        ```
    """

    def __init__(
        self,
        namespaces_tuple: NamespacesTuple | None = None,
        parent_namespace: MappingNamespace | None = None,
    ) -> None:
        self._base_ns_tuple = namespaces_tuple or NamespacesTuple({}, {})
        self._parent_ns = parent_namespace
        self._types_stack: list[type[Any] | TypeAliasType] = []

    @cached_property
    def types_namespace(self) -> NamespacesTuple:
        """The current global and local namespaces to be used for annotations evaluation."""
        if not self._types_stack:
            # TODO: should we merge the parent namespace here?
            # This is relevant for TypeAdapter, where there are no types on the stack, and we might
            # need access to the parent_ns. Right now, we sidestep this in `type_adapter.py` by passing
            # locals to both parent_ns and the base_ns_tuple, but this is a bit hacky.
            # we might consider something like:
            # if self._parent_ns is not None:
            #     # Hacky workarounds, see class docstring:
            #     # An optional parent namespace that will be added to the locals with the lowest priority
            #     locals_list: list[MappingNamespace] = [self._parent_ns, self._base_ns_tuple.locals]
            #     return NamespacesTuple(self._base_ns_tuple.globals, LazyLocalNamespace(*locals_list))
            return self._base_ns_tuple

        typ = self._types_stack[-1]

        globalns = get_module_ns_of(typ)

        locals_list: list[MappingNamespace] = []
        # Hacky workarounds, see class docstring:
        # An optional parent namespace that will be added to the locals with the lowest priority
        if self._parent_ns is not None:
            locals_list.append(self._parent_ns)
        if len(self._types_stack) > 1:
            first_type = self._types_stack[0]
            locals_list.append({first_type.__name__: first_type})

        if hasattr(typ, '__dict__'):
            # TypeAliasType is the exception.
            locals_list.append(vars(typ))

        # The len check above prevents this from being added twice:
        locals_list.append({typ.__name__: typ})

        return NamespacesTuple(globalns, LazyLocalNamespace(*locals_list))

    @contextmanager
    def push(self, typ: type[Any] | TypeAliasType, /) -> Generator[None]:
        """Push a type to the stack."""
        self._types_stack.append(typ)
        # Reset the cached property:
        self.__dict__.pop('types_namespace', None)
        try:
            yield
        finally:
            self._types_stack.pop()
            self.__dict__.pop('types_namespace', None)
@@ -0,0 +1,123 @@
"""Tools to provide pretty/human-readable display of objects."""
|
||||
|
||||
from __future__ import annotations as _annotations
|
||||
|
||||
import types
|
||||
import typing
|
||||
from typing import Any
|
||||
|
||||
import typing_extensions
|
||||
|
||||
from . import _typing_extra
|
||||
|
||||
if typing.TYPE_CHECKING:
|
||||
ReprArgs: typing_extensions.TypeAlias = 'typing.Iterable[tuple[str | None, Any]]'
|
||||
RichReprResult: typing_extensions.TypeAlias = (
|
||||
'typing.Iterable[Any | tuple[Any] | tuple[str, Any] | tuple[str, Any, Any]]'
|
||||
)
|
||||
|
||||
|
||||
class PlainRepr(str):
    """String class whose repr doesn't include quotes. Useful with Representation when you want
    to return a string representation of something that is valid (or pseudo-valid) Python.
    """

    def __repr__(self) -> str:
        return str(self)


class Representation:
    # Mixin to provide `__str__`, `__repr__`, `__pretty__`, and `__rich_repr__` methods.
    # `__pretty__` is used by [devtools](https://python-devtools.helpmanual.io/).
    # `__rich_repr__` is used by [rich](https://rich.readthedocs.io/en/stable/pretty.html).
    # (this is not a docstring to avoid adding a docstring to classes which inherit from Representation)

    # we don't want to use a type annotation here as it can break get_type_hints
    __slots__ = ()  # type: typing.Collection[str]

    def __repr_args__(self) -> ReprArgs:
        """Returns the attributes to show in __str__, __repr__, and __pretty__; this is generally overridden.

        Can either return:
        * name - value pairs, e.g.: `[('foo_name', 'foo'), ('bar_name', ['b', 'a', 'r'])]`
        * or, just values, e.g.: `[(None, 'foo'), (None, ['b', 'a', 'r'])]`
        """
        attrs_names = self.__slots__
        if not attrs_names and hasattr(self, '__dict__'):
            attrs_names = self.__dict__.keys()
        attrs = ((s, getattr(self, s)) for s in attrs_names)
        return [(a, v if v is not self else self.__repr_recursion__(v)) for a, v in attrs if v is not None]

    def __repr_name__(self) -> str:
        """Name of the instance's class, used in __repr__."""
        return self.__class__.__name__

    def __repr_recursion__(self, object: Any) -> str:
        """Returns the string representation of a recursive object."""
        # This is copied over from the stdlib `pprint` module:
        return f'<Recursion on {type(object).__name__} with id={id(object)}>'

    def __repr_str__(self, join_str: str) -> str:
        return join_str.join(repr(v) if a is None else f'{a}={v!r}' for a, v in self.__repr_args__())

    def __pretty__(self, fmt: typing.Callable[[Any], Any], **kwargs: Any) -> typing.Generator[Any, None, None]:
        """Used by devtools (https://python-devtools.helpmanual.io/) to pretty print objects."""
        yield self.__repr_name__() + '('
        yield 1
        for name, value in self.__repr_args__():
            if name is not None:
                yield name + '='
            yield fmt(value)
            yield ','
            yield 0
        yield -1
        yield ')'

    def __rich_repr__(self) -> RichReprResult:
        """Used by Rich (https://rich.readthedocs.io/en/stable/pretty.html) to pretty print objects."""
        for name, field_repr in self.__repr_args__():
            if name is None:
                yield field_repr
            else:
                yield name, field_repr

    def __str__(self) -> str:
        return self.__repr_str__(' ')

    def __repr__(self) -> str:
        return f'{self.__repr_name__()}({self.__repr_str__(", ")})'


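# Editor's note: an illustrative sketch (not part of the original module) of the mixin's
# output; kept as comments so the module's import-time behavior is unchanged.
#
#     class Point(Representation):
#         def __init__(self, x: int, y: int) -> None:
#             self.x, self.y = x, y
#
#     repr(Point(1, 2))  # -> 'Point(x=1, y=2)'
#     str(Point(1, 2))   # -> 'x=1 y=2'
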
def display_as_type(obj: Any) -> str:
    """Pretty representation of a type, should be as close as possible to the original type definition string.

    Takes some logic from `typing._type_repr`.
    """
    if isinstance(obj, (types.FunctionType, types.BuiltinFunctionType)):
        return obj.__name__
    elif obj is ...:
        return '...'
    elif isinstance(obj, Representation):
        return repr(obj)
    elif isinstance(obj, typing.ForwardRef) or _typing_extra.is_type_alias_type(obj):
        return str(obj)

    if not isinstance(obj, (_typing_extra.typing_base, _typing_extra.WithArgsTypes, type)):
        obj = obj.__class__

    if _typing_extra.origin_is_union(typing_extensions.get_origin(obj)):
        args = ', '.join(map(display_as_type, typing_extensions.get_args(obj)))
        return f'Union[{args}]'
    elif isinstance(obj, _typing_extra.WithArgsTypes):
        if _typing_extra.is_literal(obj):
            args = ', '.join(map(repr, typing_extensions.get_args(obj)))
        else:
            args = ', '.join(map(display_as_type, typing_extensions.get_args(obj)))
        try:
            return f'{obj.__qualname__}[{args}]'
        except AttributeError:
            return str(obj).replace('typing.', '').replace('typing_extensions.', '')  # handles TypeAliasType in 3.12
    elif isinstance(obj, type):
        return obj.__qualname__
    else:
        return repr(obj).replace('typing.', '').replace('typing_extensions.', '')
@@ -0,0 +1,126 @@
"""Types and utility functions used by various other internal tools."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING, Any, Callable
|
||||
|
||||
from pydantic_core import core_schema
|
||||
from typing_extensions import Literal
|
||||
|
||||
from ..annotated_handlers import GetCoreSchemaHandler, GetJsonSchemaHandler
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ..json_schema import GenerateJsonSchema, JsonSchemaValue
|
||||
from ._core_utils import CoreSchemaOrField
|
||||
from ._generate_schema import GenerateSchema
|
||||
from ._namespace_utils import NamespacesTuple
|
||||
|
||||
GetJsonSchemaFunction = Callable[[CoreSchemaOrField, GetJsonSchemaHandler], JsonSchemaValue]
|
||||
HandlerOverride = Callable[[CoreSchemaOrField], JsonSchemaValue]
|
||||
|
||||
|
||||
class GenerateJsonSchemaHandler(GetJsonSchemaHandler):
    """JsonSchemaHandler implementation that doesn't do ref unwrapping by default.

    This is used for any Annotated metadata so that we don't end up with conflicting
    modifications to the definition schema.

    Used internally by Pydantic, please do not rely on this implementation.
    See `GetJsonSchemaHandler` for the handler API.
    """

    def __init__(self, generate_json_schema: GenerateJsonSchema, handler_override: HandlerOverride | None) -> None:
        self.generate_json_schema = generate_json_schema
        self.handler = handler_override or generate_json_schema.generate_inner
        self.mode = generate_json_schema.mode

    def __call__(self, core_schema: CoreSchemaOrField, /) -> JsonSchemaValue:
        return self.handler(core_schema)

    def resolve_ref_schema(self, maybe_ref_json_schema: JsonSchemaValue) -> JsonSchemaValue:
        """Resolves `$ref` in the json schema.

        This returns the input json schema if there is no `$ref` in the json schema.

        Args:
            maybe_ref_json_schema: The input json schema that may contain `$ref`.

        Returns:
            Resolved json schema.

        Raises:
            LookupError: If it can't find the definition for `$ref`.
        """
        if '$ref' not in maybe_ref_json_schema:
            return maybe_ref_json_schema
        ref = maybe_ref_json_schema['$ref']
        json_schema = self.generate_json_schema.get_schema_from_definitions(ref)
        if json_schema is None:
            raise LookupError(
                f'Could not find a ref for {ref}.'
                ' Maybe you tried to call resolve_ref_schema from within a recursive model?'
            )
        return json_schema


class CallbackGetCoreSchemaHandler(GetCoreSchemaHandler):
    """Wrapper to use an arbitrary function as a `GetCoreSchemaHandler`.

    Used internally by Pydantic, please do not rely on this implementation.
    See `GetCoreSchemaHandler` for the handler API.
    """

    def __init__(
        self,
        handler: Callable[[Any], core_schema.CoreSchema],
        generate_schema: GenerateSchema,
        ref_mode: Literal['to-def', 'unpack'] = 'to-def',
    ) -> None:
        self._handler = handler
        self._generate_schema = generate_schema
        self._ref_mode = ref_mode

    def __call__(self, source_type: Any, /) -> core_schema.CoreSchema:
        schema = self._handler(source_type)
        ref = schema.get('ref')
        if self._ref_mode == 'to-def':
            if ref is not None:
                self._generate_schema.defs.definitions[ref] = schema
                return core_schema.definition_reference_schema(ref)
            return schema
        else:  # ref_mode = 'unpack'
            return self.resolve_ref_schema(schema)

    def _get_types_namespace(self) -> NamespacesTuple:
        return self._generate_schema._types_namespace

    def generate_schema(self, source_type: Any, /) -> core_schema.CoreSchema:
        return self._generate_schema.generate_schema(source_type)

    @property
    def field_name(self) -> str | None:
        return self._generate_schema.field_name_stack.get()

    def resolve_ref_schema(self, maybe_ref_schema: core_schema.CoreSchema) -> core_schema.CoreSchema:
        """Resolves a reference in the core schema.

        Args:
            maybe_ref_schema: The input core schema that may contain a reference.

        Returns:
            Resolved core schema.

        Raises:
            LookupError: If it can't find the definition for the reference.
        """
        if maybe_ref_schema['type'] == 'definition-ref':
            ref = maybe_ref_schema['schema_ref']
            if ref not in self._generate_schema.defs.definitions:
                raise LookupError(
                    f'Could not find a ref for {ref}.'
                    ' Maybe you tried to call resolve_ref_schema from within a recursive model?'
                )
            return self._generate_schema.defs.definitions[ref]
        elif maybe_ref_schema['type'] == 'definitions':
            return self.resolve_ref_schema(maybe_ref_schema['schema'])
        return maybe_ref_schema
@@ -0,0 +1,51 @@
from __future__ import annotations

import collections
import collections.abc
import typing
from typing import Any

from pydantic_core import PydanticOmit, core_schema

SEQUENCE_ORIGIN_MAP: dict[Any, Any] = {
    typing.Deque: collections.deque,
    collections.deque: collections.deque,
    list: list,
    typing.List: list,
    set: set,
    typing.AbstractSet: set,
    typing.Set: set,
    frozenset: frozenset,
    typing.FrozenSet: frozenset,
    typing.Sequence: list,
    typing.MutableSequence: list,
    typing.MutableSet: set,
    # this doesn't handle subclasses of these
    # parametrized typing.Set creates one of these
    collections.abc.MutableSet: set,
    collections.abc.Set: frozenset,
}


def serialize_sequence_via_list(
    v: Any, handler: core_schema.SerializerFunctionWrapHandler, info: core_schema.SerializationInfo
) -> Any:
    items: list[Any] = []

    mapped_origin = SEQUENCE_ORIGIN_MAP.get(type(v), None)
    if mapped_origin is None:
        # we shouldn't hit this branch, should probably add a serialization error or something
        return v

    for index, item in enumerate(v):
        try:
            v = handler(item, index)
        except PydanticOmit:
            pass
        else:
            items.append(v)

    if info.mode_is_json():
        return items
    else:
        return mapped_origin(items)
@@ -0,0 +1,188 @@
from __future__ import annotations

import dataclasses
from inspect import Parameter, Signature, signature
from typing import TYPE_CHECKING, Any, Callable

from pydantic_core import PydanticUndefined

from ._utils import is_valid_identifier

if TYPE_CHECKING:
    from ..config import ExtraValues
    from ..fields import FieldInfo


# Copied over from stdlib dataclasses
class _HAS_DEFAULT_FACTORY_CLASS:
    def __repr__(self):
        return '<factory>'


_HAS_DEFAULT_FACTORY = _HAS_DEFAULT_FACTORY_CLASS()


def _field_name_for_signature(field_name: str, field_info: FieldInfo) -> str:
    """Extract the correct name to use for the field when generating a signature.

    Assuming the field has a valid alias, this will return the alias. Otherwise, it will return the field name.
    First priority is given to the alias, then the validation_alias, then the field name.

    Args:
        field_name: The name of the field.
        field_info: The corresponding FieldInfo object.

    Returns:
        The correct name to use when generating a signature.
    """
    if isinstance(field_info.alias, str) and is_valid_identifier(field_info.alias):
        return field_info.alias
    if isinstance(field_info.validation_alias, str) and is_valid_identifier(field_info.validation_alias):
        return field_info.validation_alias

    return field_name


def _process_param_defaults(param: Parameter) -> Parameter:
    """Modify the signature for a parameter in a dataclass where the default value is a FieldInfo instance.

    Args:
        param (Parameter): The parameter.

    Returns:
        Parameter: The custom processed parameter.
    """
    from ..fields import FieldInfo

    param_default = param.default
    if isinstance(param_default, FieldInfo):
        annotation = param.annotation
        # Replace the annotation if appropriate
        # inspect does "clever" things to show annotations as strings because we have
        # `from __future__ import annotations` in main; we don't want that
        if annotation == 'Any':
            annotation = Any

        # Replace the field default
        default = param_default.default
        if default is PydanticUndefined:
            if param_default.default_factory is PydanticUndefined:
                default = Signature.empty
            else:
                # this is used by dataclasses to indicate a factory exists:
                default = dataclasses._HAS_DEFAULT_FACTORY  # type: ignore
        return param.replace(
            annotation=annotation, name=_field_name_for_signature(param.name, param_default), default=default
        )
    return param


def _generate_signature_parameters(  # noqa: C901 (ignore complexity, could use a refactor)
    init: Callable[..., None],
    fields: dict[str, FieldInfo],
    populate_by_name: bool,
    extra: ExtraValues | None,
) -> dict[str, Parameter]:
    """Generate a mapping of parameter names to Parameter objects for a pydantic BaseModel or dataclass."""
    from itertools import islice

    present_params = signature(init).parameters.values()
    merged_params: dict[str, Parameter] = {}
    var_kw = None
    use_var_kw = False

    for param in islice(present_params, 1, None):  # skip self arg
        # inspect does "clever" things to show annotations as strings because we have
        # `from __future__ import annotations` in main; we don't want that
        if fields.get(param.name):
            # exclude params with init=False
            if getattr(fields[param.name], 'init', True) is False:
                continue
            param = param.replace(name=_field_name_for_signature(param.name, fields[param.name]))
        if param.annotation == 'Any':
            param = param.replace(annotation=Any)
        if param.kind is param.VAR_KEYWORD:
            var_kw = param
            continue
        merged_params[param.name] = param

    if var_kw:  # if custom init has no var_kw, fields which are not declared in it cannot be passed through
        allow_names = populate_by_name
        for field_name, field in fields.items():
            # when alias is a str it should be used for signature generation
            param_name = _field_name_for_signature(field_name, field)

            if field_name in merged_params or param_name in merged_params:
                continue

            if not is_valid_identifier(param_name):
                if allow_names:
                    param_name = field_name
                else:
                    use_var_kw = True
                    continue

            if field.is_required():
                default = Parameter.empty
            elif field.default_factory is not None:
                # Mimics stdlib dataclasses:
                default = _HAS_DEFAULT_FACTORY
            else:
                default = field.default
            merged_params[param_name] = Parameter(
                param_name,
                Parameter.KEYWORD_ONLY,
                annotation=field.rebuild_annotation(),
                default=default,
            )

    if extra == 'allow':
        use_var_kw = True

    if var_kw and use_var_kw:
        # Make sure the parameter for extra kwargs
        # does not have the same name as a field
        default_model_signature = [
            ('self', Parameter.POSITIONAL_ONLY),
            ('data', Parameter.VAR_KEYWORD),
        ]
        if [(p.name, p.kind) for p in present_params] == default_model_signature:
            # if this is the standard model signature, use extra_data as the extra args name
            var_kw_name = 'extra_data'
        else:
            # else start from var_kw
            var_kw_name = var_kw.name

        # generate a name that's definitely unique
        while var_kw_name in fields:
            var_kw_name += '_'
        merged_params[var_kw_name] = var_kw.replace(name=var_kw_name)

    return merged_params


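# Editor's note: an illustrative sketch (not part of the original module) of the observable
# effect of the parameter merging above; kept as comments so the module's import-time
# behavior is unchanged.
#
#     import inspect
#     from pydantic import BaseModel, Field
#
#     class M(BaseModel):
#         x: int = Field(alias='x_alias')
#
#     str(inspect.signature(M))  # -> '(*, x_alias: int) -> None'
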
def generate_pydantic_signature(
    init: Callable[..., None],
    fields: dict[str, FieldInfo],
    populate_by_name: bool,
    extra: ExtraValues | None,
    is_dataclass: bool = False,
) -> Signature:
    """Generate signature for a pydantic BaseModel or dataclass.

    Args:
        init: The class init.
        fields: The model fields.
        populate_by_name: The `populate_by_name` value of the config.
        extra: The `extra` value of the config.
        is_dataclass: Whether the model is a dataclass.

    Returns:
        The dataclass/BaseModel subclass signature.
    """
    merged_params = _generate_signature_parameters(init, fields, populate_by_name, extra)

    if is_dataclass:
        merged_params = {k: _process_param_defaults(v) for k, v in merged_params.items()}

    return Signature(parameters=list(merged_params.values()), return_annotation=None)
@@ -0,0 +1,404 @@
"""Logic for generating pydantic-core schemas for standard library types.
|
||||
|
||||
Import of this module is deferred since it contains imports of many standard library modules.
|
||||
"""
|
||||
|
||||
# TODO: eventually, we'd like to move all of the types handled here to have pydantic-core validators
|
||||
# so that we can avoid this annotation injection and just use the standard pydantic-core schema generation
|
||||
|
||||
from __future__ import annotations as _annotations
|
||||
|
||||
import collections
|
||||
import collections.abc
|
||||
import dataclasses
|
||||
import os
|
||||
import typing
|
||||
from functools import partial
|
||||
from typing import Any, Callable, Iterable, Tuple, TypeVar, cast
|
||||
|
||||
import typing_extensions
|
||||
from pydantic_core import (
|
||||
CoreSchema,
|
||||
PydanticCustomError,
|
||||
core_schema,
|
||||
)
|
||||
from typing_extensions import get_args, get_origin
|
||||
|
||||
from pydantic._internal._serializers import serialize_sequence_via_list
|
||||
from pydantic.errors import PydanticSchemaGenerationError
|
||||
from pydantic.types import Strict
|
||||
|
||||
from ..json_schema import JsonSchemaValue
|
||||
from . import _known_annotated_metadata, _typing_extra
|
||||
from ._import_utils import import_cached_field_info
|
||||
from ._internal_dataclass import slots_true
|
||||
from ._schema_generation_shared import GetCoreSchemaHandler, GetJsonSchemaHandler
|
||||
|
||||
FieldInfo = import_cached_field_info()
|
||||
|
||||
if typing.TYPE_CHECKING:
|
||||
from ._generate_schema import GenerateSchema
|
||||
|
||||
StdSchemaFunction = Callable[[GenerateSchema, type[Any]], core_schema.CoreSchema]
|
||||
|
||||
|
||||
@dataclasses.dataclass(**slots_true)
class InnerSchemaValidator:
    """Use a fixed CoreSchema, avoiding interference from outward annotations."""

    core_schema: CoreSchema
    js_schema: JsonSchemaValue | None = None
    js_core_schema: CoreSchema | None = None
    js_schema_update: JsonSchemaValue | None = None

    def __get_pydantic_json_schema__(self, _schema: CoreSchema, handler: GetJsonSchemaHandler) -> JsonSchemaValue:
        if self.js_schema is not None:
            return self.js_schema
        js_schema = handler(self.js_core_schema or self.core_schema)
        if self.js_schema_update is not None:
            js_schema.update(self.js_schema_update)
        return js_schema

    def __get_pydantic_core_schema__(self, _source_type: Any, _handler: GetCoreSchemaHandler) -> CoreSchema:
        return self.core_schema


def path_schema_prepare_pydantic_annotations(
    source_type: Any, annotations: Iterable[Any]
) -> tuple[Any, list[Any]] | None:
    import pathlib

    orig_source_type: Any = get_origin(source_type) or source_type
    if (
        (source_type_args := get_args(source_type))
        and orig_source_type is os.PathLike
        and source_type_args[0] not in {str, bytes, Any}
    ):
        return None

    if orig_source_type not in {
        os.PathLike,
        pathlib.Path,
        pathlib.PurePath,
        pathlib.PosixPath,
        pathlib.PurePosixPath,
        pathlib.PureWindowsPath,
    }:
        return None

    metadata, remaining_annotations = _known_annotated_metadata.collect_known_metadata(annotations)
    _known_annotated_metadata.check_metadata(metadata, _known_annotated_metadata.STR_CONSTRAINTS, orig_source_type)

    is_first_arg_byte = source_type_args and source_type_args[0] is bytes
    construct_path = pathlib.PurePath if orig_source_type is os.PathLike else orig_source_type
    constrained_schema = (
        core_schema.bytes_schema(**metadata) if is_first_arg_byte else core_schema.str_schema(**metadata)
    )

    def path_validator(input_value: str | bytes) -> os.PathLike[Any]:  # type: ignore
        try:
            if is_first_arg_byte:
                if isinstance(input_value, bytes):
                    try:
                        input_value = input_value.decode()
                    except UnicodeDecodeError as e:
                        raise PydanticCustomError('bytes_type', 'Input must be valid bytes') from e
                else:
                    raise PydanticCustomError('bytes_type', 'Input must be bytes')
            elif not isinstance(input_value, str):
                raise PydanticCustomError('path_type', 'Input is not a valid path')

            return construct_path(input_value)
        except TypeError as e:
            raise PydanticCustomError('path_type', 'Input is not a valid path') from e

    instance_schema = core_schema.json_or_python_schema(
        json_schema=core_schema.no_info_after_validator_function(path_validator, constrained_schema),
        python_schema=core_schema.is_instance_schema(orig_source_type),
    )

    strict: bool | None = None
    for annotation in annotations:
        if isinstance(annotation, Strict):
            strict = annotation.strict

    schema = core_schema.lax_or_strict_schema(
        lax_schema=core_schema.union_schema(
            [
                instance_schema,
                core_schema.no_info_after_validator_function(path_validator, constrained_schema),
            ],
            custom_error_type='path_type',
            custom_error_message=f'Input is not a valid path for {orig_source_type}',
            strict=True,
        ),
        strict_schema=instance_schema,
        serialization=core_schema.to_string_ser_schema(),
        strict=strict,
    )

    return (
        orig_source_type,
        [
            InnerSchemaValidator(schema, js_core_schema=constrained_schema, js_schema_update={'format': 'path'}),
            *remaining_annotations,
        ],
    )


def deque_validator(
    input_value: Any, handler: core_schema.ValidatorFunctionWrapHandler, maxlen: None | int
) -> collections.deque[Any]:
    if isinstance(input_value, collections.deque):
        maxlens = [v for v in (input_value.maxlen, maxlen) if v is not None]
        if maxlens:
            maxlen = min(maxlens)
        return collections.deque(handler(input_value), maxlen=maxlen)
    else:
        return collections.deque(handler(input_value), maxlen=maxlen)


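# Editor's note: an illustrative sketch (not part of the original module); `list` stands in
# for the wrap handler, which is enough to see the maxlen merging above. Kept as comments
# so the module's import-time behavior is unchanged.
#
#     d = collections.deque([1, 2, 3], maxlen=5)
#     deque_validator(d, list, 2)  # -> deque([2, 3], maxlen=2): the smaller bound wins
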
@dataclasses.dataclass(**slots_true)
class DequeValidator:
    item_source_type: type[Any]
    metadata: dict[str, Any]

    def __get_pydantic_core_schema__(self, source_type: Any, handler: GetCoreSchemaHandler) -> CoreSchema:
        if _typing_extra.is_any(self.item_source_type):
            items_schema = None
        else:
            items_schema = handler.generate_schema(self.item_source_type)

        # if we have a MaxLen annotation, we might as well set that as the default maxlen on the deque;
        # this lets us reuse existing metadata annotations to let users set the maxlen on a deque
        # that e.g. comes from JSON
        coerce_instance_wrap = partial(
            core_schema.no_info_wrap_validator_function,
            partial(deque_validator, maxlen=self.metadata.get('max_length', None)),
        )

        # we have to use a lax list schema here, because we need to validate the deque's
        # items via a list schema, but it's ok if the deque itself is not a list
        metadata_with_strict_override = {**self.metadata, 'strict': False}
        constrained_schema = core_schema.list_schema(items_schema, **metadata_with_strict_override)

        check_instance = core_schema.json_or_python_schema(
            json_schema=core_schema.list_schema(),
            python_schema=core_schema.is_instance_schema(collections.deque),
        )

        serialization = core_schema.wrap_serializer_function_ser_schema(
            serialize_sequence_via_list, schema=items_schema or core_schema.any_schema(), info_arg=True
        )

        strict = core_schema.chain_schema([check_instance, coerce_instance_wrap(constrained_schema)])

        if self.metadata.get('strict', False):
            schema = strict
        else:
            lax = coerce_instance_wrap(constrained_schema)
            schema = core_schema.lax_or_strict_schema(lax_schema=lax, strict_schema=strict)
        schema['serialization'] = serialization

        return schema


def deque_schema_prepare_pydantic_annotations(
    source_type: Any, annotations: Iterable[Any]
) -> tuple[Any, list[Any]] | None:
    args = get_args(source_type)

    if not args:
        args = typing.cast(Tuple[Any], (Any,))
    elif len(args) != 1:
        raise ValueError('Expected deque to have exactly 1 generic parameter')

    item_source_type = args[0]

    metadata, remaining_annotations = _known_annotated_metadata.collect_known_metadata(annotations)
    _known_annotated_metadata.check_metadata(metadata, _known_annotated_metadata.SEQUENCE_CONSTRAINTS, source_type)

    return (source_type, [DequeValidator(item_source_type, metadata), *remaining_annotations])


MAPPING_ORIGIN_MAP: dict[Any, Any] = {
|
||||
typing.DefaultDict: collections.defaultdict,
|
||||
collections.defaultdict: collections.defaultdict,
|
||||
collections.OrderedDict: collections.OrderedDict,
|
||||
typing_extensions.OrderedDict: collections.OrderedDict,
|
||||
dict: dict,
|
||||
typing.Dict: dict,
|
||||
collections.Counter: collections.Counter,
|
||||
typing.Counter: collections.Counter,
|
||||
# this doesn't handle subclasses of these
|
||||
typing.Mapping: dict,
|
||||
typing.MutableMapping: dict,
|
||||
# parametrized typing.{Mutable}Mapping creates one of these
|
||||
collections.abc.MutableMapping: dict,
|
||||
collections.abc.Mapping: dict,
|
||||
}


def defaultdict_validator(
    input_value: Any, handler: core_schema.ValidatorFunctionWrapHandler, default_default_factory: Callable[[], Any]
) -> collections.defaultdict[Any, Any]:
    if isinstance(input_value, collections.defaultdict):
        default_factory = input_value.default_factory
        return collections.defaultdict(default_factory, handler(input_value))
    else:
        return collections.defaultdict(default_default_factory, handler(input_value))


def get_defaultdict_default_default_factory(values_source_type: Any) -> Callable[[], Any]:
    def infer_default() -> Callable[[], Any]:
        allowed_default_types: dict[Any, Any] = {
            typing.Tuple: tuple,
            tuple: tuple,
            collections.abc.Sequence: tuple,
            collections.abc.MutableSequence: list,
            typing.List: list,
            list: list,
            typing.Sequence: list,
            typing.Set: set,
            set: set,
            typing.MutableSet: set,
            collections.abc.MutableSet: set,
            collections.abc.Set: frozenset,
            typing.MutableMapping: dict,
            typing.Mapping: dict,
            collections.abc.Mapping: dict,
            collections.abc.MutableMapping: dict,
            float: float,
            int: int,
            str: str,
            bool: bool,
        }
        values_type_origin = get_origin(values_source_type) or values_source_type
        instructions = 'set using `DefaultDict[..., Annotated[..., Field(default_factory=...)]]`'
        if isinstance(values_type_origin, TypeVar):

            def type_var_default_factory() -> None:
                raise RuntimeError(
                    'Generic defaultdict cannot be used without a concrete value type or an'
                    ' explicit default factory, ' + instructions
                )

            return type_var_default_factory
        elif values_type_origin not in allowed_default_types:
            # a somewhat subjective set of types that have reasonable default values
            allowed_msg = ', '.join([t.__name__ for t in set(allowed_default_types.values())])
            raise PydanticSchemaGenerationError(
                f'Unable to infer a default factory for values of type {values_source_type}.'
                f' Only {allowed_msg} are supported, other types require an explicit default factory'
                ' ' + instructions
            )
        return allowed_default_types[values_type_origin]

    # Assume Annotated[..., Field(...)]
    if _typing_extra.is_annotated(values_source_type):
        field_info = next((v for v in get_args(values_source_type) if isinstance(v, FieldInfo)), None)
    else:
        field_info = None
    if field_info and field_info.default_factory:
        # Assume the default factory does not take any argument:
        default_default_factory = cast(Callable[[], Any], field_info.default_factory)
    else:
        default_default_factory = infer_default()
    return default_default_factory
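
# Illustrative sketch (not part of this commit): for `DefaultDict[str, List[int]]`,
# `infer_default` resolves the value origin `list` in `allowed_default_types`, so the
# validated defaultdict gets `list` as its default factory. An explicit factory can be
# supplied via `DefaultDict[str, Annotated[List[int], Field(default_factory=...)]]`.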


@dataclasses.dataclass(**slots_true)
class MappingValidator:
    mapped_origin: type[Any]
    keys_source_type: type[Any]
    values_source_type: type[Any]
    min_length: int | None = None
    max_length: int | None = None
    strict: bool = False

    def serialize_mapping_via_dict(self, v: Any, handler: core_schema.SerializerFunctionWrapHandler) -> Any:
        return handler(v)

    def __get_pydantic_core_schema__(self, source_type: Any, handler: GetCoreSchemaHandler) -> CoreSchema:
        if _typing_extra.is_any(self.keys_source_type):
            keys_schema = None
        else:
            keys_schema = handler.generate_schema(self.keys_source_type)
        if _typing_extra.is_any(self.values_source_type):
            values_schema = None
        else:
            values_schema = handler.generate_schema(self.values_source_type)

        metadata = {'min_length': self.min_length, 'max_length': self.max_length, 'strict': self.strict}

        if self.mapped_origin is dict:
            schema = core_schema.dict_schema(keys_schema, values_schema, **metadata)
        else:
            constrained_schema = core_schema.dict_schema(keys_schema, values_schema, **metadata)
            check_instance = core_schema.json_or_python_schema(
                json_schema=core_schema.dict_schema(),
                python_schema=core_schema.is_instance_schema(self.mapped_origin),
            )

            if self.mapped_origin is collections.defaultdict:
                default_default_factory = get_defaultdict_default_default_factory(self.values_source_type)
                coerce_instance_wrap = partial(
                    core_schema.no_info_wrap_validator_function,
                    partial(defaultdict_validator, default_default_factory=default_default_factory),
                )
            else:
                coerce_instance_wrap = partial(core_schema.no_info_after_validator_function, self.mapped_origin)

            serialization = core_schema.wrap_serializer_function_ser_schema(
                self.serialize_mapping_via_dict,
                schema=core_schema.dict_schema(
                    keys_schema or core_schema.any_schema(), values_schema or core_schema.any_schema()
                ),
                info_arg=False,
            )

            strict = core_schema.chain_schema([check_instance, coerce_instance_wrap(constrained_schema)])

            if metadata.get('strict', False):
                schema = strict
            else:
                lax = coerce_instance_wrap(constrained_schema)
                schema = core_schema.lax_or_strict_schema(lax_schema=lax, strict_schema=strict)
            schema['serialization'] = serialization

        return schema


def mapping_like_prepare_pydantic_annotations(
    source_type: Any, annotations: Iterable[Any]
) -> tuple[Any, list[Any]] | None:
    origin: Any = get_origin(source_type)

    mapped_origin = MAPPING_ORIGIN_MAP.get(origin, None) if origin else MAPPING_ORIGIN_MAP.get(source_type, None)
    if mapped_origin is None:
        return None

    args = get_args(source_type)

    if not args:
        args = typing.cast(Tuple[Any, Any], (Any, Any))
    elif mapped_origin is collections.Counter:
        # a single generic
        if len(args) != 1:
            raise ValueError('Expected Counter to have exactly 1 generic parameter')
        args = (args[0], int)  # Counter values are always int
    elif len(args) != 2:
        raise ValueError('Expected mapping to have exactly 2 generic parameters')

    keys_source_type, values_source_type = args

    metadata, remaining_annotations = _known_annotated_metadata.collect_known_metadata(annotations)
    _known_annotated_metadata.check_metadata(metadata, _known_annotated_metadata.SEQUENCE_CONSTRAINTS, source_type)

    return (
        source_type,
        [
            MappingValidator(mapped_origin, keys_source_type, values_source_type, **metadata),
            *remaining_annotations,
        ],
    )
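
# Illustrative sketch (not part of this commit): `Counter[str]` is treated here as
# a mapping with `int` values, so e.g.:
#
#     from collections import Counter
#     from pydantic import TypeAdapter
#
#     TypeAdapter(Counter[str]).validate_python({'a': 1, 'b': 2})
#     # Counter({'b': 2, 'a': 1})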
@@ -0,0 +1,893 @@
"""Logic for interacting with type annotations, mostly extensions, shims and hacks to wrap Python's typing module."""

from __future__ import annotations

import collections.abc
import re
import sys
import types
import typing
import warnings
from functools import lru_cache, partial
from typing import TYPE_CHECKING, Any, Callable

import typing_extensions
from typing_extensions import TypeIs, deprecated, get_args, get_origin

from ._namespace_utils import GlobalsNamespace, MappingNamespace, NsResolver, get_module_ns_of

if sys.version_info < (3, 10):
    NoneType = type(None)
    EllipsisType = type(Ellipsis)
else:
    from types import EllipsisType as EllipsisType
    from types import NoneType as NoneType

if TYPE_CHECKING:
    from pydantic import BaseModel

# See https://typing-extensions.readthedocs.io/en/latest/#runtime-use-of-types:


@lru_cache(maxsize=None)
def _get_typing_objects_by_name_of(name: str) -> tuple[Any, ...]:
    """Get the member named `name` from both `typing` and `typing-extensions` (if it exists)."""
    result = tuple(getattr(module, name) for module in (typing, typing_extensions) if hasattr(module, name))
    if not result:
        raise ValueError(f'Neither `typing` nor `typing_extensions` has an object called {name!r}')
    return result


# As suggested by the `typing-extensions` documentation, we could apply caching to this method,
# but it doesn't seem to improve performance. This also requires `obj` to be hashable, which
# might not always be the case:
def _is_typing_name(obj: object, name: str) -> bool:
    """Return whether `obj` is the member of the typing modules (includes the `typing-extensions` one) named `name`."""
    # Using `any()` is slower:
    for thing in _get_typing_objects_by_name_of(name):
        if obj is thing:
            return True
    return False


def is_any(tp: Any, /) -> bool:
    """Return whether the provided argument is the `Any` special form.

    ```python {test="skip" lint="skip"}
    is_any(Any)
    #> True
    ```
    """
    return _is_typing_name(tp, name='Any')


def is_union(tp: Any, /) -> bool:
    """Return whether the provided argument is a `Union` special form.

    ```python {test="skip" lint="skip"}
    is_union(Union[int, str])
    #> True
    is_union(int | str)
    #> False
    ```
    """
    return _is_typing_name(get_origin(tp), name='Union')


def is_literal(tp: Any, /) -> bool:
    """Return whether the provided argument is a `Literal` special form.

    ```python {test="skip" lint="skip"}
    is_literal(Literal[42])
    #> True
    ```
    """
    return _is_typing_name(get_origin(tp), name='Literal')


# TODO remove and replace with `get_args` when we drop support for Python 3.8
# (see https://docs.python.org/3/whatsnew/3.9.html#id4).
def literal_values(tp: Any, /) -> list[Any]:
    """Return the values contained in the provided `Literal` special form."""
    if not is_literal(tp):
        return [tp]

    values = get_args(tp)
    return [x for value in values for x in literal_values(value)]
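
# e.g. (illustrative, not part of this commit): nested `Literal` forms flatten to a
# single list of values, and non-`Literal` arguments are returned as singletons:
#
#     literal_values(Literal[1, Literal[2, 3]])  # [1, 2, 3]
#     literal_values(int)                        # [int]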


def is_annotated(tp: Any, /) -> bool:
    """Return whether the provided argument is an `Annotated` special form.

    ```python {test="skip" lint="skip"}
    is_annotated(Annotated[int, ...])
    #> True
    ```
    """
    return _is_typing_name(get_origin(tp), name='Annotated')


def annotated_type(tp: Any, /) -> Any | None:
    """Return the type of the `Annotated` special form, or `None`."""
    return get_args(tp)[0] if is_annotated(tp) else None


def is_unpack(tp: Any, /) -> bool:
    """Return whether the provided argument is an `Unpack` special form.

    ```python {test="skip" lint="skip"}
    is_unpack(Unpack[Ts])
    #> True
    ```
    """
    return _is_typing_name(get_origin(tp), name='Unpack')


def unpack_type(tp: Any, /) -> Any | None:
    """Return the type wrapped by the `Unpack` special form, or `None`."""
    return get_args(tp)[0] if is_unpack(tp) else None


def is_self(tp: Any, /) -> bool:
    """Return whether the provided argument is the `Self` special form.

    ```python {test="skip" lint="skip"}
    is_self(Self)
    #> True
    ```
    """
    return _is_typing_name(tp, name='Self')


def is_new_type(tp: Any, /) -> bool:
    """Return whether the provided argument is a `NewType`.

    ```python {test="skip" lint="skip"}
    is_new_type(NewType('MyInt', int))
    #> True
    ```
    """
    if sys.version_info < (3, 10):
        # On Python < 3.10, `typing.NewType` is a function
        return hasattr(tp, '__supertype__')
    else:
        return _is_typing_name(type(tp), name='NewType')


def is_hashable(tp: Any, /) -> bool:
    """Return whether the provided argument is the `Hashable` class.

    ```python {test="skip" lint="skip"}
    is_hashable(Hashable)
    #> True
    ```
    """
    # `get_origin` is documented as normalizing any typing-module aliases to `collections` classes,
    # hence the second check:
    return tp is collections.abc.Hashable or get_origin(tp) is collections.abc.Hashable


def is_callable(tp: Any, /) -> bool:
    """Return whether the provided argument is a `Callable`, parametrized or not.

    ```python {test="skip" lint="skip"}
    is_callable(Callable[[int], str])
    #> True
    is_callable(typing.Callable)
    #> True
    is_callable(collections.abc.Callable)
    #> True
    ```
    """
    # `get_origin` is documented as normalizing any typing-module aliases to `collections` classes,
    # hence the second check:
    return tp is collections.abc.Callable or get_origin(tp) is collections.abc.Callable


_PARAMSPEC_TYPES: tuple[type[typing_extensions.ParamSpec], ...] = (typing_extensions.ParamSpec,)
if sys.version_info >= (3, 10):
    _PARAMSPEC_TYPES = (*_PARAMSPEC_TYPES, typing.ParamSpec)  # pyright: ignore[reportAssignmentType]


def is_paramspec(tp: Any, /) -> bool:
    """Return whether the provided argument is a `ParamSpec`.

    ```python {test="skip" lint="skip"}
    P = ParamSpec('P')
    is_paramspec(P)
    #> True
    ```
    """
    return isinstance(tp, _PARAMSPEC_TYPES)


_TYPE_ALIAS_TYPES: tuple[type[typing_extensions.TypeAliasType], ...] = (typing_extensions.TypeAliasType,)
if sys.version_info >= (3, 12):
    _TYPE_ALIAS_TYPES = (*_TYPE_ALIAS_TYPES, typing.TypeAliasType)


def is_type_alias_type(tp: Any, /) -> TypeIs[typing_extensions.TypeAliasType]:
    """Return whether the provided argument is an instance of `TypeAliasType`.

    ```python {test="skip" lint="skip"}
    type Int = int
    is_type_alias_type(Int)
    #> True
    Str = TypeAliasType('Str', str)
    is_type_alias_type(Str)
    #> True
    ```
    """
    return isinstance(tp, _TYPE_ALIAS_TYPES)


def is_classvar(tp: Any, /) -> bool:
    """Return whether the provided argument is a `ClassVar` special form, parametrized or not.

    Note that in most cases, you will want to use the `is_classvar_annotation` function,
    which is used to check if an annotation (in the context of a Pydantic model or dataclass)
    should be treated as being a class variable.

    ```python {test="skip" lint="skip"}
    is_classvar(ClassVar[int])
    #> True
    is_classvar(ClassVar)
    #> True
    ```
    """
    # ClassVar is not necessarily parametrized:
    return _is_typing_name(tp, name='ClassVar') or _is_typing_name(get_origin(tp), name='ClassVar')


_classvar_re = re.compile(r'((\w+\.)?Annotated\[)?(\w+\.)?ClassVar\[')


def is_classvar_annotation(tp: Any, /) -> bool:
    """Return whether the provided argument represents a class variable annotation.

    Although not explicitly stated by the typing specification, `ClassVar` can be used
    inside `Annotated` and as such, this function checks for this specific scenario.

    Because this function is used to detect class variables before evaluating forward references
    (or because evaluation failed), we also fall back to a naive regex match. This is
    required because class variables are inspected before fields are collected, so we try to be
    as accurate as possible.
    """
    if is_classvar(tp) or (anntp := annotated_type(tp)) is not None and is_classvar(anntp):
        return True

    str_ann: str | None = None
    if isinstance(tp, typing.ForwardRef):
        str_ann = tp.__forward_arg__
    if isinstance(tp, str):
        str_ann = tp

    if str_ann is not None and _classvar_re.match(str_ann):
        # stdlib dataclasses do something similar, although a bit more advanced
        # (see `dataclass._is_type`).
        return True

    return False
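
# e.g. (illustrative, not part of this commit): the regex fallback lets unevaluated
# string annotations still be recognized as class variables:
#
#     is_classvar_annotation(typing.ClassVar[int])             # True
#     is_classvar_annotation('Annotated[ClassVar[int], ...]')  # True (regex match)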


# TODO implement `is_finalvar_annotation` as Final can be wrapped with other special forms:
def is_finalvar(tp: Any, /) -> bool:
    """Return whether the provided argument is a `Final` special form, parametrized or not.

    ```python {test="skip" lint="skip"}
    is_finalvar(Final[int])
    #> True
    is_finalvar(Final)
    #> True
    ```
    """
    # Final is not necessarily parametrized:
    return _is_typing_name(tp, name='Final') or _is_typing_name(get_origin(tp), name='Final')


def is_required(tp: Any, /) -> bool:
    """Return whether the provided argument is a `Required` special form.

    ```python {test="skip" lint="skip"}
    is_required(Required[int])
    #> True
    ```
    """
    return _is_typing_name(get_origin(tp), name='Required')


def is_not_required(tp: Any, /) -> bool:
    """Return whether the provided argument is a `NotRequired` special form.

    ```python {test="skip" lint="skip"}
    is_not_required(NotRequired[int])
    #> True
    ```
    """
    return _is_typing_name(get_origin(tp), name='NotRequired')


def is_no_return(tp: Any, /) -> bool:
    """Return whether the provided argument is the `NoReturn` special form.

    ```python {test="skip" lint="skip"}
    is_no_return(NoReturn)
    #> True
    ```
    """
    return _is_typing_name(tp, name='NoReturn')


def is_never(tp: Any, /) -> bool:
    """Return whether the provided argument is the `Never` special form.

    ```python {test="skip" lint="skip"}
    is_never(Never)
    #> True
    ```
    """
    return _is_typing_name(tp, name='Never')


_DEPRECATED_TYPES: tuple[type[typing_extensions.deprecated], ...] = (typing_extensions.deprecated,)
if hasattr(warnings, 'deprecated'):
    _DEPRECATED_TYPES = (*_DEPRECATED_TYPES, warnings.deprecated)  # pyright: ignore[reportAttributeAccessIssue]


def is_deprecated_instance(obj: Any, /) -> TypeIs[deprecated]:
    """Return whether the argument is an instance of the `warnings.deprecated` class or the `typing_extensions` backport."""
    return isinstance(obj, _DEPRECATED_TYPES)


_NONE_TYPES: tuple[Any, ...] = (None, NoneType, typing.Literal[None], typing_extensions.Literal[None])


def is_none_type(tp: Any, /) -> bool:
    """Return whether the argument represents the `None` type as part of an annotation.

    ```python {test="skip" lint="skip"}
    is_none_type(None)
    #> True
    is_none_type(NoneType)
    #> True
    is_none_type(Literal[None])
    #> True
    is_none_type(type[None])
    #> False
    ```
    """
    return tp in _NONE_TYPES


def is_namedtuple(tp: Any, /) -> bool:
    """Return whether the provided argument is a named tuple class.

    The class can be created using `typing.NamedTuple` or `collections.namedtuple`.
    Parametrized generic classes are *not* assumed to be named tuples.
    """
    from ._utils import lenient_issubclass  # circ. import

    return lenient_issubclass(tp, tuple) and hasattr(tp, '_fields')
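
# e.g. (illustrative, not part of this commit):
#
#     Point = collections.namedtuple('Point', ['x', 'y'])
#     is_namedtuple(Point)  # True
#     is_namedtuple(tuple)  # False (no `_fields`)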


if sys.version_info < (3, 9):

    def is_zoneinfo_type(tp: Any, /) -> bool:
        """Return whether the provided argument is the `zoneinfo.ZoneInfo` type."""
        return False

else:
    from zoneinfo import ZoneInfo

    def is_zoneinfo_type(tp: Any, /) -> TypeIs[type[ZoneInfo]]:
        """Return whether the provided argument is the `zoneinfo.ZoneInfo` type."""
        return tp is ZoneInfo


if sys.version_info < (3, 10):

    def origin_is_union(tp: Any, /) -> bool:
        """Return whether the provided argument is the `Union` special form."""
        return _is_typing_name(tp, name='Union')

    def is_generic_alias(type_: type[Any]) -> bool:
        return isinstance(type_, typing._GenericAlias)  # pyright: ignore[reportAttributeAccessIssue]

else:

    def origin_is_union(tp: Any, /) -> bool:
        """Return whether the provided argument is the `Union` special form or the `UnionType`."""
        return _is_typing_name(tp, name='Union') or tp is types.UnionType

    def is_generic_alias(tp: Any, /) -> bool:
        return isinstance(tp, (types.GenericAlias, typing._GenericAlias))  # pyright: ignore[reportAttributeAccessIssue]


# TODO: Ideally, we should avoid relying on the private `typing` constructs:

if sys.version_info < (3, 9):
    WithArgsTypes: tuple[Any, ...] = (typing._GenericAlias,)  # pyright: ignore[reportAttributeAccessIssue]
elif sys.version_info < (3, 10):
    WithArgsTypes: tuple[Any, ...] = (typing._GenericAlias, types.GenericAlias)  # pyright: ignore[reportAttributeAccessIssue]
else:
    WithArgsTypes: tuple[Any, ...] = (typing._GenericAlias, types.GenericAlias, types.UnionType)  # pyright: ignore[reportAttributeAccessIssue]


# Similarly, we shouldn't rely on this `_Final` class, which is even more private than `_GenericAlias`:
typing_base: Any = typing._Final  # pyright: ignore[reportAttributeAccessIssue]


### Annotation evaluation functions:


def parent_frame_namespace(*, parent_depth: int = 2, force: bool = False) -> dict[str, Any] | None:
    """We allow use of items in parent namespace to get around the issue with `get_type_hints` only looking in the
    global module namespace. See https://github.com/pydantic/pydantic/issues/2678#issuecomment-1008139014 -> Scope
    and suggestion at the end of the next comment by @gvanrossum.

    WARNING 1: it matters exactly where this is called. By default, this function will build a namespace from the
    parent of where it is called.

    WARNING 2: this only looks in the parent namespace, not other parents since (AFAIK) there's no way to collect a
    dict of exactly what's in scope. Using `f_back` would work sometimes but would be very wrong and confusing in many
    other cases. See https://discuss.python.org/t/is-there-a-way-to-access-parent-nested-namespaces/20659.

    There are some cases where we want to force fetching the parent namespace, e.g. during a `model_rebuild` call.
    In this case, we want both the namespace of the class' module, if applicable, and the parent namespace of the
    module where the rebuild is called.

    In other cases, like during initial schema build, if a class is defined at the top module level, we don't need to
    fetch that module's namespace, because the class' __module__ attribute can be used to access the parent namespace.
    This is done in `_namespace_utils.get_module_ns_of`. Thus, there's no need to cache the parent frame namespace in this case.
    """
    frame = sys._getframe(parent_depth)

    # note, we don't copy frame.f_locals here (or during the last return call), because we don't expect the namespace to be modified down the line
    # if this becomes a problem, we could implement some sort of frozen mapping structure to enforce this
    if force:
        return frame.f_locals

    # if either of the following conditions are true, the class is defined at the top module level
    # to better understand why we need both of these checks, see
    # https://github.com/pydantic/pydantic/pull/10113#discussion_r1714981531
    if frame.f_back is None or frame.f_code.co_name == '<module>':
        return None

    return frame.f_locals
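
# Illustrative sketch (not part of this commit): with the default `parent_depth=2`,
# a helper calling `parent_frame_namespace()` captures the locals of *its* caller:
#
#     def collect():
#         return parent_frame_namespace()  # locals of collect()'s caller
#
#     def outer():
#         LocalAlias = int
#         return collect()  # {'LocalAlias': <class 'int'>, ...} (outer() is not module-level)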


def _type_convert(arg: Any) -> Any:
    """Convert `None` to `NoneType` and strings to `ForwardRef` instances.

    This is a backport of the private `typing._type_convert` function. When
    evaluating a type, `ForwardRef._evaluate` ends up being called, and is
    responsible for making this conversion. However, we still have to apply
    it for the first argument passed to our type evaluation functions, similarly
    to the `typing.get_type_hints` function.
    """
    if arg is None:
        return NoneType
    if isinstance(arg, str):
        # Like `typing.get_type_hints`, assume the arg can be in any context,
        # hence the proper `is_argument` and `is_class` args:
        return _make_forward_ref(arg, is_argument=False, is_class=True)
    return arg


def get_model_type_hints(
    obj: type[BaseModel],
    *,
    ns_resolver: NsResolver | None = None,
) -> dict[str, tuple[Any, bool]]:
    """Collect annotations from a Pydantic model class, including those from parent classes.

    Args:
        obj: The Pydantic model to inspect.
        ns_resolver: A namespace resolver instance to use. Defaults to an empty instance.

    Returns:
        A dictionary mapping annotation names to a two-tuple: the first element is the evaluated
        type or the original annotation if a `NameError` occurred, the second element is a boolean
        indicating whether the evaluation succeeded.
    """
    hints: dict[str, Any] | dict[str, tuple[Any, bool]] = {}
    ns_resolver = ns_resolver or NsResolver()

    for base in reversed(obj.__mro__):
        ann: dict[str, Any] | None = base.__dict__.get('__annotations__')
        if not ann or isinstance(ann, types.GetSetDescriptorType):
            continue
        with ns_resolver.push(base):
            globalns, localns = ns_resolver.types_namespace
            for name, value in ann.items():
                if name.startswith('_'):
                    # For private attributes, we only need the annotation to detect the `ClassVar` special form.
                    # For this reason, we still try to evaluate it, but we also catch any possible exception (on
                    # top of the `NameError`s caught in `try_eval_type`) that could happen so that users are free
                    # to use any kind of forward annotation for private fields (e.g. circular imports, new typing
                    # syntax, etc).
                    try:
                        hints[name] = try_eval_type(value, globalns, localns)
                    except Exception:
                        hints[name] = (value, False)
                else:
                    hints[name] = try_eval_type(value, globalns, localns)
    return hints


def get_cls_type_hints(
    obj: type[Any],
    *,
    ns_resolver: NsResolver | None = None,
) -> dict[str, Any]:
    """Collect annotations from a class, including those from parent classes.

    Args:
        obj: The class to inspect.
        ns_resolver: A namespace resolver instance to use. Defaults to an empty instance.
    """
    hints: dict[str, Any] | dict[str, tuple[Any, bool]] = {}
    ns_resolver = ns_resolver or NsResolver()

    for base in reversed(obj.__mro__):
        ann: dict[str, Any] | None = base.__dict__.get('__annotations__')
        if not ann or isinstance(ann, types.GetSetDescriptorType):
            continue
        with ns_resolver.push(base):
            globalns, localns = ns_resolver.types_namespace
            for name, value in ann.items():
                hints[name] = eval_type(value, globalns, localns)
    return hints


def try_eval_type(
    value: Any,
    globalns: GlobalsNamespace | None = None,
    localns: MappingNamespace | None = None,
) -> tuple[Any, bool]:
    """Try evaluating the annotation using the provided namespaces.

    Args:
        value: The value to evaluate. If `None`, it will be replaced by `type[None]`. If an instance
            of `str`, it will be converted to a `ForwardRef`.
        globalns: The global namespace to use during annotation evaluation.
        localns: The local namespace to use during annotation evaluation.

    Returns:
        A two-tuple containing the possibly evaluated type and a boolean indicating
        whether the evaluation succeeded or not.
    """
    value = _type_convert(value)

    try:
        return eval_type_backport(value, globalns, localns), True
    except NameError:
        return value, False


def eval_type(
    value: Any,
    globalns: GlobalsNamespace | None = None,
    localns: MappingNamespace | None = None,
) -> Any:
    """Evaluate the annotation using the provided namespaces.

    Args:
        value: The value to evaluate. If `None`, it will be replaced by `type[None]`. If an instance
            of `str`, it will be converted to a `ForwardRef`.
        globalns: The global namespace to use during annotation evaluation.
        localns: The local namespace to use during annotation evaluation.
    """
    value = _type_convert(value)
    return eval_type_backport(value, globalns, localns)


@deprecated(
    '`eval_type_lenient` is deprecated, use `try_eval_type` instead.',
    category=None,
)
def eval_type_lenient(
    value: Any,
    globalns: GlobalsNamespace | None = None,
    localns: MappingNamespace | None = None,
) -> Any:
    ev, _ = try_eval_type(value, globalns, localns)
    return ev


def eval_type_backport(
    value: Any,
    globalns: GlobalsNamespace | None = None,
    localns: MappingNamespace | None = None,
    type_params: tuple[Any, ...] | None = None,
) -> Any:
    """An enhanced version of `typing._eval_type` which will fall back to using the `eval_type_backport`
    package if it's installed to let older Python versions use newer typing constructs.

    Specifically, this transforms `X | Y` into `typing.Union[X, Y]` and `list[X]` into `typing.List[X]`
    (as well as all the types made generic in PEP 585) if the original syntax is not supported in the
    current Python version.

    This function will also display a helpful error if the value passed fails to evaluate.
    """
    try:
        return _eval_type_backport(value, globalns, localns, type_params)
    except TypeError as e:
        if 'Unable to evaluate type annotation' in str(e):
            raise

        # If it is a `TypeError` and value isn't a `ForwardRef`, it would have failed during annotation definition.
        # Thus we assert here for type checking purposes:
        assert isinstance(value, typing.ForwardRef)

        message = f'Unable to evaluate type annotation {value.__forward_arg__!r}.'
        if sys.version_info >= (3, 11):
            e.add_note(message)
            raise
        else:
            raise TypeError(message) from e
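
# Illustrative sketch (assumption: running on Python < 3.10 with the optional
# `eval_type_backport` package installed):
#
#     ref = typing.ForwardRef('int | None')
#     eval_type_backport(ref, globalns={})  # typing.Optional[int]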


def _eval_type_backport(
    value: Any,
    globalns: GlobalsNamespace | None = None,
    localns: MappingNamespace | None = None,
    type_params: tuple[Any, ...] | None = None,
) -> Any:
    try:
        return _eval_type(value, globalns, localns, type_params)
    except TypeError as e:
        if not (isinstance(value, typing.ForwardRef) and is_backport_fixable_error(e)):
            raise

        try:
            from eval_type_backport import eval_type_backport
        except ImportError:
            raise TypeError(
                f'Unable to evaluate type annotation {value.__forward_arg__!r}. If you are making use '
                'of the new typing syntax (unions using `|` since Python 3.10 or builtins subscripting '
                'since Python 3.9), you should either replace the use of new syntax with the existing '
                '`typing` constructs or install the `eval_type_backport` package.'
            ) from e

        return eval_type_backport(
            value,
            globalns,
            localns,  # pyright: ignore[reportArgumentType], waiting on a new `eval_type_backport` release.
            try_default=False,
        )


def _eval_type(
    value: Any,
    globalns: GlobalsNamespace | None = None,
    localns: MappingNamespace | None = None,
    type_params: tuple[Any, ...] | None = None,
) -> Any:
    if sys.version_info >= (3, 13):
        return typing._eval_type(  # type: ignore
            value, globalns, localns, type_params=type_params
        )
    else:
        return typing._eval_type(  # type: ignore
            value, globalns, localns
        )


def is_backport_fixable_error(e: TypeError) -> bool:
    msg = str(e)

    return (
        sys.version_info < (3, 10)
        and msg.startswith('unsupported operand type(s) for |: ')
        or sys.version_info < (3, 9)
        and "' object is not subscriptable" in msg
    )


def get_function_type_hints(
    function: Callable[..., Any],
    *,
    include_keys: set[str] | None = None,
    globalns: GlobalsNamespace | None = None,
    localns: MappingNamespace | None = None,
) -> dict[str, Any]:
    """Return type hints for a function.

    This is similar to the `typing.get_type_hints` function, with a few differences:
    - Support `functools.partial` by using the underlying `func` attribute.
    - If `function` happens to be a built-in type (e.g. `int`), assume it doesn't have annotations
      but specify the `return` key as being the actual type.
    - Do not wrap type annotation of a parameter with `Optional` if it has a default value of `None`
      (related bug: https://github.com/python/cpython/issues/90353, only fixed in 3.11+).
    """
    try:
        if isinstance(function, partial):
            annotations = function.func.__annotations__
        else:
            annotations = function.__annotations__
    except AttributeError:
        type_hints = get_type_hints(function)
        if isinstance(function, type):
            # `type[...]` is a callable, which returns an instance of itself.
            # At some point, we might even look into the return type of `__new__`
            # if it returns something else.
            type_hints.setdefault('return', function)
        return type_hints

    if globalns is None:
        globalns = get_module_ns_of(function)
    type_params: tuple[Any, ...] | None = None
    if localns is None:
        # If localns was specified, it is assumed to already contain type params. This is because
        # Pydantic has more advanced logic to do so (see `_namespace_utils.ns_for_function`).
        type_params = getattr(function, '__type_params__', ())

    type_hints = {}
    for name, value in annotations.items():
        if include_keys is not None and name not in include_keys:
            continue
        if value is None:
            value = NoneType
        elif isinstance(value, str):
            value = _make_forward_ref(value)

        type_hints[name] = eval_type_backport(value, globalns, localns, type_params)

    return type_hints
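
# Illustrative sketch (not part of this commit): `functools.partial` objects are
# resolved via the wrapped function's annotations:
#
#     def add(a: int, b: int) -> int:
#         return a + b
#
#     get_function_type_hints(functools.partial(add, 1))
#     # {'a': <class 'int'>, 'b': <class 'int'>, 'return': <class 'int'>}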


if sys.version_info < (3, 9, 8) or (3, 10) <= sys.version_info < (3, 10, 1):

    def _make_forward_ref(
        arg: Any,
        is_argument: bool = True,
        *,
        is_class: bool = False,
    ) -> typing.ForwardRef:
        """Wrapper for ForwardRef that accounts for the `is_class` argument missing in older versions.

        The `module` argument is omitted as it breaks <3.9.8, =3.10.0 and isn't used in the calls below.

        See https://github.com/python/cpython/pull/28560 for some background.
        The backport happened on 3.9.8, see:
        https://github.com/pydantic/pydantic/discussions/6244#discussioncomment-6275458,
        and on 3.10.1 for the 3.10 branch, see:
        https://github.com/pydantic/pydantic/issues/6912

        Implemented as EAFP with memory.
        """
        return typing.ForwardRef(arg, is_argument)

else:
    _make_forward_ref = typing.ForwardRef


if sys.version_info >= (3, 10):
    get_type_hints = typing.get_type_hints

else:
    """
    For older versions of python, we have a custom implementation of `get_type_hints` which is as close as possible to
    the implementation in CPython 3.10.8.
    """

    @typing.no_type_check
    def get_type_hints(  # noqa: C901
        obj: Any,
        globalns: dict[str, Any] | None = None,
        localns: dict[str, Any] | None = None,
        include_extras: bool = False,
    ) -> dict[str, Any]:  # pragma: no cover
        """Taken verbatim from python 3.10.8 unchanged, except:
        * type annotations of the function definition above.
        * prefixing `typing.` where appropriate
        * Use `_make_forward_ref` instead of `typing.ForwardRef` to handle the `is_class` argument.

        https://github.com/python/cpython/blob/aaaf5174241496afca7ce4d4584570190ff972fe/Lib/typing.py#L1773-L1875

        DO NOT CHANGE THIS METHOD UNLESS ABSOLUTELY NECESSARY.
        ======================================================

        Return type hints for an object.

        This is often the same as obj.__annotations__, but it handles
        forward references encoded as string literals, adds Optional[t] if a
        default value equal to None is set and recursively replaces all
        'Annotated[T, ...]' with 'T' (unless 'include_extras=True').

        The argument may be a module, class, method, or function. The annotations
        are returned as a dictionary. For classes, annotations include also
        inherited members.

        TypeError is raised if the argument is not of a type that can contain
        annotations, and an empty dictionary is returned if no annotations are
        present.

        BEWARE -- the behavior of globalns and localns is counterintuitive
        (unless you are familiar with how eval() and exec() work). The
        search order is locals first, then globals.

        - If no dict arguments are passed, an attempt is made to use the
          globals from obj (or the respective module's globals for classes),
          and these are also used as the locals. If the object does not appear
          to have globals, an empty dictionary is used. For classes, the search
          order is globals first then locals.

        - If one dict argument is passed, it is used for both globals and
          locals.

        - If two dict arguments are passed, they specify globals and
          locals, respectively.
        """
        if getattr(obj, '__no_type_check__', None):
            return {}
        # Classes require a special treatment.
        if isinstance(obj, type):
            hints = {}
            for base in reversed(obj.__mro__):
                if globalns is None:
                    base_globals = getattr(sys.modules.get(base.__module__, None), '__dict__', {})
                else:
                    base_globals = globalns
                ann = base.__dict__.get('__annotations__', {})
                if isinstance(ann, types.GetSetDescriptorType):
                    ann = {}
                base_locals = dict(vars(base)) if localns is None else localns
                if localns is None and globalns is None:
                    # This is surprising, but required. Before Python 3.10,
                    # get_type_hints only evaluated the globalns of
                    # a class. To maintain backwards compatibility, we reverse
                    # the globalns and localns order so that eval() looks into
                    # *base_globals* first rather than *base_locals*.
                    # This only affects ForwardRefs.
                    base_globals, base_locals = base_locals, base_globals
                for name, value in ann.items():
                    if value is None:
                        value = type(None)
                    if isinstance(value, str):
                        value = _make_forward_ref(value, is_argument=False, is_class=True)

                    value = eval_type_backport(value, base_globals, base_locals)
                    hints[name] = value
            if not include_extras and hasattr(typing, '_strip_annotations'):
                return {
                    k: typing._strip_annotations(t)  # type: ignore
                    for k, t in hints.items()
                }
            else:
                return hints

        if globalns is None:
            if isinstance(obj, types.ModuleType):
                globalns = obj.__dict__
            else:
                nsobj = obj
                # Find globalns for the unwrapped object.
                while hasattr(nsobj, '__wrapped__'):
                    nsobj = nsobj.__wrapped__
                globalns = getattr(nsobj, '__globals__', {})
            if localns is None:
                localns = globalns
        elif localns is None:
            localns = globalns
        hints = getattr(obj, '__annotations__', None)
        if hints is None:
            # Return empty annotations for something that _could_ have them.
            if isinstance(obj, typing._allowed_types):  # type: ignore
                return {}
            else:
                raise TypeError(f'{obj!r} is not a module, class, method, or function.')
        defaults = typing._get_defaults(obj)  # type: ignore
        hints = dict(hints)
        for name, value in hints.items():
            if value is None:
                value = type(None)
            if isinstance(value, str):
                # class-level forward refs were handled above, this must be either
                # a module-level annotation or a function argument annotation

                value = _make_forward_ref(
                    value,
                    is_argument=not isinstance(obj, types.ModuleType),
                    is_class=False,
                )
            value = eval_type_backport(value, globalns, localns)
            if name in defaults and defaults[name] is None:
                value = typing.Optional[value]
            hints[name] = value
        return hints if include_extras else {k: typing._strip_annotations(t) for k, t in hints.items()}  # type: ignore
@@ -0,0 +1,389 @@
"""Bucket of reusable internal utilities.

This should be reduced as much as possible with functions only used in one place, moved to that place.
"""

from __future__ import annotations as _annotations

import dataclasses
import keyword
import typing
import weakref
from collections import OrderedDict, defaultdict, deque
from copy import deepcopy
from functools import cached_property
from inspect import Parameter
from itertools import zip_longest
from types import BuiltinFunctionType, CodeType, FunctionType, GeneratorType, LambdaType, ModuleType
from typing import Any, Callable, Mapping, TypeVar

from typing_extensions import TypeAlias, TypeGuard

from . import _repr, _typing_extra
from ._import_utils import import_cached_base_model

if typing.TYPE_CHECKING:
    MappingIntStrAny: TypeAlias = 'typing.Mapping[int, Any] | typing.Mapping[str, Any]'
    AbstractSetIntStr: TypeAlias = 'typing.AbstractSet[int] | typing.AbstractSet[str]'
    from ..main import BaseModel


# these are types that are returned unchanged by deepcopy
IMMUTABLE_NON_COLLECTIONS_TYPES: set[type[Any]] = {
    int,
    float,
    complex,
    str,
    bool,
    bytes,
    type,
    _typing_extra.NoneType,
    FunctionType,
    BuiltinFunctionType,
    LambdaType,
    weakref.ref,
    CodeType,
    # note: including ModuleType differs from the behaviour of deepcopy, which raises an error for modules.
    # That might not be a good idea in general, but since this function is only used internally against
    # default values of fields, it allows a field to actually have a module as its default value.
    ModuleType,
    NotImplemented.__class__,
    Ellipsis.__class__,
}

# these are types that if empty, might be copied with simple copy() instead of deepcopy()
BUILTIN_COLLECTIONS: set[type[Any]] = {
    list,
    set,
    tuple,
    frozenset,
    dict,
    OrderedDict,
    defaultdict,
    deque,
}


def can_be_positional(param: Parameter) -> bool:
    """Return whether the parameter accepts a positional argument.

    ```python {test="skip" lint="skip"}
    def func(a, /, b, *, c):
        pass

    params = inspect.signature(func).parameters
    can_be_positional(params['a'])
    #> True
    can_be_positional(params['b'])
    #> True
    can_be_positional(params['c'])
    #> False
    ```
    """
    return param.kind in (Parameter.POSITIONAL_ONLY, Parameter.POSITIONAL_OR_KEYWORD)


def sequence_like(v: Any) -> bool:
    return isinstance(v, (list, tuple, set, frozenset, GeneratorType, deque))


def lenient_isinstance(o: Any, class_or_tuple: type[Any] | tuple[type[Any], ...] | None) -> bool:  # pragma: no cover
    try:
        return isinstance(o, class_or_tuple)  # type: ignore[arg-type]
    except TypeError:
        return False


def lenient_issubclass(cls: Any, class_or_tuple: Any) -> bool:  # pragma: no cover
    try:
        return isinstance(cls, type) and issubclass(cls, class_or_tuple)
    except TypeError:
        if isinstance(cls, _typing_extra.WithArgsTypes):
            return False
        raise  # pragma: no cover


def is_model_class(cls: Any) -> TypeGuard[type[BaseModel]]:
    """Returns true if cls is a _proper_ subclass of BaseModel, and provides proper type-checking,
    unlike raw calls to lenient_issubclass.
    """
    BaseModel = import_cached_base_model()

    return lenient_issubclass(cls, BaseModel) and cls is not BaseModel


def is_valid_identifier(identifier: str) -> bool:
    """Checks that a string is a valid identifier and not a Python keyword.

    :param identifier: The identifier to test.
    :return: True if the identifier is valid.
    """
    return identifier.isidentifier() and not keyword.iskeyword(identifier)


KeyType = TypeVar('KeyType')


def deep_update(mapping: dict[KeyType, Any], *updating_mappings: dict[KeyType, Any]) -> dict[KeyType, Any]:
    updated_mapping = mapping.copy()
    for updating_mapping in updating_mappings:
        for k, v in updating_mapping.items():
            if k in updated_mapping and isinstance(updated_mapping[k], dict) and isinstance(v, dict):
                updated_mapping[k] = deep_update(updated_mapping[k], v)
            else:
                updated_mapping[k] = v
    return updated_mapping
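
# e.g. (illustrative, not part of this commit): nested dicts merge recursively
# instead of being replaced wholesale:
#
#     deep_update({'a': {'x': 1}}, {'a': {'y': 2}, 'b': 3})
#     # {'a': {'x': 1, 'y': 2}, 'b': 3}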


def update_not_none(mapping: dict[Any, Any], **update: Any) -> None:
    mapping.update({k: v for k, v in update.items() if v is not None})


T = TypeVar('T')


def unique_list(
    input_list: list[T] | tuple[T, ...],
    *,
    name_factory: typing.Callable[[T], str] = str,
) -> list[T]:
    """Make a list unique while maintaining order.
    We update the list if another one with the same name is set
    (e.g. model validator overridden in subclass).
    """
    result: list[T] = []
    result_names: list[str] = []
    for v in input_list:
        v_name = name_factory(v)
        if v_name not in result_names:
            result_names.append(v_name)
            result.append(v)
        else:
            result[result_names.index(v_name)] = v

    return result


class ValueItems(_repr.Representation):
    """Class for more convenient calculation of excluded or included fields on values."""

    __slots__ = ('_items', '_type')

    def __init__(self, value: Any, items: AbstractSetIntStr | MappingIntStrAny) -> None:
        items = self._coerce_items(items)

        if isinstance(value, (list, tuple)):
            items = self._normalize_indexes(items, len(value))  # type: ignore

        self._items: MappingIntStrAny = items  # type: ignore

    def is_excluded(self, item: Any) -> bool:
        """Check if item is fully excluded.

        :param item: key or index of a value
        """
        return self.is_true(self._items.get(item))

    def is_included(self, item: Any) -> bool:
        """Check if value is contained in self._items.

        :param item: key or index of value
        """
        return item in self._items

    def for_element(self, e: int | str) -> AbstractSetIntStr | MappingIntStrAny | None:
        """:param e: key or index of element on value
        :return: raw values for element if self._items is a dict and contains the needed element
        """
        item = self._items.get(e)  # type: ignore
        return item if not self.is_true(item) else None

    def _normalize_indexes(self, items: MappingIntStrAny, v_length: int) -> dict[int | str, Any]:
        """:param items: dict or set of indexes which will be normalized
        :param v_length: length of the sequence, indexes of which will be normalized

        >>> self._normalize_indexes({0: True, -2: True, -1: True}, 4)
        {0: True, 2: True, 3: True}
        >>> self._normalize_indexes({'__all__': True}, 4)
        {0: True, 1: True, 2: True, 3: True}
        """
        normalized_items: dict[int | str, Any] = {}
        all_items = None
        for i, v in items.items():
            if not (isinstance(v, typing.Mapping) or isinstance(v, typing.AbstractSet) or self.is_true(v)):
                raise TypeError(f'Unexpected type of exclude value for index "{i}" {v.__class__}')
            if i == '__all__':
                all_items = self._coerce_value(v)
                continue
            if not isinstance(i, int):
                raise TypeError(
                    'Excluding fields from a sequence of sub-models or dicts must be performed index-wise: '
                    'expected integer keys or keyword "__all__"'
                )
            normalized_i = v_length + i if i < 0 else i
            normalized_items[normalized_i] = self.merge(v, normalized_items.get(normalized_i))

        if not all_items:
            return normalized_items
        if self.is_true(all_items):
            for i in range(v_length):
                normalized_items.setdefault(i, ...)
            return normalized_items
        for i in range(v_length):
            normalized_item = normalized_items.setdefault(i, {})
            if not self.is_true(normalized_item):
                normalized_items[i] = self.merge(all_items, normalized_item)
        return normalized_items

    @classmethod
    def merge(cls, base: Any, override: Any, intersect: bool = False) -> Any:
        """Merge a `base` item with an `override` item.

        Both `base` and `override` are converted to dictionaries if possible.
        Sets are converted to dictionaries with the sets entries as keys and
        Ellipsis as values.

        Each key-value pair existing in `base` is merged with `override`,
        while the rest of the key-value pairs are updated recursively with this function.

        Merging takes place based on the "union" of keys if `intersect` is
        set to `False` (default) and on the intersection of keys if
        `intersect` is set to `True`.
        """
        override = cls._coerce_value(override)
        base = cls._coerce_value(base)
        if override is None:
            return base
        if cls.is_true(base) or base is None:
            return override
        if cls.is_true(override):
            return base if intersect else override

        # intersection or union of keys while preserving ordering:
        if intersect:
            merge_keys = [k for k in base if k in override] + [k for k in override if k in base]
        else:
            merge_keys = list(base) + [k for k in override if k not in base]

        merged: dict[int | str, Any] = {}
        for k in merge_keys:
            merged_item = cls.merge(base.get(k), override.get(k), intersect=intersect)
            if merged_item is not None:
                merged[k] = merged_item

        return merged
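
    # e.g. (illustrative, not part of this commit): sets are coerced to dicts with
    # `...` values before merging, so the union merge of a dict and a set gives:
    #
    #     ValueItems.merge({'a': ...}, {'b'})  # {'a': Ellipsis, 'b': Ellipsis}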

    @staticmethod
    def _coerce_items(items: AbstractSetIntStr | MappingIntStrAny) -> MappingIntStrAny:
        if isinstance(items, typing.Mapping):
            pass
        elif isinstance(items, typing.AbstractSet):
            items = dict.fromkeys(items, ...)  # type: ignore
        else:
            class_name = getattr(items, '__class__', '???')
            raise TypeError(f'Unexpected type of exclude value {class_name}')
        return items  # type: ignore

    @classmethod
    def _coerce_value(cls, value: Any) -> Any:
        if value is None or cls.is_true(value):
            return value
        return cls._coerce_items(value)

    @staticmethod
    def is_true(v: Any) -> bool:
        return v is True or v is ...

    def __repr_args__(self) -> _repr.ReprArgs:
        return [(None, self._items)]


if typing.TYPE_CHECKING:

    def LazyClassAttribute(name: str, get_value: Callable[[], T]) -> T: ...

else:

    class LazyClassAttribute:
        """A descriptor exposing an attribute only accessible on a class (hidden from instances).

        The attribute is lazily computed and cached during the first access.
        """

        def __init__(self, name: str, get_value: Callable[[], Any]) -> None:
            self.name = name
            self.get_value = get_value

        @cached_property
        def value(self) -> Any:
            return self.get_value()

        def __get__(self, instance: Any, owner: type[Any]) -> Any:
            if instance is None:
                return self.value
            raise AttributeError(f'{self.name!r} attribute of {owner.__name__!r} is class-only')


Obj = TypeVar('Obj')


def smart_deepcopy(obj: Obj) -> Obj:
    """Return type as is for immutable built-in types
    Use obj.copy() for built-in empty collections
    Use copy.deepcopy() for non-empty collections and unknown objects.
    """
    obj_type = obj.__class__
    if obj_type in IMMUTABLE_NON_COLLECTIONS_TYPES:
        return obj  # fastest case: obj is immutable and not collection therefore will not be copied anyway
    try:
        if not obj and obj_type in BUILTIN_COLLECTIONS:
            # faster way for empty collections, no need to copy its members
            return obj if obj_type is tuple else obj.copy()  # tuple doesn't have copy method  # type: ignore
    except (TypeError, ValueError, RuntimeError):
        # do we really dare to catch ALL errors? Seems a bit risky
        pass

    return deepcopy(obj)  # slowest way when we actually might need a deepcopy
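
# e.g. (illustrative, not part of this commit):
#
#     smart_deepcopy(42)          # returned as-is (immutable, never copied)
#     smart_deepcopy([])          # shallow `.copy()` of an empty builtin collection
#     smart_deepcopy([{'a': 1}])  # falls through to copy.deepcopy()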


_SENTINEL = object()


def all_identical(left: typing.Iterable[Any], right: typing.Iterable[Any]) -> bool:
    """Check that the items of `left` are the same objects as those in `right`.

    >>> a, b = object(), object()
    >>> all_identical([a, b, a], [a, b, a])
    True
    >>> all_identical([a, b, [a]], [a, b, [a]])  # new list object, while "equal" is not "identical"
    False
    """
    for left_item, right_item in zip_longest(left, right, fillvalue=_SENTINEL):
        if left_item is not right_item:
            return False
    return True


@dataclasses.dataclass(frozen=True)
class SafeGetItemProxy:
    """Wrapper redirecting `__getitem__` to `get` with a sentinel value as default

    This makes it safe to use in `operator.itemgetter` when some keys may be missing
    """

    # Define __slots__ manually for performance
    # @dataclasses.dataclass() only supports slots=True in python>=3.10
    __slots__ = ('wrapped',)

    wrapped: Mapping[str, Any]

    def __getitem__(self, key: str, /) -> Any:
        return self.wrapped.get(key, _SENTINEL)

    # required to pass the object to operator.itemgetter() instances due to a quirk of typeshed
    # https://github.com/python/mypy/issues/13713
    # https://github.com/python/typeshed/pull/8785
    # Since this is typing-only, hide it in a typing.TYPE_CHECKING block
    if typing.TYPE_CHECKING:

        def __contains__(self, key: str, /) -> bool:
            return self.wrapped.__contains__(key)
@@ -0,0 +1,115 @@
from __future__ import annotations as _annotations

import functools
import inspect
from functools import partial
from typing import Any, Awaitable, Callable

import pydantic_core

from ..config import ConfigDict
from ..plugin._schema_validator import create_schema_validator
from ._config import ConfigWrapper
from ._generate_schema import GenerateSchema, ValidateCallSupportedTypes
from ._namespace_utils import MappingNamespace, NsResolver, ns_for_function


def extract_function_name(func: ValidateCallSupportedTypes) -> str:
    """Extract the name of a `ValidateCallSupportedTypes` object."""
    return f'partial({func.func.__name__})' if isinstance(func, functools.partial) else func.__name__
|
||||
|
||||
|
||||
def extract_function_qualname(func: ValidateCallSupportedTypes) -> str:
|
||||
"""Extract the qualname of a `ValidateCallSupportedTypes` object."""
|
||||
return f'partial({func.func.__qualname__})' if isinstance(func, functools.partial) else func.__qualname__
|
||||
|
||||
|
||||
def update_wrapper_attributes(wrapped: ValidateCallSupportedTypes, wrapper: Callable[..., Any]):
|
||||
"""Update the `wrapper` function with the attributes of the `wrapped` function. Return the updated function."""
|
||||
if inspect.iscoroutinefunction(wrapped):
|
||||
|
||||
@functools.wraps(wrapped)
|
||||
async def wrapper_function(*args, **kwargs): # type: ignore
|
||||
return await wrapper(*args, **kwargs)
|
||||
else:
|
||||
|
||||
@functools.wraps(wrapped)
|
||||
def wrapper_function(*args, **kwargs):
|
||||
return wrapper(*args, **kwargs)
|
||||
|
||||
# We need to manually update this because `partial` object has no `__name__` and `__qualname__`.
|
||||
wrapper_function.__name__ = extract_function_name(wrapped)
|
||||
wrapper_function.__qualname__ = extract_function_qualname(wrapped)
|
||||
wrapper_function.raw_function = wrapped # type: ignore
|
||||
|
||||
return wrapper_function
|
||||
|
||||
|
||||
class ValidateCallWrapper:
|
||||
"""This is a wrapper around a function that validates the arguments passed to it, and optionally the return value."""
|
||||
|
||||
__slots__ = ('__pydantic_validator__', '__return_pydantic_validator__')
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
function: ValidateCallSupportedTypes,
|
||||
config: ConfigDict | None,
|
||||
validate_return: bool,
|
||||
parent_namespace: MappingNamespace | None,
|
||||
) -> None:
|
||||
if isinstance(function, partial):
|
||||
schema_type = function.func
|
||||
module = function.func.__module__
|
||||
else:
|
||||
schema_type = function
|
||||
module = function.__module__
|
||||
qualname = extract_function_qualname(function)
|
||||
|
||||
ns_resolver = NsResolver(namespaces_tuple=ns_for_function(schema_type, parent_namespace=parent_namespace))
|
||||
|
||||
config_wrapper = ConfigWrapper(config)
|
||||
gen_schema = GenerateSchema(config_wrapper, ns_resolver)
|
||||
schema = gen_schema.clean_schema(gen_schema.generate_schema(function))
|
||||
core_config = config_wrapper.core_config(title=qualname)
|
||||
|
||||
self.__pydantic_validator__ = create_schema_validator(
|
||||
schema,
|
||||
schema_type,
|
||||
module,
|
||||
qualname,
|
||||
'validate_call',
|
||||
core_config,
|
||||
config_wrapper.plugin_settings,
|
||||
)
|
||||
|
||||
if validate_return:
|
||||
signature = inspect.signature(function)
|
||||
return_type = signature.return_annotation if signature.return_annotation is not signature.empty else Any
|
||||
gen_schema = GenerateSchema(config_wrapper, ns_resolver)
|
||||
schema = gen_schema.clean_schema(gen_schema.generate_schema(return_type))
|
||||
validator = create_schema_validator(
|
||||
schema,
|
||||
schema_type,
|
||||
module,
|
||||
qualname,
|
||||
'validate_call',
|
||||
core_config,
|
||||
config_wrapper.plugin_settings,
|
||||
)
|
||||
if inspect.iscoroutinefunction(function):
|
||||
|
||||
async def return_val_wrapper(aw: Awaitable[Any]) -> None:
|
||||
return validator.validate_python(await aw)
|
||||
|
||||
self.__return_pydantic_validator__ = return_val_wrapper
|
||||
else:
|
||||
self.__return_pydantic_validator__ = validator.validate_python
|
||||
else:
|
||||
self.__return_pydantic_validator__ = None
|
||||
|
||||
def __call__(self, *args: Any, **kwargs: Any) -> Any:
|
||||
res = self.__pydantic_validator__.validate_python(pydantic_core.ArgsKwargs(args, kwargs))
|
||||
if self.__return_pydantic_validator__:
|
||||
return self.__return_pydantic_validator__(res)
|
||||
else:
|
||||
return res
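
# Example (illustrative; in normal use this wrapper is created by the public
# `pydantic.validate_call` decorator rather than instantiated directly):
#
#   def add(a: int, b: int) -> int:
#       return a + b
#
#   wrapper = ValidateCallWrapper(add, None, True, None)
#   wrapper('1', b='2')  # -> 3: arguments are coerced to int, and the return value is validated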
@@ -0,0 +1,424 @@
"""Validator functions for standard library types.

Import of this module is deferred since it contains imports of many standard library modules.
"""

from __future__ import annotations as _annotations

import math
import re
import typing
from decimal import Decimal
from fractions import Fraction
from ipaddress import IPv4Address, IPv4Interface, IPv4Network, IPv6Address, IPv6Interface, IPv6Network
from typing import Any, Callable, Union

from pydantic_core import PydanticCustomError, core_schema
from pydantic_core._pydantic_core import PydanticKnownError


def sequence_validator(
    input_value: typing.Sequence[Any],
    /,
    validator: core_schema.ValidatorFunctionWrapHandler,
) -> typing.Sequence[Any]:
    """Validator for `Sequence` types; `isinstance(v, Sequence)` has already been checked."""
    value_type = type(input_value)

    # We don't accept any plain string as a sequence
    # Relevant issue: https://github.com/pydantic/pydantic/issues/5595
    if issubclass(value_type, (str, bytes)):
        raise PydanticCustomError(
            'sequence_str',
            "'{type_name}' instances are not allowed as a Sequence value",
            {'type_name': value_type.__name__},
        )

    # TODO: refactor sequence validation to validate with either a list or a tuple
    # schema, depending on the type of the value.
    # Additionally, we should be able to remove one of either this validator or the
    # SequenceValidator in _std_types_schema.py (preferably this one, while porting over some logic).
    # Effectively, a refactor for sequence validation is needed.
    if value_type is tuple:
        input_value = list(input_value)

    v_list = validator(input_value)

    # the rest of the logic is just re-creating the original type from `v_list`
    if value_type is list:
        return v_list
    elif issubclass(value_type, range):
        # return the list as we probably can't re-create the range
        return v_list
    elif value_type is tuple:
        return tuple(v_list)
    else:
        # best guess at how to re-create the original type, more custom construction logic might be required
        return value_type(v_list)  # type: ignore[call-arg]
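
# Example (illustrative): with an identity function standing in for the wrapped
# pydantic-core handler, the original sequence type is rebuilt around the items:
#
#   sequence_validator([1, 2], lambda items: items)  # -> [1, 2]
#   sequence_validator((1, 2), lambda items: items)  # -> (1, 2): converted to a list for
#                                                    #    validation, then rebuilt as a tuple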


def import_string(value: Any) -> Any:
    if isinstance(value, str):
        try:
            return _import_string_logic(value)
        except ImportError as e:
            raise PydanticCustomError('import_error', 'Invalid python path: {error}', {'error': str(e)}) from e
    else:
        # otherwise we just return the value and let the next validator do the rest of the work
        return value


def _import_string_logic(dotted_path: str) -> Any:
    """Inspired by uvicorn — dotted paths should include a colon before the final item if that item is not a module.
    (This is necessary to distinguish between a submodule and an attribute when there is a conflict.)

    If the dotted path does not include a colon and the final item is not a valid module, importing as an attribute
    rather than a submodule will be attempted automatically.

    So, for example, the following values of `dotted_path` result in the following returned values:
    * 'collections': <module 'collections'>
    * 'collections.abc': <module 'collections.abc'>
    * 'collections.abc:Mapping': <class 'collections.abc.Mapping'>
    * 'collections.abc.Mapping': <class 'collections.abc.Mapping'> (though this is a bit slower than the previous line)

    An error will be raised under any of the following scenarios:
    * `dotted_path` contains more than one colon (e.g., 'collections:abc:Mapping')
    * the substring of `dotted_path` before the colon is not a valid module in the environment (e.g., '123:Mapping')
    * the substring of `dotted_path` after the colon is not an attribute of the module (e.g., 'collections:abc123')
    """
    from importlib import import_module

    components = dotted_path.strip().split(':')
    if len(components) > 2:
        raise ImportError(f"Import strings should have at most one ':'; received {dotted_path!r}")

    module_path = components[0]
    if not module_path:
        raise ImportError(f'Import strings should have a nonempty module name; received {dotted_path!r}')

    try:
        module = import_module(module_path)
    except ModuleNotFoundError as e:
        if '.' in module_path:
            # Check if it would be valid if the final item was separated from its module with a `:`
            maybe_module_path, maybe_attribute = dotted_path.strip().rsplit('.', 1)
            try:
                return _import_string_logic(f'{maybe_module_path}:{maybe_attribute}')
            except ImportError:
                pass
            raise ImportError(f'No module named {module_path!r}') from e
        raise e

    if len(components) > 1:
        attribute = components[1]
        try:
            return getattr(module, attribute)
        except AttributeError as e:
            raise ImportError(f'cannot import name {attribute!r} from {module_path!r}') from e
    else:
        return module
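
# Example (illustrative):
#
#   _import_string_logic('collections.abc:Mapping')  # -> <class 'collections.abc.Mapping'>
#   _import_string_logic('collections.abc.Mapping')  # same result, via the '.' -> ':' retry above
#   _import_string_logic('collections:nope')         # ImportError: cannot import name 'nope' ...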


def pattern_either_validator(input_value: Any, /) -> typing.Pattern[Any]:
    if isinstance(input_value, typing.Pattern):
        return input_value
    elif isinstance(input_value, (str, bytes)):
        # todo strict mode
        return compile_pattern(input_value)  # type: ignore
    else:
        raise PydanticCustomError('pattern_type', 'Input should be a valid pattern')


def pattern_str_validator(input_value: Any, /) -> typing.Pattern[str]:
    if isinstance(input_value, typing.Pattern):
        if isinstance(input_value.pattern, str):
            return input_value
        else:
            raise PydanticCustomError('pattern_str_type', 'Input should be a string pattern')
    elif isinstance(input_value, str):
        return compile_pattern(input_value)
    elif isinstance(input_value, bytes):
        raise PydanticCustomError('pattern_str_type', 'Input should be a string pattern')
    else:
        raise PydanticCustomError('pattern_type', 'Input should be a valid pattern')


def pattern_bytes_validator(input_value: Any, /) -> typing.Pattern[bytes]:
    if isinstance(input_value, typing.Pattern):
        if isinstance(input_value.pattern, bytes):
            return input_value
        else:
            raise PydanticCustomError('pattern_bytes_type', 'Input should be a bytes pattern')
    elif isinstance(input_value, bytes):
        return compile_pattern(input_value)
    elif isinstance(input_value, str):
        raise PydanticCustomError('pattern_bytes_type', 'Input should be a bytes pattern')
    else:
        raise PydanticCustomError('pattern_type', 'Input should be a valid pattern')


PatternType = typing.TypeVar('PatternType', str, bytes)


def compile_pattern(pattern: PatternType) -> typing.Pattern[PatternType]:
    try:
        return re.compile(pattern)
    except re.error:
        raise PydanticCustomError('pattern_regex', 'Input should be a valid regular expression')
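
# Example (illustrative):
#
#   pattern_str_validator(r'\d+')              # -> re.compile('\\d+')
#   pattern_str_validator(re.compile(r'\d+'))  # passed through unchanged
#   pattern_str_validator(b'\\d+')             # PydanticCustomError: 'pattern_str_type'
#   compile_pattern('(unbalanced')             # PydanticCustomError: 'pattern_regex'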


def ip_v4_address_validator(input_value: Any, /) -> IPv4Address:
    if isinstance(input_value, IPv4Address):
        return input_value

    try:
        return IPv4Address(input_value)
    except ValueError:
        raise PydanticCustomError('ip_v4_address', 'Input is not a valid IPv4 address')


def ip_v6_address_validator(input_value: Any, /) -> IPv6Address:
    if isinstance(input_value, IPv6Address):
        return input_value

    try:
        return IPv6Address(input_value)
    except ValueError:
        raise PydanticCustomError('ip_v6_address', 'Input is not a valid IPv6 address')


def ip_v4_network_validator(input_value: Any, /) -> IPv4Network:
    """Assume an `IPv4Network` initialised with the default `strict` argument.

    See more:
    https://docs.python.org/3/library/ipaddress.html#ipaddress.IPv4Network
    """
    if isinstance(input_value, IPv4Network):
        return input_value

    try:
        return IPv4Network(input_value)
    except ValueError:
        raise PydanticCustomError('ip_v4_network', 'Input is not a valid IPv4 network')


def ip_v6_network_validator(input_value: Any, /) -> IPv6Network:
    """Assume an `IPv6Network` initialised with the default `strict` argument.

    See more:
    https://docs.python.org/3/library/ipaddress.html#ipaddress.IPv6Network
    """
    if isinstance(input_value, IPv6Network):
        return input_value

    try:
        return IPv6Network(input_value)
    except ValueError:
        raise PydanticCustomError('ip_v6_network', 'Input is not a valid IPv6 network')


def ip_v4_interface_validator(input_value: Any, /) -> IPv4Interface:
    if isinstance(input_value, IPv4Interface):
        return input_value

    try:
        return IPv4Interface(input_value)
    except ValueError:
        raise PydanticCustomError('ip_v4_interface', 'Input is not a valid IPv4 interface')


def ip_v6_interface_validator(input_value: Any, /) -> IPv6Interface:
    if isinstance(input_value, IPv6Interface):
        return input_value

    try:
        return IPv6Interface(input_value)
    except ValueError:
        raise PydanticCustomError('ip_v6_interface', 'Input is not a valid IPv6 interface')


def fraction_validator(input_value: Any, /) -> Fraction:
    if isinstance(input_value, Fraction):
        return input_value

    try:
        return Fraction(input_value)
    except ValueError:
        raise PydanticCustomError('fraction_parsing', 'Input is not a valid fraction')


def forbid_inf_nan_check(x: Any) -> Any:
    if not math.isfinite(x):
        raise PydanticKnownError('finite_number')
    return x


def _safe_repr(v: Any) -> int | float | str:
    """The context argument for `PydanticKnownError` requires a number or str type, so we do a simple repr() coercion for types like timedelta.

    See tests/test_types.py::test_annotated_metadata_any_order for some context.
    """
    if isinstance(v, (int, float, str)):
        return v
    return repr(v)


def greater_than_validator(x: Any, gt: Any) -> Any:
    try:
        if not (x > gt):
            raise PydanticKnownError('greater_than', {'gt': _safe_repr(gt)})
        return x
    except TypeError:
        raise TypeError(f"Unable to apply constraint 'gt' to supplied value {x}")


def greater_than_or_equal_validator(x: Any, ge: Any) -> Any:
    try:
        if not (x >= ge):
            raise PydanticKnownError('greater_than_equal', {'ge': _safe_repr(ge)})
        return x
    except TypeError:
        raise TypeError(f"Unable to apply constraint 'ge' to supplied value {x}")


def less_than_validator(x: Any, lt: Any) -> Any:
    try:
        if not (x < lt):
            raise PydanticKnownError('less_than', {'lt': _safe_repr(lt)})
        return x
    except TypeError:
        raise TypeError(f"Unable to apply constraint 'lt' to supplied value {x}")


def less_than_or_equal_validator(x: Any, le: Any) -> Any:
    try:
        if not (x <= le):
            raise PydanticKnownError('less_than_equal', {'le': _safe_repr(le)})
        return x
    except TypeError:
        raise TypeError(f"Unable to apply constraint 'le' to supplied value {x}")


def multiple_of_validator(x: Any, multiple_of: Any) -> Any:
    try:
        if x % multiple_of:
            raise PydanticKnownError('multiple_of', {'multiple_of': _safe_repr(multiple_of)})
        return x
    except TypeError:
        raise TypeError(f"Unable to apply constraint 'multiple_of' to supplied value {x}")


def min_length_validator(x: Any, min_length: Any) -> Any:
    try:
        if not (len(x) >= min_length):
            raise PydanticKnownError(
                'too_short', {'field_type': 'Value', 'min_length': min_length, 'actual_length': len(x)}
            )
        return x
    except TypeError:
        raise TypeError(f"Unable to apply constraint 'min_length' to supplied value {x}")


def max_length_validator(x: Any, max_length: Any) -> Any:
    try:
        if len(x) > max_length:
            raise PydanticKnownError(
                'too_long',
                {'field_type': 'Value', 'max_length': max_length, 'actual_length': len(x)},
            )
        return x
    except TypeError:
        raise TypeError(f"Unable to apply constraint 'max_length' to supplied value {x}")


def _extract_decimal_digits_info(decimal: Decimal) -> tuple[int, int]:
    """Compute the total number of digits and decimal places for a given [`Decimal`][decimal.Decimal] instance.

    This function handles both normalized and non-normalized Decimal instances.
    Example: Decimal('1.230') -> 4 digits, 3 decimal places

    Args:
        decimal (Decimal): The decimal number to analyze.

    Returns:
        tuple[int, int]: A tuple containing the number of decimal places and total digits.

    Though this could be divided into two separate functions, the logic is easier to follow if we couple the computation
    of the number of decimals and digits together.
    """
    decimal_tuple = decimal.as_tuple()
    if not isinstance(decimal_tuple.exponent, int):
        raise TypeError(f'Unable to extract decimal digits info from supplied value {decimal}')
    exponent = decimal_tuple.exponent
    num_digits = len(decimal_tuple.digits)

    if exponent >= 0:
        # A positive exponent adds that many trailing zeros
        # Ex: digit_tuple=(1, 2, 3), exponent=2 -> 12300 -> 0 decimal places, 5 digits
        num_digits += exponent
        decimal_places = 0
    else:
        # If the absolute value of the negative exponent is larger than the
        # number of digits, the total number of digits equals the number of
        # decimal places, because the exponent consumes all the digits in
        # digit_tuple and then adds abs(exponent) - len(digit_tuple) leading
        # zeros after the decimal point.
        # Ex: digit_tuple=(1, 2, 3), exponent=-2 -> 1.23 -> 2 decimal places, 3 digits
        # Ex: digit_tuple=(1, 2, 3), exponent=-4 -> 0.0123 -> 4 decimal places, 4 digits
        decimal_places = abs(exponent)
        num_digits = max(num_digits, decimal_places)

    return decimal_places, num_digits
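
# Worked example (illustrative): Decimal('1.230').as_tuple() gives digits=(1, 2, 3, 0)
# and exponent=-3, so decimal_places = 3 and num_digits = max(4, 3) = 4; for
# Decimal('0.0123'), digits=(1, 2, 3) and exponent=-4, so decimal_places = 4 and
# num_digits = max(3, 4) = 4 (leading zeros after the point count toward the total).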


def max_digits_validator(x: Any, max_digits: Any) -> Any:
    _, num_digits = _extract_decimal_digits_info(x)
    _, normalized_num_digits = _extract_decimal_digits_info(x.normalize())

    try:
        if (num_digits > max_digits) and (normalized_num_digits > max_digits):
            raise PydanticKnownError(
                'decimal_max_digits',
                {'max_digits': max_digits},
            )
        return x
    except TypeError:
        raise TypeError(f"Unable to apply constraint 'max_digits' to supplied value {x}")


def decimal_places_validator(x: Any, decimal_places: Any) -> Any:
    decimal_places_, _ = _extract_decimal_digits_info(x)
    normalized_decimal_places, _ = _extract_decimal_digits_info(x.normalize())

    try:
        if (decimal_places_ > decimal_places) and (normalized_decimal_places > decimal_places):
            raise PydanticKnownError(
                'decimal_max_places',
                {'decimal_places': decimal_places},
            )
        return x
    except TypeError:
        raise TypeError(f"Unable to apply constraint 'decimal_places' to supplied value {x}")


NUMERIC_VALIDATOR_LOOKUP: dict[str, Callable] = {
    'gt': greater_than_validator,
    'ge': greater_than_or_equal_validator,
    'lt': less_than_validator,
    'le': less_than_or_equal_validator,
    'multiple_of': multiple_of_validator,
    'min_length': min_length_validator,
    'max_length': max_length_validator,
    'max_digits': max_digits_validator,
    'decimal_places': decimal_places_validator,
}

IpType = Union[IPv4Address, IPv6Address, IPv4Network, IPv6Network, IPv4Interface, IPv6Interface]

IP_VALIDATOR_LOOKUP: dict[type[IpType], Callable] = {
    IPv4Address: ip_v4_address_validator,
    IPv6Address: ip_v6_address_validator,
    IPv4Network: ip_v4_network_validator,
    IPv6Network: ip_v6_network_validator,
    IPv4Interface: ip_v4_interface_validator,
    IPv6Interface: ip_v6_interface_validator,
}
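
# Example (illustrative): these lookup tables let schema-generation code pick a
# validator from a constraint name or an ipaddress type without a chain of ifs:
#
#   NUMERIC_VALIDATOR_LOOKUP['gt'](5, 3)           # -> 5
#   NUMERIC_VALIDATOR_LOOKUP['gt'](2, 3)           # raises PydanticKnownError ('greater_than')
#   IP_VALIDATOR_LOOKUP[IPv4Address]('127.0.0.1')  # -> IPv4Address('127.0.0.1')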
@@ -0,0 +1,308 @@
import sys
from typing import Any, Callable, Dict

from .version import version_short

MOVED_IN_V2 = {
    'pydantic.utils:version_info': 'pydantic.version:version_info',
    'pydantic.error_wrappers:ValidationError': 'pydantic:ValidationError',
    'pydantic.utils:to_camel': 'pydantic.alias_generators:to_pascal',
    'pydantic.utils:to_lower_camel': 'pydantic.alias_generators:to_camel',
    'pydantic:PyObject': 'pydantic.types:ImportString',
    'pydantic.types:PyObject': 'pydantic.types:ImportString',
    'pydantic.generics:GenericModel': 'pydantic.BaseModel',
}

DEPRECATED_MOVED_IN_V2 = {
    'pydantic.tools:schema_of': 'pydantic.deprecated.tools:schema_of',
    'pydantic.tools:parse_obj_as': 'pydantic.deprecated.tools:parse_obj_as',
    'pydantic.tools:schema_json_of': 'pydantic.deprecated.tools:schema_json_of',
    'pydantic.json:pydantic_encoder': 'pydantic.deprecated.json:pydantic_encoder',
    'pydantic:validate_arguments': 'pydantic.deprecated.decorator:validate_arguments',
    'pydantic.json:custom_pydantic_encoder': 'pydantic.deprecated.json:custom_pydantic_encoder',
    'pydantic.json:timedelta_isoformat': 'pydantic.deprecated.json:timedelta_isoformat',
    'pydantic.decorator:validate_arguments': 'pydantic.deprecated.decorator:validate_arguments',
    'pydantic.class_validators:validator': 'pydantic.deprecated.class_validators:validator',
    'pydantic.class_validators:root_validator': 'pydantic.deprecated.class_validators:root_validator',
    'pydantic.config:BaseConfig': 'pydantic.deprecated.config:BaseConfig',
    'pydantic.config:Extra': 'pydantic.deprecated.config:Extra',
}

REDIRECT_TO_V1 = {
    f'pydantic.utils:{obj}': f'pydantic.v1.utils:{obj}'
    for obj in (
        'deep_update',
        'GetterDict',
        'lenient_issubclass',
        'lenient_isinstance',
        'is_valid_field',
        'update_not_none',
        'import_string',
        'Representation',
        'ROOT_KEY',
        'smart_deepcopy',
        'sequence_like',
    )
}


REMOVED_IN_V2 = {
    'pydantic:ConstrainedBytes',
    'pydantic:ConstrainedDate',
    'pydantic:ConstrainedDecimal',
    'pydantic:ConstrainedFloat',
    'pydantic:ConstrainedFrozenSet',
    'pydantic:ConstrainedInt',
    'pydantic:ConstrainedList',
    'pydantic:ConstrainedSet',
    'pydantic:ConstrainedStr',
    'pydantic:JsonWrapper',
    'pydantic:NoneBytes',
    'pydantic:NoneStr',
    'pydantic:NoneStrBytes',
    'pydantic:Protocol',
    'pydantic:Required',
    'pydantic:StrBytes',
    'pydantic:compiled',
    'pydantic.config:get_config',
    'pydantic.config:inherit_config',
    'pydantic.config:prepare_config',
    'pydantic:create_model_from_namedtuple',
    'pydantic:create_model_from_typeddict',
    'pydantic.dataclasses:create_pydantic_model_from_dataclass',
    'pydantic.dataclasses:make_dataclass_validator',
    'pydantic.dataclasses:set_validation',
    'pydantic.datetime_parse:parse_date',
    'pydantic.datetime_parse:parse_time',
    'pydantic.datetime_parse:parse_datetime',
    'pydantic.datetime_parse:parse_duration',
    'pydantic.error_wrappers:ErrorWrapper',
    'pydantic.errors:AnyStrMaxLengthError',
    'pydantic.errors:AnyStrMinLengthError',
    'pydantic.errors:ArbitraryTypeError',
    'pydantic.errors:BoolError',
    'pydantic.errors:BytesError',
    'pydantic.errors:CallableError',
    'pydantic.errors:ClassError',
    'pydantic.errors:ColorError',
    'pydantic.errors:ConfigError',
    'pydantic.errors:DataclassTypeError',
    'pydantic.errors:DateError',
    'pydantic.errors:DateNotInTheFutureError',
    'pydantic.errors:DateNotInThePastError',
    'pydantic.errors:DateTimeError',
    'pydantic.errors:DecimalError',
    'pydantic.errors:DecimalIsNotFiniteError',
    'pydantic.errors:DecimalMaxDigitsError',
    'pydantic.errors:DecimalMaxPlacesError',
    'pydantic.errors:DecimalWholeDigitsError',
    'pydantic.errors:DictError',
    'pydantic.errors:DurationError',
    'pydantic.errors:EmailError',
    'pydantic.errors:EnumError',
    'pydantic.errors:EnumMemberError',
    'pydantic.errors:ExtraError',
    'pydantic.errors:FloatError',
    'pydantic.errors:FrozenSetError',
    'pydantic.errors:FrozenSetMaxLengthError',
    'pydantic.errors:FrozenSetMinLengthError',
    'pydantic.errors:HashableError',
    'pydantic.errors:IPv4AddressError',
    'pydantic.errors:IPv4InterfaceError',
    'pydantic.errors:IPv4NetworkError',
    'pydantic.errors:IPv6AddressError',
    'pydantic.errors:IPv6InterfaceError',
    'pydantic.errors:IPv6NetworkError',
    'pydantic.errors:IPvAnyAddressError',
    'pydantic.errors:IPvAnyInterfaceError',
    'pydantic.errors:IPvAnyNetworkError',
    'pydantic.errors:IntEnumError',
    'pydantic.errors:IntegerError',
    'pydantic.errors:InvalidByteSize',
    'pydantic.errors:InvalidByteSizeUnit',
    'pydantic.errors:InvalidDiscriminator',
    'pydantic.errors:InvalidLengthForBrand',
    'pydantic.errors:JsonError',
    'pydantic.errors:JsonTypeError',
    'pydantic.errors:ListError',
    'pydantic.errors:ListMaxLengthError',
    'pydantic.errors:ListMinLengthError',
    'pydantic.errors:ListUniqueItemsError',
    'pydantic.errors:LuhnValidationError',
    'pydantic.errors:MissingDiscriminator',
    'pydantic.errors:MissingError',
    'pydantic.errors:NoneIsAllowedError',
    'pydantic.errors:NoneIsNotAllowedError',
    'pydantic.errors:NotDigitError',
    'pydantic.errors:NotNoneError',
    'pydantic.errors:NumberNotGeError',
    'pydantic.errors:NumberNotGtError',
    'pydantic.errors:NumberNotLeError',
    'pydantic.errors:NumberNotLtError',
    'pydantic.errors:NumberNotMultipleError',
    'pydantic.errors:PathError',
    'pydantic.errors:PathNotADirectoryError',
    'pydantic.errors:PathNotAFileError',
    'pydantic.errors:PathNotExistsError',
    'pydantic.errors:PatternError',
    'pydantic.errors:PyObjectError',
    'pydantic.errors:PydanticTypeError',
    'pydantic.errors:PydanticValueError',
    'pydantic.errors:SequenceError',
    'pydantic.errors:SetError',
    'pydantic.errors:SetMaxLengthError',
    'pydantic.errors:SetMinLengthError',
    'pydantic.errors:StrError',
    'pydantic.errors:StrRegexError',
    'pydantic.errors:StrictBoolError',
    'pydantic.errors:SubclassError',
    'pydantic.errors:TimeError',
    'pydantic.errors:TupleError',
    'pydantic.errors:TupleLengthError',
    'pydantic.errors:UUIDError',
    'pydantic.errors:UUIDVersionError',
    'pydantic.errors:UrlError',
    'pydantic.errors:UrlExtraError',
    'pydantic.errors:UrlHostError',
    'pydantic.errors:UrlHostTldError',
    'pydantic.errors:UrlPortError',
    'pydantic.errors:UrlSchemeError',
    'pydantic.errors:UrlSchemePermittedError',
    'pydantic.errors:UrlUserInfoError',
    'pydantic.errors:WrongConstantError',
    'pydantic.main:validate_model',
    'pydantic.networks:stricturl',
    'pydantic:parse_file_as',
    'pydantic:parse_raw_as',
    'pydantic:stricturl',
    'pydantic.tools:parse_file_as',
    'pydantic.tools:parse_raw_as',
    'pydantic.types:ConstrainedBytes',
    'pydantic.types:ConstrainedDate',
    'pydantic.types:ConstrainedDecimal',
    'pydantic.types:ConstrainedFloat',
    'pydantic.types:ConstrainedFrozenSet',
    'pydantic.types:ConstrainedInt',
    'pydantic.types:ConstrainedList',
    'pydantic.types:ConstrainedSet',
    'pydantic.types:ConstrainedStr',
    'pydantic.types:JsonWrapper',
    'pydantic.types:NoneBytes',
    'pydantic.types:NoneStr',
    'pydantic.types:NoneStrBytes',
    'pydantic.types:StrBytes',
    'pydantic.typing:evaluate_forwardref',
    'pydantic.typing:AbstractSetIntStr',
    'pydantic.typing:AnyCallable',
    'pydantic.typing:AnyClassMethod',
    'pydantic.typing:CallableGenerator',
    'pydantic.typing:DictAny',
    'pydantic.typing:DictIntStrAny',
    'pydantic.typing:DictStrAny',
    'pydantic.typing:IntStr',
    'pydantic.typing:ListStr',
    'pydantic.typing:MappingIntStrAny',
    'pydantic.typing:NoArgAnyCallable',
    'pydantic.typing:NoneType',
    'pydantic.typing:ReprArgs',
    'pydantic.typing:SetStr',
    'pydantic.typing:StrPath',
    'pydantic.typing:TupleGenerator',
    'pydantic.typing:WithArgsTypes',
    'pydantic.typing:all_literal_values',
    'pydantic.typing:display_as_type',
    'pydantic.typing:get_all_type_hints',
    'pydantic.typing:get_args',
    'pydantic.typing:get_origin',
    'pydantic.typing:get_sub_types',
    'pydantic.typing:is_callable_type',
    'pydantic.typing:is_classvar',
    'pydantic.typing:is_finalvar',
    'pydantic.typing:is_literal_type',
    'pydantic.typing:is_namedtuple',
    'pydantic.typing:is_new_type',
    'pydantic.typing:is_none_type',
    'pydantic.typing:is_typeddict',
    'pydantic.typing:is_typeddict_special',
    'pydantic.typing:is_union',
    'pydantic.typing:new_type_supertype',
    'pydantic.typing:resolve_annotations',
    'pydantic.typing:typing_base',
    'pydantic.typing:update_field_forward_refs',
    'pydantic.typing:update_model_forward_refs',
    'pydantic.utils:ClassAttribute',
    'pydantic.utils:DUNDER_ATTRIBUTES',
    'pydantic.utils:PyObjectStr',
    'pydantic.utils:ValueItems',
    'pydantic.utils:almost_equal_floats',
    'pydantic.utils:get_discriminator_alias_and_values',
    'pydantic.utils:get_model',
    'pydantic.utils:get_unique_discriminator_alias',
    'pydantic.utils:in_ipython',
    'pydantic.utils:is_valid_identifier',
    'pydantic.utils:path_type',
    'pydantic.utils:validate_field_name',
    'pydantic:validate_model',
}


def getattr_migration(module: str) -> Callable[[str], Any]:
    """Implement PEP 562 for objects that were either moved or removed on the migration
    to V2.

    Args:
        module: The module name.

    Returns:
        A callable that will raise an error if the object is not found.
    """
    # This avoids circular import with errors.py.
    from .errors import PydanticImportError

    def wrapper(name: str) -> object:
        """Raise an error if the object is not found, or warn if it was moved.

        In case it was moved, it still returns the object.

        Args:
            name: The object name.

        Returns:
            The object.
        """
        if name == '__path__':
            raise AttributeError(f'module {module!r} has no attribute {name!r}')

        import warnings

        from ._internal._validators import import_string

        import_path = f'{module}:{name}'
        if import_path in MOVED_IN_V2:
            new_location = MOVED_IN_V2[import_path]
            warnings.warn(f'`{import_path}` has been moved to `{new_location}`.')
            return import_string(MOVED_IN_V2[import_path])
        if import_path in DEPRECATED_MOVED_IN_V2:
            # skip the warning here because a deprecation warning will be raised elsewhere
            return import_string(DEPRECATED_MOVED_IN_V2[import_path])
        if import_path in REDIRECT_TO_V1:
            new_location = REDIRECT_TO_V1[import_path]
            warnings.warn(
                f'`{import_path}` has been removed. We are importing from `{new_location}` instead. '
                'See the migration guide for more details: https://docs.pydantic.dev/latest/migration/'
            )
            return import_string(REDIRECT_TO_V1[import_path])
        if import_path == 'pydantic:BaseSettings':
            raise PydanticImportError(
                '`BaseSettings` has been moved to the `pydantic-settings` package. '
                f'See https://docs.pydantic.dev/{version_short()}/migration/#basesettings-has-moved-to-pydantic-settings '
                'for more details.'
            )
        if import_path in REMOVED_IN_V2:
            raise PydanticImportError(f'`{import_path}` has been removed in V2.')
        globals: Dict[str, Any] = sys.modules[module].__dict__
        if name in globals:
            return globals[name]
        raise AttributeError(f'module {module!r} has no attribute {name!r}')

    return wrapper
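
# Example (illustrative): modules wire this up as a PEP 562 module `__getattr__`,
# e.g. `pydantic/class_validators.py` further down is just:
#
#   from ._migration import getattr_migration
#   __getattr__ = getattr_migration(__name__)
#
# so `from pydantic.class_validators import validator` still resolves (via the
# deprecated location) instead of failing with an AttributeError.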
@@ -0,0 +1,62 @@
"""Alias generators for converting between different capitalization conventions."""

import re

__all__ = ('to_pascal', 'to_camel', 'to_snake')

# TODO: in V3, change the argument names to be more descriptive
# Generally, don't only convert from snake_case, or name the functions
# more specifically like snake_to_camel.


def to_pascal(snake: str) -> str:
    """Convert a snake_case string to PascalCase.

    Args:
        snake: The string to convert.

    Returns:
        The PascalCase string.
    """
    camel = snake.title()
    return re.sub('([0-9A-Za-z])_(?=[0-9A-Z])', lambda m: m.group(1), camel)


def to_camel(snake: str) -> str:
    """Convert a snake_case string to camelCase.

    Args:
        snake: The string to convert.

    Returns:
        The converted camelCase string.
    """
    # If the string is already in camelCase and does not contain a digit followed
    # by a lowercase letter, return it as it is
    if re.match('^[a-z]+[A-Za-z0-9]*$', snake) and not re.search(r'\d[a-z]', snake):
        return snake

    camel = to_pascal(snake)
    return re.sub('(^_*[A-Z])', lambda m: m.group(1).lower(), camel)


def to_snake(camel: str) -> str:
    """Convert a PascalCase, camelCase, or kebab-case string to snake_case.

    Args:
        camel: The string to convert.

    Returns:
        The converted string in snake_case.
    """
    # Handle the sequence of uppercase letters followed by a lowercase letter
    snake = re.sub(r'([A-Z]+)([A-Z][a-z])', lambda m: f'{m.group(1)}_{m.group(2)}', camel)
    # Insert an underscore between a lowercase letter and an uppercase letter
    snake = re.sub(r'([a-z])([A-Z])', lambda m: f'{m.group(1)}_{m.group(2)}', snake)
    # Insert an underscore between a digit and an uppercase letter
    snake = re.sub(r'([0-9])([A-Z])', lambda m: f'{m.group(1)}_{m.group(2)}', snake)
    # Insert an underscore between a lowercase letter and a digit
    snake = re.sub(r'([a-z])([0-9])', lambda m: f'{m.group(1)}_{m.group(2)}', snake)
    # Replace hyphens with underscores to handle kebab-case
    snake = snake.replace('-', '_')
    return snake.lower()
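
# Examples (illustrative):
#
#   to_pascal('snake_to_camel')  # -> 'SnakeToCamel'
#   to_camel('snake_to_camel')   # -> 'snakeToCamel'
#   to_snake('HTTPResponse')     # -> 'http_response'
#   to_snake('kebab-case-name')  # -> 'kebab_case_name'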
@@ -0,0 +1,132 @@
"""Support for alias configurations."""

from __future__ import annotations

import dataclasses
from typing import Any, Callable, Literal

from pydantic_core import PydanticUndefined

from ._internal import _internal_dataclass

__all__ = ('AliasGenerator', 'AliasPath', 'AliasChoices')


@dataclasses.dataclass(**_internal_dataclass.slots_true)
class AliasPath:
    """Usage docs: https://docs.pydantic.dev/2.10/concepts/alias#aliaspath-and-aliaschoices

    A data class used by `validation_alias` as a convenience to create aliases.

    Attributes:
        path: A list of string or integer aliases.
    """

    path: list[int | str]

    def __init__(self, first_arg: str, *args: str | int) -> None:
        self.path = [first_arg] + list(args)

    def convert_to_aliases(self) -> list[str | int]:
        """Converts arguments to a list of string or integer aliases.

        Returns:
            The list of aliases.
        """
        return self.path

    def search_dict_for_path(self, d: dict) -> Any:
        """Searches a dictionary for the path specified by the alias.

        Returns:
            The value at the specified path, or `PydanticUndefined` if the path is not found.
        """
        v = d
        for k in self.path:
            if isinstance(v, str):
                # disallow indexing into a str, like for AliasPath('x', 0) and x='abc'
                return PydanticUndefined
            try:
                v = v[k]
            except (KeyError, IndexError, TypeError):
                return PydanticUndefined
        return v
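
# Example (illustrative):
#
#   path = AliasPath('names', 0)
#   path.search_dict_for_path({'names': ['foo', 'bar']})  # -> 'foo'
#   path.search_dict_for_path({'names': 'foo'})           # -> PydanticUndefined (no str indexing)
#   path.search_dict_for_path({})                         # -> PydanticUndefined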


@dataclasses.dataclass(**_internal_dataclass.slots_true)
class AliasChoices:
    """Usage docs: https://docs.pydantic.dev/2.10/concepts/alias#aliaspath-and-aliaschoices

    A data class used by `validation_alias` as a convenience to create aliases.

    Attributes:
        choices: A list containing a string or `AliasPath`.
    """

    choices: list[str | AliasPath]

    def __init__(self, first_choice: str | AliasPath, *choices: str | AliasPath) -> None:
        self.choices = [first_choice] + list(choices)

    def convert_to_aliases(self) -> list[list[str | int]]:
        """Converts arguments to a list of lists containing string or integer aliases.

        Returns:
            The list of aliases.
        """
        aliases: list[list[str | int]] = []
        for c in self.choices:
            if isinstance(c, AliasPath):
                aliases.append(c.convert_to_aliases())
            else:
                aliases.append([c])
        return aliases


@dataclasses.dataclass(**_internal_dataclass.slots_true)
class AliasGenerator:
    """Usage docs: https://docs.pydantic.dev/2.10/concepts/alias#using-an-aliasgenerator

    A data class used by `alias_generator` as a convenience to create various aliases.

    Attributes:
        alias: A callable that takes a field name and returns an alias for it.
        validation_alias: A callable that takes a field name and returns a validation alias for it.
        serialization_alias: A callable that takes a field name and returns a serialization alias for it.
    """

    alias: Callable[[str], str] | None = None
    validation_alias: Callable[[str], str | AliasPath | AliasChoices] | None = None
    serialization_alias: Callable[[str], str] | None = None

    def _generate_alias(
        self,
        alias_kind: Literal['alias', 'validation_alias', 'serialization_alias'],
        allowed_types: tuple[type[str] | type[AliasPath] | type[AliasChoices], ...],
        field_name: str,
    ) -> str | AliasPath | AliasChoices | None:
        """Generate an alias of the specified kind. Returns None if the alias generator is None.

        Raises:
            TypeError: If the alias generator produces an invalid type.
        """
        alias = None
        if alias_generator := getattr(self, alias_kind):
            alias = alias_generator(field_name)
            if alias and not isinstance(alias, allowed_types):
                raise TypeError(
                    f'Invalid `{alias_kind}` type. `{alias_kind}` generator must produce one of `{allowed_types}`'
                )
        return alias

    def generate_aliases(self, field_name: str) -> tuple[str | None, str | AliasPath | AliasChoices | None, str | None]:
        """Generate `alias`, `validation_alias`, and `serialization_alias` for a field.

        Returns:
            A tuple of three aliases: `alias`, `validation_alias`, and `serialization_alias`.
        """
        alias = self._generate_alias('alias', (str,), field_name)
        validation_alias = self._generate_alias('validation_alias', (str, AliasChoices, AliasPath), field_name)
        serialization_alias = self._generate_alias('serialization_alias', (str,), field_name)

        return alias, validation_alias, serialization_alias  # type: ignore
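
# Example (illustrative; `to_camel` is from `pydantic.alias_generators`):
#
#   gen = AliasGenerator(
#       validation_alias=lambda name: AliasChoices(name, to_camel(name)),
#       serialization_alias=to_camel,
#   )
#   gen.generate_aliases('first_name')
#   # -> (None, AliasChoices(choices=['first_name', 'firstName']), 'firstName')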
@@ -0,0 +1,122 @@
"""Type annotations to use with `__get_pydantic_core_schema__` and `__get_pydantic_json_schema__`."""

from __future__ import annotations as _annotations

from typing import TYPE_CHECKING, Any, Union

from pydantic_core import core_schema

if TYPE_CHECKING:
    from ._internal._namespace_utils import NamespacesTuple
    from .json_schema import JsonSchemaMode, JsonSchemaValue

    CoreSchemaOrField = Union[
        core_schema.CoreSchema,
        core_schema.ModelField,
        core_schema.DataclassField,
        core_schema.TypedDictField,
        core_schema.ComputedField,
    ]

__all__ = ('GetJsonSchemaHandler', 'GetCoreSchemaHandler')


class GetJsonSchemaHandler:
    """Handler to call into the next JSON schema generation function.

    Attributes:
        mode: JSON schema mode; either `validation` or `serialization`.
    """

    mode: JsonSchemaMode

    def __call__(self, core_schema: CoreSchemaOrField, /) -> JsonSchemaValue:
        """Call the inner handler and get the JsonSchemaValue it returns.
        This will call the next JSON schema modifying function up until it calls
        into `pydantic.json_schema.GenerateJsonSchema`, which will raise a
        `pydantic.errors.PydanticInvalidForJsonSchema` error if it cannot generate
        a JSON schema.

        Args:
            core_schema: A `pydantic_core.core_schema.CoreSchema`.

        Returns:
            JsonSchemaValue: The JSON schema generated by the inner JSON schema
            modifying functions.
        """
        raise NotImplementedError

    def resolve_ref_schema(self, maybe_ref_json_schema: JsonSchemaValue, /) -> JsonSchemaValue:
        """Get the real schema for a `{"$ref": ...}` schema.
        If the schema given is not a `$ref` schema, it will be returned as is.
        This means you don't have to check before calling this function.

        Args:
            maybe_ref_json_schema: A JsonSchemaValue which may be a `$ref` schema.

        Raises:
            LookupError: If the ref is not found.

        Returns:
            JsonSchemaValue: A JsonSchemaValue that has no `$ref`.
        """
        raise NotImplementedError


class GetCoreSchemaHandler:
    """Handler to call into the next CoreSchema schema generation function."""

    def __call__(self, source_type: Any, /) -> core_schema.CoreSchema:
        """Call the inner handler and get the CoreSchema it returns.
        This will call the next CoreSchema modifying function up until it calls
        into Pydantic's internal schema generation machinery, which will raise a
        `pydantic.errors.PydanticSchemaGenerationError` error if it cannot generate
        a CoreSchema for the given source type.

        Args:
            source_type: The input type.

        Returns:
            CoreSchema: The `pydantic-core` CoreSchema generated.
        """
        raise NotImplementedError

    def generate_schema(self, source_type: Any, /) -> core_schema.CoreSchema:
        """Generate a schema unrelated to the current context.
        Use this function if e.g. you are handling schema generation for a sequence
        and want to generate a schema for its items.
        Otherwise, you may end up doing something like applying a `min_length` constraint
        that was intended for the sequence itself to its items!

        Args:
            source_type: The input type.

        Returns:
            CoreSchema: The `pydantic-core` CoreSchema generated.
        """
        raise NotImplementedError

    def resolve_ref_schema(self, maybe_ref_schema: core_schema.CoreSchema, /) -> core_schema.CoreSchema:
        """Get the real schema for a `definition-ref` schema.
        If the schema given is not a `definition-ref` schema, it will be returned as is.
        This means you don't have to check before calling this function.

        Args:
            maybe_ref_schema: A `CoreSchema`, `ref`-based or not.

        Raises:
            LookupError: If the `ref` is not found.

        Returns:
            A concrete `CoreSchema`.
        """
        raise NotImplementedError

    @property
    def field_name(self) -> str | None:
        """Get the name of the closest field to this validator."""
        raise NotImplementedError

    def _get_types_namespace(self) -> NamespacesTuple:
        """Internal method used during type resolution for serializer annotations."""
        raise NotImplementedError
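
# Example (illustrative): a custom type hooks into schema generation by accepting
# one of these handlers and delegating to it for the underlying type:
#
#   class Lowercase:
#       @classmethod
#       def __get_pydantic_core_schema__(cls, source_type, handler):
#           schema = handler(str)  # build the schema for `str` via the next handler
#           return core_schema.no_info_after_validator_function(str.lower, schema)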
@@ -0,0 +1,5 @@
"""The `class_validators` module is a backport module from V1."""

from ._migration import getattr_migration

__getattr__ = getattr_migration(__name__)
@@ -0,0 +1,604 @@
"""Color definitions are used as per the CSS3
[CSS Color Module Level 3](http://www.w3.org/TR/css3-color/#svg-color) specification.

A few colors have multiple names referring to the same colors, e.g. `grey` and `gray` or `aqua` and `cyan`.

In these cases the _last_ color when sorted alphabetically takes precedence,
e.g. `Color((0, 255, 255)).as_named() == 'cyan'` because "cyan" comes after "aqua".

Warning: Deprecated
    The `Color` class is deprecated, use `pydantic_extra_types` instead.
    See [`pydantic-extra-types.Color`](../usage/types/extra_types/color_types.md)
    for more information.
"""

import math
import re
from colorsys import hls_to_rgb, rgb_to_hls
from typing import Any, Callable, Optional, Tuple, Type, Union, cast

from pydantic_core import CoreSchema, PydanticCustomError, core_schema
from typing_extensions import deprecated

from ._internal import _repr
from ._internal._schema_generation_shared import GetJsonSchemaHandler as _GetJsonSchemaHandler
from .json_schema import JsonSchemaValue
from .warnings import PydanticDeprecatedSince20

ColorTuple = Union[Tuple[int, int, int], Tuple[int, int, int, float]]
ColorType = Union[ColorTuple, str]
HslColorTuple = Union[Tuple[float, float, float], Tuple[float, float, float, float]]


class RGBA:
    """Internal use only as a representation of a color."""

    __slots__ = 'r', 'g', 'b', 'alpha', '_tuple'

    def __init__(self, r: float, g: float, b: float, alpha: Optional[float]):
        self.r = r
        self.g = g
        self.b = b
        self.alpha = alpha

        self._tuple: Tuple[float, float, float, Optional[float]] = (r, g, b, alpha)

    def __getitem__(self, item: Any) -> Any:
        return self._tuple[item]


# these are not compiled here to avoid import slowdown; they'll be compiled the first time they're used, then cached
_r_255 = r'(\d{1,3}(?:\.\d+)?)'
_r_comma = r'\s*,\s*'
_r_alpha = r'(\d(?:\.\d+)?|\.\d+|\d{1,2}%)'
_r_h = r'(-?\d+(?:\.\d+)?|-?\.\d+)(deg|rad|turn)?'
_r_sl = r'(\d{1,3}(?:\.\d+)?)%'
r_hex_short = r'\s*(?:#|0x)?([0-9a-f])([0-9a-f])([0-9a-f])([0-9a-f])?\s*'
r_hex_long = r'\s*(?:#|0x)?([0-9a-f]{2})([0-9a-f]{2})([0-9a-f]{2})([0-9a-f]{2})?\s*'
# CSS3 RGB examples: rgb(0, 0, 0), rgba(0, 0, 0, 0.5), rgba(0, 0, 0, 50%)
r_rgb = rf'\s*rgba?\(\s*{_r_255}{_r_comma}{_r_255}{_r_comma}{_r_255}(?:{_r_comma}{_r_alpha})?\s*\)\s*'
# CSS3 HSL examples: hsl(270, 60%, 50%), hsla(270, 60%, 50%, 0.5), hsla(270, 60%, 50%, 50%)
r_hsl = rf'\s*hsla?\(\s*{_r_h}{_r_comma}{_r_sl}{_r_comma}{_r_sl}(?:{_r_comma}{_r_alpha})?\s*\)\s*'
# CSS4 RGB examples: rgb(0 0 0), rgb(0 0 0 / 0.5), rgb(0 0 0 / 50%), rgba(0 0 0 / 50%)
r_rgb_v4_style = rf'\s*rgba?\(\s*{_r_255}\s+{_r_255}\s+{_r_255}(?:\s*/\s*{_r_alpha})?\s*\)\s*'
# CSS4 HSL examples: hsl(270 60% 50%), hsl(270 60% 50% / 0.5), hsl(270 60% 50% / 50%), hsla(270 60% 50% / 50%)
r_hsl_v4_style = rf'\s*hsla?\(\s*{_r_h}\s+{_r_sl}\s+{_r_sl}(?:\s*/\s*{_r_alpha})?\s*\)\s*'

# colors where the two hex characters are the same; if all channels match this, the short form of the hex color can be used
repeat_colors = {int(c * 2, 16) for c in '0123456789abcdef'}
rads = 2 * math.pi


@deprecated(
    'The `Color` class is deprecated, use `pydantic_extra_types` instead. '
    'See https://docs.pydantic.dev/latest/api/pydantic_extra_types_color/.',
    category=PydanticDeprecatedSince20,
)
class Color(_repr.Representation):
    """Represents a color."""

    __slots__ = '_original', '_rgba'

    def __init__(self, value: ColorType) -> None:
        self._rgba: RGBA
        self._original: ColorType
        if isinstance(value, (tuple, list)):
            self._rgba = parse_tuple(value)
        elif isinstance(value, str):
            self._rgba = parse_str(value)
        elif isinstance(value, Color):
            self._rgba = value._rgba
            value = value._original
        else:
            raise PydanticCustomError(
                'color_error', 'value is not a valid color: value must be a tuple, list or string'
            )

        # if we've got here value must be a valid color
        self._original = value

    @classmethod
    def __get_pydantic_json_schema__(
        cls, core_schema: core_schema.CoreSchema, handler: _GetJsonSchemaHandler
    ) -> JsonSchemaValue:
        field_schema = {}
        field_schema.update(type='string', format='color')
        return field_schema

    def original(self) -> ColorType:
        """Original value passed to `Color`."""
        return self._original

    def as_named(self, *, fallback: bool = False) -> str:
        """Returns the name of the color if it can be found in the `COLORS_BY_VALUE` dictionary,
        otherwise returns the hexadecimal representation of the color or raises `ValueError`.

        Args:
            fallback: If True, falls back to returning the hexadecimal representation of
                the color instead of raising a ValueError when no named color is found.

        Returns:
            The name of the color, or the hexadecimal representation of the color.

        Raises:
            ValueError: When no named color is found and fallback is `False`.
        """
        if self._rgba.alpha is None:
            rgb = cast(Tuple[int, int, int], self.as_rgb_tuple())
            try:
                return COLORS_BY_VALUE[rgb]
            except KeyError as e:
                if fallback:
                    return self.as_hex()
                else:
                    raise ValueError('no named color found, use fallback=True, as_hex() or as_rgb()') from e
        else:
            return self.as_hex()

    def as_hex(self) -> str:
        """Returns the hexadecimal representation of the color.

        The hex string can be 3, 4, 6, or 8 characters long, depending on whether
        a "short" representation of the color is possible and whether there's an alpha channel.

        Returns:
            The hexadecimal representation of the color.
        """
        values = [float_to_255(c) for c in self._rgba[:3]]
        if self._rgba.alpha is not None:
            values.append(float_to_255(self._rgba.alpha))

        as_hex = ''.join(f'{v:02x}' for v in values)
        if all(c in repeat_colors for c in values):
            as_hex = ''.join(as_hex[c] for c in range(0, len(as_hex), 2))
        return '#' + as_hex

    def as_rgb(self) -> str:
        """Color as an `rgb(<r>, <g>, <b>)` or `rgba(<r>, <g>, <b>, <a>)` string."""
        if self._rgba.alpha is None:
            return f'rgb({float_to_255(self._rgba.r)}, {float_to_255(self._rgba.g)}, {float_to_255(self._rgba.b)})'
        else:
            return (
                f'rgba({float_to_255(self._rgba.r)}, {float_to_255(self._rgba.g)}, {float_to_255(self._rgba.b)}, '
                f'{round(self._alpha_float(), 2)})'
            )

    def as_rgb_tuple(self, *, alpha: Optional[bool] = None) -> ColorTuple:
        """Returns the color as an RGB or RGBA tuple.

        Args:
            alpha: Whether to include the alpha channel. There are three options for this input:

                - `None` (default): Include alpha only if it's set. (e.g. not `None`)
                - `True`: Always include alpha.
                - `False`: Always omit alpha.

        Returns:
            A tuple that contains the values of the red, green, and blue channels in the range 0 to 255.
            If alpha is included, it is in the range 0 to 1.
        """
        r, g, b = (float_to_255(c) for c in self._rgba[:3])
        if alpha is None:
            if self._rgba.alpha is None:
                return r, g, b
            else:
                return r, g, b, self._alpha_float()
        elif alpha:
            return r, g, b, self._alpha_float()
        else:
            # alpha is False
            return r, g, b

    def as_hsl(self) -> str:
        """Color as an `hsl(<h>, <s>, <l>)` or `hsl(<h>, <s>, <l>, <a>)` string."""
        if self._rgba.alpha is None:
            h, s, li = self.as_hsl_tuple(alpha=False)  # type: ignore
            return f'hsl({h * 360:0.0f}, {s:0.0%}, {li:0.0%})'
        else:
            h, s, li, a = self.as_hsl_tuple(alpha=True)  # type: ignore
            return f'hsl({h * 360:0.0f}, {s:0.0%}, {li:0.0%}, {round(a, 2)})'

    def as_hsl_tuple(self, *, alpha: Optional[bool] = None) -> HslColorTuple:
        """Returns the color as an HSL or HSLA tuple.

        Args:
            alpha: Whether to include the alpha channel.

                - `None` (default): Include the alpha channel only if it's set (e.g. not `None`).
                - `True`: Always include alpha.
                - `False`: Always omit alpha.

        Returns:
            The color as a tuple of hue, saturation, lightness, and alpha (if included).
            All elements are in the range 0 to 1.

        Note:
            This is HSL as used in HTML and most other places, not HLS as used in Python's `colorsys`.
        """
        h, l, s = rgb_to_hls(self._rgba.r, self._rgba.g, self._rgba.b)  # noqa: E741
        if alpha is None:
            if self._rgba.alpha is None:
                return h, s, l
            else:
                return h, s, l, self._alpha_float()
        if alpha:
            return h, s, l, self._alpha_float()
        else:
            # alpha is False
            return h, s, l

    def _alpha_float(self) -> float:
        return 1 if self._rgba.alpha is None else self._rgba.alpha

    @classmethod
    def __get_pydantic_core_schema__(
        cls, source: Type[Any], handler: Callable[[Any], CoreSchema]
    ) -> core_schema.CoreSchema:
        return core_schema.with_info_plain_validator_function(
            cls._validate, serialization=core_schema.to_string_ser_schema()
        )

    @classmethod
    def _validate(cls, __input_value: Any, _: Any) -> 'Color':
        return cls(__input_value)

    def __str__(self) -> str:
        return self.as_named(fallback=True)

    def __repr_args__(self) -> '_repr.ReprArgs':
        return [(None, self.as_named(fallback=True))] + [('rgb', self.as_rgb_tuple())]

    def __eq__(self, other: Any) -> bool:
        return isinstance(other, Color) and self.as_rgb_tuple() == other.as_rgb_tuple()

    def __hash__(self) -> int:
        return hash(self.as_rgb_tuple())
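
# Example (illustrative; `Color` itself is deprecated in favour of
# `pydantic_extra_types.color.Color`, but the parsing behaviour is the same):
#
#   c = Color('#09F')
#   c.as_rgb_tuple()                 # -> (0, 153, 255)
#   c.as_hex()                       # -> '#09f' (short form: every channel repeats its hex digit)
#   Color((0, 255, 255)).as_named()  # -> 'cyan' (alphabetically last name wins over 'aqua')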
|
||||
|
||||
|
||||
def parse_tuple(value: Tuple[Any, ...]) -> RGBA:
    """Parse a tuple or list to get RGBA values.

    Args:
        value: A tuple or list.

    Returns:
        An `RGBA` tuple parsed from the input tuple.

    Raises:
        PydanticCustomError: If the tuple is not a valid color.
    """
    if len(value) == 3:
        r, g, b = (parse_color_value(v) for v in value)
        return RGBA(r, g, b, None)
    elif len(value) == 4:
        r, g, b = (parse_color_value(v) for v in value[:3])
        return RGBA(r, g, b, parse_float_alpha(value[3]))
    else:
        raise PydanticCustomError('color_error', 'value is not a valid color: tuples must have length 3 or 4')


def parse_str(value: str) -> RGBA:
    """Parse a string representing a color to an RGBA tuple.

    Possible formats for the input string include:

    * named color, see `COLORS_BY_NAME`
    * hex short, e.g. `<prefix>fff` (prefix can be `#`, `0x` or nothing)
    * hex long, e.g. `<prefix>ffffff` (prefix can be `#`, `0x` or nothing)
    * `rgb(<r>, <g>, <b>)`
    * `rgba(<r>, <g>, <b>, <a>)`
    * `hsl(<h>, <s>, <l>)`
    * `hsl(<h>, <s>, <l>, <a>)`

    Args:
        value: A string representing a color.

    Returns:
        An `RGBA` tuple parsed from the input string.

    Raises:
        ValueError: If the input string cannot be parsed to an RGBA tuple.
    """
    value_lower = value.lower()
    try:
        r, g, b = COLORS_BY_NAME[value_lower]
    except KeyError:
        pass
    else:
        return ints_to_rgba(r, g, b, None)

    m = re.fullmatch(r_hex_short, value_lower)
    if m:
        *rgb, a = m.groups()
        r, g, b = (int(v * 2, 16) for v in rgb)
        if a:
            alpha: Optional[float] = int(a * 2, 16) / 255
        else:
            alpha = None
        return ints_to_rgba(r, g, b, alpha)

    m = re.fullmatch(r_hex_long, value_lower)
    if m:
        *rgb, a = m.groups()
        r, g, b = (int(v, 16) for v in rgb)
        if a:
            alpha = int(a, 16) / 255
        else:
            alpha = None
        return ints_to_rgba(r, g, b, alpha)

    m = re.fullmatch(r_rgb, value_lower) or re.fullmatch(r_rgb_v4_style, value_lower)
    if m:
        return ints_to_rgba(*m.groups())  # type: ignore

    m = re.fullmatch(r_hsl, value_lower) or re.fullmatch(r_hsl_v4_style, value_lower)
    if m:
        return parse_hsl(*m.groups())  # type: ignore

    raise PydanticCustomError('color_error', 'value is not a valid color: string not recognised as a valid color')
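# A quick sketch of the string formats handled above (illustrative only; the
# components go through parse_color_value/parse_float_alpha defined below):
#
#     parse_str('skyblue')                 # named colour lookup in COLORS_BY_NAME
#     parse_str('#09c')                    # short hex, digits doubled to '0099cc'
#     parse_str('0x0099cc')                # long hex with a 0x prefix
#     parse_str('rgba(0, 153, 204, 0.5)')  # rgba() with an explicit alpha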
def ints_to_rgba(r: Union[int, str], g: Union[int, str], b: Union[int, str], alpha: Optional[float] = None) -> RGBA:
    """Converts integer or string values for RGB color and an optional alpha value to an `RGBA` object.

    Args:
        r: An integer or string representing the red color value.
        g: An integer or string representing the green color value.
        b: An integer or string representing the blue color value.
        alpha: A float representing the alpha value. Defaults to None.

    Returns:
        An instance of the `RGBA` class with the corresponding color and alpha values.
    """
    return RGBA(parse_color_value(r), parse_color_value(g), parse_color_value(b), parse_float_alpha(alpha))


def parse_color_value(value: Union[int, str], max_val: int = 255) -> float:
    """Parse the color value provided and return a number between 0 and 1.

    Args:
        value: An integer or string color value.
        max_val: Maximum range value. Defaults to 255.

    Raises:
        PydanticCustomError: If the value is not a valid color.

    Returns:
        A number between 0 and 1.
    """
    try:
        color = float(value)
    except ValueError:
        raise PydanticCustomError('color_error', 'value is not a valid color: color values must be a valid number')
    if 0 <= color <= max_val:
        return color / max_val
    else:
        raise PydanticCustomError(
            'color_error',
            'value is not a valid color: color values must be in the range 0 to {max_val}',
            {'max_val': max_val},
        )


def parse_float_alpha(value: Union[None, str, float, int]) -> Optional[float]:
    """Parse an alpha value checking it's a valid float in the range 0 to 1.

    Args:
        value: The input value to parse.

    Returns:
        The parsed value as a float, or `None` if the value was `None` or equal to 1.

    Raises:
        PydanticCustomError: If the input value cannot be successfully parsed as a float in the expected range.
    """
    if value is None:
        return None
    try:
        if isinstance(value, str) and value.endswith('%'):
            alpha = float(value[:-1]) / 100
        else:
            alpha = float(value)
    except ValueError:
        raise PydanticCustomError('color_error', 'value is not a valid color: alpha values must be a valid float')

    if math.isclose(alpha, 1):
        return None
    elif 0 <= alpha <= 1:
        return alpha
    else:
        raise PydanticCustomError('color_error', 'value is not a valid color: alpha values must be in the range 0 to 1')
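# Illustrative behaviour of the alpha parsing above, read straight off the code:
#
#     parse_float_alpha(None)    # None  (alpha unset)
#     parse_float_alpha('50%')   # 0.5   (percentage form)
#     parse_float_alpha(1.0)     # None  (an alpha of 1 is treated as "no alpha")
#     parse_float_alpha(1.5)     # raises PydanticCustomError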
def parse_hsl(h: str, h_units: str, sat: str, light: str, alpha: Optional[float] = None) -> RGBA:
    """Parse raw hue, saturation, lightness, and alpha values and convert to RGBA.

    Args:
        h: The hue value.
        h_units: The unit for the hue value.
        sat: The saturation value.
        light: The lightness value.
        alpha: Alpha value.

    Returns:
        An instance of `RGBA`.
    """
    s_value, l_value = parse_color_value(sat, 100), parse_color_value(light, 100)

    h_value = float(h)
    if h_units in {None, 'deg'}:
        h_value = h_value % 360 / 360
    elif h_units == 'rad':
        h_value = h_value % rads / rads
    else:
        # turns
        h_value = h_value % 1

    r, g, b = hls_to_rgb(h_value, l_value, s_value)
    return RGBA(r, g, b, parse_float_alpha(alpha))


def float_to_255(c: float) -> int:
    """Converts a float value between 0 and 1 (inclusive) to an integer between 0 and 255 (inclusive).

    Args:
        c: The float value to be converted. Assumed to be between 0 and 1 (inclusive);
            values outside that range are not checked here.

    Returns:
        The integer equivalent of the given float value rounded to the nearest whole number.
    """
    return int(round(c * 255))
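# Sketch of the rounding behaviour above (Python's round() uses banker's
# rounding, so the 127.5 midpoint rounds to its even neighbour 128):
#
#     float_to_255(0.0)  # 0
#     float_to_255(0.5)  # 128
#     float_to_255(1.0)  # 255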
COLORS_BY_NAME = {
    'aliceblue': (240, 248, 255),
    'antiquewhite': (250, 235, 215),
    'aqua': (0, 255, 255),
    'aquamarine': (127, 255, 212),
    'azure': (240, 255, 255),
    'beige': (245, 245, 220),
    'bisque': (255, 228, 196),
    'black': (0, 0, 0),
    'blanchedalmond': (255, 235, 205),
    'blue': (0, 0, 255),
    'blueviolet': (138, 43, 226),
    'brown': (165, 42, 42),
    'burlywood': (222, 184, 135),
    'cadetblue': (95, 158, 160),
    'chartreuse': (127, 255, 0),
    'chocolate': (210, 105, 30),
    'coral': (255, 127, 80),
    'cornflowerblue': (100, 149, 237),
    'cornsilk': (255, 248, 220),
    'crimson': (220, 20, 60),
    'cyan': (0, 255, 255),
    'darkblue': (0, 0, 139),
    'darkcyan': (0, 139, 139),
    'darkgoldenrod': (184, 134, 11),
    'darkgray': (169, 169, 169),
    'darkgreen': (0, 100, 0),
    'darkgrey': (169, 169, 169),
    'darkkhaki': (189, 183, 107),
    'darkmagenta': (139, 0, 139),
    'darkolivegreen': (85, 107, 47),
    'darkorange': (255, 140, 0),
    'darkorchid': (153, 50, 204),
    'darkred': (139, 0, 0),
    'darksalmon': (233, 150, 122),
    'darkseagreen': (143, 188, 143),
    'darkslateblue': (72, 61, 139),
    'darkslategray': (47, 79, 79),
    'darkslategrey': (47, 79, 79),
    'darkturquoise': (0, 206, 209),
    'darkviolet': (148, 0, 211),
    'deeppink': (255, 20, 147),
    'deepskyblue': (0, 191, 255),
    'dimgray': (105, 105, 105),
    'dimgrey': (105, 105, 105),
    'dodgerblue': (30, 144, 255),
    'firebrick': (178, 34, 34),
    'floralwhite': (255, 250, 240),
    'forestgreen': (34, 139, 34),
    'fuchsia': (255, 0, 255),
    'gainsboro': (220, 220, 220),
    'ghostwhite': (248, 248, 255),
    'gold': (255, 215, 0),
    'goldenrod': (218, 165, 32),
    'gray': (128, 128, 128),
    'green': (0, 128, 0),
    'greenyellow': (173, 255, 47),
    'grey': (128, 128, 128),
    'honeydew': (240, 255, 240),
    'hotpink': (255, 105, 180),
    'indianred': (205, 92, 92),
    'indigo': (75, 0, 130),
    'ivory': (255, 255, 240),
    'khaki': (240, 230, 140),
    'lavender': (230, 230, 250),
    'lavenderblush': (255, 240, 245),
    'lawngreen': (124, 252, 0),
    'lemonchiffon': (255, 250, 205),
    'lightblue': (173, 216, 230),
    'lightcoral': (240, 128, 128),
    'lightcyan': (224, 255, 255),
    'lightgoldenrodyellow': (250, 250, 210),
    'lightgray': (211, 211, 211),
    'lightgreen': (144, 238, 144),
    'lightgrey': (211, 211, 211),
    'lightpink': (255, 182, 193),
    'lightsalmon': (255, 160, 122),
    'lightseagreen': (32, 178, 170),
    'lightskyblue': (135, 206, 250),
    'lightslategray': (119, 136, 153),
    'lightslategrey': (119, 136, 153),
    'lightsteelblue': (176, 196, 222),
    'lightyellow': (255, 255, 224),
    'lime': (0, 255, 0),
    'limegreen': (50, 205, 50),
    'linen': (250, 240, 230),
    'magenta': (255, 0, 255),
    'maroon': (128, 0, 0),
    'mediumaquamarine': (102, 205, 170),
    'mediumblue': (0, 0, 205),
    'mediumorchid': (186, 85, 211),
    'mediumpurple': (147, 112, 219),
    'mediumseagreen': (60, 179, 113),
    'mediumslateblue': (123, 104, 238),
    'mediumspringgreen': (0, 250, 154),
    'mediumturquoise': (72, 209, 204),
    'mediumvioletred': (199, 21, 133),
    'midnightblue': (25, 25, 112),
    'mintcream': (245, 255, 250),
    'mistyrose': (255, 228, 225),
    'moccasin': (255, 228, 181),
    'navajowhite': (255, 222, 173),
    'navy': (0, 0, 128),
    'oldlace': (253, 245, 230),
    'olive': (128, 128, 0),
    'olivedrab': (107, 142, 35),
    'orange': (255, 165, 0),
    'orangered': (255, 69, 0),
    'orchid': (218, 112, 214),
    'palegoldenrod': (238, 232, 170),
    'palegreen': (152, 251, 152),
    'paleturquoise': (175, 238, 238),
    'palevioletred': (219, 112, 147),
    'papayawhip': (255, 239, 213),
    'peachpuff': (255, 218, 185),
    'peru': (205, 133, 63),
    'pink': (255, 192, 203),
    'plum': (221, 160, 221),
    'powderblue': (176, 224, 230),
    'purple': (128, 0, 128),
    'red': (255, 0, 0),
    'rosybrown': (188, 143, 143),
    'royalblue': (65, 105, 225),
    'saddlebrown': (139, 69, 19),
    'salmon': (250, 128, 114),
    'sandybrown': (244, 164, 96),
    'seagreen': (46, 139, 87),
    'seashell': (255, 245, 238),
    'sienna': (160, 82, 45),
    'silver': (192, 192, 192),
    'skyblue': (135, 206, 235),
    'slateblue': (106, 90, 205),
    'slategray': (112, 128, 144),
    'slategrey': (112, 128, 144),
    'snow': (255, 250, 250),
    'springgreen': (0, 255, 127),
    'steelblue': (70, 130, 180),
    'tan': (210, 180, 140),
    'teal': (0, 128, 128),
    'thistle': (216, 191, 216),
    'tomato': (255, 99, 71),
    'turquoise': (64, 224, 208),
    'violet': (238, 130, 238),
    'wheat': (245, 222, 179),
    'white': (255, 255, 255),
    'whitesmoke': (245, 245, 245),
    'yellow': (255, 255, 0),
    'yellowgreen': (154, 205, 50),
}

COLORS_BY_VALUE = {v: k for k, v in COLORS_BY_NAME.items()}
File diff suppressed because it is too large
@@ -0,0 +1,366 @@
"""Provide an enhanced dataclass that performs validation."""

from __future__ import annotations as _annotations

import dataclasses
import sys
import types
from typing import TYPE_CHECKING, Any, Callable, Generic, NoReturn, TypeVar, overload
from warnings import warn

from typing_extensions import Literal, TypeGuard, dataclass_transform

from ._internal import _config, _decorators, _namespace_utils, _typing_extra
from ._internal import _dataclasses as _pydantic_dataclasses
from ._migration import getattr_migration
from .config import ConfigDict
from .errors import PydanticUserError
from .fields import Field, FieldInfo, PrivateAttr

if TYPE_CHECKING:
    from ._internal._dataclasses import PydanticDataclass
    from ._internal._namespace_utils import MappingNamespace

__all__ = 'dataclass', 'rebuild_dataclass'

_T = TypeVar('_T')

if sys.version_info >= (3, 10):

    @dataclass_transform(field_specifiers=(dataclasses.field, Field, PrivateAttr))
    @overload
    def dataclass(
        *,
        init: Literal[False] = False,
        repr: bool = True,
        eq: bool = True,
        order: bool = False,
        unsafe_hash: bool = False,
        frozen: bool | None = None,
        config: ConfigDict | type[object] | None = None,
        validate_on_init: bool | None = None,
        kw_only: bool = ...,
        slots: bool = ...,
    ) -> Callable[[type[_T]], type[PydanticDataclass]]:  # type: ignore
        ...

    @dataclass_transform(field_specifiers=(dataclasses.field, Field, PrivateAttr))
    @overload
    def dataclass(
        _cls: type[_T],  # type: ignore
        *,
        init: Literal[False] = False,
        repr: bool = True,
        eq: bool = True,
        order: bool = False,
        unsafe_hash: bool = False,
        frozen: bool | None = None,
        config: ConfigDict | type[object] | None = None,
        validate_on_init: bool | None = None,
        kw_only: bool = ...,
        slots: bool = ...,
    ) -> type[PydanticDataclass]: ...

else:

    @dataclass_transform(field_specifiers=(dataclasses.field, Field, PrivateAttr))
    @overload
    def dataclass(
        *,
        init: Literal[False] = False,
        repr: bool = True,
        eq: bool = True,
        order: bool = False,
        unsafe_hash: bool = False,
        frozen: bool | None = None,
        config: ConfigDict | type[object] | None = None,
        validate_on_init: bool | None = None,
    ) -> Callable[[type[_T]], type[PydanticDataclass]]:  # type: ignore
        ...

    @dataclass_transform(field_specifiers=(dataclasses.field, Field, PrivateAttr))
    @overload
    def dataclass(
        _cls: type[_T],  # type: ignore
        *,
        init: Literal[False] = False,
        repr: bool = True,
        eq: bool = True,
        order: bool = False,
        unsafe_hash: bool = False,
        frozen: bool | None = None,
        config: ConfigDict | type[object] | None = None,
        validate_on_init: bool | None = None,
    ) -> type[PydanticDataclass]: ...


@dataclass_transform(field_specifiers=(dataclasses.field, Field, PrivateAttr))
def dataclass(
    _cls: type[_T] | None = None,
    *,
    init: Literal[False] = False,
    repr: bool = True,
    eq: bool = True,
    order: bool = False,
    unsafe_hash: bool = False,
    frozen: bool | None = None,
    config: ConfigDict | type[object] | None = None,
    validate_on_init: bool | None = None,
    kw_only: bool = False,
    slots: bool = False,
) -> Callable[[type[_T]], type[PydanticDataclass]] | type[PydanticDataclass]:
    """Usage docs: https://docs.pydantic.dev/2.10/concepts/dataclasses/

    A decorator used to create a Pydantic-enhanced dataclass, similar to the standard Python `dataclass`,
    but with added validation.

    This function should be used similarly to `dataclasses.dataclass`.

    Args:
        _cls: The target `dataclass`.
        init: Included for signature compatibility with `dataclasses.dataclass`, and is passed through to
            `dataclasses.dataclass` when appropriate. If specified, must be set to `False`, as pydantic inserts its
            own `__init__` function.
        repr: A boolean indicating whether to include the field in the `__repr__` output.
        eq: Determines if a `__eq__` method should be generated for the class.
        order: Determines if comparison magic methods should be generated, such as `__lt__`, but not `__eq__`.
        unsafe_hash: Determines if a `__hash__` method should be included in the class, as in `dataclasses.dataclass`.
        frozen: Determines if the generated class should be a 'frozen' `dataclass`, which does not allow its
            attributes to be modified after it has been initialized. If not set, the value from the provided
            `config` argument will be used (and will default to `False` otherwise).
        config: The Pydantic config to use for the `dataclass`.
        validate_on_init: A deprecated parameter included for backwards compatibility; in V2, all Pydantic dataclasses
            are validated on init.
        kw_only: Determines if `__init__` method parameters must be specified by keyword only. Defaults to `False`.
        slots: Determines if the generated class should be a 'slots' `dataclass`, which does not allow the addition of
            new attributes after instantiation.

    Returns:
        A decorator that accepts a class as its argument and returns a Pydantic `dataclass`.

    Raises:
        AssertionError: Raised if `init` is not `False` or `validate_on_init` is `False`.
    """
    assert init is False, 'pydantic.dataclasses.dataclass only supports init=False'
    assert validate_on_init is not False, 'validate_on_init=False is no longer supported'

    if sys.version_info >= (3, 10):
        kwargs = {'kw_only': kw_only, 'slots': slots}
    else:
        kwargs = {}

    def make_pydantic_fields_compatible(cls: type[Any]) -> None:
        """Make sure that stdlib `dataclasses` understands `Field` kwargs like `kw_only`.

        To do that, we simply change
        `x: int = pydantic.Field(..., kw_only=True)`
        into
        `x: int = dataclasses.field(default=pydantic.Field(..., kw_only=True), kw_only=True)`
        """
        for annotation_cls in cls.__mro__:
            # In Python < 3.9, `__annotations__` might not be present if there are no fields.
            # we therefore need to use `getattr` to avoid an `AttributeError`.
            annotations = getattr(annotation_cls, '__annotations__', [])
            for field_name in annotations:
                field_value = getattr(cls, field_name, None)
                # Process only if this is an instance of `FieldInfo`.
                if not isinstance(field_value, FieldInfo):
                    continue

                # Initialize arguments for the standard `dataclasses.field`.
                field_args: dict = {'default': field_value}

                # Handle `kw_only` for Python 3.10+
                if sys.version_info >= (3, 10) and field_value.kw_only:
                    field_args['kw_only'] = True

                # Set `repr` attribute if it's explicitly specified to be not `True`.
                if field_value.repr is not True:
                    field_args['repr'] = field_value.repr

                setattr(cls, field_name, dataclasses.field(**field_args))
                # In Python 3.8, dataclasses checks cls.__dict__['__annotations__'] for annotations,
                # so we must make sure it's initialized before we add to it.
                if cls.__dict__.get('__annotations__') is None:
                    cls.__annotations__ = {}
                cls.__annotations__[field_name] = annotations[field_name]

    def create_dataclass(cls: type[Any]) -> type[PydanticDataclass]:
        """Create a Pydantic dataclass from a regular dataclass.

        Args:
            cls: The class to create the Pydantic dataclass from.

        Returns:
            A Pydantic dataclass.
        """
        from ._internal._utils import is_model_class

        if is_model_class(cls):
            raise PydanticUserError(
                f'Cannot create a Pydantic dataclass from {cls.__name__} as it is already a Pydantic model',
                code='dataclass-on-model',
            )

        original_cls = cls

        # we warn on conflicting config specifications, but only if the class doesn't have a dataclass base
        # because a dataclass base might provide a __pydantic_config__ attribute that we don't want to warn about
        has_dataclass_base = any(dataclasses.is_dataclass(base) for base in cls.__bases__)
        if not has_dataclass_base and config is not None and hasattr(cls, '__pydantic_config__'):
            warn(
                f'`config` is set via both the `dataclass` decorator and `__pydantic_config__` for dataclass {cls.__name__}. '
                f'The `config` specification from `dataclass` decorator will take priority.',
                category=UserWarning,
                stacklevel=2,
            )

        # if config is not explicitly provided, try to read it from the type
        config_dict = config if config is not None else getattr(cls, '__pydantic_config__', None)
        config_wrapper = _config.ConfigWrapper(config_dict)
        decorators = _decorators.DecoratorInfos.build(cls)

        # Keep track of the original __doc__ so that we can restore it after applying the dataclasses decorator
        # Otherwise, classes with no __doc__ will have their signature added into the JSON schema description,
        # since dataclasses.dataclass will set this as the __doc__
        original_doc = cls.__doc__

        if _pydantic_dataclasses.is_builtin_dataclass(cls):
            # Don't preserve the docstring for vanilla dataclasses, as it may include the signature
            # This matches v1 behavior, and there was an explicit test for it
            original_doc = None

            # We don't want to add validation to the existing std lib dataclass, so we will subclass it
            # If the class is generic, we need to make sure the subclass also inherits from Generic
            # with all the same parameters.
            bases = (cls,)
            if issubclass(cls, Generic):
                generic_base = Generic[cls.__parameters__]  # type: ignore
                bases = bases + (generic_base,)
            cls = types.new_class(cls.__name__, bases)

        make_pydantic_fields_compatible(cls)

        # Respect frozen setting from dataclass constructor and fallback to config setting if not provided
        if frozen is not None:
            frozen_ = frozen
            if config_wrapper.frozen:
                # It's not recommended to define both, as the setting from the dataclass decorator will take priority.
                warn(
                    f'`frozen` is set via both the `dataclass` decorator and `config` for dataclass {cls.__name__!r}. '
                    'This is not recommended. The `frozen` specification on `dataclass` will take priority.',
                    category=UserWarning,
                    stacklevel=2,
                )
        else:
            frozen_ = config_wrapper.frozen or False

        cls = dataclasses.dataclass(  # type: ignore[call-overload]
            cls,
            # the value of init here doesn't affect anything except that it makes it easier to generate a signature
            init=True,
            repr=repr,
            eq=eq,
            order=order,
            unsafe_hash=unsafe_hash,
            frozen=frozen_,
            **kwargs,
        )

        cls.__pydantic_decorators__ = decorators  # type: ignore
        cls.__doc__ = original_doc
        cls.__module__ = original_cls.__module__
        cls.__qualname__ = original_cls.__qualname__
        cls.__pydantic_complete__ = False  # `complete_dataclass` will set it to `True` if successful.
        # TODO `parent_namespace` is currently None, but we could do the same thing as Pydantic models:
        # fetch the parent ns using `parent_frame_namespace` (if the dataclass was defined in a function),
        # and possibly cache it (see the `__pydantic_parent_namespace__` logic for models).
        _pydantic_dataclasses.complete_dataclass(cls, config_wrapper, raise_errors=False)
        return cls

    return create_dataclass if _cls is None else create_dataclass(_cls)
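# A minimal usage sketch of the decorator above (illustrative; `Meal` and its
# fields are hypothetical, not part of this diff):
#
#     import pydantic.dataclasses
#
#     @pydantic.dataclasses.dataclass
#     class Meal:
#         name: str
#         calories: int
#
#     Meal(name='soup', calories='250')   # str coerced to calories=250
#     Meal(name='soup', calories='lots')  # raises pydantic.ValidationError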
__getattr__ = getattr_migration(__name__)

if (3, 8) <= sys.version_info < (3, 11):
    # Monkeypatch dataclasses.InitVar so that typing doesn't error if it occurs as a type when evaluating type hints
    # Starting in 3.11, typing.get_type_hints will not raise an error if the retrieved type hints are not callable.

    def _call_initvar(*args: Any, **kwargs: Any) -> NoReturn:
        """This function does nothing but raise an error that is as similar as possible to what you'd get
        if you were to try calling `InitVar[int]()` without this monkeypatch. The whole purpose is just
        to ensure typing._type_check does not error if the type hint evaluates to `InitVar[<parameter>]`.
        """
        raise TypeError("'InitVar' object is not callable")

    dataclasses.InitVar.__call__ = _call_initvar


def rebuild_dataclass(
    cls: type[PydanticDataclass],
    *,
    force: bool = False,
    raise_errors: bool = True,
    _parent_namespace_depth: int = 2,
    _types_namespace: MappingNamespace | None = None,
) -> bool | None:
    """Try to rebuild the pydantic-core schema for the dataclass.

    This may be necessary when one of the annotations is a ForwardRef which could not be resolved during
    the initial attempt to build the schema, and automatic rebuilding fails.

    This is analogous to `BaseModel.model_rebuild`.

    Args:
        cls: The class to rebuild the pydantic-core schema for.
        force: Whether to force the rebuilding of the schema, defaults to `False`.
        raise_errors: Whether to raise errors, defaults to `True`.
        _parent_namespace_depth: The depth level of the parent namespace, defaults to 2.
        _types_namespace: The types namespace, defaults to `None`.

    Returns:
        Returns `None` if the schema is already "complete" and rebuilding was not required.
        If rebuilding _was_ required, returns `True` if rebuilding was successful, otherwise `False`.
    """
    if not force and cls.__pydantic_complete__:
        return None

    if '__pydantic_core_schema__' in cls.__dict__:
        delattr(cls, '__pydantic_core_schema__')  # delete cached value to ensure full rebuild happens

    if _types_namespace is not None:
        rebuild_ns = _types_namespace
    elif _parent_namespace_depth > 0:
        rebuild_ns = _typing_extra.parent_frame_namespace(parent_depth=_parent_namespace_depth, force=True) or {}
    else:
        rebuild_ns = {}

    ns_resolver = _namespace_utils.NsResolver(
        parent_namespace=rebuild_ns,
    )

    return _pydantic_dataclasses.complete_dataclass(
        cls,
        _config.ConfigWrapper(cls.__pydantic_config__, check=False),
        raise_errors=raise_errors,
        ns_resolver=ns_resolver,
        # We could provide a different config (with `'defer_build'` set to `True`) instead
        # of this explicit `_force_build` argument, but because config can come from the
        # decorator parameter or the `__pydantic_config__` attribute, `complete_dataclass`
        # will overwrite `__pydantic_config__` with the provided config above:
        _force_build=True,
    )
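# Illustrative sketch of when rebuild_dataclass helps (hypothetical classes;
# depending on definition order the automatic rebuild may already succeed):
#
#     @pydantic.dataclasses.dataclass
#     class Outer:
#         inner: 'Inner'        # forward ref, not yet defined; schema build deferred
#
#     @pydantic.dataclasses.dataclass
#     class Inner:
#         value: int
#
#     rebuild_dataclass(Outer)  # resolves 'Inner' now that it exists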
def is_pydantic_dataclass(class_: type[Any], /) -> TypeGuard[type[PydanticDataclass]]:
    """Whether a class is a pydantic dataclass.

    Args:
        class_: The class.

    Returns:
        `True` if the class is a pydantic dataclass, `False` otherwise.
    """
    try:
        return '__pydantic_validator__' in class_.__dict__ and dataclasses.is_dataclass(class_)
    except AttributeError:
        return False
@@ -0,0 +1,5 @@
"""The `datetime_parse` module is a backport module from V1."""

from ._migration import getattr_migration

__getattr__ = getattr_migration(__name__)
@@ -0,0 +1,5 @@
"""The `decorator` module is a backport module from V1."""

from ._migration import getattr_migration

__getattr__ = getattr_migration(__name__)
@@ -0,0 +1,256 @@
"""Old `@validator` and `@root_validator` function validators from V1."""

from __future__ import annotations as _annotations

from functools import partial, partialmethod
from types import FunctionType
from typing import TYPE_CHECKING, Any, Callable, TypeVar, Union, overload
from warnings import warn

from typing_extensions import Literal, Protocol, TypeAlias, deprecated

from .._internal import _decorators, _decorators_v1
from ..errors import PydanticUserError
from ..warnings import PydanticDeprecatedSince20

_ALLOW_REUSE_WARNING_MESSAGE = '`allow_reuse` is deprecated and will be ignored; it should no longer be necessary'


if TYPE_CHECKING:

    class _OnlyValueValidatorClsMethod(Protocol):
        def __call__(self, __cls: Any, __value: Any) -> Any: ...

    class _V1ValidatorWithValuesClsMethod(Protocol):
        def __call__(self, __cls: Any, __value: Any, values: dict[str, Any]) -> Any: ...

    class _V1ValidatorWithValuesKwOnlyClsMethod(Protocol):
        def __call__(self, __cls: Any, __value: Any, *, values: dict[str, Any]) -> Any: ...

    class _V1ValidatorWithKwargsClsMethod(Protocol):
        def __call__(self, __cls: Any, **kwargs: Any) -> Any: ...

    class _V1ValidatorWithValuesAndKwargsClsMethod(Protocol):
        def __call__(self, __cls: Any, values: dict[str, Any], **kwargs: Any) -> Any: ...

    class _V1RootValidatorClsMethod(Protocol):
        def __call__(
            self, __cls: Any, __values: _decorators_v1.RootValidatorValues
        ) -> _decorators_v1.RootValidatorValues: ...

    V1Validator = Union[
        _OnlyValueValidatorClsMethod,
        _V1ValidatorWithValuesClsMethod,
        _V1ValidatorWithValuesKwOnlyClsMethod,
        _V1ValidatorWithKwargsClsMethod,
        _V1ValidatorWithValuesAndKwargsClsMethod,
        _decorators_v1.V1ValidatorWithValues,
        _decorators_v1.V1ValidatorWithValuesKwOnly,
        _decorators_v1.V1ValidatorWithKwargs,
        _decorators_v1.V1ValidatorWithValuesAndKwargs,
    ]

    V1RootValidator = Union[
        _V1RootValidatorClsMethod,
        _decorators_v1.V1RootValidatorFunction,
    ]

    _PartialClsOrStaticMethod: TypeAlias = Union[classmethod[Any, Any, Any], staticmethod[Any, Any], partialmethod[Any]]

    # Allow both a V1 (assumed pre=False) or V2 (assumed mode='after') validator
    # We lie to type checkers and say we return the same thing we get
    # but in reality we return a proxy object that _mostly_ behaves like the wrapped thing
    _V1ValidatorType = TypeVar('_V1ValidatorType', V1Validator, _PartialClsOrStaticMethod)
    _V1RootValidatorFunctionType = TypeVar(
        '_V1RootValidatorFunctionType',
        _decorators_v1.V1RootValidatorFunction,
        _V1RootValidatorClsMethod,
        _PartialClsOrStaticMethod,
    )
else:
    # See PyCharm issues https://youtrack.jetbrains.com/issue/PY-21915
    # and https://youtrack.jetbrains.com/issue/PY-51428
    DeprecationWarning = PydanticDeprecatedSince20


@deprecated(
    'Pydantic V1 style `@validator` validators are deprecated.'
    ' You should migrate to Pydantic V2 style `@field_validator` validators,'
    ' see the migration guide for more details',
    category=None,
)
def validator(
    __field: str,
    *fields: str,
    pre: bool = False,
    each_item: bool = False,
    always: bool = False,
    check_fields: bool | None = None,
    allow_reuse: bool = False,
) -> Callable[[_V1ValidatorType], _V1ValidatorType]:
    """Decorate methods on the class indicating that they should be used to validate fields.

    Args:
        __field (str): The first field the validator should be called on; this is separate
            from `fields` to ensure an error is raised if you don't pass at least one.
        *fields (str): Additional field(s) the validator should be called on.
        pre (bool, optional): Whether this validator should be called before the standard
            validators (else after). Defaults to False.
        each_item (bool, optional): For complex objects (sets, lists etc.) whether to validate
            individual elements rather than the whole object. Defaults to False.
        always (bool, optional): Whether this method and other validators should be called even if
            the value is missing. Defaults to False.
        check_fields (bool | None, optional): Whether to check that the fields actually exist on the model.
            Defaults to None.
        allow_reuse (bool, optional): Whether to track and raise an error if another validator refers to
            the decorated function. Defaults to False.

    Returns:
        Callable: A decorator that can be used to decorate a function to be used as a validator.
    """
    warn(
        'Pydantic V1 style `@validator` validators are deprecated.'
        ' You should migrate to Pydantic V2 style `@field_validator` validators,'
        ' see the migration guide for more details',
        DeprecationWarning,
        stacklevel=2,
    )

    if allow_reuse is True:  # pragma: no cover
        warn(_ALLOW_REUSE_WARNING_MESSAGE, DeprecationWarning)
    fields = __field, *fields
    if isinstance(fields[0], FunctionType):
        raise PydanticUserError(
            '`@validator` should be used with fields and keyword arguments, not bare. '
            "E.g. usage should be `@validator('<field_name>', ...)`",
            code='validator-no-fields',
        )
    elif not all(isinstance(field, str) for field in fields):
        raise PydanticUserError(
            '`@validator` fields should be passed as separate string args. '
            "E.g. usage should be `@validator('<field_name_1>', '<field_name_2>', ...)`",
            code='validator-invalid-fields',
        )

    mode: Literal['before', 'after'] = 'before' if pre is True else 'after'

    def dec(f: Any) -> _decorators.PydanticDescriptorProxy[Any]:
        if _decorators.is_instance_method_from_sig(f):
            raise PydanticUserError(
                '`@validator` cannot be applied to instance methods', code='validator-instance-method'
            )
        # auto apply the @classmethod decorator
        f = _decorators.ensure_classmethod_based_on_signature(f)
        wrap = _decorators_v1.make_generic_v1_field_validator
        validator_wrapper_info = _decorators.ValidatorDecoratorInfo(
            fields=fields,
            mode=mode,
            each_item=each_item,
            always=always,
            check_fields=check_fields,
        )
        return _decorators.PydanticDescriptorProxy(f, validator_wrapper_info, shim=wrap)

    return dec  # type: ignore[return-value]
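# Migration sketch for the deprecated decorator above (hypothetical model):
#
#     class User(BaseModel):
#         name: str
#
#         @validator('name')            # V1 style, emits PydanticDeprecatedSince20
#         def not_empty(cls, v):
#             assert v, 'name must not be empty'
#             return v
#
#     # V2 replacement: @field_validator('name') with the same body.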
@overload
def root_validator(
    *,
    # if you don't specify `pre` the default is `pre=False`
    # which means you need to specify `skip_on_failure=True`
    skip_on_failure: Literal[True],
    allow_reuse: bool = ...,
) -> Callable[
    [_V1RootValidatorFunctionType],
    _V1RootValidatorFunctionType,
]: ...


@overload
def root_validator(
    *,
    # if you specify `pre=True` then you don't need to specify
    # `skip_on_failure`, in fact it is not allowed as an argument!
    pre: Literal[True],
    allow_reuse: bool = ...,
) -> Callable[
    [_V1RootValidatorFunctionType],
    _V1RootValidatorFunctionType,
]: ...


@overload
def root_validator(
    *,
    # if you explicitly specify `pre=False` then you
    # MUST specify `skip_on_failure=True`
    pre: Literal[False],
    skip_on_failure: Literal[True],
    allow_reuse: bool = ...,
) -> Callable[
    [_V1RootValidatorFunctionType],
    _V1RootValidatorFunctionType,
]: ...


@deprecated(
    'Pydantic V1 style `@root_validator` validators are deprecated.'
    ' You should migrate to Pydantic V2 style `@model_validator` validators,'
    ' see the migration guide for more details',
    category=None,
)
def root_validator(
    *__args,
    pre: bool = False,
    skip_on_failure: bool = False,
    allow_reuse: bool = False,
) -> Any:
    """Decorate methods on a model indicating that they should be used to validate (and perhaps
    modify) data either before or after standard model parsing/validation is performed.

    Args:
        pre (bool, optional): Whether this validator should be called before the standard
            validators (else after). Defaults to False.
        skip_on_failure (bool, optional): Whether to stop validation and return as soon as a
            failure is encountered. Defaults to False.
        allow_reuse (bool, optional): Whether to track and raise an error if another validator
            refers to the decorated function. Defaults to False.

    Returns:
        Any: A decorator that can be used to decorate a function to be used as a root_validator.
    """
    warn(
        'Pydantic V1 style `@root_validator` validators are deprecated.'
        ' You should migrate to Pydantic V2 style `@model_validator` validators,'
        ' see the migration guide for more details',
        DeprecationWarning,
        stacklevel=2,
    )

    if __args:
        # Ensure a nice error is raised if someone attempts to use the bare decorator
        return root_validator()(*__args)  # type: ignore

    if allow_reuse is True:  # pragma: no cover
        warn(_ALLOW_REUSE_WARNING_MESSAGE, DeprecationWarning)
    mode: Literal['before', 'after'] = 'before' if pre is True else 'after'
    if pre is False and skip_on_failure is not True:
        raise PydanticUserError(
            'If you use `@root_validator` with pre=False (the default) you MUST specify `skip_on_failure=True`.'
            ' Note that `@root_validator` is deprecated and should be replaced with `@model_validator`.',
            code='root-validator-pre-skip',
        )

    wrap = partial(_decorators_v1.make_v1_generic_root_validator, pre=pre)

    def dec(f: Callable[..., Any] | classmethod[Any, Any, Any] | staticmethod[Any, Any]) -> Any:
        if _decorators.is_instance_method_from_sig(f):
            raise TypeError('`@root_validator` cannot be applied to instance methods')
        # auto apply the @classmethod decorator
        res = _decorators.ensure_classmethod_based_on_signature(f)
        dec_info = _decorators.RootValidatorDecoratorInfo(mode=mode)
        return _decorators.PydanticDescriptorProxy(res, dec_info, shim=wrap)

    return dec
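# Usage sketch for the deprecated decorator above (hypothetical model; note the
# mandatory skip_on_failure=True when pre=False, enforced by the check above):
#
#     class Window(BaseModel):
#         start: int
#         end: int
#
#         @root_validator(skip_on_failure=True)
#         def check_order(cls, values):
#             assert values['start'] <= values['end']
#             return values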
@@ -0,0 +1,72 @@
from __future__ import annotations as _annotations

import warnings
from typing import TYPE_CHECKING, Any

from typing_extensions import Literal, deprecated

from .._internal import _config
from ..warnings import PydanticDeprecatedSince20

if not TYPE_CHECKING:
    # See PyCharm issues https://youtrack.jetbrains.com/issue/PY-21915
    # and https://youtrack.jetbrains.com/issue/PY-51428
    DeprecationWarning = PydanticDeprecatedSince20

__all__ = 'BaseConfig', 'Extra'


class _ConfigMetaclass(type):
    def __getattr__(self, item: str) -> Any:
        try:
            obj = _config.config_defaults[item]
            warnings.warn(_config.DEPRECATION_MESSAGE, DeprecationWarning)
            return obj
        except KeyError as exc:
            raise AttributeError(f"type object '{self.__name__}' has no attribute {exc}") from exc


@deprecated('BaseConfig is deprecated. Use the `pydantic.ConfigDict` instead.', category=PydanticDeprecatedSince20)
class BaseConfig(metaclass=_ConfigMetaclass):
    """This class is only retained for backwards compatibility.

    !!! Warning "Deprecated"
        BaseConfig is deprecated. Use the [`pydantic.ConfigDict`][pydantic.ConfigDict] instead.
    """

    def __getattr__(self, item: str) -> Any:
        try:
            obj = super().__getattribute__(item)
            warnings.warn(_config.DEPRECATION_MESSAGE, DeprecationWarning)
            return obj
        except AttributeError as exc:
            try:
                return getattr(type(self), item)
            except AttributeError:
                # re-raising changes the displayed text to reflect that `self` is not a type
                raise AttributeError(str(exc)) from exc

    def __init_subclass__(cls, **kwargs: Any) -> None:
        warnings.warn(_config.DEPRECATION_MESSAGE, DeprecationWarning)
        return super().__init_subclass__(**kwargs)


class _ExtraMeta(type):
    def __getattribute__(self, __name: str) -> Any:
        # The @deprecated decorator accesses other attributes, so we only emit a warning for the expected ones
        if __name in {'allow', 'ignore', 'forbid'}:
            warnings.warn(
                "`pydantic.config.Extra` is deprecated, use literal values instead (e.g. `extra='allow'`)",
                DeprecationWarning,
                stacklevel=2,
            )
        return super().__getattribute__(__name)


@deprecated(
    "Extra is deprecated. Use literal values instead (e.g. `extra='allow'`)", category=PydanticDeprecatedSince20
)
class Extra(metaclass=_ExtraMeta):
    allow: Literal['allow'] = 'allow'
    ignore: Literal['ignore'] = 'ignore'
    forbid: Literal['forbid'] = 'forbid'
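# Behaviour sketch of the shims above: attribute access on Extra warns via
# _ExtraMeta and returns the literal, which maps directly onto ConfigDict in V2
# (hypothetical model shown):
#
#     Extra.forbid                    # 'forbid' (emits a deprecation warning)
#
#     class Model(BaseModel):
#         model_config = ConfigDict(extra='forbid')   # V2 replacement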
@@ -0,0 +1,224 @@
from __future__ import annotations as _annotations

import typing
from copy import deepcopy
from enum import Enum
from typing import Any, Tuple

import typing_extensions

from .._internal import (
    _model_construction,
    _typing_extra,
    _utils,
)

if typing.TYPE_CHECKING:
    from .. import BaseModel
    from .._internal._utils import AbstractSetIntStr, MappingIntStrAny

    AnyClassMethod = classmethod[Any, Any, Any]
    TupleGenerator = typing.Generator[Tuple[str, Any], None, None]
    Model = typing.TypeVar('Model', bound='BaseModel')
    # should be `set[int] | set[str] | dict[int, IncEx] | dict[str, IncEx] | None`, but mypy can't cope
    IncEx: typing_extensions.TypeAlias = 'set[int] | set[str] | dict[int, Any] | dict[str, Any] | None'

_object_setattr = _model_construction.object_setattr


def _iter(
    self: BaseModel,
    to_dict: bool = False,
    by_alias: bool = False,
    include: AbstractSetIntStr | MappingIntStrAny | None = None,
    exclude: AbstractSetIntStr | MappingIntStrAny | None = None,
    exclude_unset: bool = False,
    exclude_defaults: bool = False,
    exclude_none: bool = False,
) -> TupleGenerator:
    # Merge field-level excludes with the explicit `exclude` parameter; the explicit parameter
    # overrides the per-field settings.
    # The extra "is not None" guards are not logically necessary but optimize performance for the simple case.
    if exclude is not None:
        exclude = _utils.ValueItems.merge(
            {k: v.exclude for k, v in self.__pydantic_fields__.items() if v.exclude is not None}, exclude
        )

    if include is not None:
        include = _utils.ValueItems.merge({k: True for k in self.__pydantic_fields__}, include, intersect=True)

    allowed_keys = _calculate_keys(self, include=include, exclude=exclude, exclude_unset=exclude_unset)  # type: ignore
    if allowed_keys is None and not (to_dict or by_alias or exclude_unset or exclude_defaults or exclude_none):
        # huge boost for plain _iter()
        yield from self.__dict__.items()
        if self.__pydantic_extra__:
            yield from self.__pydantic_extra__.items()
        return

    value_exclude = _utils.ValueItems(self, exclude) if exclude is not None else None
    value_include = _utils.ValueItems(self, include) if include is not None else None

    if self.__pydantic_extra__ is None:
        items = self.__dict__.items()
    else:
        items = list(self.__dict__.items()) + list(self.__pydantic_extra__.items())

    for field_key, v in items:
        if (allowed_keys is not None and field_key not in allowed_keys) or (exclude_none and v is None):
            continue

        if exclude_defaults:
            try:
                field = self.__pydantic_fields__[field_key]
            except KeyError:
                pass
            else:
                if not field.is_required() and field.default == v:
                    continue

        if by_alias and field_key in self.__pydantic_fields__:
            dict_key = self.__pydantic_fields__[field_key].alias or field_key
        else:
            dict_key = field_key

        if to_dict or value_include or value_exclude:
            v = _get_value(
                type(self),
                v,
                to_dict=to_dict,
                by_alias=by_alias,
                include=value_include and value_include.for_element(field_key),
                exclude=value_exclude and value_exclude.for_element(field_key),
                exclude_unset=exclude_unset,
                exclude_defaults=exclude_defaults,
                exclude_none=exclude_none,
            )
        yield dict_key, v


def _copy_and_set_values(
    self: Model,
    values: dict[str, Any],
    fields_set: set[str],
    extra: dict[str, Any] | None = None,
    private: dict[str, Any] | None = None,
    *,
    deep: bool,  # UP006
) -> Model:
    if deep:
        # chances of having empty dict here are quite low for using smart_deepcopy
        values = deepcopy(values)
        extra = deepcopy(extra)
        private = deepcopy(private)

    cls = self.__class__
    m = cls.__new__(cls)
    _object_setattr(m, '__dict__', values)
    _object_setattr(m, '__pydantic_extra__', extra)
    _object_setattr(m, '__pydantic_fields_set__', fields_set)
    _object_setattr(m, '__pydantic_private__', private)

    return m


@typing.no_type_check
def _get_value(
    cls: type[BaseModel],
    v: Any,
    to_dict: bool,
    by_alias: bool,
    include: AbstractSetIntStr | MappingIntStrAny | None,
    exclude: AbstractSetIntStr | MappingIntStrAny | None,
    exclude_unset: bool,
    exclude_defaults: bool,
    exclude_none: bool,
) -> Any:
    from .. import BaseModel

    if isinstance(v, BaseModel):
        if to_dict:
            return v.model_dump(
                by_alias=by_alias,
                exclude_unset=exclude_unset,
                exclude_defaults=exclude_defaults,
                include=include,  # type: ignore
                exclude=exclude,  # type: ignore
                exclude_none=exclude_none,
            )
        else:
            return v.copy(include=include, exclude=exclude)

    value_exclude = _utils.ValueItems(v, exclude) if exclude else None
    value_include = _utils.ValueItems(v, include) if include else None

    if isinstance(v, dict):
        return {
            k_: _get_value(
                cls,
                v_,
                to_dict=to_dict,
                by_alias=by_alias,
                exclude_unset=exclude_unset,
                exclude_defaults=exclude_defaults,
                include=value_include and value_include.for_element(k_),
                exclude=value_exclude and value_exclude.for_element(k_),
                exclude_none=exclude_none,
            )
            for k_, v_ in v.items()
            if (not value_exclude or not value_exclude.is_excluded(k_))
            and (not value_include or value_include.is_included(k_))
        }

    elif _utils.sequence_like(v):
        seq_args = (
            _get_value(
                cls,
                v_,
                to_dict=to_dict,
                by_alias=by_alias,
                exclude_unset=exclude_unset,
                exclude_defaults=exclude_defaults,
                include=value_include and value_include.for_element(i),
                exclude=value_exclude and value_exclude.for_element(i),
                exclude_none=exclude_none,
            )
            for i, v_ in enumerate(v)
            if (not value_exclude or not value_exclude.is_excluded(i))
            and (not value_include or value_include.is_included(i))
        )

        return v.__class__(*seq_args) if _typing_extra.is_namedtuple(v.__class__) else v.__class__(seq_args)

    elif isinstance(v, Enum) and getattr(cls.model_config, 'use_enum_values', False):
        return v.value

    else:
        return v


def _calculate_keys(
    self: BaseModel,
    include: MappingIntStrAny | None,
    exclude: MappingIntStrAny | None,
    exclude_unset: bool,
    update: typing.Dict[str, Any] | None = None,  # noqa UP006
) -> typing.AbstractSet[str] | None:
    if include is None and exclude is None and exclude_unset is False:
        return None

    keys: typing.AbstractSet[str]
    if exclude_unset:
        keys = self.__pydantic_fields_set__.copy()
    else:
        keys = set(self.__dict__.keys())
        keys = keys | (self.__pydantic_extra__ or {}).keys()

    if include is not None:
        keys &= include.keys()

    if update:
        keys -= update.keys()

    if exclude:
        keys -= {k for k, v in exclude.items() if _utils.ValueItems.is_true(v)}

    return keys
@@ -0,0 +1,283 @@
import warnings
from functools import wraps
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Mapping, Optional, Tuple, Type, TypeVar, Union, overload

from typing_extensions import deprecated

from .._internal import _config, _typing_extra
from ..alias_generators import to_pascal
from ..errors import PydanticUserError
from ..functional_validators import field_validator
from ..main import BaseModel, create_model
from ..warnings import PydanticDeprecatedSince20

if not TYPE_CHECKING:
    # See PyCharm issues https://youtrack.jetbrains.com/issue/PY-21915
    # and https://youtrack.jetbrains.com/issue/PY-51428
    DeprecationWarning = PydanticDeprecatedSince20

__all__ = ('validate_arguments',)

if TYPE_CHECKING:
    AnyCallable = Callable[..., Any]

    AnyCallableT = TypeVar('AnyCallableT', bound=AnyCallable)
    ConfigType = Union[None, Type[Any], Dict[str, Any]]


@overload
def validate_arguments(
    func: None = None, *, config: 'ConfigType' = None
) -> Callable[['AnyCallableT'], 'AnyCallableT']: ...


@overload
def validate_arguments(func: 'AnyCallableT') -> 'AnyCallableT': ...


@deprecated(
    'The `validate_arguments` method is deprecated; use `validate_call` instead.',
    category=None,
)
def validate_arguments(func: Optional['AnyCallableT'] = None, *, config: 'ConfigType' = None) -> Any:
    """Decorator to validate the arguments passed to a function."""
    warnings.warn(
        'The `validate_arguments` method is deprecated; use `validate_call` instead.',
        PydanticDeprecatedSince20,
        stacklevel=2,
    )

    def validate(_func: 'AnyCallable') -> 'AnyCallable':
        vd = ValidatedFunction(_func, config)

        @wraps(_func)
        def wrapper_function(*args: Any, **kwargs: Any) -> Any:
            return vd.call(*args, **kwargs)

        wrapper_function.vd = vd  # type: ignore
        wrapper_function.validate = vd.init_model_instance  # type: ignore
        wrapper_function.raw_function = vd.raw_function  # type: ignore
        wrapper_function.model = vd.model  # type: ignore
        return wrapper_function

    if func:
        return validate(func)
    else:
        return validate
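# Usage sketch for the deprecated decorator above (hypothetical function;
# `validate_call` is the V2 replacement):
#
#     @validate_arguments
#     def repeat(s: str, count: int) -> str:
#         return s * count
#
#     repeat('x', '3')    # count coerced to 3 -> 'xxx'
#     repeat('x', 'abc')  # raises pydantic.ValidationError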
ALT_V_ARGS = 'v__args'
|
||||
ALT_V_KWARGS = 'v__kwargs'
|
||||
V_POSITIONAL_ONLY_NAME = 'v__positional_only'
|
||||
V_DUPLICATE_KWARGS = 'v__duplicate_kwargs'
|
||||
|
||||
|
||||
class ValidatedFunction:
|
||||
def __init__(self, function: 'AnyCallable', config: 'ConfigType'):
|
||||
from inspect import Parameter, signature
|
||||
|
||||
parameters: Mapping[str, Parameter] = signature(function).parameters
|
||||
|
||||
if parameters.keys() & {ALT_V_ARGS, ALT_V_KWARGS, V_POSITIONAL_ONLY_NAME, V_DUPLICATE_KWARGS}:
|
||||
raise PydanticUserError(
|
||||
f'"{ALT_V_ARGS}", "{ALT_V_KWARGS}", "{V_POSITIONAL_ONLY_NAME}" and "{V_DUPLICATE_KWARGS}" '
|
||||
f'are not permitted as argument names when using the "{validate_arguments.__name__}" decorator',
|
||||
code=None,
|
||||
)
|
||||
|
||||
self.raw_function = function
|
||||
self.arg_mapping: Dict[int, str] = {}
|
||||
self.positional_only_args: set[str] = set()
|
||||
self.v_args_name = 'args'
|
||||
self.v_kwargs_name = 'kwargs'
|
||||
|
||||
        type_hints = _typing_extra.get_type_hints(function, include_extras=True)
        takes_args = False
        takes_kwargs = False
        fields: Dict[str, Tuple[Any, Any]] = {}
        for i, (name, p) in enumerate(parameters.items()):
            if p.annotation is p.empty:
                annotation = Any
            else:
                annotation = type_hints[name]

            default = ... if p.default is p.empty else p.default
            if p.kind == Parameter.POSITIONAL_ONLY:
                self.arg_mapping[i] = name
                fields[name] = annotation, default
                fields[V_POSITIONAL_ONLY_NAME] = List[str], None
                self.positional_only_args.add(name)
            elif p.kind == Parameter.POSITIONAL_OR_KEYWORD:
                self.arg_mapping[i] = name
                fields[name] = annotation, default
                fields[V_DUPLICATE_KWARGS] = List[str], None
            elif p.kind == Parameter.KEYWORD_ONLY:
                fields[name] = annotation, default
            elif p.kind == Parameter.VAR_POSITIONAL:
                self.v_args_name = name
                fields[name] = Tuple[annotation, ...], None
                takes_args = True
            else:
                assert p.kind == Parameter.VAR_KEYWORD, p.kind
                self.v_kwargs_name = name
                fields[name] = Dict[str, annotation], None
                takes_kwargs = True

        # these checks avoid a clash between "args" and a field with that name
        if not takes_args and self.v_args_name in fields:
            self.v_args_name = ALT_V_ARGS

        # same with "kwargs"
        if not takes_kwargs and self.v_kwargs_name in fields:
            self.v_kwargs_name = ALT_V_KWARGS

        if not takes_args:
            # we add the field so validation below can raise the correct exception
            fields[self.v_args_name] = List[Any], None

        if not takes_kwargs:
            # same with kwargs
            fields[self.v_kwargs_name] = Dict[Any, Any], None

        self.create_model(fields, takes_args, takes_kwargs, config)

    def init_model_instance(self, *args: Any, **kwargs: Any) -> BaseModel:
        values = self.build_values(args, kwargs)
        return self.model(**values)

    def call(self, *args: Any, **kwargs: Any) -> Any:
        m = self.init_model_instance(*args, **kwargs)
        return self.execute(m)

    def build_values(self, args: Tuple[Any, ...], kwargs: Dict[str, Any]) -> Dict[str, Any]:
        values: Dict[str, Any] = {}
        if args:
            arg_iter = enumerate(args)
            while True:
                try:
                    i, a = next(arg_iter)
                except StopIteration:
                    break
                arg_name = self.arg_mapping.get(i)
                if arg_name is not None:
                    values[arg_name] = a
                else:
                    values[self.v_args_name] = [a] + [a for _, a in arg_iter]
                    break

        var_kwargs: Dict[str, Any] = {}
        wrong_positional_args = []
        duplicate_kwargs = []
        fields_alias = [
            field.alias
            for name, field in self.model.__pydantic_fields__.items()
            if name not in (self.v_args_name, self.v_kwargs_name)
        ]
        non_var_fields = set(self.model.__pydantic_fields__) - {self.v_args_name, self.v_kwargs_name}
        for k, v in kwargs.items():
            if k in non_var_fields or k in fields_alias:
                if k in self.positional_only_args:
                    wrong_positional_args.append(k)
                if k in values:
                    duplicate_kwargs.append(k)
                values[k] = v
            else:
                var_kwargs[k] = v

        if var_kwargs:
            values[self.v_kwargs_name] = var_kwargs
        if wrong_positional_args:
            values[V_POSITIONAL_ONLY_NAME] = wrong_positional_args
        if duplicate_kwargs:
            values[V_DUPLICATE_KWARGS] = duplicate_kwargs
        return values

    def execute(self, m: BaseModel) -> Any:
        d = {
            k: v
            for k, v in m.__dict__.items()
            if k in m.__pydantic_fields_set__ or m.__pydantic_fields__[k].default_factory
        }
        var_kwargs = d.pop(self.v_kwargs_name, {})

        if self.v_args_name in d:
            args_: List[Any] = []
            in_kwargs = False
            kwargs = {}
            for name, value in d.items():
                if in_kwargs:
                    kwargs[name] = value
                elif name == self.v_args_name:
                    args_ += value
                    in_kwargs = True
                else:
                    args_.append(value)
            return self.raw_function(*args_, **kwargs, **var_kwargs)
        elif self.positional_only_args:
            args_ = []
            kwargs = {}
            for name, value in d.items():
                if name in self.positional_only_args:
                    args_.append(value)
                else:
                    kwargs[name] = value
            return self.raw_function(*args_, **kwargs, **var_kwargs)
        else:
            return self.raw_function(**d, **var_kwargs)

    def create_model(self, fields: Dict[str, Any], takes_args: bool, takes_kwargs: bool, config: 'ConfigType') -> None:
        pos_args = len(self.arg_mapping)

        config_wrapper = _config.ConfigWrapper(config)

        if config_wrapper.alias_generator:
            raise PydanticUserError(
                'Setting the "alias_generator" property on custom Config for '
                '@validate_arguments is not yet supported, please remove.',
                code=None,
            )
        if config_wrapper.extra is None:
            config_wrapper.config_dict['extra'] = 'forbid'

        class DecoratorBaseModel(BaseModel):
            @field_validator(self.v_args_name, check_fields=False)
            @classmethod
            def check_args(cls, v: Optional[List[Any]]) -> Optional[List[Any]]:
                if takes_args or v is None:
                    return v

                raise TypeError(f'{pos_args} positional arguments expected but {pos_args + len(v)} given')

            @field_validator(self.v_kwargs_name, check_fields=False)
            @classmethod
            def check_kwargs(cls, v: Optional[Dict[str, Any]]) -> Optional[Dict[str, Any]]:
                if takes_kwargs or v is None:
                    return v

                plural = '' if len(v) == 1 else 's'
                keys = ', '.join(map(repr, v.keys()))
                raise TypeError(f'unexpected keyword argument{plural}: {keys}')

            @field_validator(V_POSITIONAL_ONLY_NAME, check_fields=False)
            @classmethod
            def check_positional_only(cls, v: Optional[List[str]]) -> None:
                if v is None:
                    return

                plural = '' if len(v) == 1 else 's'
                keys = ', '.join(map(repr, v))
                raise TypeError(f'positional-only argument{plural} passed as keyword argument{plural}: {keys}')

            @field_validator(V_DUPLICATE_KWARGS, check_fields=False)
            @classmethod
            def check_duplicate_kwargs(cls, v: Optional[List[str]]) -> None:
                if v is None:
                    return

                plural = '' if len(v) == 1 else 's'
                keys = ', '.join(map(repr, v))
                raise TypeError(f'multiple values for argument{plural}: {keys}')

            model_config = config_wrapper.config_dict

        self.model = create_model(to_pascal(self.raw_function.__name__), __base__=DecoratorBaseModel, **fields)
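Taken together, the methods above build a model from a function signature and route validated values back into the original call. A minimal sketch of the same behavior through `validate_call`, the non-deprecated v2 entry point (the `repeat` function itself is an illustrative assumption):

```python
from pydantic import validate_call

@validate_call
def repeat(word: str, times: int) -> str:
    return word * times

print(repeat('ab', times='3'))  # '3' is coerced to int in lax mode
#> ababab
```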
@@ -0,0 +1,141 @@
import datetime
import warnings
from collections import deque
from decimal import Decimal
from enum import Enum
from ipaddress import IPv4Address, IPv4Interface, IPv4Network, IPv6Address, IPv6Interface, IPv6Network
from pathlib import Path
from re import Pattern
from types import GeneratorType
from typing import TYPE_CHECKING, Any, Callable, Dict, Type, Union
from uuid import UUID

from typing_extensions import deprecated

from .._internal._import_utils import import_cached_base_model
from ..color import Color
from ..networks import NameEmail
from ..types import SecretBytes, SecretStr
from ..warnings import PydanticDeprecatedSince20

if not TYPE_CHECKING:
    # See PyCharm issues https://youtrack.jetbrains.com/issue/PY-21915
    # and https://youtrack.jetbrains.com/issue/PY-51428
    DeprecationWarning = PydanticDeprecatedSince20

__all__ = 'pydantic_encoder', 'custom_pydantic_encoder', 'timedelta_isoformat'


def isoformat(o: Union[datetime.date, datetime.time]) -> str:
    return o.isoformat()


def decimal_encoder(dec_value: Decimal) -> Union[int, float]:
    """Encodes a Decimal as an int if there's no exponent, otherwise as a float.

    This is useful when we use ConstrainedDecimal to represent Numeric(x,0)
    where an integer (but not int typed) is used. Encoding this as a float
    results in failed round-tripping between encode and parse.
    Our Id type is a prime example of this.

    >>> decimal_encoder(Decimal("1.0"))
    1.0

    >>> decimal_encoder(Decimal("1"))
    1
    """
    exponent = dec_value.as_tuple().exponent
    if isinstance(exponent, int) and exponent >= 0:
        return int(dec_value)
    else:
        return float(dec_value)


ENCODERS_BY_TYPE: Dict[Type[Any], Callable[[Any], Any]] = {
    bytes: lambda o: o.decode(),
    Color: str,
    datetime.date: isoformat,
    datetime.datetime: isoformat,
    datetime.time: isoformat,
    datetime.timedelta: lambda td: td.total_seconds(),
    Decimal: decimal_encoder,
    Enum: lambda o: o.value,
    frozenset: list,
    deque: list,
    GeneratorType: list,
    IPv4Address: str,
    IPv4Interface: str,
    IPv4Network: str,
    IPv6Address: str,
    IPv6Interface: str,
    IPv6Network: str,
    NameEmail: str,
    Path: str,
    Pattern: lambda o: o.pattern,
    SecretBytes: str,
    SecretStr: str,
    set: list,
    UUID: str,
}


@deprecated(
    '`pydantic_encoder` is deprecated, use `pydantic_core.to_jsonable_python` instead.',
    category=None,
)
def pydantic_encoder(obj: Any) -> Any:
    warnings.warn(
        '`pydantic_encoder` is deprecated, use `pydantic_core.to_jsonable_python` instead.',
        category=PydanticDeprecatedSince20,
        stacklevel=2,
    )
    from dataclasses import asdict, is_dataclass

    BaseModel = import_cached_base_model()

    if isinstance(obj, BaseModel):
        return obj.model_dump()
    elif is_dataclass(obj):
        return asdict(obj)  # type: ignore

    # Check the class type and its superclasses for a matching encoder
    for base in obj.__class__.__mro__[:-1]:
        try:
            encoder = ENCODERS_BY_TYPE[base]
        except KeyError:
            continue
        return encoder(obj)
    else:  # We have exited the for loop without finding a suitable encoder
        raise TypeError(f"Object of type '{obj.__class__.__name__}' is not JSON serializable")


# TODO: Add a suggested migration path once there is a way to use custom encoders
@deprecated(
    '`custom_pydantic_encoder` is deprecated, use `BaseModel.model_dump` instead.',
    category=None,
)
def custom_pydantic_encoder(type_encoders: Dict[Any, Callable[[Type[Any]], Any]], obj: Any) -> Any:
    warnings.warn(
        '`custom_pydantic_encoder` is deprecated, use `BaseModel.model_dump` instead.',
        category=PydanticDeprecatedSince20,
        stacklevel=2,
    )
    # Check the class type and its superclasses for a matching encoder
    for base in obj.__class__.__mro__[:-1]:
        try:
            encoder = type_encoders[base]
        except KeyError:
            continue

        return encoder(obj)
    else:  # We have exited the for loop without finding a suitable encoder
        return pydantic_encoder(obj)


@deprecated('`timedelta_isoformat` is deprecated.', category=None)
def timedelta_isoformat(td: datetime.timedelta) -> str:
    """ISO 8601 encoding for Python timedelta object."""
    warnings.warn('`timedelta_isoformat` is deprecated.', category=PydanticDeprecatedSince20, stacklevel=2)
    minutes, seconds = divmod(td.seconds, 60)
    hours, minutes = divmod(minutes, 60)
    return f'{"-" if td.days < 0 else ""}P{abs(td.days)}DT{hours:d}H{minutes:d}M{seconds:d}.{td.microseconds:06d}S'
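A quick sketch of the MRO-based encoder lookup and the ISO helper above (deprecation warnings are expected when calling these):

```python
import datetime

# walking date.__mro__ finds the `isoformat` encoder registered for datetime.date
print(pydantic_encoder(datetime.date(2024, 1, 1)))
#> 2024-01-01
print(timedelta_isoformat(datetime.timedelta(days=1, hours=2)))
#> P1DT2H0M0.000000S
```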
@@ -0,0 +1,80 @@
from __future__ import annotations

import json
import pickle
import warnings
from enum import Enum
from pathlib import Path
from typing import TYPE_CHECKING, Any, Callable

from typing_extensions import deprecated

from ..warnings import PydanticDeprecatedSince20

if not TYPE_CHECKING:
    # See PyCharm issues https://youtrack.jetbrains.com/issue/PY-21915
    # and https://youtrack.jetbrains.com/issue/PY-51428
    DeprecationWarning = PydanticDeprecatedSince20


class Protocol(str, Enum):
    json = 'json'
    pickle = 'pickle'


@deprecated('`load_str_bytes` is deprecated.', category=None)
def load_str_bytes(
    b: str | bytes,
    *,
    content_type: str | None = None,
    encoding: str = 'utf8',
    proto: Protocol | None = None,
    allow_pickle: bool = False,
    json_loads: Callable[[str], Any] = json.loads,
) -> Any:
    warnings.warn('`load_str_bytes` is deprecated.', category=PydanticDeprecatedSince20, stacklevel=2)
    if proto is None and content_type:
        if content_type.endswith(('json', 'javascript')):
            pass
        elif allow_pickle and content_type.endswith('pickle'):
            proto = Protocol.pickle
        else:
            raise TypeError(f'Unknown content-type: {content_type}')

    proto = proto or Protocol.json

    if proto == Protocol.json:
        if isinstance(b, bytes):
            b = b.decode(encoding)
        return json_loads(b)  # type: ignore
    elif proto == Protocol.pickle:
        if not allow_pickle:
            raise RuntimeError('Trying to decode with pickle with allow_pickle=False')
        bb = b if isinstance(b, bytes) else b.encode()  # type: ignore
        return pickle.loads(bb)
    else:
        raise TypeError(f'Unknown protocol: {proto}')


@deprecated('`load_file` is deprecated.', category=None)
def load_file(
    path: str | Path,
    *,
    content_type: str | None = None,
    encoding: str = 'utf8',
    proto: Protocol | None = None,
    allow_pickle: bool = False,
    json_loads: Callable[[str], Any] = json.loads,
) -> Any:
    warnings.warn('`load_file` is deprecated.', category=PydanticDeprecatedSince20, stacklevel=2)
    path = Path(path)
    b = path.read_bytes()
    if content_type is None:
        if path.suffix in ('.js', '.json'):
            proto = Protocol.json
        elif path.suffix == '.pkl':
            proto = Protocol.pickle

    return load_str_bytes(
        b, proto=proto, content_type=content_type, encoding=encoding, allow_pickle=allow_pickle, json_loads=json_loads
    )
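A minimal sketch of the protocol dispatch above (again, a deprecation warning is expected):

```python
print(load_str_bytes('{"a": 1}'))  # no content type, so it defaults to JSON
#> {'a': 1}
print(load_str_bytes(b'[1, 2]', content_type='application/json'))
#> [1, 2]
```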
@@ -0,0 +1,103 @@
from __future__ import annotations

import json
import warnings
from typing import TYPE_CHECKING, Any, Callable, Type, TypeVar, Union

from typing_extensions import deprecated

from ..json_schema import DEFAULT_REF_TEMPLATE, GenerateJsonSchema
from ..type_adapter import TypeAdapter
from ..warnings import PydanticDeprecatedSince20

if not TYPE_CHECKING:
    # See PyCharm issues https://youtrack.jetbrains.com/issue/PY-21915
    # and https://youtrack.jetbrains.com/issue/PY-51428
    DeprecationWarning = PydanticDeprecatedSince20

__all__ = 'parse_obj_as', 'schema_of', 'schema_json_of'

NameFactory = Union[str, Callable[[Type[Any]], str]]


T = TypeVar('T')


@deprecated(
    '`parse_obj_as` is deprecated. Use `pydantic.TypeAdapter.validate_python` instead.',
    category=None,
)
def parse_obj_as(type_: type[T], obj: Any, type_name: NameFactory | None = None) -> T:
    warnings.warn(
        '`parse_obj_as` is deprecated. Use `pydantic.TypeAdapter.validate_python` instead.',
        category=PydanticDeprecatedSince20,
        stacklevel=2,
    )
    if type_name is not None:  # pragma: no cover
        warnings.warn(
            'The type_name parameter is deprecated. parse_obj_as no longer creates temporary models',
            DeprecationWarning,
            stacklevel=2,
        )
    return TypeAdapter(type_).validate_python(obj)


@deprecated(
    '`schema_of` is deprecated. Use `pydantic.TypeAdapter.json_schema` instead.',
    category=None,
)
def schema_of(
    type_: Any,
    *,
    title: NameFactory | None = None,
    by_alias: bool = True,
    ref_template: str = DEFAULT_REF_TEMPLATE,
    schema_generator: type[GenerateJsonSchema] = GenerateJsonSchema,
) -> dict[str, Any]:
    """Generate a JSON schema (as dict) for the passed model or dynamically generated one."""
    warnings.warn(
        '`schema_of` is deprecated. Use `pydantic.TypeAdapter.json_schema` instead.',
        category=PydanticDeprecatedSince20,
        stacklevel=2,
    )
    res = TypeAdapter(type_).json_schema(
        by_alias=by_alias,
        schema_generator=schema_generator,
        ref_template=ref_template,
    )
    if title is not None:
        if isinstance(title, str):
            res['title'] = title
        else:
            warnings.warn(
                'Passing a callable for the `title` parameter is deprecated and no longer supported',
                DeprecationWarning,
                stacklevel=2,
            )
            res['title'] = title(type_)
    return res


@deprecated(
    '`schema_json_of` is deprecated. Use `pydantic.TypeAdapter.json_schema` instead.',
    category=None,
)
def schema_json_of(
    type_: Any,
    *,
    title: NameFactory | None = None,
    by_alias: bool = True,
    ref_template: str = DEFAULT_REF_TEMPLATE,
    schema_generator: type[GenerateJsonSchema] = GenerateJsonSchema,
    **dumps_kwargs: Any,
) -> str:
    """Generate a JSON schema (as JSON) for the passed model or dynamically generated one."""
    warnings.warn(
        '`schema_json_of` is deprecated. Use `pydantic.TypeAdapter.json_schema` instead.',
        category=PydanticDeprecatedSince20,
        stacklevel=2,
    )
    return json.dumps(
        schema_of(type_, title=title, by_alias=by_alias, ref_template=ref_template, schema_generator=schema_generator),
        **dumps_kwargs,
    )
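For reference, the migration these warnings point at, as a sketch using the stable `TypeAdapter` API:

```python
from typing import List

from pydantic import TypeAdapter

# parse_obj_as(List[int], ['1', '2']) is now spelled:
print(TypeAdapter(List[int]).validate_python(['1', '2']))
#> [1, 2]
# schema_of(int) is now spelled:
print(TypeAdapter(int).json_schema())
#> {'type': 'integer'}
```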
@@ -0,0 +1,5 @@
"""The `env_settings` module is a backport module from V1."""

from ._migration import getattr_migration

__getattr__ = getattr_migration(__name__)
@@ -0,0 +1,5 @@
"""The `error_wrappers` module is a backport module from V1."""

from ._migration import getattr_migration

__getattr__ = getattr_migration(__name__)
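Both stubs delegate to `getattr_migration`, which reroutes V1-era attribute access to the new V2 location, or raises `PydanticImportError` for removed objects. A hedged sketch, assuming the standard migration table covers this name:

```python
# emits a migration warning and resolves to pydantic.ValidationError
from pydantic.error_wrappers import ValidationError  # noqa: F401
```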
@@ -0,0 +1,162 @@
"""Pydantic-specific errors."""

from __future__ import annotations as _annotations

import re

from typing_extensions import Literal, Self

from ._migration import getattr_migration
from .version import version_short

__all__ = (
    'PydanticUserError',
    'PydanticUndefinedAnnotation',
    'PydanticImportError',
    'PydanticSchemaGenerationError',
    'PydanticInvalidForJsonSchema',
    'PydanticErrorCodes',
)

# We use this URL to allow for future flexibility about how we host the docs, while allowing for Pydantic
# code in the wild with "old" URLs to still work.
# 'u' refers to "user errors" - e.g. errors caused by developers using pydantic, as opposed to validation errors.
DEV_ERROR_DOCS_URL = f'https://errors.pydantic.dev/{version_short()}/u/'
PydanticErrorCodes = Literal[
    'class-not-fully-defined',
    'custom-json-schema',
    'decorator-missing-field',
    'discriminator-no-field',
    'discriminator-alias-type',
    'discriminator-needs-literal',
    'discriminator-alias',
    'discriminator-validator',
    'callable-discriminator-no-tag',
    'typed-dict-version',
    'model-field-overridden',
    'model-field-missing-annotation',
    'config-both',
    'removed-kwargs',
    'circular-reference-schema',
    'invalid-for-json-schema',
    'json-schema-already-used',
    'base-model-instantiated',
    'undefined-annotation',
    'schema-for-unknown-type',
    'import-error',
    'create-model-field-definitions',
    'create-model-config-base',
    'validator-no-fields',
    'validator-invalid-fields',
    'validator-instance-method',
    'validator-input-type',
    'root-validator-pre-skip',
    'model-serializer-instance-method',
    'validator-field-config-info',
    'validator-v1-signature',
    'validator-signature',
    'field-serializer-signature',
    'model-serializer-signature',
    'multiple-field-serializers',
    'invalid-annotated-type',
    'type-adapter-config-unused',
    'root-model-extra',
    'unevaluable-type-annotation',
    'dataclass-init-false-extra-allow',
    'clashing-init-and-init-var',
    'model-config-invalid-field-name',
    'with-config-on-model',
    'dataclass-on-model',
    'validate-call-type',
    'unpack-typed-dict',
    'overlapping-unpack-typed-dict',
    'invalid-self-type',
]


class PydanticErrorMixin:
    """A mixin class for common functionality shared by all Pydantic-specific errors.

    Attributes:
        message: A message describing the error.
        code: An optional error code from PydanticErrorCodes enum.
    """

    def __init__(self, message: str, *, code: PydanticErrorCodes | None) -> None:
        self.message = message
        self.code = code

    def __str__(self) -> str:
        if self.code is None:
            return self.message
        else:
            return f'{self.message}\n\nFor further information visit {DEV_ERROR_DOCS_URL}{self.code}'


class PydanticUserError(PydanticErrorMixin, TypeError):
    """An error raised due to incorrect use of Pydantic."""


class PydanticUndefinedAnnotation(PydanticErrorMixin, NameError):
    """A subclass of `NameError` raised when handling undefined annotations during `CoreSchema` generation.

    Attributes:
        name: Name of the error.
        message: Description of the error.
    """

    def __init__(self, name: str, message: str) -> None:
        self.name = name
        super().__init__(message=message, code='undefined-annotation')

    @classmethod
    def from_name_error(cls, name_error: NameError) -> Self:
        """Convert a `NameError` to a `PydanticUndefinedAnnotation` error.

        Args:
            name_error: `NameError` to be converted.

        Returns:
            Converted `PydanticUndefinedAnnotation` error.
        """
        try:
            name = name_error.name  # type: ignore  # python > 3.10
        except AttributeError:
            name = re.search(r".*'(.+?)'", str(name_error)).group(1)  # type: ignore[union-attr]
        return cls(name=name, message=str(name_error))


class PydanticImportError(PydanticErrorMixin, ImportError):
    """An error raised when an import fails due to module changes between V1 and V2.

    Attributes:
        message: Description of the error.
    """

    def __init__(self, message: str) -> None:
        super().__init__(message, code='import-error')


class PydanticSchemaGenerationError(PydanticUserError):
    """An error raised during failures to generate a `CoreSchema` for some type.

    Attributes:
        message: Description of the error.
    """

    def __init__(self, message: str) -> None:
        super().__init__(message, code='schema-for-unknown-type')


class PydanticInvalidForJsonSchema(PydanticUserError):
    """An error raised during failures to generate a JSON schema for some `CoreSchema`.

    Attributes:
        message: Description of the error.
    """

    def __init__(self, message: str) -> None:
        super().__init__(message, code='invalid-for-json-schema')


__getattr__ = getattr_migration(__name__)
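The error-code wiring above in action, directly exercising `PydanticErrorMixin.__str__` (the message text is an illustrative assumption; the version segment in the URL comes from `version_short()`):

```python
err = PydanticUserError('`TypeAdapter` was re-configured.', code='type-adapter-config-unused')
print(str(err))
#> `TypeAdapter` was re-configured.
#>
#> For further information visit https://errors.pydantic.dev/<version>/u/type-adapter-config-unused
```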
@@ -0,0 +1,10 @@
"""The "experimental" module of pydantic contains potential new features that are subject to change."""

import warnings

from pydantic.warnings import PydanticExperimentalWarning

warnings.warn(
    'This module is experimental, its contents are subject to change and deprecation.',
    category=PydanticExperimentalWarning,
)
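If the import-time warning above is noise in code that deliberately opts in, it can be filtered like any other warning category (a standard-library mechanism, not a pydantic API):

```python
import warnings

from pydantic.warnings import PydanticExperimentalWarning

warnings.filterwarnings('ignore', category=PydanticExperimentalWarning)
```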
@@ -0,0 +1,669 @@
"""Experimental pipeline API functionality. Be careful with this API, it's subject to change."""

from __future__ import annotations

import datetime
import operator
import re
import sys
from collections import deque
from collections.abc import Container
from dataclasses import dataclass
from decimal import Decimal
from functools import cached_property, partial
from typing import TYPE_CHECKING, Any, Callable, Generic, Pattern, Protocol, TypeVar, Union, overload

import annotated_types
from typing_extensions import Annotated

if TYPE_CHECKING:
    from pydantic_core import core_schema as cs

    from pydantic import GetCoreSchemaHandler

from pydantic._internal._internal_dataclass import slots_true as _slots_true

if sys.version_info < (3, 10):
    EllipsisType = type(Ellipsis)
else:
    from types import EllipsisType

__all__ = ['validate_as', 'validate_as_deferred', 'transform']

_slots_frozen = {**_slots_true, 'frozen': True}


@dataclass(**_slots_frozen)
class _ValidateAs:
    tp: type[Any]
    strict: bool = False


@dataclass
class _ValidateAsDefer:
    func: Callable[[], type[Any]]

    @cached_property
    def tp(self) -> type[Any]:
        return self.func()


@dataclass(**_slots_frozen)
class _Transform:
    func: Callable[[Any], Any]


@dataclass(**_slots_frozen)
class _PipelineOr:
    left: _Pipeline[Any, Any]
    right: _Pipeline[Any, Any]


@dataclass(**_slots_frozen)
class _PipelineAnd:
    left: _Pipeline[Any, Any]
    right: _Pipeline[Any, Any]


@dataclass(**_slots_frozen)
class _Eq:
    value: Any


@dataclass(**_slots_frozen)
class _NotEq:
    value: Any


@dataclass(**_slots_frozen)
class _In:
    values: Container[Any]


@dataclass(**_slots_frozen)
class _NotIn:
    values: Container[Any]


_ConstraintAnnotation = Union[
    annotated_types.Le,
    annotated_types.Ge,
    annotated_types.Lt,
    annotated_types.Gt,
    annotated_types.Len,
    annotated_types.MultipleOf,
    annotated_types.Timezone,
    annotated_types.Interval,
    annotated_types.Predicate,
    # common predicates not included in annotated_types
    _Eq,
    _NotEq,
    _In,
    _NotIn,
    # regular expressions
    Pattern[str],
]


@dataclass(**_slots_frozen)
class _Constraint:
    constraint: _ConstraintAnnotation


_Step = Union[_ValidateAs, _ValidateAsDefer, _Transform, _PipelineOr, _PipelineAnd, _Constraint]

_InT = TypeVar('_InT')
_OutT = TypeVar('_OutT')
_NewOutT = TypeVar('_NewOutT')


class _FieldTypeMarker:
    pass


# TODO: ultimately, make this public, see https://github.com/pydantic/pydantic/pull/9459#discussion_r1628197626
# Also, make this frozen eventually, but that doesn't work right now because of the generic base,
# which attempts to modify __orig_base__ and such.
# We could go with a manual freeze, but that seems overkill for now.
@dataclass(**_slots_true)
class _Pipeline(Generic[_InT, _OutT]):
    """Abstract representation of a chain of validation, transformation, and parsing steps."""

    _steps: tuple[_Step, ...]

    def transform(
        self,
        func: Callable[[_OutT], _NewOutT],
    ) -> _Pipeline[_InT, _NewOutT]:
        """Transform the output of the previous step.

        If used as the first step in a pipeline, the type of the field is used.
        That is, the transformation is applied after the value is parsed to the field's type.
        """
        return _Pipeline[_InT, _NewOutT](self._steps + (_Transform(func),))

    @overload
    def validate_as(self, tp: type[_NewOutT], *, strict: bool = ...) -> _Pipeline[_InT, _NewOutT]: ...

    @overload
    def validate_as(self, tp: EllipsisType, *, strict: bool = ...) -> _Pipeline[_InT, Any]:  # type: ignore
        ...

    def validate_as(self, tp: type[_NewOutT] | EllipsisType, *, strict: bool = False) -> _Pipeline[_InT, Any]:  # type: ignore
        """Validate / parse the input into a new type.

        If no type is provided, the type of the field is used.

        Types are parsed in Pydantic's `lax` mode by default,
        but you can enable `strict` mode by passing `strict=True`.
        """
        if isinstance(tp, EllipsisType):
            return _Pipeline[_InT, Any](self._steps + (_ValidateAs(_FieldTypeMarker, strict=strict),))
        return _Pipeline[_InT, _NewOutT](self._steps + (_ValidateAs(tp, strict=strict),))

    def validate_as_deferred(self, func: Callable[[], type[_NewOutT]]) -> _Pipeline[_InT, _NewOutT]:
        """Parse the input into a new type, deferring resolution of the type until the current class
        is fully defined.

        This is useful when you need to reference the class in its own type annotations.
        """
        return _Pipeline[_InT, _NewOutT](self._steps + (_ValidateAsDefer(func),))

    # constraints
    @overload
    def constrain(self: _Pipeline[_InT, _NewOutGe], constraint: annotated_types.Ge) -> _Pipeline[_InT, _NewOutGe]: ...

    @overload
    def constrain(self: _Pipeline[_InT, _NewOutGt], constraint: annotated_types.Gt) -> _Pipeline[_InT, _NewOutGt]: ...

    @overload
    def constrain(self: _Pipeline[_InT, _NewOutLe], constraint: annotated_types.Le) -> _Pipeline[_InT, _NewOutLe]: ...

    @overload
    def constrain(self: _Pipeline[_InT, _NewOutLt], constraint: annotated_types.Lt) -> _Pipeline[_InT, _NewOutLt]: ...

    @overload
    def constrain(
        self: _Pipeline[_InT, _NewOutLen], constraint: annotated_types.Len
    ) -> _Pipeline[_InT, _NewOutLen]: ...

    @overload
    def constrain(
        self: _Pipeline[_InT, _NewOutT], constraint: annotated_types.MultipleOf
    ) -> _Pipeline[_InT, _NewOutT]: ...

    @overload
    def constrain(
        self: _Pipeline[_InT, _NewOutDatetime], constraint: annotated_types.Timezone
    ) -> _Pipeline[_InT, _NewOutDatetime]: ...

    @overload
    def constrain(self: _Pipeline[_InT, _OutT], constraint: annotated_types.Predicate) -> _Pipeline[_InT, _OutT]: ...

    @overload
    def constrain(
        self: _Pipeline[_InT, _NewOutInterval], constraint: annotated_types.Interval
    ) -> _Pipeline[_InT, _NewOutInterval]: ...

    @overload
    def constrain(self: _Pipeline[_InT, _OutT], constraint: _Eq) -> _Pipeline[_InT, _OutT]: ...

    @overload
    def constrain(self: _Pipeline[_InT, _OutT], constraint: _NotEq) -> _Pipeline[_InT, _OutT]: ...

    @overload
    def constrain(self: _Pipeline[_InT, _OutT], constraint: _In) -> _Pipeline[_InT, _OutT]: ...

    @overload
    def constrain(self: _Pipeline[_InT, _OutT], constraint: _NotIn) -> _Pipeline[_InT, _OutT]: ...

    @overload
    def constrain(self: _Pipeline[_InT, _NewOutT], constraint: Pattern[str]) -> _Pipeline[_InT, _NewOutT]: ...

    def constrain(self, constraint: _ConstraintAnnotation) -> Any:
        """Constrain a value to meet a certain condition.

        We support most conditions from `annotated_types`, as well as regular expressions.

        Most of the time you'll be calling a shortcut method like `gt`, `lt`, `len`, etc.,
        so you don't need to call this directly.
        """
        return _Pipeline[_InT, _OutT](self._steps + (_Constraint(constraint),))

    def predicate(self: _Pipeline[_InT, _NewOutT], func: Callable[[_NewOutT], bool]) -> _Pipeline[_InT, _NewOutT]:
        """Constrain a value to meet a certain predicate."""
        return self.constrain(annotated_types.Predicate(func))

    def gt(self: _Pipeline[_InT, _NewOutGt], gt: _NewOutGt) -> _Pipeline[_InT, _NewOutGt]:
        """Constrain a value to be greater than a certain value."""
        return self.constrain(annotated_types.Gt(gt))

    def lt(self: _Pipeline[_InT, _NewOutLt], lt: _NewOutLt) -> _Pipeline[_InT, _NewOutLt]:
        """Constrain a value to be less than a certain value."""
        return self.constrain(annotated_types.Lt(lt))

    def ge(self: _Pipeline[_InT, _NewOutGe], ge: _NewOutGe) -> _Pipeline[_InT, _NewOutGe]:
        """Constrain a value to be greater than or equal to a certain value."""
        return self.constrain(annotated_types.Ge(ge))

    def le(self: _Pipeline[_InT, _NewOutLe], le: _NewOutLe) -> _Pipeline[_InT, _NewOutLe]:
        """Constrain a value to be less than or equal to a certain value."""
        return self.constrain(annotated_types.Le(le))

    def len(self: _Pipeline[_InT, _NewOutLen], min_len: int, max_len: int | None = None) -> _Pipeline[_InT, _NewOutLen]:
        """Constrain a value to have a certain length."""
        return self.constrain(annotated_types.Len(min_len, max_len))

    @overload
    def multiple_of(self: _Pipeline[_InT, _NewOutDiv], multiple_of: _NewOutDiv) -> _Pipeline[_InT, _NewOutDiv]: ...

    @overload
    def multiple_of(self: _Pipeline[_InT, _NewOutMod], multiple_of: _NewOutMod) -> _Pipeline[_InT, _NewOutMod]: ...

    def multiple_of(self: _Pipeline[_InT, Any], multiple_of: Any) -> _Pipeline[_InT, Any]:
        """Constrain a value to be a multiple of a certain number."""
        return self.constrain(annotated_types.MultipleOf(multiple_of))

    def eq(self: _Pipeline[_InT, _OutT], value: _OutT) -> _Pipeline[_InT, _OutT]:
        """Constrain a value to be equal to a certain value."""
        return self.constrain(_Eq(value))

    def not_eq(self: _Pipeline[_InT, _OutT], value: _OutT) -> _Pipeline[_InT, _OutT]:
        """Constrain a value to not be equal to a certain value."""
        return self.constrain(_NotEq(value))

    def in_(self: _Pipeline[_InT, _OutT], values: Container[_OutT]) -> _Pipeline[_InT, _OutT]:
        """Constrain a value to be in a certain set."""
        return self.constrain(_In(values))

    def not_in(self: _Pipeline[_InT, _OutT], values: Container[_OutT]) -> _Pipeline[_InT, _OutT]:
        """Constrain a value to not be in a certain set."""
        return self.constrain(_NotIn(values))

    # timezone methods
    def datetime_tz_naive(self: _Pipeline[_InT, datetime.datetime]) -> _Pipeline[_InT, datetime.datetime]:
        return self.constrain(annotated_types.Timezone(None))

    def datetime_tz_aware(self: _Pipeline[_InT, datetime.datetime]) -> _Pipeline[_InT, datetime.datetime]:
        return self.constrain(annotated_types.Timezone(...))

    def datetime_tz(
        self: _Pipeline[_InT, datetime.datetime], tz: datetime.tzinfo
    ) -> _Pipeline[_InT, datetime.datetime]:
        return self.constrain(annotated_types.Timezone(tz))  # type: ignore

    def datetime_with_tz(
        self: _Pipeline[_InT, datetime.datetime], tz: datetime.tzinfo | None
    ) -> _Pipeline[_InT, datetime.datetime]:
        return self.transform(partial(datetime.datetime.replace, tzinfo=tz))

    # string methods
    def str_lower(self: _Pipeline[_InT, str]) -> _Pipeline[_InT, str]:
        return self.transform(str.lower)

    def str_upper(self: _Pipeline[_InT, str]) -> _Pipeline[_InT, str]:
        return self.transform(str.upper)

    def str_title(self: _Pipeline[_InT, str]) -> _Pipeline[_InT, str]:
        return self.transform(str.title)

    def str_strip(self: _Pipeline[_InT, str]) -> _Pipeline[_InT, str]:
        return self.transform(str.strip)

    def str_pattern(self: _Pipeline[_InT, str], pattern: str) -> _Pipeline[_InT, str]:
        return self.constrain(re.compile(pattern))

    def str_contains(self: _Pipeline[_InT, str], substring: str) -> _Pipeline[_InT, str]:
        return self.predicate(lambda v: substring in v)

    def str_starts_with(self: _Pipeline[_InT, str], prefix: str) -> _Pipeline[_InT, str]:
        return self.predicate(lambda v: v.startswith(prefix))

    def str_ends_with(self: _Pipeline[_InT, str], suffix: str) -> _Pipeline[_InT, str]:
        return self.predicate(lambda v: v.endswith(suffix))

    # operators
    def otherwise(self, other: _Pipeline[_OtherIn, _OtherOut]) -> _Pipeline[_InT | _OtherIn, _OutT | _OtherOut]:
        """Combine two validation chains, returning the result of the first chain if it succeeds, and the second chain if it fails."""
        return _Pipeline((_PipelineOr(self, other),))

    __or__ = otherwise

    def then(self, other: _Pipeline[_OutT, _OtherOut]) -> _Pipeline[_InT, _OtherOut]:
        """Pipe the result of one validation chain into another."""
        return _Pipeline((_PipelineAnd(self, other),))

    __and__ = then

    def __get_pydantic_core_schema__(self, source_type: Any, handler: GetCoreSchemaHandler) -> cs.CoreSchema:
        from pydantic_core import core_schema as cs

        queue = deque(self._steps)

        s = None

        while queue:
            step = queue.popleft()
            s = _apply_step(step, s, handler, source_type)

        s = s or cs.any_schema()
        return s

    def __supports_type__(self, _: _OutT) -> bool:
        raise NotImplementedError


validate_as = _Pipeline[Any, Any](()).validate_as
validate_as_deferred = _Pipeline[Any, Any](()).validate_as_deferred
transform = _Pipeline[Any, Any]((_ValidateAs(_FieldTypeMarker),)).transform
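A hedged usage sketch of the entry points just defined, combining `validate_as` with the fluent constraint shortcuts (the `User` model and its values are illustrative assumptions):

```python
from typing_extensions import Annotated

from pydantic import BaseModel
from pydantic.experimental.pipeline import validate_as

class User(BaseModel):
    # parse as str, strip whitespace, then require a non-empty value
    name: Annotated[str, validate_as(str).str_strip().len(1)]
    # parse as int (lax mode coerces '42'), then constrain the range
    age: Annotated[int, validate_as(int).ge(0).lt(130)]

print(User(name='  alice ', age='42'))
#> name='alice' age=42
```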
def _check_func(
    func: Callable[[Any], bool], predicate_err: str | Callable[[], str], s: cs.CoreSchema | None
) -> cs.CoreSchema:
    from pydantic_core import core_schema as cs

    def handler(v: Any) -> Any:
        if func(v):
            return v
        raise ValueError(f'Expected {predicate_err if isinstance(predicate_err, str) else predicate_err()}')

    if s is None:
        return cs.no_info_plain_validator_function(handler)
    else:
        return cs.no_info_after_validator_function(handler, s)


def _apply_step(step: _Step, s: cs.CoreSchema | None, handler: GetCoreSchemaHandler, source_type: Any) -> cs.CoreSchema:
    from pydantic_core import core_schema as cs

    if isinstance(step, _ValidateAs):
        s = _apply_parse(s, step.tp, step.strict, handler, source_type)
    elif isinstance(step, _ValidateAsDefer):
        s = _apply_parse(s, step.tp, False, handler, source_type)
    elif isinstance(step, _Transform):
        s = _apply_transform(s, step.func, handler)
    elif isinstance(step, _Constraint):
        s = _apply_constraint(s, step.constraint)
    elif isinstance(step, _PipelineOr):
        s = cs.union_schema([handler(step.left), handler(step.right)])
    else:
        assert isinstance(step, _PipelineAnd)
        s = cs.chain_schema([handler(step.left), handler(step.right)])
    return s


def _apply_parse(
    s: cs.CoreSchema | None,
    tp: type[Any],
    strict: bool,
    handler: GetCoreSchemaHandler,
    source_type: Any,
) -> cs.CoreSchema:
    from pydantic_core import core_schema as cs

    from pydantic import Strict

    if tp is _FieldTypeMarker:
        return handler(source_type)

    if strict:
        tp = Annotated[tp, Strict()]  # type: ignore

    if s and s['type'] == 'any':
        return handler(tp)
    else:
        return cs.chain_schema([s, handler(tp)]) if s else handler(tp)


def _apply_transform(
    s: cs.CoreSchema | None, func: Callable[[Any], Any], handler: GetCoreSchemaHandler
) -> cs.CoreSchema:
    from pydantic_core import core_schema as cs

    if s is None:
        return cs.no_info_plain_validator_function(func)

    if s['type'] == 'str':
        if func is str.strip:
            s = s.copy()
            s['strip_whitespace'] = True
            return s
        elif func is str.lower:
            s = s.copy()
            s['to_lower'] = True
            return s
        elif func is str.upper:
            s = s.copy()
            s['to_upper'] = True
            return s

    return cs.no_info_after_validator_function(func, s)


def _apply_constraint(  # noqa: C901
    s: cs.CoreSchema | None, constraint: _ConstraintAnnotation
) -> cs.CoreSchema:
    """Apply a single constraint to a schema."""
    if isinstance(constraint, annotated_types.Gt):
        gt = constraint.gt
        if s and s['type'] in {'int', 'float', 'decimal'}:
            s = s.copy()
            if s['type'] == 'int' and isinstance(gt, int):
                s['gt'] = gt
            elif s['type'] == 'float' and isinstance(gt, float):
                s['gt'] = gt
            elif s['type'] == 'decimal' and isinstance(gt, Decimal):
                s['gt'] = gt
        else:

            def check_gt(v: Any) -> bool:
                return v > gt

            s = _check_func(check_gt, f'> {gt}', s)
    elif isinstance(constraint, annotated_types.Ge):
        ge = constraint.ge
        if s and s['type'] in {'int', 'float', 'decimal'}:
            s = s.copy()
            if s['type'] == 'int' and isinstance(ge, int):
                s['ge'] = ge
            elif s['type'] == 'float' and isinstance(ge, float):
                s['ge'] = ge
            elif s['type'] == 'decimal' and isinstance(ge, Decimal):
                s['ge'] = ge

        def check_ge(v: Any) -> bool:
            return v >= ge

        s = _check_func(check_ge, f'>= {ge}', s)
    elif isinstance(constraint, annotated_types.Lt):
        lt = constraint.lt
        if s and s['type'] in {'int', 'float', 'decimal'}:
            s = s.copy()
            if s['type'] == 'int' and isinstance(lt, int):
                s['lt'] = lt
            elif s['type'] == 'float' and isinstance(lt, float):
                s['lt'] = lt
            elif s['type'] == 'decimal' and isinstance(lt, Decimal):
                s['lt'] = lt

        def check_lt(v: Any) -> bool:
            return v < lt

        s = _check_func(check_lt, f'< {lt}', s)
    elif isinstance(constraint, annotated_types.Le):
        le = constraint.le
        if s and s['type'] in {'int', 'float', 'decimal'}:
            s = s.copy()
            if s['type'] == 'int' and isinstance(le, int):
                s['le'] = le
            elif s['type'] == 'float' and isinstance(le, float):
                s['le'] = le
            elif s['type'] == 'decimal' and isinstance(le, Decimal):
                s['le'] = le

        def check_le(v: Any) -> bool:
            return v <= le

        s = _check_func(check_le, f'<= {le}', s)
    elif isinstance(constraint, annotated_types.Len):
        min_len = constraint.min_length
        max_len = constraint.max_length

        if s and s['type'] in {'str', 'list', 'tuple', 'set', 'frozenset', 'dict'}:
            assert (
                s['type'] == 'str'
                or s['type'] == 'list'
                or s['type'] == 'tuple'
                or s['type'] == 'set'
                or s['type'] == 'dict'
                or s['type'] == 'frozenset'
            )
            s = s.copy()
            if min_len != 0:
                s['min_length'] = min_len
            if max_len is not None:
                s['max_length'] = max_len

        def check_len(v: Any) -> bool:
            if max_len is not None:
                return (min_len <= len(v)) and (len(v) <= max_len)
            return min_len <= len(v)

        s = _check_func(check_len, f'length >= {min_len} and length <= {max_len}', s)
    elif isinstance(constraint, annotated_types.MultipleOf):
        multiple_of = constraint.multiple_of
        if s and s['type'] in {'int', 'float', 'decimal'}:
            s = s.copy()
            if s['type'] == 'int' and isinstance(multiple_of, int):
                s['multiple_of'] = multiple_of
            elif s['type'] == 'float' and isinstance(multiple_of, float):
                s['multiple_of'] = multiple_of
            elif s['type'] == 'decimal' and isinstance(multiple_of, Decimal):
                s['multiple_of'] = multiple_of

        def check_multiple_of(v: Any) -> bool:
            return v % multiple_of == 0

        s = _check_func(check_multiple_of, f'% {multiple_of} == 0', s)
    elif isinstance(constraint, annotated_types.Timezone):
        tz = constraint.tz

        if tz is ...:
            if s and s['type'] == 'datetime':
                s = s.copy()
                s['tz_constraint'] = 'aware'
            else:

                def check_tz_aware(v: object) -> bool:
                    assert isinstance(v, datetime.datetime)
                    return v.tzinfo is not None

                s = _check_func(check_tz_aware, 'timezone aware', s)
        elif tz is None:
            if s and s['type'] == 'datetime':
                s = s.copy()
                s['tz_constraint'] = 'naive'
            else:

                def check_tz_naive(v: object) -> bool:
                    assert isinstance(v, datetime.datetime)
                    return v.tzinfo is None

                s = _check_func(check_tz_naive, 'timezone naive', s)
        else:
            raise NotImplementedError('Constraining to a specific timezone is not yet supported')
    elif isinstance(constraint, annotated_types.Interval):
        if constraint.ge:
            s = _apply_constraint(s, annotated_types.Ge(constraint.ge))
        if constraint.gt:
            s = _apply_constraint(s, annotated_types.Gt(constraint.gt))
        if constraint.le:
            s = _apply_constraint(s, annotated_types.Le(constraint.le))
        if constraint.lt:
            s = _apply_constraint(s, annotated_types.Lt(constraint.lt))
        assert s is not None
    elif isinstance(constraint, annotated_types.Predicate):
        func = constraint.func

        if func.__name__ == '<lambda>':
            # attempt to extract the source code for a lambda function
            # to use as the function name in error messages
            # TODO: is there a better way? should we just not do this?
            import inspect

            try:
                # remove ')' suffix, can use removesuffix once we drop 3.8
                source = inspect.getsource(func).strip()
                if source.endswith(')'):
                    source = source[:-1]
                lambda_source_code = '`' + ''.join(''.join(source.split('lambda ')[1:]).split(':')[1:]).strip() + '`'
            except OSError:
                # stringified annotations
                lambda_source_code = 'lambda'

            s = _check_func(func, lambda_source_code, s)
        else:
            s = _check_func(func, func.__name__, s)
    elif isinstance(constraint, _NotEq):
        value = constraint.value

        def check_not_eq(v: Any) -> bool:
            return operator.__ne__(v, value)

        s = _check_func(check_not_eq, f'!= {value}', s)
    elif isinstance(constraint, _Eq):
        value = constraint.value

        def check_eq(v: Any) -> bool:
            return operator.__eq__(v, value)

        s = _check_func(check_eq, f'== {value}', s)
    elif isinstance(constraint, _In):
        values = constraint.values

        def check_in(v: Any) -> bool:
            return operator.__contains__(values, v)

        s = _check_func(check_in, f'in {values}', s)
    elif isinstance(constraint, _NotIn):
        values = constraint.values

        def check_not_in(v: Any) -> bool:
            return operator.__not__(operator.__contains__(values, v))

        s = _check_func(check_not_in, f'not in {values}', s)
    else:
        assert isinstance(constraint, Pattern)
        if s and s['type'] == 'str':
            s = s.copy()
            s['pattern'] = constraint.pattern
        else:

            def check_pattern(v: object) -> bool:
                assert isinstance(v, str)
                return constraint.match(v) is not None

            s = _check_func(check_pattern, f'~ {constraint.pattern}', s)
    return s


class _SupportsRange(annotated_types.SupportsLe, annotated_types.SupportsGe, Protocol):
    pass


class _SupportsLen(Protocol):
    def __len__(self) -> int: ...


_NewOutGt = TypeVar('_NewOutGt', bound=annotated_types.SupportsGt)
_NewOutGe = TypeVar('_NewOutGe', bound=annotated_types.SupportsGe)
_NewOutLt = TypeVar('_NewOutLt', bound=annotated_types.SupportsLt)
_NewOutLe = TypeVar('_NewOutLe', bound=annotated_types.SupportsLe)
_NewOutLen = TypeVar('_NewOutLen', bound=_SupportsLen)
_NewOutDiv = TypeVar('_NewOutDiv', bound=annotated_types.SupportsDiv)
_NewOutMod = TypeVar('_NewOutMod', bound=annotated_types.SupportsMod)
_NewOutDatetime = TypeVar('_NewOutDatetime', bound=datetime.datetime)
_NewOutInterval = TypeVar('_NewOutInterval', bound=_SupportsRange)
_OtherIn = TypeVar('_OtherIn')
_OtherOut = TypeVar('_OtherOut')
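A small sketch of the two paths inside `_apply_constraint` above: when the constraint matches the schema's native keys it is merged in place, otherwise it falls back to a `_check_func` wrapper (assuming `pydantic_core` and `annotated_types` are importable in this context):

```python
import annotated_types
from pydantic_core import core_schema as cs

merged = _apply_constraint(cs.int_schema(), annotated_types.Gt(0))
print(merged)
#> {'type': 'int', 'gt': 0}

wrapped = _apply_constraint(cs.str_schema(), annotated_types.Gt('a'))
print(wrapped['type'])  # falls back to an after-validator performing the check at runtime
#> function-after
```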
File diff suppressed because it is too large
@@ -0,0 +1,449 @@
|
||||
"""This module contains related classes and functions for serialization."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import dataclasses
|
||||
from functools import partial, partialmethod
|
||||
from typing import TYPE_CHECKING, Any, Callable, TypeVar, overload
|
||||
|
||||
from pydantic_core import PydanticUndefined, core_schema
|
||||
from pydantic_core.core_schema import SerializationInfo, SerializerFunctionWrapHandler, WhenUsed
|
||||
from typing_extensions import Annotated, Literal, TypeAlias
|
||||
|
||||
from . import PydanticUndefinedAnnotation
|
||||
from ._internal import _decorators, _internal_dataclass
|
||||
from .annotated_handlers import GetCoreSchemaHandler
|
||||
|
||||
|
||||
@dataclasses.dataclass(**_internal_dataclass.slots_true, frozen=True)
|
||||
class PlainSerializer:
|
||||
"""Plain serializers use a function to modify the output of serialization.
|
||||
|
||||
This is particularly helpful when you want to customize the serialization for annotated types.
|
||||
Consider an input of `list`, which will be serialized into a space-delimited string.
|
||||
|
||||
```python
|
||||
from typing import List
|
||||
|
||||
from typing_extensions import Annotated
|
||||
|
||||
from pydantic import BaseModel, PlainSerializer
|
||||
|
||||
CustomStr = Annotated[
|
||||
List, PlainSerializer(lambda x: ' '.join(x), return_type=str)
|
||||
]
|
||||
|
||||
class StudentModel(BaseModel):
|
||||
courses: CustomStr
|
||||
|
||||
student = StudentModel(courses=['Math', 'Chemistry', 'English'])
|
||||
print(student.model_dump())
|
||||
#> {'courses': 'Math Chemistry English'}
|
||||
```
|
||||
|
||||
Attributes:
|
||||
func: The serializer function.
|
||||
return_type: The return type for the function. If omitted it will be inferred from the type annotation.
|
||||
when_used: Determines when this serializer should be used. Accepts a string with values `'always'`,
|
||||
`'unless-none'`, `'json'`, and `'json-unless-none'`. Defaults to 'always'.
|
||||
"""
|
||||
|
||||
func: core_schema.SerializerFunction
|
||||
return_type: Any = PydanticUndefined
|
||||
when_used: WhenUsed = 'always'
|
||||
|
||||
def __get_pydantic_core_schema__(self, source_type: Any, handler: GetCoreSchemaHandler) -> core_schema.CoreSchema:
|
||||
"""Gets the Pydantic core schema.
|
||||
|
||||
Args:
|
||||
source_type: The source type.
|
||||
handler: The `GetCoreSchemaHandler` instance.
|
||||
|
||||
Returns:
|
||||
The Pydantic core schema.
|
||||
"""
|
||||
schema = handler(source_type)
|
||||
try:
|
||||
# Do not pass in globals as the function could be defined in a different module.
|
||||
# Instead, let `get_function_return_type` infer the globals to use, but still pass
|
||||
# in locals that may contain a parent/rebuild namespace:
|
||||
return_type = _decorators.get_function_return_type(
|
||||
self.func,
|
||||
self.return_type,
|
||||
localns=handler._get_types_namespace().locals,
|
||||
)
|
||||
except NameError as e:
|
||||
raise PydanticUndefinedAnnotation.from_name_error(e) from e
|
||||
return_schema = None if return_type is PydanticUndefined else handler.generate_schema(return_type)
|
||||
schema['serialization'] = core_schema.plain_serializer_function_ser_schema(
|
||||
function=self.func,
|
||||
info_arg=_decorators.inspect_annotated_serializer(self.func, 'plain'),
|
||||
return_schema=return_schema,
|
||||
when_used=self.when_used,
|
||||
)
|
||||
return schema
|
||||
|
||||
|
||||
@dataclasses.dataclass(**_internal_dataclass.slots_true, frozen=True)
|
||||
class WrapSerializer:
|
||||
"""Wrap serializers receive the raw inputs along with a handler function that applies the standard serialization
|
||||
logic, and can modify the resulting value before returning it as the final output of serialization.
|
||||
|
||||
For example, here's a scenario in which a wrap serializer transforms timezones to UTC **and** utilizes the existing `datetime` serialization logic.
|
||||
|
||||
```python
|
||||
from datetime import datetime, timezone
|
||||
from typing import Any, Dict
|
||||
|
||||
from typing_extensions import Annotated
|
||||
|
||||
from pydantic import BaseModel, WrapSerializer
|
||||
|
||||
class EventDatetime(BaseModel):
|
||||
start: datetime
|
||||
end: datetime
|
||||
|
||||
def convert_to_utc(value: Any, handler, info) -> Dict[str, datetime]:
|
||||
# Note that `handler` can actually help serialize the `value` for
|
||||
# further custom serialization in case it's a subclass.
|
||||
partial_result = handler(value, info)
|
||||
if info.mode == 'json':
|
||||
return {
|
||||
k: datetime.fromisoformat(v).astimezone(timezone.utc)
|
||||
for k, v in partial_result.items()
|
||||
}
|
||||
return {k: v.astimezone(timezone.utc) for k, v in partial_result.items()}
|
||||
|
||||
UTCEventDatetime = Annotated[EventDatetime, WrapSerializer(convert_to_utc)]
|
||||
|
||||
class EventModel(BaseModel):
|
||||
event_datetime: UTCEventDatetime
|
||||
|
||||
dt = EventDatetime(
|
||||
start='2024-01-01T07:00:00-08:00', end='2024-01-03T20:00:00+06:00'
|
||||
)
|
||||
event = EventModel(event_datetime=dt)
|
||||
print(event.model_dump())
|
||||
'''
|
||||
{
|
||||
'event_datetime': {
|
||||
'start': datetime.datetime(
|
||||
2024, 1, 1, 15, 0, tzinfo=datetime.timezone.utc
|
||||
),
|
||||
'end': datetime.datetime(
|
||||
2024, 1, 3, 14, 0, tzinfo=datetime.timezone.utc
|
||||
),
|
||||
}
|
||||
}
|
||||
'''
|
||||
|
||||
print(event.model_dump_json())
|
||||
'''
|
||||
{"event_datetime":{"start":"2024-01-01T15:00:00Z","end":"2024-01-03T14:00:00Z"}}
|
||||
'''
|
||||
```
|
||||
|
||||
Attributes:
|
||||
func: The serializer function to be wrapped.
|
||||
return_type: The return type for the function. If omitted it will be inferred from the type annotation.
|
||||
when_used: Determines when this serializer should be used. Accepts a string with values `'always'`,
|
||||
`'unless-none'`, `'json'`, and `'json-unless-none'`. Defaults to 'always'.
|
||||
"""
|
||||
|
||||
func: core_schema.WrapSerializerFunction
|
||||
return_type: Any = PydanticUndefined
|
||||
when_used: WhenUsed = 'always'
|
||||
|
||||
def __get_pydantic_core_schema__(self, source_type: Any, handler: GetCoreSchemaHandler) -> core_schema.CoreSchema:
|
||||
"""This method is used to get the Pydantic core schema of the class.
|
||||
|
||||
Args:
|
||||
source_type: Source type.
|
||||
handler: Core schema handler.
|
||||
|
||||
Returns:
|
||||
The generated core schema of the class.
|
||||
"""
|
||||
schema = handler(source_type)
|
||||
globalns, localns = handler._get_types_namespace()
|
||||
try:
|
||||
# Do not pass in globals as the function could be defined in a different module.
|
||||
# Instead, let `get_function_return_type` infer the globals to use, but still pass
|
||||
# in locals that may contain a parent/rebuild namespace:
|
||||
            return_type = _decorators.get_function_return_type(
                self.func,
                self.return_type,
                localns=handler._get_types_namespace().locals,
            )
        except NameError as e:
            raise PydanticUndefinedAnnotation.from_name_error(e) from e
        return_schema = None if return_type is PydanticUndefined else handler.generate_schema(return_type)
        schema['serialization'] = core_schema.wrap_serializer_function_ser_schema(
            function=self.func,
            info_arg=_decorators.inspect_annotated_serializer(self.func, 'wrap'),
            return_schema=return_schema,
            when_used=self.when_used,
        )
        return schema


if TYPE_CHECKING:
    _Partial: TypeAlias = 'partial[Any] | partialmethod[Any]'

    FieldPlainSerializer: TypeAlias = 'core_schema.SerializerFunction | _Partial'
    """A field serializer method or function in `plain` mode."""

    FieldWrapSerializer: TypeAlias = 'core_schema.WrapSerializerFunction | _Partial'
    """A field serializer method or function in `wrap` mode."""

    FieldSerializer: TypeAlias = 'FieldPlainSerializer | FieldWrapSerializer'
    """A field serializer method or function."""

    _FieldPlainSerializerT = TypeVar('_FieldPlainSerializerT', bound=FieldPlainSerializer)
    _FieldWrapSerializerT = TypeVar('_FieldWrapSerializerT', bound=FieldWrapSerializer)


@overload
def field_serializer(
    field: str,
    /,
    *fields: str,
    mode: Literal['wrap'],
    return_type: Any = ...,
    when_used: WhenUsed = ...,
    check_fields: bool | None = ...,
) -> Callable[[_FieldWrapSerializerT], _FieldWrapSerializerT]: ...


@overload
def field_serializer(
    field: str,
    /,
    *fields: str,
    mode: Literal['plain'] = ...,
    return_type: Any = ...,
    when_used: WhenUsed = ...,
    check_fields: bool | None = ...,
) -> Callable[[_FieldPlainSerializerT], _FieldPlainSerializerT]: ...


def field_serializer(
    *fields: str,
    mode: Literal['plain', 'wrap'] = 'plain',
    return_type: Any = PydanticUndefined,
    when_used: WhenUsed = 'always',
    check_fields: bool | None = None,
) -> (
    Callable[[_FieldWrapSerializerT], _FieldWrapSerializerT]
    | Callable[[_FieldPlainSerializerT], _FieldPlainSerializerT]
):
    """Decorator that enables custom field serialization.

    In the below example, a field of type `set` is used to mitigate duplication. A `field_serializer` is used to serialize the data as a sorted list.

    ```python
    from typing import Set

    from pydantic import BaseModel, field_serializer

    class StudentModel(BaseModel):
        name: str = 'Jane'
        courses: Set[str]

        @field_serializer('courses', when_used='json')
        def serialize_courses_in_order(self, courses: Set[str]):
            return sorted(courses)

    student = StudentModel(courses={'Math', 'Chemistry', 'English'})
    print(student.model_dump_json())
    #> {"name":"Jane","courses":["Chemistry","English","Math"]}
    ```

    See [Custom serializers](../concepts/serialization.md#custom-serializers) for more information.

    Four signatures are supported:

    - `(self, value: Any, info: FieldSerializationInfo)`
    - `(self, value: Any, nxt: SerializerFunctionWrapHandler, info: FieldSerializationInfo)`
    - `(value: Any, info: SerializationInfo)`
    - `(value: Any, nxt: SerializerFunctionWrapHandler, info: SerializationInfo)`

    Args:
        fields: Which field(s) the method should be called on.
        mode: The serialization mode.

            - `plain` means the function will be called instead of the default serialization logic,
            - `wrap` means the function will be called with an argument to optionally call the
              default serialization logic.
        return_type: Optional return type for the function, if omitted it will be inferred from the type annotation.
        when_used: Determines when this serializer should be used for serialization.
        check_fields: Whether to check that the fields actually exist on the model.

    Returns:
        The decorator function.
    """

    def dec(f: FieldSerializer) -> _decorators.PydanticDescriptorProxy[Any]:
        dec_info = _decorators.FieldSerializerDecoratorInfo(
            fields=fields,
            mode=mode,
            return_type=return_type,
            when_used=when_used,
            check_fields=check_fields,
        )
        return _decorators.PydanticDescriptorProxy(f, dec_info)  # pyright: ignore[reportArgumentType]

    return dec  # pyright: ignore[reportReturnType]

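# Editor's note: the docstring above only demonstrates `mode='plain'`. The following
# is a hedged, standalone sketch (not part of the upstream module) of a `mode='wrap'`
# field serializer that delegates to the default logic via the handler; the `Order`
# model and its field are hypothetical.
from pydantic import BaseModel, SerializerFunctionWrapHandler, field_serializer


class Order(BaseModel):
    price: float

    @field_serializer('price', mode='wrap')
    def round_price(self, value: float, nxt: SerializerFunctionWrapHandler) -> float:
        # `nxt` runs the default serialization logic; we post-process its output.
        return round(nxt(value), 2)


print(Order(price=3.14159).model_dump())
#> {'price': 3.14}
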
if TYPE_CHECKING:
    # The first argument in the following callables represents the `self` type:

    ModelPlainSerializerWithInfo: TypeAlias = Callable[[Any, SerializationInfo], Any]
    """A model serializer method with the `info` argument, in `plain` mode."""

    ModelPlainSerializerWithoutInfo: TypeAlias = Callable[[Any], Any]
    """A model serializer method without the `info` argument, in `plain` mode."""

    ModelPlainSerializer: TypeAlias = 'ModelPlainSerializerWithInfo | ModelPlainSerializerWithoutInfo'
    """A model serializer method in `plain` mode."""

    ModelWrapSerializerWithInfo: TypeAlias = Callable[[Any, SerializerFunctionWrapHandler, SerializationInfo], Any]
    """A model serializer method with the `info` argument, in `wrap` mode."""

    ModelWrapSerializerWithoutInfo: TypeAlias = Callable[[Any, SerializerFunctionWrapHandler], Any]
    """A model serializer method without the `info` argument, in `wrap` mode."""

    ModelWrapSerializer: TypeAlias = 'ModelWrapSerializerWithInfo | ModelWrapSerializerWithoutInfo'
    """A model serializer method in `wrap` mode."""

    ModelSerializer: TypeAlias = 'ModelPlainSerializer | ModelWrapSerializer'

    _ModelPlainSerializerT = TypeVar('_ModelPlainSerializerT', bound=ModelPlainSerializer)
    _ModelWrapSerializerT = TypeVar('_ModelWrapSerializerT', bound=ModelWrapSerializer)


@overload
def model_serializer(f: _ModelPlainSerializerT, /) -> _ModelPlainSerializerT: ...


@overload
def model_serializer(
    *, mode: Literal['wrap'], when_used: WhenUsed = 'always', return_type: Any = ...
) -> Callable[[_ModelWrapSerializerT], _ModelWrapSerializerT]: ...


@overload
def model_serializer(
    *,
    mode: Literal['plain'] = ...,
    when_used: WhenUsed = 'always',
    return_type: Any = ...,
) -> Callable[[_ModelPlainSerializerT], _ModelPlainSerializerT]: ...


def model_serializer(
    f: _ModelPlainSerializerT | _ModelWrapSerializerT | None = None,
    /,
    *,
    mode: Literal['plain', 'wrap'] = 'plain',
    when_used: WhenUsed = 'always',
    return_type: Any = PydanticUndefined,
) -> (
    _ModelPlainSerializerT
    | Callable[[_ModelWrapSerializerT], _ModelWrapSerializerT]
    | Callable[[_ModelPlainSerializerT], _ModelPlainSerializerT]
):
    """Decorator that enables custom model serialization.

    This is useful when a model needs to be serialized in a customized manner, allowing for flexibility beyond just specific fields.

    An example would be to serialize temperature to the same temperature scale, such as degrees Celsius.

    ```python
    from typing import Literal

    from pydantic import BaseModel, model_serializer

    class TemperatureModel(BaseModel):
        unit: Literal['C', 'F']
        value: int

        @model_serializer()
        def serialize_model(self):
            if self.unit == 'F':
                return {'unit': 'C', 'value': int((self.value - 32) / 1.8)}
            return {'unit': self.unit, 'value': self.value}

    temperature = TemperatureModel(unit='F', value=212)
    print(temperature.model_dump())
    #> {'unit': 'C', 'value': 100}
    ```

    Two signatures are supported for `mode='plain'`, which is the default:

    - `(self)`
    - `(self, info: SerializationInfo)`

    And two other signatures for `mode='wrap'`:

    - `(self, nxt: SerializerFunctionWrapHandler)`
    - `(self, nxt: SerializerFunctionWrapHandler, info: SerializationInfo)`

    See [Custom serializers](../concepts/serialization.md#custom-serializers) for more information.

    Args:
        f: The function to be decorated.
        mode: The serialization mode.

            - `'plain'` means the function will be called instead of the default serialization logic
            - `'wrap'` means the function will be called with an argument to optionally call the default
              serialization logic.
        when_used: Determines when this serializer should be used.
        return_type: The return type for the function. If omitted it will be inferred from the type annotation.

    Returns:
        The decorator function.
    """

    def dec(f: ModelSerializer) -> _decorators.PydanticDescriptorProxy[Any]:
        dec_info = _decorators.ModelSerializerDecoratorInfo(mode=mode, return_type=return_type, when_used=when_used)
        return _decorators.PydanticDescriptorProxy(f, dec_info)

    if f is None:
        return dec  # pyright: ignore[reportReturnType]
    else:
        return dec(f)  # pyright: ignore[reportReturnType]

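# Editor's note: a hedged, standalone sketch (not part of the upstream module) of
# `mode='wrap'`, where the handler produces the default serialization and the method
# augments it. The `User` model and the added key are hypothetical.
from pydantic import BaseModel, SerializerFunctionWrapHandler, model_serializer


class User(BaseModel):
    name: str

    @model_serializer(mode='wrap')
    def add_kind(self, nxt: SerializerFunctionWrapHandler) -> dict:
        # `nxt(self)` runs the default model serialization (a dict in python mode).
        data = nxt(self)
        data['kind'] = 'user'
        return data


print(User(name='Jane').model_dump())
#> {'name': 'Jane', 'kind': 'user'}
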
AnyType = TypeVar('AnyType')


if TYPE_CHECKING:
    SerializeAsAny = Annotated[AnyType, ...]  # SerializeAsAny[list[str]] will be treated by type checkers as list[str]
    """Force serialization to ignore whatever is defined in the schema and instead ask the object
    itself how it should be serialized.
    In particular, this means that when model subclasses are serialized, fields present in the subclass
    but not in the original schema will be included.
    """
else:

    @dataclasses.dataclass(**_internal_dataclass.slots_true)
    class SerializeAsAny:  # noqa: D101
        def __class_getitem__(cls, item: Any) -> Any:
            return Annotated[item, SerializeAsAny()]

        def __get_pydantic_core_schema__(
            self, source_type: Any, handler: GetCoreSchemaHandler
        ) -> core_schema.CoreSchema:
            schema = handler(source_type)
            schema_to_update = schema
            while schema_to_update['type'] == 'definitions':
                schema_to_update = schema_to_update.copy()
                schema_to_update = schema_to_update['schema']
            schema_to_update['serialization'] = core_schema.wrap_serializer_function_ser_schema(
                lambda x, h: h(x), schema=core_schema.any_schema()
            )
            return schema

        __hash__ = object.__hash__
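# Editor's note: a hedged usage sketch (not part of the upstream module) of
# `SerializeAsAny`: duck-typed serialization keeps subclass-only fields that the
# declared schema would otherwise drop. The models are hypothetical.
from pydantic import BaseModel, SerializeAsAny


class Pet(BaseModel):
    name: str


class Dog(Pet):
    breed: str


class OwnerStrict(BaseModel):
    pet: Pet


class OwnerAny(BaseModel):
    pet: SerializeAsAny[Pet]


dog = Dog(name='Rex', breed='collie')
print(OwnerStrict(pet=dog).model_dump())  # subclass-only field dropped
#> {'pet': {'name': 'Rex'}}
print(OwnerAny(pet=dog).model_dump())  # subclass-only field kept
#> {'pet': {'name': 'Rex', 'breed': 'collie'}}
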
@@ -0,0 +1,825 @@
"""This module contains related classes and functions for validation."""

from __future__ import annotations as _annotations

import dataclasses
import sys
from functools import partialmethod
from types import FunctionType
from typing import TYPE_CHECKING, Any, Callable, TypeVar, Union, cast, overload

from pydantic_core import PydanticUndefined, core_schema
from pydantic_core import core_schema as _core_schema
from typing_extensions import Annotated, Literal, Self, TypeAlias

from ._internal import _decorators, _generics, _internal_dataclass
from .annotated_handlers import GetCoreSchemaHandler
from .errors import PydanticUserError

if sys.version_info < (3, 11):
    from typing_extensions import Protocol
else:
    from typing import Protocol

_inspect_validator = _decorators.inspect_validator


@dataclasses.dataclass(frozen=True, **_internal_dataclass.slots_true)
class AfterValidator:
    """Usage docs: https://docs.pydantic.dev/2.10/concepts/validators/#field-validators

    A metadata class that indicates that a validation should be applied **after** the inner validation logic.

    Attributes:
        func: The validator function.

    Example:
        ```python
        from typing_extensions import Annotated

        from pydantic import AfterValidator, BaseModel, ValidationError

        MyInt = Annotated[int, AfterValidator(lambda v: v + 1)]

        class Model(BaseModel):
            a: MyInt

        print(Model(a=1).a)
        #> 2

        try:
            Model(a='a')
        except ValidationError as e:
            print(e.json(indent=2))
            '''
            [
              {
                "type": "int_parsing",
                "loc": [
                  "a"
                ],
                "msg": "Input should be a valid integer, unable to parse string as an integer",
                "input": "a",
                "url": "https://errors.pydantic.dev/2/v/int_parsing"
              }
            ]
            '''
        ```
    """

    func: core_schema.NoInfoValidatorFunction | core_schema.WithInfoValidatorFunction

    def __get_pydantic_core_schema__(self, source_type: Any, handler: GetCoreSchemaHandler) -> core_schema.CoreSchema:
        schema = handler(source_type)
        info_arg = _inspect_validator(self.func, 'after')
        if info_arg:
            func = cast(core_schema.WithInfoValidatorFunction, self.func)
            return core_schema.with_info_after_validator_function(func, schema=schema, field_name=handler.field_name)
        else:
            func = cast(core_schema.NoInfoValidatorFunction, self.func)
            return core_schema.no_info_after_validator_function(func, schema=schema)

    @classmethod
    def _from_decorator(cls, decorator: _decorators.Decorator[_decorators.FieldValidatorDecoratorInfo]) -> Self:
        return cls(func=decorator.func)


@dataclasses.dataclass(frozen=True, **_internal_dataclass.slots_true)
class BeforeValidator:
    """Usage docs: https://docs.pydantic.dev/2.10/concepts/validators/#field-validators

    A metadata class that indicates that a validation should be applied **before** the inner validation logic.

    Attributes:
        func: The validator function.
        json_schema_input_type: The input type of the function. This is only used to generate the appropriate
            JSON Schema (in validation mode).

    Example:
        ```python
        from typing_extensions import Annotated

        from pydantic import BaseModel, BeforeValidator

        MyInt = Annotated[int, BeforeValidator(lambda v: v + 1)]

        class Model(BaseModel):
            a: MyInt

        print(Model(a=1).a)
        #> 2

        try:
            Model(a='a')
        except TypeError as e:
            print(e)
            #> can only concatenate str (not "int") to str
        ```
    """

    func: core_schema.NoInfoValidatorFunction | core_schema.WithInfoValidatorFunction
    json_schema_input_type: Any = PydanticUndefined

    def __get_pydantic_core_schema__(self, source_type: Any, handler: GetCoreSchemaHandler) -> core_schema.CoreSchema:
        schema = handler(source_type)
        input_schema = (
            None
            if self.json_schema_input_type is PydanticUndefined
            else handler.generate_schema(self.json_schema_input_type)
        )

        info_arg = _inspect_validator(self.func, 'before')
        if info_arg:
            func = cast(core_schema.WithInfoValidatorFunction, self.func)
            return core_schema.with_info_before_validator_function(
                func,
                schema=schema,
                field_name=handler.field_name,
                json_schema_input_schema=input_schema,
            )
        else:
            func = cast(core_schema.NoInfoValidatorFunction, self.func)
            return core_schema.no_info_before_validator_function(
                func, schema=schema, json_schema_input_schema=input_schema
            )

    @classmethod
    def _from_decorator(cls, decorator: _decorators.Decorator[_decorators.FieldValidatorDecoratorInfo]) -> Self:
        return cls(
            func=decorator.func,
            json_schema_input_type=decorator.info.json_schema_input_type,
        )


@dataclasses.dataclass(frozen=True, **_internal_dataclass.slots_true)
class PlainValidator:
    """Usage docs: https://docs.pydantic.dev/2.10/concepts/validators/#field-validators

    A metadata class that indicates that a validation should be applied **instead** of the inner validation logic.

    !!! note
        Before v2.9, `PlainValidator` wasn't always compatible with JSON Schema generation for `mode='validation'`.
        You can now use the `json_schema_input_type` argument to specify the input type of the function
        to be used in the JSON schema when `mode='validation'` (the default). See the example below for more details.

    Attributes:
        func: The validator function.
        json_schema_input_type: The input type of the function. This is only used to generate the appropriate
            JSON Schema (in validation mode). If not provided, will default to `Any`.

    Example:
        ```python
        from typing import Union

        from typing_extensions import Annotated

        from pydantic import BaseModel, PlainValidator

        MyInt = Annotated[
            int,
            PlainValidator(
                lambda v: int(v) + 1, json_schema_input_type=Union[str, int]  # (1)!
            ),
        ]

        class Model(BaseModel):
            a: MyInt

        print(Model(a='1').a)
        #> 2

        print(Model(a=1).a)
        #> 2
        ```

    1. In this example, we've specified the `json_schema_input_type` as `Union[str, int]` which indicates to the JSON schema
    generator that in validation mode, the input type for the `a` field can be either a `str` or an `int`.
    """

    func: core_schema.NoInfoValidatorFunction | core_schema.WithInfoValidatorFunction
    json_schema_input_type: Any = Any

    def __get_pydantic_core_schema__(self, source_type: Any, handler: GetCoreSchemaHandler) -> core_schema.CoreSchema:
        # Note that for some valid uses of PlainValidator, it is not possible to generate a core schema for the
        # source_type, so calling `handler(source_type)` will error, which prevents us from generating a proper
        # serialization schema. To work around this for use cases that will not involve serialization, we simply
        # catch any PydanticSchemaGenerationError that may be raised while attempting to build the serialization schema
        # and abort any attempts to handle special serialization.
        from pydantic import PydanticSchemaGenerationError

        try:
            schema = handler(source_type)
            # TODO if `schema['serialization']` is one of `'include-exclude-dict/sequence'`,
            # schema validation will fail. That's why we use 'type ignore' comments below.
            serialization = schema.get(
                'serialization',
                core_schema.wrap_serializer_function_ser_schema(
                    function=lambda v, h: h(v),
                    schema=schema,
                    return_schema=handler.generate_schema(source_type),
                ),
            )
        except PydanticSchemaGenerationError:
            serialization = None

        input_schema = handler.generate_schema(self.json_schema_input_type)

        info_arg = _inspect_validator(self.func, 'plain')
        if info_arg:
            func = cast(core_schema.WithInfoValidatorFunction, self.func)
            return core_schema.with_info_plain_validator_function(
                func,
                field_name=handler.field_name,
                serialization=serialization,  # pyright: ignore[reportArgumentType]
                json_schema_input_schema=input_schema,
            )
        else:
            func = cast(core_schema.NoInfoValidatorFunction, self.func)
            return core_schema.no_info_plain_validator_function(
                func,
                serialization=serialization,  # pyright: ignore[reportArgumentType]
                json_schema_input_schema=input_schema,
            )

    @classmethod
    def _from_decorator(cls, decorator: _decorators.Decorator[_decorators.FieldValidatorDecoratorInfo]) -> Self:
        return cls(
            func=decorator.func,
            json_schema_input_type=decorator.info.json_schema_input_type,
        )


@dataclasses.dataclass(frozen=True, **_internal_dataclass.slots_true)
class WrapValidator:
    """Usage docs: https://docs.pydantic.dev/2.10/concepts/validators/#field-validators

    A metadata class that indicates that a validation should be applied **around** the inner validation logic.

    Attributes:
        func: The validator function.
        json_schema_input_type: The input type of the function. This is only used to generate the appropriate
            JSON Schema (in validation mode).

    ```python
    from datetime import datetime

    from typing_extensions import Annotated

    from pydantic import BaseModel, ValidationError, WrapValidator

    def validate_timestamp(v, handler):
        if v == 'now':
            # we don't want to bother with further validation, just return the new value
            return datetime.now()
        try:
            return handler(v)
        except ValidationError:
            # validation failed, in this case we want to return a default value
            return datetime(2000, 1, 1)

    MyTimestamp = Annotated[datetime, WrapValidator(validate_timestamp)]

    class Model(BaseModel):
        a: MyTimestamp

    print(Model(a='now').a)
    #> 2032-01-02 03:04:05.000006
    print(Model(a='invalid').a)
    #> 2000-01-01 00:00:00
    ```
    """

    func: core_schema.NoInfoWrapValidatorFunction | core_schema.WithInfoWrapValidatorFunction
    json_schema_input_type: Any = PydanticUndefined

    def __get_pydantic_core_schema__(self, source_type: Any, handler: GetCoreSchemaHandler) -> core_schema.CoreSchema:
        schema = handler(source_type)
        input_schema = (
            None
            if self.json_schema_input_type is PydanticUndefined
            else handler.generate_schema(self.json_schema_input_type)
        )

        info_arg = _inspect_validator(self.func, 'wrap')
        if info_arg:
            func = cast(core_schema.WithInfoWrapValidatorFunction, self.func)
            return core_schema.with_info_wrap_validator_function(
                func,
                schema=schema,
                field_name=handler.field_name,
                json_schema_input_schema=input_schema,
            )
        else:
            func = cast(core_schema.NoInfoWrapValidatorFunction, self.func)
            return core_schema.no_info_wrap_validator_function(
                func,
                schema=schema,
                json_schema_input_schema=input_schema,
            )

    @classmethod
    def _from_decorator(cls, decorator: _decorators.Decorator[_decorators.FieldValidatorDecoratorInfo]) -> Self:
        return cls(
            func=decorator.func,
            json_schema_input_type=decorator.info.json_schema_input_type,
        )


if TYPE_CHECKING:

    class _OnlyValueValidatorClsMethod(Protocol):
        def __call__(self, cls: Any, value: Any, /) -> Any: ...

    class _V2ValidatorClsMethod(Protocol):
        def __call__(self, cls: Any, value: Any, info: _core_schema.ValidationInfo, /) -> Any: ...

    class _OnlyValueWrapValidatorClsMethod(Protocol):
        def __call__(self, cls: Any, value: Any, handler: _core_schema.ValidatorFunctionWrapHandler, /) -> Any: ...

    class _V2WrapValidatorClsMethod(Protocol):
        def __call__(
            self,
            cls: Any,
            value: Any,
            handler: _core_schema.ValidatorFunctionWrapHandler,
            info: _core_schema.ValidationInfo,
            /,
        ) -> Any: ...

    _V2Validator = Union[
        _V2ValidatorClsMethod,
        _core_schema.WithInfoValidatorFunction,
        _OnlyValueValidatorClsMethod,
        _core_schema.NoInfoValidatorFunction,
    ]

    _V2WrapValidator = Union[
        _V2WrapValidatorClsMethod,
        _core_schema.WithInfoWrapValidatorFunction,
        _OnlyValueWrapValidatorClsMethod,
        _core_schema.NoInfoWrapValidatorFunction,
    ]

    _PartialClsOrStaticMethod: TypeAlias = Union[classmethod[Any, Any, Any], staticmethod[Any, Any], partialmethod[Any]]

    _V2BeforeAfterOrPlainValidatorType = TypeVar(
        '_V2BeforeAfterOrPlainValidatorType',
        bound=Union[_V2Validator, _PartialClsOrStaticMethod],
    )
    _V2WrapValidatorType = TypeVar('_V2WrapValidatorType', bound=Union[_V2WrapValidator, _PartialClsOrStaticMethod])

FieldValidatorModes: TypeAlias = Literal['before', 'after', 'wrap', 'plain']


@overload
def field_validator(
    field: str,
    /,
    *fields: str,
    mode: Literal['wrap'],
    check_fields: bool | None = ...,
    json_schema_input_type: Any = ...,
) -> Callable[[_V2WrapValidatorType], _V2WrapValidatorType]: ...


@overload
def field_validator(
    field: str,
    /,
    *fields: str,
    mode: Literal['before', 'plain'],
    check_fields: bool | None = ...,
    json_schema_input_type: Any = ...,
) -> Callable[[_V2BeforeAfterOrPlainValidatorType], _V2BeforeAfterOrPlainValidatorType]: ...


@overload
def field_validator(
    field: str,
    /,
    *fields: str,
    mode: Literal['after'] = ...,
    check_fields: bool | None = ...,
) -> Callable[[_V2BeforeAfterOrPlainValidatorType], _V2BeforeAfterOrPlainValidatorType]: ...


def field_validator(
    field: str,
    /,
    *fields: str,
    mode: FieldValidatorModes = 'after',
    check_fields: bool | None = None,
    json_schema_input_type: Any = PydanticUndefined,
) -> Callable[[Any], Any]:
    """Usage docs: https://docs.pydantic.dev/2.10/concepts/validators/#field-validators

    Decorate methods on the class indicating that they should be used to validate fields.

    Example usage:
    ```python
    from typing import Any

    from pydantic import (
        BaseModel,
        ValidationError,
        field_validator,
    )

    class Model(BaseModel):
        a: str

        @field_validator('a')
        @classmethod
        def ensure_foobar(cls, v: Any):
            if 'foobar' not in v:
                raise ValueError('"foobar" not found in a')
            return v

    print(repr(Model(a='this is foobar good')))
    #> Model(a='this is foobar good')

    try:
        Model(a='snap')
    except ValidationError as exc_info:
        print(exc_info)
        '''
        1 validation error for Model
        a
          Value error, "foobar" not found in a [type=value_error, input_value='snap', input_type=str]
        '''
    ```

    For more in depth examples, see [Field Validators](../concepts/validators.md#field-validators).

    Args:
        field: The first field the `field_validator` should be called on; this is separate
            from `fields` to ensure an error is raised if you don't pass at least one.
        *fields: Additional field(s) the `field_validator` should be called on.
        mode: Specifies whether to run the validator before or after the inner validation logic.
        check_fields: Whether to check that the fields actually exist on the model.
        json_schema_input_type: The input type of the function. This is only used to generate
            the appropriate JSON Schema (in validation mode) and can only be specified
            when `mode` is either `'before'`, `'plain'` or `'wrap'`.

    Returns:
        A decorator that can be used to decorate a function to be used as a field_validator.

    Raises:
        PydanticUserError:
            - If `@field_validator` is used bare (with no fields).
            - If the args passed to `@field_validator` as fields are not strings.
            - If `@field_validator` is applied to instance methods.
    """
    if isinstance(field, FunctionType):
        raise PydanticUserError(
            '`@field_validator` should be used with fields and keyword arguments, not bare. '
            "E.g. usage should be `@field_validator('<field_name>', ...)`",
            code='validator-no-fields',
        )

    if mode not in ('before', 'plain', 'wrap') and json_schema_input_type is not PydanticUndefined:
        raise PydanticUserError(
            f"`json_schema_input_type` can't be used when mode is set to {mode!r}",
            code='validator-input-type',
        )

    if json_schema_input_type is PydanticUndefined and mode == 'plain':
        json_schema_input_type = Any

    fields = field, *fields
    if not all(isinstance(field, str) for field in fields):
        raise PydanticUserError(
            '`@field_validator` fields should be passed as separate string args. '
            "E.g. usage should be `@field_validator('<field_name_1>', '<field_name_2>', ...)`",
            code='validator-invalid-fields',
        )

    def dec(
        f: Callable[..., Any] | staticmethod[Any, Any] | classmethod[Any, Any, Any],
    ) -> _decorators.PydanticDescriptorProxy[Any]:
        if _decorators.is_instance_method_from_sig(f):
            raise PydanticUserError(
                '`@field_validator` cannot be applied to instance methods', code='validator-instance-method'
            )

        # auto apply the @classmethod decorator
        f = _decorators.ensure_classmethod_based_on_signature(f)

        dec_info = _decorators.FieldValidatorDecoratorInfo(
            fields=fields, mode=mode, check_fields=check_fields, json_schema_input_type=json_schema_input_type
        )
        return _decorators.PydanticDescriptorProxy(f, dec_info)

    return dec

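# Editor's note: a hedged, standalone sketch (not part of the upstream module) of
# `mode='before'` together with `json_schema_input_type`, consistent with the
# rules enforced above. The `Tags` model and its field are hypothetical.
from typing import List, Union

from pydantic import BaseModel, field_validator


class Tags(BaseModel):
    values: List[str]

    @field_validator('values', mode='before', json_schema_input_type=Union[str, List[str]])
    @classmethod
    def split_csv(cls, v):
        # Accept a comma-separated string before the inner `List[str]` validation runs.
        if isinstance(v, str):
            return v.split(',')
        return v


print(Tags(values='a,b,c').values)
#> ['a', 'b', 'c']
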
_ModelType = TypeVar('_ModelType')
_ModelTypeCo = TypeVar('_ModelTypeCo', covariant=True)


class ModelWrapValidatorHandler(_core_schema.ValidatorFunctionWrapHandler, Protocol[_ModelTypeCo]):
    """`@model_validator` decorated function handler argument type. This is used when `mode='wrap'`."""

    def __call__(  # noqa: D102
        self,
        value: Any,
        outer_location: str | int | None = None,
        /,
    ) -> _ModelTypeCo:  # pragma: no cover
        ...


class ModelWrapValidatorWithoutInfo(Protocol[_ModelType]):
    """A `@model_validator` decorated function signature.
    This is used when `mode='wrap'` and the function does not have an `info` argument.
    """

    def __call__(  # noqa: D102
        self,
        cls: type[_ModelType],
        # this can be a dict, a model instance
        # or anything else that gets passed to validate_python
        # thus validators _must_ handle all cases
        value: Any,
        handler: ModelWrapValidatorHandler[_ModelType],
        /,
    ) -> _ModelType: ...


class ModelWrapValidator(Protocol[_ModelType]):
    """A `@model_validator` decorated function signature. This is used when `mode='wrap'`."""

    def __call__(  # noqa: D102
        self,
        cls: type[_ModelType],
        # this can be a dict, a model instance
        # or anything else that gets passed to validate_python
        # thus validators _must_ handle all cases
        value: Any,
        handler: ModelWrapValidatorHandler[_ModelType],
        info: _core_schema.ValidationInfo,
        /,
    ) -> _ModelType: ...


class FreeModelBeforeValidatorWithoutInfo(Protocol):
    """A `@model_validator` decorated function signature.
    This is used when `mode='before'` and the function does not have an `info` argument.
    """

    def __call__(  # noqa: D102
        self,
        # this can be a dict, a model instance
        # or anything else that gets passed to validate_python
        # thus validators _must_ handle all cases
        value: Any,
        /,
    ) -> Any: ...


class ModelBeforeValidatorWithoutInfo(Protocol):
    """A `@model_validator` decorated function signature.
    This is used when `mode='before'` and the function does not have an `info` argument.
    """

    def __call__(  # noqa: D102
        self,
        cls: Any,
        # this can be a dict, a model instance
        # or anything else that gets passed to validate_python
        # thus validators _must_ handle all cases
        value: Any,
        /,
    ) -> Any: ...


class FreeModelBeforeValidator(Protocol):
    """A `@model_validator` decorated function signature. This is used when `mode='before'`."""

    def __call__(  # noqa: D102
        self,
        # this can be a dict, a model instance
        # or anything else that gets passed to validate_python
        # thus validators _must_ handle all cases
        value: Any,
        info: _core_schema.ValidationInfo,
        /,
    ) -> Any: ...


class ModelBeforeValidator(Protocol):
    """A `@model_validator` decorated function signature. This is used when `mode='before'`."""

    def __call__(  # noqa: D102
        self,
        cls: Any,
        # this can be a dict, a model instance
        # or anything else that gets passed to validate_python
        # thus validators _must_ handle all cases
        value: Any,
        info: _core_schema.ValidationInfo,
        /,
    ) -> Any: ...


ModelAfterValidatorWithoutInfo = Callable[[_ModelType], _ModelType]
"""A `@model_validator` decorated function signature. This is used when `mode='after'` and the function does not
have an `info` argument.
"""

ModelAfterValidator = Callable[[_ModelType, _core_schema.ValidationInfo], _ModelType]
"""A `@model_validator` decorated function signature. This is used when `mode='after'`."""

_AnyModelWrapValidator = Union[ModelWrapValidator[_ModelType], ModelWrapValidatorWithoutInfo[_ModelType]]
_AnyModelBeforeValidator = Union[
    FreeModelBeforeValidator, ModelBeforeValidator, FreeModelBeforeValidatorWithoutInfo, ModelBeforeValidatorWithoutInfo
]
_AnyModelAfterValidator = Union[ModelAfterValidator[_ModelType], ModelAfterValidatorWithoutInfo[_ModelType]]


@overload
def model_validator(
    *,
    mode: Literal['wrap'],
) -> Callable[
    [_AnyModelWrapValidator[_ModelType]], _decorators.PydanticDescriptorProxy[_decorators.ModelValidatorDecoratorInfo]
]: ...


@overload
def model_validator(
    *,
    mode: Literal['before'],
) -> Callable[
    [_AnyModelBeforeValidator], _decorators.PydanticDescriptorProxy[_decorators.ModelValidatorDecoratorInfo]
]: ...


@overload
def model_validator(
    *,
    mode: Literal['after'],
) -> Callable[
    [_AnyModelAfterValidator[_ModelType]], _decorators.PydanticDescriptorProxy[_decorators.ModelValidatorDecoratorInfo]
]: ...


def model_validator(
    *,
    mode: Literal['wrap', 'before', 'after'],
) -> Any:
    """Usage docs: https://docs.pydantic.dev/2.10/concepts/validators/#model-validators

    Decorate model methods for validation purposes.

    Example usage:
    ```python
    from typing_extensions import Self

    from pydantic import BaseModel, ValidationError, model_validator

    class Square(BaseModel):
        width: float
        height: float

        @model_validator(mode='after')
        def verify_square(self) -> Self:
            if self.width != self.height:
                raise ValueError('width and height do not match')
            return self

    s = Square(width=1, height=1)
    print(repr(s))
    #> Square(width=1.0, height=1.0)

    try:
        Square(width=1, height=2)
    except ValidationError as e:
        print(e)
        '''
        1 validation error for Square
          Value error, width and height do not match [type=value_error, input_value={'width': 1, 'height': 2}, input_type=dict]
        '''
    ```

    For more in depth examples, see [Model Validators](../concepts/validators.md#model-validators).

    Args:
        mode: A required string literal that specifies the validation mode.
            It can be one of the following: 'wrap', 'before', or 'after'.

    Returns:
        A decorator that can be used to decorate a function to be used as a model validator.
    """

    def dec(f: Any) -> _decorators.PydanticDescriptorProxy[Any]:
        # auto apply the @classmethod decorator
        f = _decorators.ensure_classmethod_based_on_signature(f)
        dec_info = _decorators.ModelValidatorDecoratorInfo(mode=mode)
        return _decorators.PydanticDescriptorProxy(f, dec_info)

    return dec

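# Editor's note: a hedged, standalone sketch (not part of the upstream module) of
# `mode='wrap'`, which receives the handler described by `ModelWrapValidatorHandler`
# above. The `Point` model and its fallback behaviour are hypothetical.
from typing import Any

from pydantic import BaseModel, ValidationError, model_validator


class Point(BaseModel):
    x: int
    y: int

    @model_validator(mode='wrap')
    @classmethod
    def default_on_error(cls, value: Any, handler) -> 'Point':
        try:
            # `handler` runs the inner validation and returns the validated model.
            return handler(value)
        except ValidationError:
            # Fall back to a default instance instead of propagating the error.
            return cls(x=0, y=0)


print(Point(x=1, y=2))
#> x=1 y=2
print(Point.model_validate({'x': 'bad', 'y': 2}))
#> x=0 y=0
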
AnyType = TypeVar('AnyType')


if TYPE_CHECKING:
    # If we add configurable attributes to InstanceOf, we'd probably need to stop hiding it from type checkers like this
    InstanceOf = Annotated[AnyType, ...]  # `InstanceOf[Sequence]` will be recognized by type checkers as `Sequence`

else:

    @dataclasses.dataclass(**_internal_dataclass.slots_true)
    class InstanceOf:
        '''Generic type for annotating a type that is an instance of a given class.

        Example:
            ```python
            from pydantic import BaseModel, InstanceOf, ValidationError

            class Foo:
                ...

            class Bar(BaseModel):
                foo: InstanceOf[Foo]

            Bar(foo=Foo())
            try:
                Bar(foo=42)
            except ValidationError as e:
                print(e)
                """
                [
                │   {
                │   │   'type': 'is_instance_of',
                │   │   'loc': ('foo',),
                │   │   'msg': 'Input should be an instance of Foo',
                │   │   'input': 42,
                │   │   'ctx': {'class': 'Foo'},
                │   │   'url': 'https://errors.pydantic.dev/0.38.0/v/is_instance_of'
                │   }
                ]
                """
            ```
        '''

        @classmethod
        def __class_getitem__(cls, item: AnyType) -> AnyType:
            return Annotated[item, cls()]

        @classmethod
        def __get_pydantic_core_schema__(cls, source: Any, handler: GetCoreSchemaHandler) -> core_schema.CoreSchema:
            from pydantic import PydanticSchemaGenerationError

            # use the generic _origin_ as the second argument to isinstance when appropriate
            instance_of_schema = core_schema.is_instance_schema(_generics.get_origin(source) or source)

            try:
                # Try to generate the "standard" schema, which will be used when loading from JSON
                original_schema = handler(source)
            except PydanticSchemaGenerationError:
                # If that fails, just produce a schema that can validate from python
                return instance_of_schema
            else:
                # Use the "original" approach to serialization
                instance_of_schema['serialization'] = core_schema.wrap_serializer_function_ser_schema(
                    function=lambda v, h: h(v), schema=original_schema
                )
                return core_schema.json_or_python_schema(python_schema=instance_of_schema, json_schema=original_schema)

        __hash__ = object.__hash__


if TYPE_CHECKING:
    SkipValidation = Annotated[AnyType, ...]  # SkipValidation[list[str]] will be treated by type checkers as list[str]
else:

    @dataclasses.dataclass(**_internal_dataclass.slots_true)
    class SkipValidation:
        """If this is applied as an annotation (e.g., via `x: Annotated[int, SkipValidation]`), validation will be
        skipped. You can also use `SkipValidation[int]` as a shorthand for `Annotated[int, SkipValidation]`.

        This can be useful if you want to use a type annotation for documentation/IDE/type-checking purposes,
        and know that it is safe to skip validation for one or more of the fields.

        Because this converts the validation schema to `any_schema`, subsequent annotation-applied transformations
        may not have the expected effects. Therefore, when used, this annotation should generally be the final
        annotation applied to a type.
        """

        def __class_getitem__(cls, item: Any) -> Any:
            return Annotated[item, SkipValidation()]

        @classmethod
        def __get_pydantic_core_schema__(cls, source: Any, handler: GetCoreSchemaHandler) -> core_schema.CoreSchema:
            original_schema = handler(source)
            metadata = {'pydantic_js_annotation_functions': [lambda _c, h: h(original_schema)]}
            return core_schema.any_schema(
                metadata=metadata,
                serialization=core_schema.wrap_serializer_function_ser_schema(
                    function=lambda v, h: h(v), schema=original_schema
                ),
            )

        __hash__ = object.__hash__
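# Editor's note: a hedged, standalone sketch (not part of the upstream module):
# with `SkipValidation`, the annotation is kept for type checkers but the value is
# not validated or coerced at runtime. The model is hypothetical.
from typing import List

from pydantic import BaseModel, SkipValidation


class Model(BaseModel):
    names: SkipValidation[List[int]]


# The strings slip through unvalidated because the schema became `any_schema`.
print(Model(names=['not', 'ints']).names)
#> ['not', 'ints']
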
@@ -0,0 +1,5 @@
"""The `generics` module is a backport module from V1."""

from ._migration import getattr_migration

__getattr__ = getattr_migration(__name__)
@@ -0,0 +1,5 @@
"""The `json` module is a backport module from V1."""

from ._migration import getattr_migration

__getattr__ = getattr_migration(__name__)
File diff suppressed because it is too large
File diff suppressed because it is too large
File diff suppressed because it is too large
File diff suppressed because it is too large
@@ -0,0 +1,5 @@
"""The `parse` module is a backport module from V1."""

from ._migration import getattr_migration

__getattr__ = getattr_migration(__name__)
@@ -0,0 +1,171 @@
"""Usage docs: https://docs.pydantic.dev/2.10/concepts/plugins#build-a-plugin

Plugin interface for Pydantic plugins, and related types.
"""

from __future__ import annotations

from typing import Any, Callable, NamedTuple

from pydantic_core import CoreConfig, CoreSchema, ValidationError
from typing_extensions import Literal, Protocol, TypeAlias

__all__ = (
    'PydanticPluginProtocol',
    'BaseValidateHandlerProtocol',
    'ValidatePythonHandlerProtocol',
    'ValidateJsonHandlerProtocol',
    'ValidateStringsHandlerProtocol',
    'NewSchemaReturns',
    'SchemaTypePath',
    'SchemaKind',
)

NewSchemaReturns: TypeAlias = 'tuple[ValidatePythonHandlerProtocol | None, ValidateJsonHandlerProtocol | None, ValidateStringsHandlerProtocol | None]'


class SchemaTypePath(NamedTuple):
    """Path defining where `schema_type` was defined, or where `TypeAdapter` was called."""

    module: str
    name: str


SchemaKind: TypeAlias = Literal['BaseModel', 'TypeAdapter', 'dataclass', 'create_model', 'validate_call']


class PydanticPluginProtocol(Protocol):
    """Protocol defining the interface for Pydantic plugins."""

    def new_schema_validator(
        self,
        schema: CoreSchema,
        schema_type: Any,
        schema_type_path: SchemaTypePath,
        schema_kind: SchemaKind,
        config: CoreConfig | None,
        plugin_settings: dict[str, object],
    ) -> tuple[
        ValidatePythonHandlerProtocol | None, ValidateJsonHandlerProtocol | None, ValidateStringsHandlerProtocol | None
    ]:
        """This method is called for each plugin every time a new [`SchemaValidator`][pydantic_core.SchemaValidator]
        is created.

        It should return an event handler for each of the three validation methods, or `None` if the plugin does not
        implement that method.

        Args:
            schema: The schema to validate against.
            schema_type: The original type which the schema was created from, e.g. the model class.
            schema_type_path: Path defining where `schema_type` was defined, or where `TypeAdapter` was called.
            schema_kind: The kind of schema to validate against.
            config: The config to use for validation.
            plugin_settings: Any plugin settings.

        Returns:
            A tuple of optional event handlers for each of the three validation methods -
            `validate_python`, `validate_json`, `validate_strings`.
        """
        raise NotImplementedError('Pydantic plugins should implement `new_schema_validator`.')


class BaseValidateHandlerProtocol(Protocol):
    """Base class for plugin callback protocols.

    You shouldn't implement this protocol directly; instead use one of the subclasses, which adds the correctly
    typed `on_error` method.
    """

    on_enter: Callable[..., None]
    """`on_enter` is changed to be more specific on all subclasses"""

    def on_success(self, result: Any) -> None:
        """Callback to be notified of successful validation.

        Args:
            result: The result of the validation.
        """
        return

    def on_error(self, error: ValidationError) -> None:
        """Callback to be notified of validation errors.

        Args:
            error: The validation error.
        """
        return

    def on_exception(self, exception: Exception) -> None:
        """Callback to be notified of validation exceptions.

        Args:
            exception: The exception raised during validation.
        """
        return


class ValidatePythonHandlerProtocol(BaseValidateHandlerProtocol, Protocol):
    """Event handler for `SchemaValidator.validate_python`."""

    def on_enter(
        self,
        input: Any,
        *,
        strict: bool | None = None,
        from_attributes: bool | None = None,
        context: dict[str, Any] | None = None,
        self_instance: Any | None = None,
    ) -> None:
        """Callback to be notified of validation start, and create an instance of the event handler.

        Args:
            input: The input to be validated.
            strict: Whether to validate the object in strict mode.
            from_attributes: Whether to validate objects as inputs by extracting attributes.
            context: The context to use for validation, this is passed to functional validators.
            self_instance: An instance of a model to set attributes on from validation, this is used when running
                validation from the `__init__` method of a model.
        """
        pass


class ValidateJsonHandlerProtocol(BaseValidateHandlerProtocol, Protocol):
    """Event handler for `SchemaValidator.validate_json`."""

    def on_enter(
        self,
        input: str | bytes | bytearray,
        *,
        strict: bool | None = None,
        context: dict[str, Any] | None = None,
        self_instance: Any | None = None,
    ) -> None:
        """Callback to be notified of validation start, and create an instance of the event handler.

        Args:
            input: The JSON data to be validated.
            strict: Whether to validate the object in strict mode.
            context: The context to use for validation, this is passed to functional validators.
            self_instance: An instance of a model to set attributes on from validation, this is used when running
                validation from the `__init__` method of a model.
        """
        pass


StringInput: TypeAlias = 'dict[str, StringInput]'


class ValidateStringsHandlerProtocol(BaseValidateHandlerProtocol, Protocol):
    """Event handler for `SchemaValidator.validate_strings`."""

    def on_enter(
        self, input: StringInput, *, strict: bool | None = None, context: dict[str, Any] | None = None
    ) -> None:
        """Callback to be notified of validation start, and create an instance of the event handler.

        Args:
            input: The string data to be validated.
            strict: Whether to validate the object in strict mode.
            context: The context to use for validation, this is passed to functional validators.
        """
        pass
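# Editor's note: a hedged, minimal plugin sketch (not part of the upstream module)
# against the protocols above. The class names are hypothetical, and a real plugin
# would be registered under the 'pydantic' entry-point group so the loader can
# discover it.
from typing import Any

from pydantic_core import CoreConfig, CoreSchema


class LoggingOnEnterHandler:
    """Implements only the `ValidatePythonHandlerProtocol` event it cares about."""

    def on_enter(self, input: Any, **kwargs: Any) -> None:
        print(f'validate_python called with: {input!r}')


class LoggingPlugin:
    """Implements `PydanticPluginProtocol.new_schema_validator`."""

    def new_schema_validator(
        self,
        schema: CoreSchema,
        schema_type: Any,
        schema_type_path: SchemaTypePath,
        schema_kind: SchemaKind,
        config: CoreConfig | None,
        plugin_settings: dict[str, object],
    ):
        # Only hook `validate_python`; return None for the JSON and strings events.
        return LoggingOnEnterHandler(), None, None
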
@@ -0,0 +1,56 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import importlib.metadata as importlib_metadata
|
||||
import os
|
||||
import warnings
|
||||
from typing import TYPE_CHECKING, Final, Iterable
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from . import PydanticPluginProtocol
|
||||
|
||||
|
||||
PYDANTIC_ENTRY_POINT_GROUP: Final[str] = 'pydantic'
|
||||
|
||||
# cache of plugins
|
||||
_plugins: dict[str, PydanticPluginProtocol] | None = None
|
||||
# return no plugins while loading plugins to avoid recursion and errors while import plugins
|
||||
# this means that if plugins use pydantic
|
||||
_loading_plugins: bool = False
|
||||
|
||||
|
||||
def get_plugins() -> Iterable[PydanticPluginProtocol]:
|
||||
"""Load plugins for Pydantic.
|
||||
|
||||
Inspired by: https://github.com/pytest-dev/pluggy/blob/1.3.0/src/pluggy/_manager.py#L376-L402
|
||||
"""
|
||||
disabled_plugins = os.getenv('PYDANTIC_DISABLE_PLUGINS')
|
||||
global _plugins, _loading_plugins
|
||||
if _loading_plugins:
|
||||
# this happens when plugins themselves use pydantic, we return no plugins
|
||||
return ()
|
||||
elif disabled_plugins in ('__all__', '1', 'true'):
|
||||
return ()
|
||||
elif _plugins is None:
|
||||
_plugins = {}
|
||||
# set _loading_plugins so any plugins that use pydantic don't themselves use plugins
|
||||
_loading_plugins = True
|
||||
try:
|
||||
for dist in importlib_metadata.distributions():
|
||||
for entry_point in dist.entry_points:
|
||||
if entry_point.group != PYDANTIC_ENTRY_POINT_GROUP:
|
||||
continue
|
||||
if entry_point.value in _plugins:
|
||||
continue
|
||||
if disabled_plugins is not None and entry_point.name in disabled_plugins.split(','):
|
||||
continue
|
||||
try:
|
||||
_plugins[entry_point.value] = entry_point.load()
|
||||
except (ImportError, AttributeError) as e:
|
||||
warnings.warn(
|
||||
f'{e.__class__.__name__} while loading the `{entry_point.name}` Pydantic plugin, '
|
||||
f'this plugin will not be installed.\n\n{e!r}'
|
||||
)
|
||||
finally:
|
||||
_loading_plugins = False
|
||||
|
||||
return _plugins.values()
|
||||
@@ -0,0 +1,139 @@
|
||||
"""Pluggable schema validator for pydantic."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import functools
|
||||
from typing import TYPE_CHECKING, Any, Callable, Iterable, TypeVar
|
||||
|
||||
from pydantic_core import CoreConfig, CoreSchema, SchemaValidator, ValidationError
|
||||
from typing_extensions import Literal, ParamSpec
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from . import BaseValidateHandlerProtocol, PydanticPluginProtocol, SchemaKind, SchemaTypePath
|
||||
|
||||
|
||||
P = ParamSpec('P')
|
||||
R = TypeVar('R')
|
||||
Event = Literal['on_validate_python', 'on_validate_json', 'on_validate_strings']
|
||||
events: list[Event] = list(Event.__args__) # type: ignore
|
||||
|
||||
|
||||
def create_schema_validator(
|
||||
schema: CoreSchema,
|
||||
schema_type: Any,
|
||||
schema_type_module: str,
|
||||
schema_type_name: str,
|
||||
schema_kind: SchemaKind,
|
||||
config: CoreConfig | None = None,
|
||||
plugin_settings: dict[str, Any] | None = None,
|
||||
) -> SchemaValidator | PluggableSchemaValidator:
|
||||
"""Create a `SchemaValidator` or `PluggableSchemaValidator` if plugins are installed.
|
||||
|
||||
Returns:
|
||||
If plugins are installed then return `PluggableSchemaValidator`, otherwise return `SchemaValidator`.
|
||||
"""
|
||||
from . import SchemaTypePath
|
||||
from ._loader import get_plugins
|
||||
|
||||
plugins = get_plugins()
|
||||
if plugins:
|
||||
return PluggableSchemaValidator(
|
||||
schema,
|
||||
schema_type,
|
||||
SchemaTypePath(schema_type_module, schema_type_name),
|
||||
schema_kind,
|
||||
config,
|
||||
plugins,
|
||||
plugin_settings or {},
|
||||
)
|
||||
else:
|
||||
return SchemaValidator(schema, config)
|
||||
|
||||
|
||||
class PluggableSchemaValidator:
|
||||
"""Pluggable schema validator."""
|
||||
|
||||
__slots__ = '_schema_validator', 'validate_json', 'validate_python', 'validate_strings'
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
schema: CoreSchema,
|
||||
schema_type: Any,
|
||||
schema_type_path: SchemaTypePath,
|
||||
schema_kind: SchemaKind,
|
||||
config: CoreConfig | None,
|
||||
plugins: Iterable[PydanticPluginProtocol],
|
||||
plugin_settings: dict[str, Any],
|
||||
) -> None:
|
||||
self._schema_validator = SchemaValidator(schema, config)
|
||||
|
||||
python_event_handlers: list[BaseValidateHandlerProtocol] = []
|
||||
json_event_handlers: list[BaseValidateHandlerProtocol] = []
|
||||
strings_event_handlers: list[BaseValidateHandlerProtocol] = []
|
||||
for plugin in plugins:
|
||||
try:
|
||||
p, j, s = plugin.new_schema_validator(
|
||||
schema, schema_type, schema_type_path, schema_kind, config, plugin_settings
|
||||
)
|
||||
except TypeError as e: # pragma: no cover
|
||||
raise TypeError(f'Error using plugin `{plugin.__module__}:{plugin.__class__.__name__}`: {e}') from e
|
||||
if p is not None:
|
||||
python_event_handlers.append(p)
|
||||
if j is not None:
|
||||
json_event_handlers.append(j)
|
||||
if s is not None:
|
||||
strings_event_handlers.append(s)
|
||||
|
||||
self.validate_python = build_wrapper(self._schema_validator.validate_python, python_event_handlers)
|
||||
self.validate_json = build_wrapper(self._schema_validator.validate_json, json_event_handlers)
|
||||
self.validate_strings = build_wrapper(self._schema_validator.validate_strings, strings_event_handlers)
|
||||
|
||||
def __getattr__(self, name: str) -> Any:
|
||||
return getattr(self._schema_validator, name)
|
||||
|
||||
|
||||
def build_wrapper(func: Callable[P, R], event_handlers: list[BaseValidateHandlerProtocol]) -> Callable[P, R]:
|
||||
if not event_handlers:
|
||||
return func
|
||||
else:
|
||||
on_enters = tuple(h.on_enter for h in event_handlers if filter_handlers(h, 'on_enter'))
|
||||
on_successes = tuple(h.on_success for h in event_handlers if filter_handlers(h, 'on_success'))
|
||||
on_errors = tuple(h.on_error for h in event_handlers if filter_handlers(h, 'on_error'))
|
||||
on_exceptions = tuple(h.on_exception for h in event_handlers if filter_handlers(h, 'on_exception'))
|
||||
|
||||
@functools.wraps(func)
|
||||
def wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
|
||||
for on_enter_handler in on_enters:
|
||||
on_enter_handler(*args, **kwargs)
|
||||
|
||||
try:
|
||||
result = func(*args, **kwargs)
|
||||
except ValidationError as error:
|
||||
for on_error_handler in on_errors:
|
||||
on_error_handler(error)
|
||||
raise
|
||||
except Exception as exception:
|
||||
for on_exception_handler in on_exceptions:
|
||||
on_exception_handler(exception)
|
||||
raise
|
||||
else:
|
||||
for on_success_handler in on_successes:
|
||||
on_success_handler(result)
|
||||
return result
|
||||
|
||||
return wrapper
|
||||
|
||||
|
||||
def filter_handlers(handler_cls: BaseValidateHandlerProtocol, method_name: str) -> bool:
|
||||
"""Filter out handler methods which are not implemented by the plugin directly - e.g. are missing
|
||||
or are inherited from the protocol.
|
||||
"""
|
||||
handler = getattr(handler_cls, method_name, None)
|
||||
if handler is None:
|
||||
return False
|
||||
elif handler.__module__ == 'pydantic.plugin':
|
||||
# this is the original handler, from the protocol due to runtime inheritance
|
||||
# we don't want to call it
|
||||
return False
|
||||
else:
|
||||
return True
|
||||
@@ -0,0 +1,156 @@
|
||||
"""RootModel class and type definitions."""
|
||||
|
||||
from __future__ import annotations as _annotations
|
||||
|
||||
import typing
|
||||
from copy import copy, deepcopy
|
||||
|
||||
from pydantic_core import PydanticUndefined
|
||||
|
||||
from . import PydanticUserError
|
||||
from ._internal import _model_construction, _repr
|
||||
from .main import BaseModel, _object_setattr
|
||||
|
||||
if typing.TYPE_CHECKING:
|
||||
from typing import Any
|
||||
|
||||
from typing_extensions import Literal, Self, dataclass_transform
|
||||
|
||||
from .fields import Field as PydanticModelField
|
||||
from .fields import PrivateAttr as PydanticModelPrivateAttr
|
||||
|
||||
# dataclass_transform could be applied to RootModel directly, but `ModelMetaclass`'s dataclass_transform
|
||||
# takes priority (at least with pyright). We trick type checkers into thinking we apply dataclass_transform
|
||||
# on a new metaclass.
|
||||
@dataclass_transform(kw_only_default=False, field_specifiers=(PydanticModelField, PydanticModelPrivateAttr))
|
||||
    class _RootModelMetaclass(_model_construction.ModelMetaclass): ...
else:
    _RootModelMetaclass = _model_construction.ModelMetaclass

__all__ = ('RootModel',)

RootModelRootType = typing.TypeVar('RootModelRootType')


class RootModel(BaseModel, typing.Generic[RootModelRootType], metaclass=_RootModelMetaclass):
    """Usage docs: https://docs.pydantic.dev/2.10/concepts/models/#rootmodel-and-custom-root-types

    A Pydantic `BaseModel` for the root object of the model.

    Attributes:
        root: The root object of the model.
        __pydantic_root_model__: Whether the model is a RootModel.
        __pydantic_private__: Private fields in the model.
        __pydantic_extra__: Extra fields in the model.

    """

    __pydantic_root_model__ = True
    __pydantic_private__ = None
    __pydantic_extra__ = None

    root: RootModelRootType

    def __init_subclass__(cls, **kwargs):
        extra = cls.model_config.get('extra')
        if extra is not None:
            raise PydanticUserError(
                "`RootModel` does not support setting `model_config['extra']`", code='root-model-extra'
            )
        super().__init_subclass__(**kwargs)

    def __init__(self, /, root: RootModelRootType = PydanticUndefined, **data) -> None:  # type: ignore
        __tracebackhide__ = True
        if data:
            if root is not PydanticUndefined:
                raise ValueError(
                    '"RootModel.__init__" accepts either a single positional argument or arbitrary keyword arguments'
                )
            root = data  # type: ignore
        self.__pydantic_validator__.validate_python(root, self_instance=self)

    __init__.__pydantic_base_init__ = True  # pyright: ignore[reportFunctionMemberAccess]

    @classmethod
    def model_construct(cls, root: RootModelRootType, _fields_set: set[str] | None = None) -> Self:  # type: ignore
        """Create a new model using the provided root object and update fields set.

        Args:
            root: The root object of the model.
            _fields_set: The set of fields to be updated.

        Returns:
            The new model.

        Raises:
            NotImplemented: If the model is not a subclass of `RootModel`.
        """
        return super().model_construct(root=root, _fields_set=_fields_set)

    def __getstate__(self) -> dict[Any, Any]:
        return {
            '__dict__': self.__dict__,
            '__pydantic_fields_set__': self.__pydantic_fields_set__,
        }

    def __setstate__(self, state: dict[Any, Any]) -> None:
        _object_setattr(self, '__pydantic_fields_set__', state['__pydantic_fields_set__'])
        _object_setattr(self, '__dict__', state['__dict__'])

    def __copy__(self) -> Self:
        """Returns a shallow copy of the model."""
        cls = type(self)
        m = cls.__new__(cls)
        _object_setattr(m, '__dict__', copy(self.__dict__))
        _object_setattr(m, '__pydantic_fields_set__', copy(self.__pydantic_fields_set__))
        return m

    def __deepcopy__(self, memo: dict[int, Any] | None = None) -> Self:
        """Returns a deep copy of the model."""
        cls = type(self)
        m = cls.__new__(cls)
        _object_setattr(m, '__dict__', deepcopy(self.__dict__, memo=memo))
        # This next line doesn't need a deepcopy because __pydantic_fields_set__ is a set[str],
        # and attempting a deepcopy would be marginally slower.
        _object_setattr(m, '__pydantic_fields_set__', copy(self.__pydantic_fields_set__))
        return m

    if typing.TYPE_CHECKING:

        def model_dump(  # type: ignore
            self,
            *,
            mode: Literal['json', 'python'] | str = 'python',
            include: Any = None,
            exclude: Any = None,
            context: dict[str, Any] | None = None,
            by_alias: bool = False,
            exclude_unset: bool = False,
            exclude_defaults: bool = False,
            exclude_none: bool = False,
            round_trip: bool = False,
            warnings: bool | Literal['none', 'warn', 'error'] = True,
            serialize_as_any: bool = False,
        ) -> Any:
            """This method is included just to get a more accurate return type for type checkers.
            It is included in this `if TYPE_CHECKING:` block since no override is actually necessary.

            See the documentation of `BaseModel.model_dump` for more details about the arguments.

            Generally, this method will have a return type of `RootModelRootType`, assuming that `RootModelRootType` is
            not a `BaseModel` subclass. If `RootModelRootType` is a `BaseModel` subclass, then the return
            type will likely be `dict[str, Any]`, as `model_dump` calls are recursive. The return type could
            even be something different, in the case of a custom serializer.
            Thus, `Any` is used here to catch all of these cases.
            """
            ...

    def __eq__(self, other: Any) -> bool:
        if not isinstance(other, RootModel):
            return NotImplemented
        return self.__pydantic_fields__['root'].annotation == other.__pydantic_fields__[
            'root'
        ].annotation and super().__eq__(other)

    def __repr_args__(self) -> _repr.ReprArgs:
        yield 'root', self.root
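
A minimal usage sketch of `RootModel` (added for illustration, not part of the diff; the `Pets` name is hypothetical):

    from typing import List

    from pydantic import RootModel

    Pets = RootModel[List[str]]

    pets = Pets(['dog', 'cat'])
    print(pets.root)          # ['dog', 'cat']
    print(pets.model_dump())  # ['dog', 'cat'], the root value itself rather than a dict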
@@ -0,0 +1,5 @@
"""The `schema` module is a backport module from V1."""

from ._migration import getattr_migration

__getattr__ = getattr_migration(__name__)
@@ -0,0 +1,5 @@
"""The `tools` module is a backport module from V1."""

from ._migration import getattr_migration

__getattr__ = getattr_migration(__name__)
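
Both backport stubs above rely on the module-level `__getattr__` hook (PEP 562): `getattr_migration(__name__)` builds a function that intercepts lookups of names the module no longer defines and redirects them or raises a descriptive migration error. A generic sketch of that mechanism (illustrative only, not pydantic's actual implementation):

    # mymodule.py, a hypothetical module using a PEP 562 migration hook
    _MOVED = {'old_helper': 'newpackage.new_helper'}

    def __getattr__(name: str):
        # Called only when normal attribute lookup on the module fails.
        if name in _MOVED:
            raise ImportError(f'`{name}` has moved; import it from `{_MOVED[name]}` instead')
        raise AttributeError(f'module {__name__!r} has no attribute {name!r}')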
@@ -0,0 +1,676 @@
"""Type adapter specification."""

from __future__ import annotations as _annotations

import sys
from dataclasses import is_dataclass
from types import FrameType
from typing import (
    Any,
    Generic,
    Iterable,
    Literal,
    TypeVar,
    cast,
    final,
    overload,
)

from pydantic_core import CoreSchema, SchemaSerializer, SchemaValidator, Some
from typing_extensions import ParamSpec, is_typeddict

from pydantic.errors import PydanticUserError
from pydantic.main import BaseModel, IncEx

from ._internal import _config, _generate_schema, _mock_val_ser, _namespace_utils, _repr, _typing_extra, _utils
from .config import ConfigDict
from .errors import PydanticUndefinedAnnotation
from .json_schema import (
    DEFAULT_REF_TEMPLATE,
    GenerateJsonSchema,
    JsonSchemaKeyT,
    JsonSchemaMode,
    JsonSchemaValue,
)
from .plugin._schema_validator import PluggableSchemaValidator, create_schema_validator

T = TypeVar('T')
R = TypeVar('R')
P = ParamSpec('P')
TypeAdapterT = TypeVar('TypeAdapterT', bound='TypeAdapter')


def _getattr_no_parents(obj: Any, attribute: str) -> Any:
    """Returns the attribute value without attempting to look up attributes from parent types."""
    if hasattr(obj, '__dict__'):
        try:
            return obj.__dict__[attribute]
        except KeyError:
            pass

    slots = getattr(obj, '__slots__', None)
    if slots is not None and attribute in slots:
        return getattr(obj, attribute)
    else:
        raise AttributeError(attribute)


def _type_has_config(type_: Any) -> bool:
    """Returns whether the type has config."""
    type_ = _typing_extra.annotated_type(type_) or type_
    try:
        return issubclass(type_, BaseModel) or is_dataclass(type_) or is_typeddict(type_)
    except TypeError:
        # type is not a class
        return False


@final
class TypeAdapter(Generic[T]):
    """Usage docs: https://docs.pydantic.dev/2.10/concepts/type_adapter/

    Type adapters provide a flexible way to perform validation and serialization based on a Python type.

    A `TypeAdapter` instance exposes some of the functionality from `BaseModel` instance methods
    for types that do not have such methods (such as dataclasses, primitive types, and more).

    **Note:** `TypeAdapter` instances are not types, and cannot be used as type annotations for fields.

    Args:
        type: The type associated with the `TypeAdapter`.
        config: Configuration for the `TypeAdapter`, should be a dictionary conforming to
            [`ConfigDict`][pydantic.config.ConfigDict].

            !!! note
                You cannot provide a configuration when instantiating a `TypeAdapter` if the type you're using
                has its own config that cannot be overridden (ex: `BaseModel`, `TypedDict`, and `dataclass`). A
                [`type-adapter-config-unused`](../errors/usage_errors.md#type-adapter-config-unused) error will
                be raised in this case.
        _parent_depth: Depth at which to search for the [parent frame][frame-objects]. This frame is used when
            resolving forward annotations during schema building, by looking for the globals and locals of this
            frame. Defaults to 2, which will result in the frame where the `TypeAdapter` was instantiated.

            !!! note
                This parameter is named with an underscore to suggest its private nature and discourage use.
                It may be deprecated in a minor version, so we only recommend using it if you're comfortable
                with potential change in behavior/support. Its default value is 2 because internally,
                the `TypeAdapter` class makes another call to fetch the frame.
        module: The module passed to the plugin, if provided.

    Attributes:
        core_schema: The core schema for the type.
        validator: The schema validator for the type.
        serializer: The schema serializer for the type.
        pydantic_complete: Whether the core schema for the type is successfully built.

    ??? tip "Compatibility with `mypy`"
        Depending on the type used, `mypy` might raise an error when instantiating a `TypeAdapter`. As a workaround, you can explicitly
        annotate your variable:

        ```py
        from typing import Union

        from pydantic import TypeAdapter

        ta: TypeAdapter[Union[str, int]] = TypeAdapter(Union[str, int])  # type: ignore[arg-type]
        ```

    ??? info "Namespace management nuances and implementation details"

        Here, we collect some notes on namespace management, and subtle differences from `BaseModel`:

        `BaseModel` uses its own `__module__` to find out where it was defined
        and then looks for symbols to resolve forward references in those globals.
        On the other hand, `TypeAdapter` can be initialized with arbitrary objects,
        which may not be types and thus do not have a `__module__` available.
        So instead we look at the globals in our parent stack frame.

        It is expected that the `ns_resolver` passed to this function will have the correct
        namespace for the type we're adapting. See the source code for `TypeAdapter.__init__`
        and `TypeAdapter.rebuild` for various ways to construct this namespace.

        This works for the case where this function is called in a module that
        has the target of forward references in its scope, but
        does not always work for more complex cases.

        For example, take the following:

        ```python {title="a.py"}
        from typing import Dict, List

        IntList = List[int]
        OuterDict = Dict[str, 'IntList']
        ```

        ```python {test="skip" title="b.py"}
        from a import OuterDict

        from pydantic import TypeAdapter

        IntList = int  # replaces the symbol the forward reference is looking for
        v = TypeAdapter(OuterDict)
        v({'x': 1})  # should fail but doesn't
        ```

        If `OuterDict` were a `BaseModel`, this would work because it would resolve
        the forward reference within the `a.py` namespace.
        But `TypeAdapter(OuterDict)` can't determine what module `OuterDict` came from.

        In other words, the assumption that _all_ forward references exist in the
        module we are being called from is not technically always true.
        Although most of the time it is and it works fine for recursive models and such,
        `BaseModel`'s behavior isn't perfect either and _can_ break in similar ways,
        so there is no right or wrong between the two.

        But at the very least this behavior is _subtly_ different from `BaseModel`'s.
    """

    core_schema: CoreSchema
    validator: SchemaValidator | PluggableSchemaValidator
    serializer: SchemaSerializer
    pydantic_complete: bool

    @overload
    def __init__(
        self,
        type: type[T],
        *,
        config: ConfigDict | None = ...,
        _parent_depth: int = ...,
        module: str | None = ...,
    ) -> None: ...

    # This second overload is for unsupported special forms (such as Annotated, Union, etc.)
    # Currently there is no way to type this correctly
    # See https://github.com/python/typing/pull/1618
    @overload
    def __init__(
        self,
        type: Any,
        *,
        config: ConfigDict | None = ...,
        _parent_depth: int = ...,
        module: str | None = ...,
    ) -> None: ...

    def __init__(
        self,
        type: Any,
        *,
        config: ConfigDict | None = None,
        _parent_depth: int = 2,
        module: str | None = None,
    ) -> None:
        if _type_has_config(type) and config is not None:
            raise PydanticUserError(
                'Cannot use `config` when the type is a BaseModel, dataclass or TypedDict.'
                ' These types can have their own config and setting the config via the `config`'
                ' parameter to TypeAdapter will not override it, thus the `config` you passed to'
                ' TypeAdapter becomes meaningless, which is probably not what you want.',
                code='type-adapter-config-unused',
            )

        self._type = type
        self._config = config
        self._parent_depth = _parent_depth
        self.pydantic_complete = False

        parent_frame = self._fetch_parent_frame()
        if parent_frame is not None:
            globalns = parent_frame.f_globals
            # Do not provide a local ns if the type adapter happens to be instantiated at the module level:
            localns = parent_frame.f_locals if parent_frame.f_locals is not globalns else {}
        else:
            globalns = {}
            localns = {}

        self._module_name = module or cast(str, globalns.get('__name__', ''))
        self._init_core_attrs(
            ns_resolver=_namespace_utils.NsResolver(
                namespaces_tuple=_namespace_utils.NamespacesTuple(locals=localns, globals=globalns),
                parent_namespace=localns,
            ),
            force=False,
        )

    def _fetch_parent_frame(self) -> FrameType | None:
        frame = sys._getframe(self._parent_depth)
        if frame.f_globals.get('__name__') == 'typing':
            # Because `TypeAdapter` is generic, explicitly parametrizing the class results
            # in a `typing._GenericAlias` instance, which proxies instantiation calls to the
            # "real" `TypeAdapter` class and thus adding an extra frame to the call. To avoid
            # pulling anything from the `typing` module, use the correct frame (the one before):
            return frame.f_back

        return frame

    def _init_core_attrs(
        self, ns_resolver: _namespace_utils.NsResolver, force: bool, raise_errors: bool = False
    ) -> bool:
        """Initialize the core schema, validator, and serializer for the type.

        Args:
            ns_resolver: The namespace resolver to use when building the core schema for the adapted type.
            force: Whether to force the construction of the core schema, validator, and serializer.
                If `force` is set to `False` and `_defer_build` is `True`, the core schema, validator, and serializer will be set to mocks.
            raise_errors: Whether to raise errors if initializing any of the core attrs fails.

        Returns:
            `True` if the core schema, validator, and serializer were successfully initialized, otherwise `False`.

        Raises:
            PydanticUndefinedAnnotation: If `PydanticUndefinedAnnotation` occurs in `__get_pydantic_core_schema__`
                and `raise_errors=True`.
"""
|
||||
if not force and self._defer_build:
|
||||
_mock_val_ser.set_type_adapter_mocks(self, str(self._type))
|
||||
self.pydantic_complete = False
|
||||
return False
|
||||
|
||||
try:
|
||||
self.core_schema = _getattr_no_parents(self._type, '__pydantic_core_schema__')
|
||||
self.validator = _getattr_no_parents(self._type, '__pydantic_validator__')
|
||||
self.serializer = _getattr_no_parents(self._type, '__pydantic_serializer__')
|
||||
|
||||
# TODO: we don't go through the rebuild logic here directly because we don't want
|
||||
# to repeat all of the namespace fetching logic that we've already done
|
||||
# so we simply skip to the block below that does the actual schema generation
|
||||
if (
|
||||
isinstance(self.core_schema, _mock_val_ser.MockCoreSchema)
|
||||
or isinstance(self.validator, _mock_val_ser.MockValSer)
|
||||
or isinstance(self.serializer, _mock_val_ser.MockValSer)
|
||||
):
|
||||
raise AttributeError()
|
||||
except AttributeError:
|
||||
config_wrapper = _config.ConfigWrapper(self._config)
|
||||
|
||||
schema_generator = _generate_schema.GenerateSchema(config_wrapper, ns_resolver=ns_resolver)
|
||||
|
||||
try:
|
||||
core_schema = schema_generator.generate_schema(self._type)
|
||||
except PydanticUndefinedAnnotation:
|
||||
if raise_errors:
|
||||
raise
|
||||
_mock_val_ser.set_type_adapter_mocks(self, str(self._type))
|
||||
return False
|
||||
|
||||
try:
|
||||
self.core_schema = schema_generator.clean_schema(core_schema)
|
||||
except schema_generator.CollectedInvalid:
|
||||
_mock_val_ser.set_type_adapter_mocks(self, str(self._type))
|
||||
return False
|
||||
|
||||
core_config = config_wrapper.core_config(None)
|
||||
|
||||
self.validator = create_schema_validator(
|
||||
schema=self.core_schema,
|
||||
schema_type=self._type,
|
||||
schema_type_module=self._module_name,
|
||||
schema_type_name=str(self._type),
|
||||
schema_kind='TypeAdapter',
|
||||
config=core_config,
|
||||
plugin_settings=config_wrapper.plugin_settings,
|
||||
)
|
||||
self.serializer = SchemaSerializer(self.core_schema, core_config)
|
||||
|
||||
self.pydantic_complete = True
|
||||
return True
|
||||
|
||||
@property
|
||||
def _defer_build(self) -> bool:
|
||||
config = self._config if self._config is not None else self._model_config
|
||||
if config:
|
||||
return config.get('defer_build') is True
|
||||
return False
|
||||
|
||||
@property
|
||||
def _model_config(self) -> ConfigDict | None:
|
||||
type_: Any = _typing_extra.annotated_type(self._type) or self._type # Eg FastAPI heavily uses Annotated
|
||||
if _utils.lenient_issubclass(type_, BaseModel):
|
||||
return type_.model_config
|
||||
return getattr(type_, '__pydantic_config__', None)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f'TypeAdapter({_repr.display_as_type(self._type)})'
|
||||
|
||||
def rebuild(
|
||||
self,
|
||||
*,
|
||||
force: bool = False,
|
||||
raise_errors: bool = True,
|
||||
_parent_namespace_depth: int = 2,
|
||||
_types_namespace: _namespace_utils.MappingNamespace | None = None,
|
||||
) -> bool | None:
|
||||
"""Try to rebuild the pydantic-core schema for the adapter's type.
|
||||
|
||||
This may be necessary when one of the annotations is a ForwardRef which could not be resolved during
|
||||
the initial attempt to build the schema, and automatic rebuilding fails.
|
||||
|
||||
Args:
|
||||
force: Whether to force the rebuilding of the type adapter's schema, defaults to `False`.
|
||||
raise_errors: Whether to raise errors, defaults to `True`.
|
||||
_parent_namespace_depth: Depth at which to search for the [parent frame][frame-objects]. This
|
||||
frame is used when resolving forward annotations during schema rebuilding, by looking for
|
||||
the locals of this frame. Defaults to 2, which will result in the frame where the method
|
||||
was called.
|
||||
_types_namespace: An explicit types namespace to use, instead of using the local namespace
|
||||
from the parent frame. Defaults to `None`.
|
||||
|
||||
Returns:
|
||||
Returns `None` if the schema is already "complete" and rebuilding was not required.
|
||||
If rebuilding _was_ required, returns `True` if rebuilding was successful, otherwise `False`.
|
||||
"""
|
||||
if not force and self.pydantic_complete:
|
||||
return None
|
||||
|
||||
if _types_namespace is not None:
|
||||
rebuild_ns = _types_namespace
|
||||
elif _parent_namespace_depth > 0:
|
||||
rebuild_ns = _typing_extra.parent_frame_namespace(parent_depth=_parent_namespace_depth, force=True) or {}
|
||||
else:
|
||||
rebuild_ns = {}
|
||||
|
||||
# we have to manually fetch globals here because there's no type on the stack of the NsResolver
|
||||
# and so we skip the globalns = get_module_ns_of(typ) call that would normally happen
|
||||
globalns = sys._getframe(max(_parent_namespace_depth - 1, 1)).f_globals
|
||||
ns_resolver = _namespace_utils.NsResolver(
|
||||
namespaces_tuple=_namespace_utils.NamespacesTuple(locals=rebuild_ns, globals=globalns),
|
||||
parent_namespace=rebuild_ns,
|
||||
)
|
||||
return self._init_core_attrs(ns_resolver=ns_resolver, force=True, raise_errors=raise_errors)
|
||||
|
||||
def validate_python(
|
||||
self,
|
||||
object: Any,
|
||||
/,
|
||||
*,
|
||||
strict: bool | None = None,
|
||||
from_attributes: bool | None = None,
|
||||
context: dict[str, Any] | None = None,
|
||||
experimental_allow_partial: bool | Literal['off', 'on', 'trailing-strings'] = False,
|
||||
) -> T:
|
||||
"""Validate a Python object against the model.
|
||||
|
||||
Args:
|
||||
object: The Python object to validate against the model.
|
||||
strict: Whether to strictly check types.
|
||||
from_attributes: Whether to extract data from object attributes.
|
||||
context: Additional context to pass to the validator.
|
||||
experimental_allow_partial: **Experimental** whether to enable
|
||||
[partial validation](../concepts/experimental.md#partial-validation), e.g. to process streams.
|
||||
* False / 'off': Default behavior, no partial validation.
|
||||
* True / 'on': Enable partial validation.
|
||||
* 'trailing-strings': Enable partial validation and allow trailing strings in the input.
|
||||
|
||||
!!! note
|
||||
When using `TypeAdapter` with a Pydantic `dataclass`, the use of the `from_attributes`
|
||||
argument is not supported.
|
||||
|
||||
Returns:
|
||||
The validated object.
|
||||
"""
|
||||
return self.validator.validate_python(
|
||||
object,
|
||||
strict=strict,
|
||||
from_attributes=from_attributes,
|
||||
context=context,
|
||||
allow_partial=experimental_allow_partial,
|
||||
)
|
||||
|
||||
def validate_json(
|
||||
self,
|
||||
data: str | bytes | bytearray,
|
||||
/,
|
||||
*,
|
||||
strict: bool | None = None,
|
||||
context: dict[str, Any] | None = None,
|
||||
experimental_allow_partial: bool | Literal['off', 'on', 'trailing-strings'] = False,
|
||||
) -> T:
|
||||
"""Usage docs: https://docs.pydantic.dev/2.10/concepts/json/#json-parsing
|
||||
|
||||
Validate a JSON string or bytes against the model.
|
||||
|
||||
Args:
|
||||
data: The JSON data to validate against the model.
|
||||
strict: Whether to strictly check types.
|
||||
context: Additional context to use during validation.
|
||||
experimental_allow_partial: **Experimental** whether to enable
|
||||
[partial validation](../concepts/experimental.md#partial-validation), e.g. to process streams.
|
||||
* False / 'off': Default behavior, no partial validation.
|
||||
* True / 'on': Enable partial validation.
|
||||
* 'trailing-strings': Enable partial validation and allow trailing strings in the input.
|
||||
|
||||
Returns:
|
||||
The validated object.
|
||||
"""
|
||||
return self.validator.validate_json(
|
||||
data, strict=strict, context=context, allow_partial=experimental_allow_partial
|
||||
)
|
||||
|
||||
def validate_strings(
|
||||
self,
|
||||
obj: Any,
|
||||
/,
|
||||
*,
|
||||
strict: bool | None = None,
|
||||
context: dict[str, Any] | None = None,
|
||||
experimental_allow_partial: bool | Literal['off', 'on', 'trailing-strings'] = False,
|
||||
) -> T:
|
||||
"""Validate object contains string data against the model.
|
||||
|
||||
Args:
|
||||
obj: The object contains string data to validate.
|
||||
            strict: Whether to strictly check types.
            context: Additional context to use during validation.
            experimental_allow_partial: **Experimental** whether to enable
                [partial validation](../concepts/experimental.md#partial-validation), e.g. to process streams.
                * False / 'off': Default behavior, no partial validation.
                * True / 'on': Enable partial validation.
                * 'trailing-strings': Enable partial validation and allow trailing strings in the input.

        Returns:
            The validated object.
        """
        return self.validator.validate_strings(
            obj, strict=strict, context=context, allow_partial=experimental_allow_partial
        )

    def get_default_value(self, *, strict: bool | None = None, context: dict[str, Any] | None = None) -> Some[T] | None:
        """Get the default value for the wrapped type.

        Args:
            strict: Whether to strictly check types.
            context: Additional context to pass to the validator.

        Returns:
            The default value wrapped in a `Some` if there is one or None if not.
        """
        return self.validator.get_default_value(strict=strict, context=context)

    def dump_python(
        self,
        instance: T,
        /,
        *,
        mode: Literal['json', 'python'] = 'python',
        include: IncEx | None = None,
        exclude: IncEx | None = None,
        by_alias: bool = False,
        exclude_unset: bool = False,
        exclude_defaults: bool = False,
        exclude_none: bool = False,
        round_trip: bool = False,
        warnings: bool | Literal['none', 'warn', 'error'] = True,
        serialize_as_any: bool = False,
        context: dict[str, Any] | None = None,
    ) -> Any:
        """Dump an instance of the adapted type to a Python object.

        Args:
            instance: The Python object to serialize.
            mode: The output format.
            include: Fields to include in the output.
            exclude: Fields to exclude from the output.
            by_alias: Whether to use alias names for field names.
            exclude_unset: Whether to exclude unset fields.
            exclude_defaults: Whether to exclude fields with default values.
            exclude_none: Whether to exclude fields with None values.
            round_trip: Whether to output the serialized data in a way that is compatible with deserialization.
            warnings: How to handle serialization errors. False/"none" ignores them, True/"warn" logs errors,
                "error" raises a [`PydanticSerializationError`][pydantic_core.PydanticSerializationError].
            serialize_as_any: Whether to serialize fields with duck-typing serialization behavior.
            context: Additional context to pass to the serializer.

        Returns:
            The serialized object.
        """
        return self.serializer.to_python(
            instance,
            mode=mode,
            by_alias=by_alias,
            include=include,
            exclude=exclude,
            exclude_unset=exclude_unset,
            exclude_defaults=exclude_defaults,
            exclude_none=exclude_none,
            round_trip=round_trip,
            warnings=warnings,
            serialize_as_any=serialize_as_any,
            context=context,
        )

    def dump_json(
        self,
        instance: T,
        /,
        *,
        indent: int | None = None,
        include: IncEx | None = None,
        exclude: IncEx | None = None,
        by_alias: bool = False,
        exclude_unset: bool = False,
        exclude_defaults: bool = False,
        exclude_none: bool = False,
        round_trip: bool = False,
        warnings: bool | Literal['none', 'warn', 'error'] = True,
        serialize_as_any: bool = False,
        context: dict[str, Any] | None = None,
    ) -> bytes:
        """Usage docs: https://docs.pydantic.dev/2.10/concepts/json/#json-serialization

        Serialize an instance of the adapted type to JSON.

        Args:
            instance: The instance to be serialized.
            indent: Number of spaces for JSON indentation.
            include: Fields to include.
            exclude: Fields to exclude.
            by_alias: Whether to use alias names for field names.
            exclude_unset: Whether to exclude unset fields.
            exclude_defaults: Whether to exclude fields with default values.
            exclude_none: Whether to exclude fields with a value of `None`.
            round_trip: Whether to serialize and deserialize the instance to ensure round-tripping.
            warnings: How to handle serialization errors. False/"none" ignores them, True/"warn" logs errors,
                "error" raises a [`PydanticSerializationError`][pydantic_core.PydanticSerializationError].
            serialize_as_any: Whether to serialize fields with duck-typing serialization behavior.
            context: Additional context to pass to the serializer.

        Returns:
            The JSON representation of the given instance as bytes.
        """
        return self.serializer.to_json(
            instance,
            indent=indent,
            include=include,
            exclude=exclude,
            by_alias=by_alias,
            exclude_unset=exclude_unset,
            exclude_defaults=exclude_defaults,
            exclude_none=exclude_none,
            round_trip=round_trip,
            warnings=warnings,
            serialize_as_any=serialize_as_any,
            context=context,
        )

    def json_schema(
        self,
        *,
        by_alias: bool = True,
        ref_template: str = DEFAULT_REF_TEMPLATE,
        schema_generator: type[GenerateJsonSchema] = GenerateJsonSchema,
        mode: JsonSchemaMode = 'validation',
    ) -> dict[str, Any]:
        """Generate a JSON schema for the adapted type.

        Args:
            by_alias: Whether to use alias names for field names.
            ref_template: The format string used for generating $ref strings.
            schema_generator: The generator class used for creating the schema.
            mode: The mode to use for schema generation.

        Returns:
            The JSON schema for the model as a dictionary.
        """
        schema_generator_instance = schema_generator(by_alias=by_alias, ref_template=ref_template)
        if isinstance(self.core_schema, _mock_val_ser.MockCoreSchema):
            self.core_schema.rebuild()
            assert not isinstance(self.core_schema, _mock_val_ser.MockCoreSchema), 'this is a bug! please report it'
        return schema_generator_instance.generate(self.core_schema, mode=mode)

    @staticmethod
    def json_schemas(
        inputs: Iterable[tuple[JsonSchemaKeyT, JsonSchemaMode, TypeAdapter[Any]]],
        /,
        *,
        by_alias: bool = True,
        title: str | None = None,
        description: str | None = None,
        ref_template: str = DEFAULT_REF_TEMPLATE,
        schema_generator: type[GenerateJsonSchema] = GenerateJsonSchema,
    ) -> tuple[dict[tuple[JsonSchemaKeyT, JsonSchemaMode], JsonSchemaValue], JsonSchemaValue]:
        """Generate a JSON schema including definitions from multiple type adapters.

        Args:
            inputs: Inputs to schema generation. The first two items will form the keys of the (first)
                output mapping; the type adapters will provide the core schemas that get converted into
                definitions in the output JSON schema.
            by_alias: Whether to use alias names.
            title: The title for the schema.
            description: The description for the schema.
            ref_template: The format string used for generating $ref strings.
            schema_generator: The generator class used for creating the schema.

        Returns:
            A tuple where:

                - The first element is a dictionary whose keys are tuples of JSON schema key type and JSON mode, and
                    whose values are the JSON schema corresponding to that pair of inputs. (These schemas may have
                    JsonRef references to definitions that are defined in the second returned element.)
                - The second element is a JSON schema containing all definitions referenced in the first returned
                    element, along with the optional title and description keys.

        """
        schema_generator_instance = schema_generator(by_alias=by_alias, ref_template=ref_template)

        inputs_ = []
        for key, mode, adapter in inputs:
            # This is the same pattern we follow for model json schemas - we attempt a core schema rebuild if we detect a mock
            if isinstance(adapter.core_schema, _mock_val_ser.MockCoreSchema):
                adapter.core_schema.rebuild()
                assert not isinstance(
                    adapter.core_schema, _mock_val_ser.MockCoreSchema
                ), 'this is a bug! please report it'
            inputs_.append((key, mode, adapter.core_schema))

        json_schemas_map, definitions = schema_generator_instance.generate_definitions(inputs_)

        json_schema: dict[str, Any] = {}
        if definitions:
            json_schema['$defs'] = definitions
        if title:
            json_schema['title'] = title
        if description:
            json_schema['description'] = description

        return json_schemas_map, json_schema
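
A minimal usage sketch for `TypeAdapter` (added for illustration, not part of the diff):

    from typing import List

    from pydantic import TypeAdapter

    ta = TypeAdapter(List[int])

    assert ta.validate_python(['1', 2]) == [1, 2]      # lax mode coerces '1' to 1
    assert ta.validate_json('[1, 2, 3]') == [1, 2, 3]  # parse and validate JSON in one step
    assert ta.dump_json([1, 2, 3]) == b'[1,2,3]'       # serialize to JSON bytes
    print(ta.json_schema())  # {'items': {'type': 'integer'}, 'type': 'array'}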
File diff suppressed because it is too large
@@ -0,0 +1,5 @@
"""The `typing` module is a backport module from V1."""

from ._migration import getattr_migration

__getattr__ = getattr_migration(__name__)
@@ -0,0 +1,5 @@
"""The `utils` module is a backport module from V1."""

from ._migration import getattr_migration

__getattr__ = getattr_migration(__name__)
@@ -0,0 +1,131 @@
# flake8: noqa
from pydantic.v1 import dataclasses
from pydantic.v1.annotated_types import create_model_from_namedtuple, create_model_from_typeddict
from pydantic.v1.class_validators import root_validator, validator
from pydantic.v1.config import BaseConfig, ConfigDict, Extra
from pydantic.v1.decorator import validate_arguments
from pydantic.v1.env_settings import BaseSettings
from pydantic.v1.error_wrappers import ValidationError
from pydantic.v1.errors import *
from pydantic.v1.fields import Field, PrivateAttr, Required
from pydantic.v1.main import *
from pydantic.v1.networks import *
from pydantic.v1.parse import Protocol
from pydantic.v1.tools import *
from pydantic.v1.types import *
from pydantic.v1.version import VERSION, compiled

__version__ = VERSION

# WARNING __all__ from pydantic.errors is not included here, it will be removed as an export here in v2
# please use "from pydantic.v1.errors import ..." instead
__all__ = [
    # annotated types utils
    'create_model_from_namedtuple',
    'create_model_from_typeddict',
    # dataclasses
    'dataclasses',
    # class_validators
    'root_validator',
    'validator',
    # config
    'BaseConfig',
    'ConfigDict',
    'Extra',
    # decorator
    'validate_arguments',
    # env_settings
    'BaseSettings',
    # error_wrappers
    'ValidationError',
    # fields
    'Field',
    'Required',
    # main
    'BaseModel',
    'create_model',
    'validate_model',
    # network
    'AnyUrl',
    'AnyHttpUrl',
    'FileUrl',
    'HttpUrl',
    'stricturl',
    'EmailStr',
    'NameEmail',
    'IPvAnyAddress',
    'IPvAnyInterface',
    'IPvAnyNetwork',
    'PostgresDsn',
    'CockroachDsn',
    'AmqpDsn',
    'RedisDsn',
    'MongoDsn',
    'KafkaDsn',
    'validate_email',
    # parse
    'Protocol',
    # tools
    'parse_file_as',
    'parse_obj_as',
    'parse_raw_as',
    'schema_of',
    'schema_json_of',
    # types
    'NoneStr',
    'NoneBytes',
    'StrBytes',
    'NoneStrBytes',
    'StrictStr',
    'ConstrainedBytes',
    'conbytes',
    'ConstrainedList',
    'conlist',
    'ConstrainedSet',
    'conset',
    'ConstrainedFrozenSet',
    'confrozenset',
    'ConstrainedStr',
    'constr',
    'PyObject',
    'ConstrainedInt',
    'conint',
    'PositiveInt',
    'NegativeInt',
    'NonNegativeInt',
    'NonPositiveInt',
    'ConstrainedFloat',
    'confloat',
    'PositiveFloat',
    'NegativeFloat',
    'NonNegativeFloat',
    'NonPositiveFloat',
    'FiniteFloat',
    'ConstrainedDecimal',
    'condecimal',
    'ConstrainedDate',
    'condate',
    'UUID1',
    'UUID3',
    'UUID4',
    'UUID5',
    'FilePath',
    'DirectoryPath',
    'Json',
    'JsonWrapper',
    'SecretField',
    'SecretStr',
    'SecretBytes',
    'StrictBool',
    'StrictBytes',
    'StrictInt',
    'StrictFloat',
    'PaymentCardNumber',
    'PrivateAttr',
    'ByteSize',
    'PastDate',
    'FutureDate',
    # version
    'compiled',
    'VERSION',
]
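
With this shim, code written against the V1 API keeps working by changing only the import path (illustrative sketch):

    # Before, on pydantic 1.x:  from pydantic import BaseModel, validator
    # After, with pydantic 2 installed, go through the v1 compatibility layer:
    from pydantic.v1 import BaseModel, validator

    class User(BaseModel):
        name: str

        @validator('name')
        def name_not_blank(cls, v):
            assert v.strip(), 'name must not be blank'
            return v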
@@ -0,0 +1,391 @@
"""
Register Hypothesis strategies for Pydantic custom types.

This enables fully-automatic generation of test data for most Pydantic classes.

Note that this module has *no* runtime impact on Pydantic itself; instead it
is registered as a setuptools entry point and Hypothesis will import it if
Pydantic is installed. See also:

https://hypothesis.readthedocs.io/en/latest/strategies.html#registering-strategies-via-setuptools-entry-points
https://hypothesis.readthedocs.io/en/latest/data.html#hypothesis.strategies.register_type_strategy
https://hypothesis.readthedocs.io/en/latest/strategies.html#interaction-with-pytest-cov
https://docs.pydantic.dev/usage/types/#pydantic-types

Note that because our motivation is to *improve user experience*, the strategies
are always sound (never generate invalid data) but sacrifice completeness for
maintainability (ie may be unable to generate some tricky but valid data).

Finally, this module makes liberal use of `# type: ignore[<code>]` pragmas.
This is because Hypothesis annotates `register_type_strategy()` with
`(T, SearchStrategy[T])`, but in most cases we register e.g. `ConstrainedInt`
to generate instances of the builtin `int` type which match the constraints.
"""

import contextlib
import datetime
import ipaddress
import json
import math
from fractions import Fraction
from typing import Callable, Dict, Type, Union, cast, overload

import hypothesis.strategies as st

import pydantic
import pydantic.color
import pydantic.types
from pydantic.v1.utils import lenient_issubclass

# FilePath and DirectoryPath are explicitly unsupported, as we'd have to create
# them on-disk, and that's unsafe in general without being told *where* to do so.
#
# URLs are unsupported because it's easy for users to define their own strategy for
# "normal" URLs, and hard for us to define a general strategy which includes "weird"
# URLs but doesn't also have unpredictable performance problems.
#
# conlist() and conset() are unsupported for now, because the workarounds for
# Cython and Hypothesis to handle parametrized generic types are incompatible.
# We are rethinking Hypothesis compatibility in Pydantic v2.

# Emails
try:
    import email_validator
except ImportError:  # pragma: no cover
    pass
else:

    def is_valid_email(s: str) -> bool:
        # Hypothesis' st.emails() occasionally generates emails like 0@A0--0.ac
        # that are invalid according to email-validator, so we filter those out.
        try:
            email_validator.validate_email(s, check_deliverability=False)
            return True
        except email_validator.EmailNotValidError:  # pragma: no cover
            return False

    # Note that these strategies deliberately stay away from any tricky Unicode
    # or other encoding issues; we're just trying to generate *something* valid.
    st.register_type_strategy(pydantic.EmailStr, st.emails().filter(is_valid_email))  # type: ignore[arg-type]
    st.register_type_strategy(
        pydantic.NameEmail,
        st.builds(
            '{} <{}>'.format,  # type: ignore[arg-type]
            st.from_regex('[A-Za-z0-9_]+( [A-Za-z0-9_]+){0,5}', fullmatch=True),
            st.emails().filter(is_valid_email),
        ),
    )

# PyObject - dotted names, in this case taken from the math module.
st.register_type_strategy(
    pydantic.PyObject,  # type: ignore[arg-type]
    st.sampled_from(
        [cast(pydantic.PyObject, f'math.{name}') for name in sorted(vars(math)) if not name.startswith('_')]
    ),
)

# CSS3 Colors; as name, hex, rgb(a) tuples or strings, or hsl strings
_color_regexes = (
    '|'.join(
        (
            pydantic.color.r_hex_short,
            pydantic.color.r_hex_long,
            pydantic.color.r_rgb,
            pydantic.color.r_rgba,
            pydantic.color.r_hsl,
            pydantic.color.r_hsla,
        )
    )
    # Use more precise regex patterns to avoid value-out-of-range errors
    .replace(pydantic.color._r_sl, r'(?:(\d\d?(?:\.\d+)?|100(?:\.0+)?)%)')
    .replace(pydantic.color._r_alpha, r'(?:(0(?:\.\d+)?|1(?:\.0+)?|\.\d+|\d{1,2}%))')
    .replace(pydantic.color._r_255, r'(?:((?:\d|\d\d|[01]\d\d|2[0-4]\d|25[0-4])(?:\.\d+)?|255(?:\.0+)?))')
)
st.register_type_strategy(
    pydantic.color.Color,
    st.one_of(
        st.sampled_from(sorted(pydantic.color.COLORS_BY_NAME)),
        st.tuples(
            st.integers(0, 255),
            st.integers(0, 255),
            st.integers(0, 255),
            st.none() | st.floats(0, 1) | st.floats(0, 100).map('{}%'.format),
        ),
        st.from_regex(_color_regexes, fullmatch=True),
    ),
)


# Card numbers, valid according to the Luhn algorithm


def add_luhn_digit(card_number: str) -> str:
    # See https://en.wikipedia.org/wiki/Luhn_algorithm
    for digit in '0123456789':
        with contextlib.suppress(Exception):
            pydantic.PaymentCardNumber.validate_luhn_check_digit(card_number + digit)
            return card_number + digit
    raise AssertionError('Unreachable')  # pragma: no cover


card_patterns = (
    # Note that these patterns omit the Luhn check digit; that's added by the function above
    '4[0-9]{14}',  # Visa
    '5[12345][0-9]{13}',  # Mastercard
    '3[47][0-9]{12}',  # American Express
    '[0-26-9][0-9]{10,17}',  # other (incomplete to avoid overlap)
)
st.register_type_strategy(
    pydantic.PaymentCardNumber,
    st.from_regex('|'.join(card_patterns), fullmatch=True).map(add_luhn_digit),  # type: ignore[arg-type]
)

# UUIDs
st.register_type_strategy(pydantic.UUID1, st.uuids(version=1))
st.register_type_strategy(pydantic.UUID3, st.uuids(version=3))
st.register_type_strategy(pydantic.UUID4, st.uuids(version=4))
st.register_type_strategy(pydantic.UUID5, st.uuids(version=5))

# Secrets
st.register_type_strategy(pydantic.SecretBytes, st.binary().map(pydantic.SecretBytes))
st.register_type_strategy(pydantic.SecretStr, st.text().map(pydantic.SecretStr))

# IP addresses, networks, and interfaces
st.register_type_strategy(pydantic.IPvAnyAddress, st.ip_addresses())  # type: ignore[arg-type]
st.register_type_strategy(
    pydantic.IPvAnyInterface,
    st.from_type(ipaddress.IPv4Interface) | st.from_type(ipaddress.IPv6Interface),  # type: ignore[arg-type]
)
st.register_type_strategy(
    pydantic.IPvAnyNetwork,
    st.from_type(ipaddress.IPv4Network) | st.from_type(ipaddress.IPv6Network),  # type: ignore[arg-type]
)

# We hook into the con***() functions and the ConstrainedNumberMeta metaclass,
# so here we only have to register subclasses for other constrained types which
# don't go via those mechanisms. Then there are the registration hooks below.
st.register_type_strategy(pydantic.StrictBool, st.booleans())
st.register_type_strategy(pydantic.StrictStr, st.text())


# FutureDate, PastDate
st.register_type_strategy(pydantic.FutureDate, st.dates(min_value=datetime.date.today() + datetime.timedelta(days=1)))
st.register_type_strategy(pydantic.PastDate, st.dates(max_value=datetime.date.today() - datetime.timedelta(days=1)))


# Constrained-type resolver functions
#
# For these ones, we actually want to inspect the type in order to work out a
# satisfying strategy. First up, the machinery for tracking resolver functions:

RESOLVERS: Dict[type, Callable[[type], st.SearchStrategy]] = {}  # type: ignore[type-arg]


@overload
def _registered(typ: Type[pydantic.types.T]) -> Type[pydantic.types.T]:
    pass


@overload
def _registered(typ: pydantic.types.ConstrainedNumberMeta) -> pydantic.types.ConstrainedNumberMeta:
    pass


def _registered(
    typ: Union[Type[pydantic.types.T], pydantic.types.ConstrainedNumberMeta]
) -> Union[Type[pydantic.types.T], pydantic.types.ConstrainedNumberMeta]:
    # This function replaces the version in `pydantic.types`, in order to
    # effect the registration of new constrained types so that Hypothesis
    # can generate valid examples.
    pydantic.types._DEFINED_TYPES.add(typ)
    for supertype, resolver in RESOLVERS.items():
        if issubclass(typ, supertype):
            st.register_type_strategy(typ, resolver(typ))  # type: ignore
            return typ
    raise NotImplementedError(f'Unknown type {typ!r} has no resolver to register')  # pragma: no cover


def resolves(
    typ: Union[type, pydantic.types.ConstrainedNumberMeta]
) -> Callable[[Callable[..., st.SearchStrategy]], Callable[..., st.SearchStrategy]]:  # type: ignore[type-arg]
    def inner(f):  # type: ignore
        assert f not in RESOLVERS
        RESOLVERS[typ] = f
        return f

    return inner


# Type-to-strategy resolver functions


@resolves(pydantic.JsonWrapper)
def resolve_json(cls):  # type: ignore[no-untyped-def]
    try:
        inner = st.none() if cls.inner_type is None else st.from_type(cls.inner_type)
    except Exception:  # pragma: no cover
        finite = st.floats(allow_infinity=False, allow_nan=False)
        inner = st.recursive(
            base=st.one_of(st.none(), st.booleans(), st.integers(), finite, st.text()),
            extend=lambda x: st.lists(x) | st.dictionaries(st.text(), x),  # type: ignore
        )
    inner_type = getattr(cls, 'inner_type', None)
    return st.builds(
        cls.inner_type.json if lenient_issubclass(inner_type, pydantic.BaseModel) else json.dumps,
        inner,
        ensure_ascii=st.booleans(),
        indent=st.none() | st.integers(0, 16),
        sort_keys=st.booleans(),
    )


@resolves(pydantic.ConstrainedBytes)
def resolve_conbytes(cls):  # type: ignore[no-untyped-def]  # pragma: no cover
    min_size = cls.min_length or 0
    max_size = cls.max_length
    if not cls.strip_whitespace:
        return st.binary(min_size=min_size, max_size=max_size)
    # Fun with regex to ensure we neither start nor end with whitespace
    repeats = '{{{},{}}}'.format(
        min_size - 2 if min_size > 2 else 0,
        max_size - 2 if (max_size or 0) > 2 else '',
    )
    if min_size >= 2:
        pattern = rf'\W.{repeats}\W'
    elif min_size == 1:
        pattern = rf'\W(.{repeats}\W)?'
    else:
        assert min_size == 0
        pattern = rf'(\W(.{repeats}\W)?)?'
    return st.from_regex(pattern.encode(), fullmatch=True)


@resolves(pydantic.ConstrainedDecimal)
def resolve_condecimal(cls):  # type: ignore[no-untyped-def]
    min_value = cls.ge
    max_value = cls.le
    if cls.gt is not None:
        assert min_value is None, 'Set `gt` or `ge`, but not both'
        min_value = cls.gt
    if cls.lt is not None:
        assert max_value is None, 'Set `lt` or `le`, but not both'
        max_value = cls.lt
    s = st.decimals(min_value, max_value, allow_nan=False, places=cls.decimal_places)
    if cls.lt is not None:
        s = s.filter(lambda d: d < cls.lt)
    if cls.gt is not None:
        s = s.filter(lambda d: cls.gt < d)
    return s


@resolves(pydantic.ConstrainedFloat)
def resolve_confloat(cls):  # type: ignore[no-untyped-def]
    min_value = cls.ge
    max_value = cls.le
    exclude_min = False
    exclude_max = False

    if cls.gt is not None:
        assert min_value is None, 'Set `gt` or `ge`, but not both'
        min_value = cls.gt
        exclude_min = True
    if cls.lt is not None:
        assert max_value is None, 'Set `lt` or `le`, but not both'
        max_value = cls.lt
        exclude_max = True

    if cls.multiple_of is None:
        return st.floats(min_value, max_value, exclude_min=exclude_min, exclude_max=exclude_max, allow_nan=False)

    if min_value is not None:
        min_value = math.ceil(min_value / cls.multiple_of)
        if exclude_min:
            min_value = min_value + 1
    if max_value is not None:
        assert max_value >= cls.multiple_of, 'Cannot build model with max value smaller than multiple of'
        max_value = math.floor(max_value / cls.multiple_of)
        if exclude_max:
            max_value = max_value - 1

    return st.integers(min_value, max_value).map(lambda x: x * cls.multiple_of)


@resolves(pydantic.ConstrainedInt)
def resolve_conint(cls):  # type: ignore[no-untyped-def]
    min_value = cls.ge
    max_value = cls.le
    if cls.gt is not None:
        assert min_value is None, 'Set `gt` or `ge`, but not both'
        min_value = cls.gt + 1
    if cls.lt is not None:
        assert max_value is None, 'Set `lt` or `le`, but not both'
        max_value = cls.lt - 1

    if cls.multiple_of is None or cls.multiple_of == 1:
        return st.integers(min_value, max_value)

    # These adjustments and the .map handle integer-valued multiples, while the
    # .filter handles trickier cases as for confloat.
    if min_value is not None:
        min_value = math.ceil(Fraction(min_value) / Fraction(cls.multiple_of))
    if max_value is not None:
        max_value = math.floor(Fraction(max_value) / Fraction(cls.multiple_of))
    return st.integers(min_value, max_value).map(lambda x: x * cls.multiple_of)


@resolves(pydantic.ConstrainedDate)
def resolve_condate(cls):  # type: ignore[no-untyped-def]
    if cls.ge is not None:
        assert cls.gt is None, 'Set `gt` or `ge`, but not both'
        min_value = cls.ge
    elif cls.gt is not None:
        min_value = cls.gt + datetime.timedelta(days=1)
    else:
        min_value = datetime.date.min
    if cls.le is not None:
        assert cls.lt is None, 'Set `lt` or `le`, but not both'
        max_value = cls.le
    elif cls.lt is not None:
        max_value = cls.lt - datetime.timedelta(days=1)
    else:
        max_value = datetime.date.max
    return st.dates(min_value, max_value)


@resolves(pydantic.ConstrainedStr)
def resolve_constr(cls):  # type: ignore[no-untyped-def]  # pragma: no cover
    min_size = cls.min_length or 0
    max_size = cls.max_length

    if cls.regex is None and not cls.strip_whitespace:
        return st.text(min_size=min_size, max_size=max_size)

    if cls.regex is not None:
        strategy = st.from_regex(cls.regex)
        if cls.strip_whitespace:
            strategy = strategy.filter(lambda s: s == s.strip())
    elif cls.strip_whitespace:
        repeats = '{{{},{}}}'.format(
            min_size - 2 if min_size > 2 else 0,
            max_size - 2 if (max_size or 0) > 2 else '',
        )
        if min_size >= 2:
            strategy = st.from_regex(rf'\W.{repeats}\W')
        elif min_size == 1:
            strategy = st.from_regex(rf'\W(.{repeats}\W)?')
        else:
            assert min_size == 0
            strategy = st.from_regex(rf'(\W(.{repeats}\W)?)?')

    if min_size == 0 and max_size is None:
        return strategy
    elif max_size is None:
        return strategy.filter(lambda s: min_size <= len(s))
    return strategy.filter(lambda s: min_size <= len(s) <= max_size)


# Finally, register all previously-defined types, and patch in our new function
for typ in list(pydantic.types._DEFINED_TYPES):
    _registered(typ)
pydantic.types._registered = _registered
st.register_type_strategy(pydantic.Json, resolve_json)
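
A sketch of what these registrations enable, following the pydantic 1.x usage this plugin was written for (illustrative; assumes hypothesis is installed and the plugin is active for the V1 constrained types):

    from hypothesis import given, strategies as st
    from pydantic.v1 import conint

    # conint() creates a new ConstrainedInt subclass, which passes through the
    # patched _registered() hook above, so st.from_type() can resolve it.
    @given(st.from_type(conint(ge=0, le=10, multiple_of=2)))
    def test_conint_values(x):
        assert 0 <= x <= 10 and x % 2 == 0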
@@ -0,0 +1,72 @@
import sys
from typing import TYPE_CHECKING, Any, Dict, FrozenSet, NamedTuple, Type

from pydantic.v1.fields import Required
from pydantic.v1.main import BaseModel, create_model
from pydantic.v1.typing import is_typeddict, is_typeddict_special

if TYPE_CHECKING:
    from typing_extensions import TypedDict

if sys.version_info < (3, 11):

    def is_legacy_typeddict(typeddict_cls: Type['TypedDict']) -> bool:  # type: ignore[valid-type]
        return is_typeddict(typeddict_cls) and type(typeddict_cls).__module__ == 'typing'

else:

    def is_legacy_typeddict(_: Any) -> Any:
        return False


def create_model_from_typeddict(
    # Mypy bug: `Type[TypedDict]` is resolved as `Any` https://github.com/python/mypy/issues/11030
    typeddict_cls: Type['TypedDict'],  # type: ignore[valid-type]
    **kwargs: Any,
) -> Type['BaseModel']:
    """
    Create a `BaseModel` based on the fields of a `TypedDict`.
    Since `typing.TypedDict` in Python 3.8 does not store runtime information about optional keys,
    we raise an error if this happens (see https://bugs.python.org/issue38834).
    """
    field_definitions: Dict[str, Any]

    # Best case scenario: with python 3.9+ or when `TypedDict` is imported from `typing_extensions`
    if not hasattr(typeddict_cls, '__required_keys__'):
        raise TypeError(
            'You should use `typing_extensions.TypedDict` instead of `typing.TypedDict` with Python < 3.9.2. '
            'Without it, there is no way to differentiate required and optional fields when subclassed.'
        )

    if is_legacy_typeddict(typeddict_cls) and any(
        is_typeddict_special(t) for t in typeddict_cls.__annotations__.values()
    ):
        raise TypeError(
            'You should use `typing_extensions.TypedDict` instead of `typing.TypedDict` with Python < 3.11. '
            'Without it, there is no way to reflect Required/NotRequired keys.'
        )

    required_keys: FrozenSet[str] = typeddict_cls.__required_keys__  # type: ignore[attr-defined]
    field_definitions = {
        field_name: (field_type, Required if field_name in required_keys else None)
        for field_name, field_type in typeddict_cls.__annotations__.items()
    }

    return create_model(typeddict_cls.__name__, **kwargs, **field_definitions)


def create_model_from_namedtuple(namedtuple_cls: Type['NamedTuple'], **kwargs: Any) -> Type['BaseModel']:
    """
    Create a `BaseModel` based on the fields of a named tuple.
    A named tuple can be created with `typing.NamedTuple` and declared annotations
    but also with `collections.namedtuple`, in this case we consider all fields
    to have type `Any`.
    """
    # With python 3.10+, `__annotations__` always exists but can be empty hence the `getattr... or...` logic
    namedtuple_annotations: Dict[str, Type[Any]] = getattr(namedtuple_cls, '__annotations__', None) or {
        k: Any for k in namedtuple_cls._fields
    }
    field_definitions: Dict[str, Any] = {
        field_name: (field_type, Required) for field_name, field_type in namedtuple_annotations.items()
    }
    return create_model(namedtuple_cls.__name__, **kwargs, **field_definitions)
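
A brief usage sketch for the helpers above (illustrative; the `Movie` TypedDict is hypothetical):

    from typing_extensions import TypedDict

    from pydantic.v1.annotated_types import create_model_from_typeddict

    class Movie(TypedDict):
        title: str
        year: int

    MovieModel = create_model_from_typeddict(Movie)
    m = MovieModel(title='Alien', year='1979')  # V1 coerces '1979' to 1979
    assert m.year == 1979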
@@ -0,0 +1,361 @@
|
||||
import warnings
|
||||
from collections import ChainMap
|
||||
from functools import partial, partialmethod, wraps
|
||||
from itertools import chain
|
||||
from types import FunctionType
|
||||
from typing import TYPE_CHECKING, Any, Callable, Dict, Iterable, List, Optional, Set, Tuple, Type, Union, overload
|
||||
|
||||
from pydantic.v1.errors import ConfigError
|
||||
from pydantic.v1.typing import AnyCallable
|
||||
from pydantic.v1.utils import ROOT_KEY, in_ipython
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from pydantic.v1.typing import AnyClassMethod
|
||||
|
||||
|
||||
class Validator:
|
||||
__slots__ = 'func', 'pre', 'each_item', 'always', 'check_fields', 'skip_on_failure'
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
func: AnyCallable,
|
||||
pre: bool = False,
|
||||
each_item: bool = False,
|
||||
always: bool = False,
|
||||
check_fields: bool = False,
|
||||
skip_on_failure: bool = False,
|
||||
):
|
||||
self.func = func
|
||||
self.pre = pre
|
||||
self.each_item = each_item
|
||||
self.always = always
|
||||
self.check_fields = check_fields
|
||||
self.skip_on_failure = skip_on_failure
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from inspect import Signature
|
||||
|
||||
from pydantic.v1.config import BaseConfig
|
||||
from pydantic.v1.fields import ModelField
|
||||
from pydantic.v1.types import ModelOrDc
|
||||
|
||||
ValidatorCallable = Callable[[Optional[ModelOrDc], Any, Dict[str, Any], ModelField, Type[BaseConfig]], Any]
|
||||
ValidatorsList = List[ValidatorCallable]
|
||||
ValidatorListDict = Dict[str, List[Validator]]
|
||||
|
||||
_FUNCS: Set[str] = set()
|
||||
VALIDATOR_CONFIG_KEY = '__validator_config__'
|
||||
ROOT_VALIDATOR_CONFIG_KEY = '__root_validator_config__'
|
||||
|
||||
|
||||
def validator(
|
||||
*fields: str,
|
||||
pre: bool = False,
|
||||
each_item: bool = False,
|
||||
always: bool = False,
|
||||
check_fields: bool = True,
|
||||
whole: Optional[bool] = None,
|
||||
allow_reuse: bool = False,
|
||||
) -> Callable[[AnyCallable], 'AnyClassMethod']:
|
||||
"""
|
||||
Decorate methods on the class indicating that they should be used to validate fields
|
||||
:param fields: which field(s) the method should be called on
|
||||
:param pre: whether or not this validator should be called before the standard validators (else after)
|
||||
:param each_item: for complex objects (sets, lists etc.) whether to validate individual elements rather than the
|
||||
whole object
|
||||
:param always: whether this method and other validators should be called even if the value is missing
|
||||
:param check_fields: whether to check that the fields actually exist on the model
|
||||
:param allow_reuse: whether to track and raise an error if another validator refers to the decorated function
|
||||
"""
|
||||
if not fields:
|
||||
raise ConfigError('validator with no fields specified')
|
||||
elif isinstance(fields[0], FunctionType):
|
||||
raise ConfigError(
|
||||
"validators should be used with fields and keyword arguments, not bare. " # noqa: Q000
|
||||
"E.g. usage should be `@validator('<field_name>', ...)`"
|
||||
)
|
||||
elif not all(isinstance(field, str) for field in fields):
|
||||
raise ConfigError(
|
||||
"validator fields should be passed as separate string args. " # noqa: Q000
|
||||
"E.g. usage should be `@validator('<field_name_1>', '<field_name_2>', ...)`"
|
||||
)
|
||||
|
||||
if whole is not None:
|
||||
warnings.warn(
|
||||
'The "whole" keyword argument is deprecated, use "each_item" (inverse meaning, default False) instead',
|
||||
DeprecationWarning,
|
||||
)
|
||||
assert each_item is False, '"each_item" and "whole" conflict, remove "whole"'
|
||||
each_item = not whole
|
||||
|
||||
def dec(f: AnyCallable) -> 'AnyClassMethod':
|
||||
f_cls = _prepare_validator(f, allow_reuse)
|
||||
setattr(
|
||||
f_cls,
|
||||
VALIDATOR_CONFIG_KEY,
|
||||
(
|
||||
fields,
|
||||
Validator(func=f_cls.__func__, pre=pre, each_item=each_item, always=always, check_fields=check_fields),
|
||||
),
|
||||
)
|
||||
return f_cls
|
||||
|
||||
return dec
|
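# Illustrative sketch of the decorator above, written as standalone user code;
# the model and field names are hypothetical, not from the original source.
from pydantic.v1 import BaseModel, validator

class User(BaseModel):
    name: str

    @validator('name')
    def name_must_not_be_blank(cls, v: str) -> str:
        if not v.strip():
            raise ValueError('name must not be blank')
        return v.title()

assert User(name='jane doe').name == 'Jane Doe'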
||||
|
||||
|
||||
@overload
|
||||
def root_validator(_func: AnyCallable) -> 'AnyClassMethod':
|
||||
...
|
||||
|
||||
|
||||
@overload
|
||||
def root_validator(
|
||||
*, pre: bool = False, allow_reuse: bool = False, skip_on_failure: bool = False
|
||||
) -> Callable[[AnyCallable], 'AnyClassMethod']:
|
||||
...
|
||||
|
||||
|
||||
def root_validator(
|
||||
_func: Optional[AnyCallable] = None, *, pre: bool = False, allow_reuse: bool = False, skip_on_failure: bool = False
|
||||
) -> Union['AnyClassMethod', Callable[[AnyCallable], 'AnyClassMethod']]:
|
||||
"""
|
||||
Decorate methods on a model indicating that they should be used to validate (and perhaps modify) data either
|
||||
before or after standard model parsing/validation is performed.
|
||||
"""
|
||||
if _func:
|
||||
f_cls = _prepare_validator(_func, allow_reuse)
|
||||
setattr(
|
||||
f_cls, ROOT_VALIDATOR_CONFIG_KEY, Validator(func=f_cls.__func__, pre=pre, skip_on_failure=skip_on_failure)
|
||||
)
|
||||
return f_cls
|
||||
|
||||
def dec(f: AnyCallable) -> 'AnyClassMethod':
|
||||
f_cls = _prepare_validator(f, allow_reuse)
|
||||
setattr(
|
||||
f_cls, ROOT_VALIDATOR_CONFIG_KEY, Validator(func=f_cls.__func__, pre=pre, skip_on_failure=skip_on_failure)
|
||||
)
|
||||
return f_cls
|
||||
|
||||
return dec
|
||||
|
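# Illustrative sketch of a root validator for cross-field checks, as standalone
# user code; `Window`, `lo` and `hi` are hypothetical names.
from pydantic.v1 import BaseModel, root_validator

class Window(BaseModel):
    lo: int
    hi: int

    @root_validator
    def check_bounds(cls, values: dict) -> dict:
        lo, hi = values.get('lo'), values.get('hi')
        if lo is not None and hi is not None and lo > hi:
            raise ValueError('lo must not exceed hi')
        return values

assert Window(lo=1, hi=2).hi == 2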
||||
|
||||
def _prepare_validator(function: AnyCallable, allow_reuse: bool) -> 'AnyClassMethod':
|
||||
"""
|
||||
Avoid validators with duplicated names since without this, validators can be overwritten silently
|
||||
which generally isn't the intended behaviour, don't run in ipython (see #312) or if allow_reuse is False.
|
||||
"""
|
||||
f_cls = function if isinstance(function, classmethod) else classmethod(function)
|
||||
if not in_ipython() and not allow_reuse:
|
||||
ref = (
|
||||
getattr(f_cls.__func__, '__module__', '<No __module__>')
|
||||
+ '.'
|
||||
+ getattr(f_cls.__func__, '__qualname__', f'<No __qualname__: id:{id(f_cls.__func__)}>')
|
||||
)
|
||||
if ref in _FUNCS:
|
||||
raise ConfigError(f'duplicate validator function "{ref}"; if this is intended, set `allow_reuse=True`')
|
||||
_FUNCS.add(ref)
|
||||
return f_cls
|
||||
|
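# Minimal sketch of the duplicate-name guard above (assuming we are not running
# in IPython, where the check is skipped); `_example_check` is a hypothetical name.
def _example_check(cls, v):
    return v

_prepare_validator(_example_check, allow_reuse=False)       # registers module.qualname in _FUNCS
try:
    _prepare_validator(_example_check, allow_reuse=False)   # same name again -> ConfigError
except ConfigError:
    pass
_prepare_validator(_example_check, allow_reuse=True)        # reuse explicitly allowed, check skipped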
||||
|
||||
class ValidatorGroup:
|
||||
def __init__(self, validators: 'ValidatorListDict') -> None:
|
||||
self.validators = validators
|
||||
self.used_validators = {'*'}
|
||||
|
||||
def get_validators(self, name: str) -> Optional[Dict[str, Validator]]:
|
||||
self.used_validators.add(name)
|
||||
validators = self.validators.get(name, [])
|
||||
if name != ROOT_KEY:
|
||||
validators += self.validators.get('*', [])
|
||||
if validators:
|
||||
return {getattr(v.func, '__name__', f'<No __name__: id:{id(v.func)}>'): v for v in validators}
|
||||
else:
|
||||
return None
|
||||
|
||||
def check_for_unused(self) -> None:
|
||||
unused_validators = set(
|
||||
chain.from_iterable(
|
||||
(
|
||||
getattr(v.func, '__name__', f'<No __name__: id:{id(v.func)}>')
|
||||
for v in self.validators[f]
|
||||
if v.check_fields
|
||||
)
|
||||
for f in (self.validators.keys() - self.used_validators)
|
||||
)
|
||||
)
|
||||
if unused_validators:
|
||||
fn = ', '.join(unused_validators)
|
||||
raise ConfigError(
|
||||
f"Validators defined with incorrect fields: {fn} " # noqa: Q000
|
||||
f"(use check_fields=False if you're inheriting from the model and intended this)"
|
||||
)
|
||||
|
||||
|
||||
def extract_validators(namespace: Dict[str, Any]) -> Dict[str, List[Validator]]:
|
||||
validators: Dict[str, List[Validator]] = {}
|
||||
for var_name, value in namespace.items():
|
||||
validator_config = getattr(value, VALIDATOR_CONFIG_KEY, None)
|
||||
if validator_config:
|
||||
fields, v = validator_config
|
||||
for field in fields:
|
||||
if field in validators:
|
||||
validators[field].append(v)
|
||||
else:
|
||||
validators[field] = [v]
|
||||
return validators
|
||||
|
||||
|
||||
def extract_root_validators(namespace: Dict[str, Any]) -> Tuple[List[AnyCallable], List[Tuple[bool, AnyCallable]]]:
|
||||
from inspect import signature
|
||||
|
||||
pre_validators: List[AnyCallable] = []
|
||||
post_validators: List[Tuple[bool, AnyCallable]] = []
|
||||
for name, value in namespace.items():
|
||||
validator_config: Optional[Validator] = getattr(value, ROOT_VALIDATOR_CONFIG_KEY, None)
|
||||
if validator_config:
|
||||
sig = signature(validator_config.func)
|
||||
args = list(sig.parameters.keys())
|
||||
if args[0] == 'self':
|
||||
raise ConfigError(
|
||||
f'Invalid signature for root validator {name}: {sig}, "self" not permitted as first argument, '
|
||||
f'should be: (cls, values).'
|
||||
)
|
||||
if len(args) != 2:
|
||||
raise ConfigError(f'Invalid signature for root validator {name}: {sig}, should be: (cls, values).')
|
||||
# check function signature
|
||||
if validator_config.pre:
|
||||
pre_validators.append(validator_config.func)
|
||||
else:
|
||||
post_validators.append((validator_config.skip_on_failure, validator_config.func))
|
||||
return pre_validators, post_validators
|
||||
|
||||
|
||||
def inherit_validators(base_validators: 'ValidatorListDict', validators: 'ValidatorListDict') -> 'ValidatorListDict':
|
||||
for field, field_validators in base_validators.items():
|
||||
if field not in validators:
|
||||
validators[field] = []
|
||||
validators[field] += field_validators
|
||||
return validators
|
||||
|
||||
|
||||
def make_generic_validator(validator: AnyCallable) -> 'ValidatorCallable':
|
||||
"""
|
||||
Make a generic function which calls a validator with the right arguments.
|
||||
|
||||
Unfortunately other approaches (eg. return a partial of a function that builds the arguments) is slow,
|
||||
hence this laborious way of doing things.
|
||||
|
||||
It's done like this so validators don't all need **kwargs in their signature, eg. any combination of
|
||||
the arguments "values", "fields" and/or "config" are permitted.
|
||||
"""
|
||||
from inspect import signature
|
||||
|
||||
if not isinstance(validator, (partial, partialmethod)):
|
||||
# This should be the default case, so overhead is reduced
|
||||
sig = signature(validator)
|
||||
args = list(sig.parameters.keys())
|
||||
else:
|
||||
# Fix the generated argument lists of partial methods
|
||||
sig = signature(validator.func)
|
||||
args = [
|
||||
k
|
||||
for k in signature(validator.func).parameters.keys()
|
||||
if k not in validator.args | validator.keywords.keys()
|
||||
]
|
||||
|
||||
first_arg = args.pop(0)
|
||||
if first_arg == 'self':
|
||||
raise ConfigError(
|
||||
f'Invalid signature for validator {validator}: {sig}, "self" not permitted as first argument, '
|
||||
f'should be: (cls, value, values, config, field), "values", "config" and "field" are all optional.'
|
||||
)
|
||||
elif first_arg == 'cls':
|
||||
# assume the second argument is value
|
||||
return wraps(validator)(_generic_validator_cls(validator, sig, set(args[1:])))
|
||||
else:
|
||||
# assume the first argument was value which has already been removed
|
||||
return wraps(validator)(_generic_validator_basic(validator, sig, set(args)))
|
||||
|
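# Minimal sketch: each of these signatures is acceptable; `make_generic_validator`
# wraps them all into the uniform (cls, v, values, field, config) shape used internally.
def _v_plain(cls, v):
    return v

def _v_values(cls, v, values):
    return v

def _v_full(cls, v, values, field, config):
    return v

for _f in (_v_plain, _v_values, _v_full):
    _wrapped = make_generic_validator(_f)
    assert _wrapped(None, 1, {}, None, None) == 1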
||||
|
||||
def prep_validators(v_funcs: Iterable[AnyCallable]) -> 'ValidatorsList':
|
||||
return [make_generic_validator(f) for f in v_funcs if f]
|
||||
|
||||
|
||||
all_kwargs = {'values', 'field', 'config'}
|
||||
|
||||
|
||||
def _generic_validator_cls(validator: AnyCallable, sig: 'Signature', args: Set[str]) -> 'ValidatorCallable':
|
||||
# assume the first argument is value
|
||||
has_kwargs = False
|
||||
if 'kwargs' in args:
|
||||
has_kwargs = True
|
||||
args -= {'kwargs'}
|
||||
|
||||
if not args.issubset(all_kwargs):
|
||||
raise ConfigError(
|
||||
f'Invalid signature for validator {validator}: {sig}, should be: '
|
||||
f'(cls, value, values, config, field), "values", "config" and "field" are all optional.'
|
||||
)
|
||||
|
||||
if has_kwargs:
|
||||
return lambda cls, v, values, field, config: validator(cls, v, values=values, field=field, config=config)
|
||||
elif args == set():
|
||||
return lambda cls, v, values, field, config: validator(cls, v)
|
||||
elif args == {'values'}:
|
||||
return lambda cls, v, values, field, config: validator(cls, v, values=values)
|
||||
elif args == {'field'}:
|
||||
return lambda cls, v, values, field, config: validator(cls, v, field=field)
|
||||
elif args == {'config'}:
|
||||
return lambda cls, v, values, field, config: validator(cls, v, config=config)
|
||||
elif args == {'values', 'field'}:
|
||||
return lambda cls, v, values, field, config: validator(cls, v, values=values, field=field)
|
||||
elif args == {'values', 'config'}:
|
||||
return lambda cls, v, values, field, config: validator(cls, v, values=values, config=config)
|
||||
elif args == {'field', 'config'}:
|
||||
return lambda cls, v, values, field, config: validator(cls, v, field=field, config=config)
|
||||
else:
|
||||
# args == {'values', 'field', 'config'}
|
||||
return lambda cls, v, values, field, config: validator(cls, v, values=values, field=field, config=config)
|
||||
|
||||
|
||||
def _generic_validator_basic(validator: AnyCallable, sig: 'Signature', args: Set[str]) -> 'ValidatorCallable':
|
||||
has_kwargs = False
|
||||
if 'kwargs' in args:
|
||||
has_kwargs = True
|
||||
args -= {'kwargs'}
|
||||
|
||||
if not args.issubset(all_kwargs):
|
||||
raise ConfigError(
|
||||
f'Invalid signature for validator {validator}: {sig}, should be: '
|
||||
f'(value, values, config, field), "values", "config" and "field" are all optional.'
|
||||
)
|
||||
|
||||
if has_kwargs:
|
||||
return lambda cls, v, values, field, config: validator(v, values=values, field=field, config=config)
|
||||
elif args == set():
|
||||
return lambda cls, v, values, field, config: validator(v)
|
||||
elif args == {'values'}:
|
||||
return lambda cls, v, values, field, config: validator(v, values=values)
|
||||
elif args == {'field'}:
|
||||
return lambda cls, v, values, field, config: validator(v, field=field)
|
||||
elif args == {'config'}:
|
||||
return lambda cls, v, values, field, config: validator(v, config=config)
|
||||
elif args == {'values', 'field'}:
|
||||
return lambda cls, v, values, field, config: validator(v, values=values, field=field)
|
||||
elif args == {'values', 'config'}:
|
||||
return lambda cls, v, values, field, config: validator(v, values=values, config=config)
|
||||
elif args == {'field', 'config'}:
|
||||
return lambda cls, v, values, field, config: validator(v, field=field, config=config)
|
||||
else:
|
||||
# args == {'values', 'field', 'config'}
|
||||
return lambda cls, v, values, field, config: validator(v, values=values, field=field, config=config)
|
||||
|
||||
|
||||
def gather_all_validators(type_: 'ModelOrDc') -> Dict[str, 'AnyClassMethod']:
|
||||
all_attributes = ChainMap(*[cls.__dict__ for cls in type_.__mro__]) # type: ignore[arg-type,var-annotated]
|
||||
return {
|
||||
k: v
|
||||
for k, v in all_attributes.items()
|
||||
if hasattr(v, VALIDATOR_CONFIG_KEY) or hasattr(v, ROOT_VALIDATOR_CONFIG_KEY)
|
||||
}
|
||||
@@ -0,0 +1,494 @@
|
||||
"""
|
||||
Color definitions are used as per CSS3 specification:
|
||||
http://www.w3.org/TR/css3-color/#svg-color
|
||||
|
||||
A few colors have multiple names referring to the sames colors, eg. `grey` and `gray` or `aqua` and `cyan`.
|
||||
|
||||
In these cases the LAST color when sorted alphabetically takes preferences,
|
||||
eg. Color((0, 255, 255)).as_named() == 'cyan' because "cyan" comes after "aqua".
|
||||
"""
|
||||
import math
|
||||
import re
|
||||
from colorsys import hls_to_rgb, rgb_to_hls
|
||||
from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple, Union, cast
|
||||
|
||||
from pydantic.v1.errors import ColorError
|
||||
from pydantic.v1.utils import Representation, almost_equal_floats
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from pydantic.v1.typing import CallableGenerator, ReprArgs
|
||||
|
||||
ColorTuple = Union[Tuple[int, int, int], Tuple[int, int, int, float]]
|
||||
ColorType = Union[ColorTuple, str]
|
||||
HslColorTuple = Union[Tuple[float, float, float], Tuple[float, float, float, float]]
|
||||
|
||||
|
||||
class RGBA:
|
||||
"""
|
||||
Internal use only as a representation of a color.
|
||||
"""
|
||||
|
||||
__slots__ = 'r', 'g', 'b', 'alpha', '_tuple'
|
||||
|
||||
def __init__(self, r: float, g: float, b: float, alpha: Optional[float]):
|
||||
self.r = r
|
||||
self.g = g
|
||||
self.b = b
|
||||
self.alpha = alpha
|
||||
|
||||
self._tuple: Tuple[float, float, float, Optional[float]] = (r, g, b, alpha)
|
||||
|
||||
def __getitem__(self, item: Any) -> Any:
|
||||
return self._tuple[item]
|
||||
|
||||
|
||||
# these are not compiled here to avoid import slowdown, they'll be compiled the first time they're used, then cached
|
||||
r_hex_short = r'\s*(?:#|0x)?([0-9a-f])([0-9a-f])([0-9a-f])([0-9a-f])?\s*'
|
||||
r_hex_long = r'\s*(?:#|0x)?([0-9a-f]{2})([0-9a-f]{2})([0-9a-f]{2})([0-9a-f]{2})?\s*'
|
||||
_r_255 = r'(\d{1,3}(?:\.\d+)?)'
|
||||
_r_comma = r'\s*,\s*'
|
||||
r_rgb = fr'\s*rgb\(\s*{_r_255}{_r_comma}{_r_255}{_r_comma}{_r_255}\)\s*'
|
||||
_r_alpha = r'(\d(?:\.\d+)?|\.\d+|\d{1,2}%)'
|
||||
r_rgba = fr'\s*rgba\(\s*{_r_255}{_r_comma}{_r_255}{_r_comma}{_r_255}{_r_comma}{_r_alpha}\s*\)\s*'
|
||||
_r_h = r'(-?\d+(?:\.\d+)?|-?\.\d+)(deg|rad|turn)?'
|
||||
_r_sl = r'(\d{1,3}(?:\.\d+)?)%'
|
||||
r_hsl = fr'\s*hsl\(\s*{_r_h}{_r_comma}{_r_sl}{_r_comma}{_r_sl}\s*\)\s*'
|
||||
r_hsla = fr'\s*hsl\(\s*{_r_h}{_r_comma}{_r_sl}{_r_comma}{_r_sl}{_r_comma}{_r_alpha}\s*\)\s*'
|
||||
|
||||
# colors where the two hex characters are the same; if every channel matches, the short form of the hex color can be used
|
||||
repeat_colors = {int(c * 2, 16) for c in '0123456789abcdef'}
|
||||
rads = 2 * math.pi
|
||||
|
||||
|
||||
class Color(Representation):
|
||||
__slots__ = '_original', '_rgba'
|
||||
|
||||
def __init__(self, value: ColorType) -> None:
|
||||
self._rgba: RGBA
|
||||
self._original: ColorType
|
||||
if isinstance(value, (tuple, list)):
|
||||
self._rgba = parse_tuple(value)
|
||||
elif isinstance(value, str):
|
||||
self._rgba = parse_str(value)
|
||||
elif isinstance(value, Color):
|
||||
self._rgba = value._rgba
|
||||
value = value._original
|
||||
else:
|
||||
raise ColorError(reason='value must be a tuple, list or string')
|
||||
|
||||
# if we've got here value must be a valid color
|
||||
self._original = value
|
||||
|
||||
@classmethod
|
||||
def __modify_schema__(cls, field_schema: Dict[str, Any]) -> None:
|
||||
field_schema.update(type='string', format='color')
|
||||
|
||||
def original(self) -> ColorType:
|
||||
"""
|
||||
Original value passed to Color
|
||||
"""
|
||||
return self._original
|
||||
|
||||
def as_named(self, *, fallback: bool = False) -> str:
|
||||
if self._rgba.alpha is None:
|
||||
rgb = cast(Tuple[int, int, int], self.as_rgb_tuple())
|
||||
try:
|
||||
return COLORS_BY_VALUE[rgb]
|
||||
except KeyError as e:
|
||||
if fallback:
|
||||
return self.as_hex()
|
||||
else:
|
||||
raise ValueError('no named color found, use fallback=True, as_hex() or as_rgb()') from e
|
||||
else:
|
||||
return self.as_hex()
|
||||
|
||||
def as_hex(self) -> str:
|
||||
"""
|
||||
Hex string representing the color can be 3, 4, 6 or 8 characters depending on whether the string
|
||||
a "short" representation of the color is possible and whether there's an alpha channel.
|
||||
"""
|
||||
values = [float_to_255(c) for c in self._rgba[:3]]
|
||||
if self._rgba.alpha is not None:
|
||||
values.append(float_to_255(self._rgba.alpha))
|
||||
|
||||
as_hex = ''.join(f'{v:02x}' for v in values)
|
||||
if all(c in repeat_colors for c in values):
|
||||
as_hex = ''.join(as_hex[c] for c in range(0, len(as_hex), 2))
|
||||
return '#' + as_hex
|
||||
|
||||
def as_rgb(self) -> str:
|
||||
"""
|
||||
Color as an rgb(<r>, <g>, <b>) or rgba(<r>, <g>, <b>, <a>) string.
|
||||
"""
|
||||
if self._rgba.alpha is None:
|
||||
return f'rgb({float_to_255(self._rgba.r)}, {float_to_255(self._rgba.g)}, {float_to_255(self._rgba.b)})'
|
||||
else:
|
||||
return (
|
||||
f'rgba({float_to_255(self._rgba.r)}, {float_to_255(self._rgba.g)}, {float_to_255(self._rgba.b)}, '
|
||||
f'{round(self._alpha_float(), 2)})'
|
||||
)
|
||||
|
||||
def as_rgb_tuple(self, *, alpha: Optional[bool] = None) -> ColorTuple:
|
||||
"""
|
||||
Color as an RGB or RGBA tuple; red, green and blue are in the range 0 to 255, alpha if included is
|
||||
in the range 0 to 1.
|
||||
|
||||
:param alpha: whether to include the alpha channel, options are
|
||||
          None - (default) include alpha only if it's set (i.e. not None)
|
||||
True - always include alpha,
|
||||
False - always omit alpha,
|
||||
"""
|
||||
r, g, b = (float_to_255(c) for c in self._rgba[:3])
|
||||
if alpha is None:
|
||||
if self._rgba.alpha is None:
|
||||
return r, g, b
|
||||
else:
|
||||
return r, g, b, self._alpha_float()
|
||||
elif alpha:
|
||||
return r, g, b, self._alpha_float()
|
||||
else:
|
||||
# alpha is False
|
||||
return r, g, b
|
||||
|
||||
def as_hsl(self) -> str:
|
||||
"""
|
||||
Color as an hsl(<h>, <s>, <l>) or hsl(<h>, <s>, <l>, <a>) string.
|
||||
"""
|
||||
if self._rgba.alpha is None:
|
||||
h, s, li = self.as_hsl_tuple(alpha=False) # type: ignore
|
||||
return f'hsl({h * 360:0.0f}, {s:0.0%}, {li:0.0%})'
|
||||
else:
|
||||
h, s, li, a = self.as_hsl_tuple(alpha=True) # type: ignore
|
||||
return f'hsl({h * 360:0.0f}, {s:0.0%}, {li:0.0%}, {round(a, 2)})'
|
||||
|
||||
def as_hsl_tuple(self, *, alpha: Optional[bool] = None) -> HslColorTuple:
|
||||
"""
|
||||
        Color as an HSL or HSLA tuple, i.e. hue, saturation, lightness and optionally alpha; all elements are in
|
||||
the range 0 to 1.
|
||||
|
||||
NOTE: this is HSL as used in HTML and most other places, not HLS as used in python's colorsys.
|
||||
|
||||
:param alpha: whether to include the alpha channel, options are
|
||||
          None - (default) include alpha only if it's set (i.e. not None)
|
||||
True - always include alpha,
|
||||
False - always omit alpha,
|
||||
"""
|
||||
h, l, s = rgb_to_hls(self._rgba.r, self._rgba.g, self._rgba.b)
|
||||
if alpha is None:
|
||||
if self._rgba.alpha is None:
|
||||
return h, s, l
|
||||
else:
|
||||
return h, s, l, self._alpha_float()
|
||||
if alpha:
|
||||
return h, s, l, self._alpha_float()
|
||||
else:
|
||||
# alpha is False
|
||||
return h, s, l
|
||||
|
||||
def _alpha_float(self) -> float:
|
||||
return 1 if self._rgba.alpha is None else self._rgba.alpha
|
||||
|
||||
@classmethod
|
||||
def __get_validators__(cls) -> 'CallableGenerator':
|
||||
yield cls
|
||||
|
||||
def __str__(self) -> str:
|
||||
return self.as_named(fallback=True)
|
||||
|
||||
def __repr_args__(self) -> 'ReprArgs':
|
||||
return [(None, self.as_named(fallback=True))] + [('rgb', self.as_rgb_tuple())] # type: ignore
|
||||
|
||||
def __eq__(self, other: Any) -> bool:
|
||||
return isinstance(other, Color) and self.as_rgb_tuple() == other.as_rgb_tuple()
|
||||
|
||||
def __hash__(self) -> int:
|
||||
return hash(self.as_rgb_tuple())
|
||||
|
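# Illustrative sketch: the same color parsed from several supported inputs. Shown as
# standalone user code, since the name tables used by `as_named` are defined at the
# end of this module.
from pydantic.v1.color import Color

c = Color('#0ff')
assert c == Color('rgb(0, 255, 255)') == Color((0, 255, 255)) == Color('hsl(180, 100%, 50%)')
assert c.as_named() == 'cyan'   # 'cyan' wins over 'aqua', see the module docstring
assert c.as_hex() == '#0ff'     # short form: every channel repeats its hex digit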
||||
|
||||
def parse_tuple(value: Tuple[Any, ...]) -> RGBA:
|
||||
"""
|
||||
Parse a tuple or list as a color.
|
||||
"""
|
||||
if len(value) == 3:
|
||||
r, g, b = (parse_color_value(v) for v in value)
|
||||
return RGBA(r, g, b, None)
|
||||
elif len(value) == 4:
|
||||
r, g, b = (parse_color_value(v) for v in value[:3])
|
||||
return RGBA(r, g, b, parse_float_alpha(value[3]))
|
||||
else:
|
||||
raise ColorError(reason='tuples must have length 3 or 4')
|
||||
|
||||
|
||||
def parse_str(value: str) -> RGBA:
|
||||
"""
|
||||
Parse a string to an RGBA tuple, trying the following formats (in this order):
|
||||
* named color, see COLORS_BY_NAME below
|
||||
* hex short eg. `<prefix>fff` (prefix can be `#`, `0x` or nothing)
|
||||
* hex long eg. `<prefix>ffffff` (prefix can be `#`, `0x` or nothing)
|
||||
* `rgb(<r>, <g>, <b>) `
|
||||
* `rgba(<r>, <g>, <b>, <a>)`
|
||||
"""
|
||||
value_lower = value.lower()
|
||||
try:
|
||||
r, g, b = COLORS_BY_NAME[value_lower]
|
||||
except KeyError:
|
||||
pass
|
||||
else:
|
||||
return ints_to_rgba(r, g, b, None)
|
||||
|
||||
m = re.fullmatch(r_hex_short, value_lower)
|
||||
if m:
|
||||
*rgb, a = m.groups()
|
||||
r, g, b = (int(v * 2, 16) for v in rgb)
|
||||
if a:
|
||||
alpha: Optional[float] = int(a * 2, 16) / 255
|
||||
else:
|
||||
alpha = None
|
||||
return ints_to_rgba(r, g, b, alpha)
|
||||
|
||||
m = re.fullmatch(r_hex_long, value_lower)
|
||||
if m:
|
||||
*rgb, a = m.groups()
|
||||
r, g, b = (int(v, 16) for v in rgb)
|
||||
if a:
|
||||
alpha = int(a, 16) / 255
|
||||
else:
|
||||
alpha = None
|
||||
return ints_to_rgba(r, g, b, alpha)
|
||||
|
||||
m = re.fullmatch(r_rgb, value_lower)
|
||||
if m:
|
||||
return ints_to_rgba(*m.groups(), None) # type: ignore
|
||||
|
||||
m = re.fullmatch(r_rgba, value_lower)
|
||||
if m:
|
||||
return ints_to_rgba(*m.groups()) # type: ignore
|
||||
|
||||
m = re.fullmatch(r_hsl, value_lower)
|
||||
if m:
|
||||
h, h_units, s, l_ = m.groups()
|
||||
return parse_hsl(h, h_units, s, l_)
|
||||
|
||||
m = re.fullmatch(r_hsla, value_lower)
|
||||
if m:
|
||||
h, h_units, s, l_, a = m.groups()
|
||||
return parse_hsl(h, h_units, s, l_, parse_float_alpha(a))
|
||||
|
||||
raise ColorError(reason='string not recognised as a valid color')
|
||||
|
||||
|
||||
def ints_to_rgba(r: Union[int, str], g: Union[int, str], b: Union[int, str], alpha: Optional[float]) -> RGBA:
|
||||
return RGBA(parse_color_value(r), parse_color_value(g), parse_color_value(b), parse_float_alpha(alpha))
|
||||
|
||||
|
||||
def parse_color_value(value: Union[int, str], max_val: int = 255) -> float:
|
||||
"""
|
||||
Parse a value checking it's a valid int in the range 0 to max_val and divide by max_val to give a number
|
||||
in the range 0 to 1
|
||||
"""
|
||||
try:
|
||||
color = float(value)
|
||||
except ValueError:
|
||||
raise ColorError(reason='color values must be a valid number')
|
||||
if 0 <= color <= max_val:
|
||||
return color / max_val
|
||||
else:
|
||||
raise ColorError(reason=f'color values must be in the range 0 to {max_val}')
|
||||
|
||||
|
||||
def parse_float_alpha(value: Union[None, str, float, int]) -> Optional[float]:
|
||||
"""
|
||||
Parse a value checking it's a valid float in the range 0 to 1
|
||||
"""
|
||||
if value is None:
|
||||
return None
|
||||
try:
|
||||
if isinstance(value, str) and value.endswith('%'):
|
||||
alpha = float(value[:-1]) / 100
|
||||
else:
|
||||
alpha = float(value)
|
||||
except ValueError:
|
||||
raise ColorError(reason='alpha values must be a valid float')
|
||||
|
||||
if almost_equal_floats(alpha, 1):
|
||||
return None
|
||||
elif 0 <= alpha <= 1:
|
||||
return alpha
|
||||
else:
|
||||
raise ColorError(reason='alpha values must be in the range 0 to 1')
|
||||
|
||||
|
||||
def parse_hsl(h: str, h_units: str, sat: str, light: str, alpha: Optional[float] = None) -> RGBA:
|
||||
"""
|
||||
Parse raw hue, saturation, lightness and alpha values and convert to RGBA.
|
||||
"""
|
||||
s_value, l_value = parse_color_value(sat, 100), parse_color_value(light, 100)
|
||||
|
||||
h_value = float(h)
|
||||
if h_units in {None, 'deg'}:
|
||||
h_value = h_value % 360 / 360
|
||||
elif h_units == 'rad':
|
||||
h_value = h_value % rads / rads
|
||||
else:
|
||||
# turns
|
||||
h_value = h_value % 1
|
||||
|
||||
r, g, b = hls_to_rgb(h_value, l_value, s_value)
|
||||
return RGBA(r, g, b, alpha)
|
||||
|
||||
|
||||
def float_to_255(c: float) -> int:
|
||||
return int(round(c * 255))
|
||||
|
||||
|
||||
COLORS_BY_NAME = {
|
||||
'aliceblue': (240, 248, 255),
|
||||
'antiquewhite': (250, 235, 215),
|
||||
'aqua': (0, 255, 255),
|
||||
'aquamarine': (127, 255, 212),
|
||||
'azure': (240, 255, 255),
|
||||
'beige': (245, 245, 220),
|
||||
'bisque': (255, 228, 196),
|
||||
'black': (0, 0, 0),
|
||||
'blanchedalmond': (255, 235, 205),
|
||||
'blue': (0, 0, 255),
|
||||
'blueviolet': (138, 43, 226),
|
||||
'brown': (165, 42, 42),
|
||||
'burlywood': (222, 184, 135),
|
||||
'cadetblue': (95, 158, 160),
|
||||
'chartreuse': (127, 255, 0),
|
||||
'chocolate': (210, 105, 30),
|
||||
'coral': (255, 127, 80),
|
||||
'cornflowerblue': (100, 149, 237),
|
||||
'cornsilk': (255, 248, 220),
|
||||
'crimson': (220, 20, 60),
|
||||
'cyan': (0, 255, 255),
|
||||
'darkblue': (0, 0, 139),
|
||||
'darkcyan': (0, 139, 139),
|
||||
'darkgoldenrod': (184, 134, 11),
|
||||
'darkgray': (169, 169, 169),
|
||||
'darkgreen': (0, 100, 0),
|
||||
'darkgrey': (169, 169, 169),
|
||||
'darkkhaki': (189, 183, 107),
|
||||
'darkmagenta': (139, 0, 139),
|
||||
'darkolivegreen': (85, 107, 47),
|
||||
'darkorange': (255, 140, 0),
|
||||
'darkorchid': (153, 50, 204),
|
||||
'darkred': (139, 0, 0),
|
||||
'darksalmon': (233, 150, 122),
|
||||
'darkseagreen': (143, 188, 143),
|
||||
'darkslateblue': (72, 61, 139),
|
||||
'darkslategray': (47, 79, 79),
|
||||
'darkslategrey': (47, 79, 79),
|
||||
'darkturquoise': (0, 206, 209),
|
||||
'darkviolet': (148, 0, 211),
|
||||
'deeppink': (255, 20, 147),
|
||||
'deepskyblue': (0, 191, 255),
|
||||
'dimgray': (105, 105, 105),
|
||||
'dimgrey': (105, 105, 105),
|
||||
'dodgerblue': (30, 144, 255),
|
||||
'firebrick': (178, 34, 34),
|
||||
'floralwhite': (255, 250, 240),
|
||||
'forestgreen': (34, 139, 34),
|
||||
'fuchsia': (255, 0, 255),
|
||||
'gainsboro': (220, 220, 220),
|
||||
'ghostwhite': (248, 248, 255),
|
||||
'gold': (255, 215, 0),
|
||||
'goldenrod': (218, 165, 32),
|
||||
'gray': (128, 128, 128),
|
||||
'green': (0, 128, 0),
|
||||
'greenyellow': (173, 255, 47),
|
||||
'grey': (128, 128, 128),
|
||||
'honeydew': (240, 255, 240),
|
||||
'hotpink': (255, 105, 180),
|
||||
'indianred': (205, 92, 92),
|
||||
'indigo': (75, 0, 130),
|
||||
'ivory': (255, 255, 240),
|
||||
'khaki': (240, 230, 140),
|
||||
'lavender': (230, 230, 250),
|
||||
'lavenderblush': (255, 240, 245),
|
||||
'lawngreen': (124, 252, 0),
|
||||
'lemonchiffon': (255, 250, 205),
|
||||
'lightblue': (173, 216, 230),
|
||||
'lightcoral': (240, 128, 128),
|
||||
'lightcyan': (224, 255, 255),
|
||||
'lightgoldenrodyellow': (250, 250, 210),
|
||||
'lightgray': (211, 211, 211),
|
||||
'lightgreen': (144, 238, 144),
|
||||
'lightgrey': (211, 211, 211),
|
||||
'lightpink': (255, 182, 193),
|
||||
'lightsalmon': (255, 160, 122),
|
||||
'lightseagreen': (32, 178, 170),
|
||||
'lightskyblue': (135, 206, 250),
|
||||
'lightslategray': (119, 136, 153),
|
||||
'lightslategrey': (119, 136, 153),
|
||||
'lightsteelblue': (176, 196, 222),
|
||||
'lightyellow': (255, 255, 224),
|
||||
'lime': (0, 255, 0),
|
||||
'limegreen': (50, 205, 50),
|
||||
'linen': (250, 240, 230),
|
||||
'magenta': (255, 0, 255),
|
||||
'maroon': (128, 0, 0),
|
||||
'mediumaquamarine': (102, 205, 170),
|
||||
'mediumblue': (0, 0, 205),
|
||||
'mediumorchid': (186, 85, 211),
|
||||
'mediumpurple': (147, 112, 219),
|
||||
'mediumseagreen': (60, 179, 113),
|
||||
'mediumslateblue': (123, 104, 238),
|
||||
'mediumspringgreen': (0, 250, 154),
|
||||
'mediumturquoise': (72, 209, 204),
|
||||
'mediumvioletred': (199, 21, 133),
|
||||
'midnightblue': (25, 25, 112),
|
||||
'mintcream': (245, 255, 250),
|
||||
'mistyrose': (255, 228, 225),
|
||||
'moccasin': (255, 228, 181),
|
||||
'navajowhite': (255, 222, 173),
|
||||
'navy': (0, 0, 128),
|
||||
'oldlace': (253, 245, 230),
|
||||
'olive': (128, 128, 0),
|
||||
'olivedrab': (107, 142, 35),
|
||||
'orange': (255, 165, 0),
|
||||
'orangered': (255, 69, 0),
|
||||
'orchid': (218, 112, 214),
|
||||
'palegoldenrod': (238, 232, 170),
|
||||
'palegreen': (152, 251, 152),
|
||||
'paleturquoise': (175, 238, 238),
|
||||
'palevioletred': (219, 112, 147),
|
||||
'papayawhip': (255, 239, 213),
|
||||
'peachpuff': (255, 218, 185),
|
||||
'peru': (205, 133, 63),
|
||||
'pink': (255, 192, 203),
|
||||
'plum': (221, 160, 221),
|
||||
'powderblue': (176, 224, 230),
|
||||
'purple': (128, 0, 128),
|
||||
'red': (255, 0, 0),
|
||||
'rosybrown': (188, 143, 143),
|
||||
'royalblue': (65, 105, 225),
|
||||
'saddlebrown': (139, 69, 19),
|
||||
'salmon': (250, 128, 114),
|
||||
'sandybrown': (244, 164, 96),
|
||||
'seagreen': (46, 139, 87),
|
||||
'seashell': (255, 245, 238),
|
||||
'sienna': (160, 82, 45),
|
||||
'silver': (192, 192, 192),
|
||||
'skyblue': (135, 206, 235),
|
||||
'slateblue': (106, 90, 205),
|
||||
'slategray': (112, 128, 144),
|
||||
'slategrey': (112, 128, 144),
|
||||
'snow': (255, 250, 250),
|
||||
'springgreen': (0, 255, 127),
|
||||
'steelblue': (70, 130, 180),
|
||||
'tan': (210, 180, 140),
|
||||
'teal': (0, 128, 128),
|
||||
'thistle': (216, 191, 216),
|
||||
'tomato': (255, 99, 71),
|
||||
'turquoise': (64, 224, 208),
|
||||
'violet': (238, 130, 238),
|
||||
'wheat': (245, 222, 179),
|
||||
'white': (255, 255, 255),
|
||||
'whitesmoke': (245, 245, 245),
|
||||
'yellow': (255, 255, 0),
|
||||
'yellowgreen': (154, 205, 50),
|
||||
}
|
||||
|
||||
COLORS_BY_VALUE = {v: k for k, v in COLORS_BY_NAME.items()}
|
||||
@@ -0,0 +1,191 @@
|
||||
import json
|
||||
from enum import Enum
|
||||
from typing import TYPE_CHECKING, Any, Callable, Dict, ForwardRef, Optional, Tuple, Type, Union
|
||||
|
||||
from typing_extensions import Literal, Protocol
|
||||
|
||||
from pydantic.v1.typing import AnyArgTCallable, AnyCallable
|
||||
from pydantic.v1.utils import GetterDict
|
||||
from pydantic.v1.version import compiled
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from typing import overload
|
||||
|
||||
from pydantic.v1.fields import ModelField
|
||||
from pydantic.v1.main import BaseModel
|
||||
|
||||
ConfigType = Type['BaseConfig']
|
||||
|
||||
class SchemaExtraCallable(Protocol):
|
||||
@overload
|
||||
def __call__(self, schema: Dict[str, Any]) -> None:
|
||||
pass
|
||||
|
||||
@overload
|
||||
def __call__(self, schema: Dict[str, Any], model_class: Type[BaseModel]) -> None:
|
||||
pass
|
||||
|
||||
else:
|
||||
SchemaExtraCallable = Callable[..., None]
|
||||
|
||||
__all__ = 'BaseConfig', 'ConfigDict', 'get_config', 'Extra', 'inherit_config', 'prepare_config'
|
||||
|
||||
|
||||
class Extra(str, Enum):
|
||||
allow = 'allow'
|
||||
ignore = 'ignore'
|
||||
forbid = 'forbid'
|
||||
|
||||
|
||||
# https://github.com/cython/cython/issues/4003
|
||||
# Fixed in Cython 3 and Pydantic v1 won't support Cython 3.
|
||||
# Pydantic v2 doesn't depend on Cython at all.
|
||||
if not compiled:
|
||||
from typing_extensions import TypedDict
|
||||
|
||||
class ConfigDict(TypedDict, total=False):
|
||||
title: Optional[str]
|
||||
anystr_lower: bool
|
||||
anystr_strip_whitespace: bool
|
||||
min_anystr_length: int
|
||||
max_anystr_length: Optional[int]
|
||||
validate_all: bool
|
||||
extra: Extra
|
||||
allow_mutation: bool
|
||||
frozen: bool
|
||||
allow_population_by_field_name: bool
|
||||
use_enum_values: bool
|
||||
fields: Dict[str, Union[str, Dict[str, str]]]
|
||||
validate_assignment: bool
|
||||
error_msg_templates: Dict[str, str]
|
||||
arbitrary_types_allowed: bool
|
||||
orm_mode: bool
|
||||
getter_dict: Type[GetterDict]
|
||||
alias_generator: Optional[Callable[[str], str]]
|
||||
keep_untouched: Tuple[type, ...]
|
||||
schema_extra: Union[Dict[str, object], 'SchemaExtraCallable']
|
||||
json_loads: Callable[[str], object]
|
||||
json_dumps: AnyArgTCallable[str]
|
||||
json_encoders: Dict[Type[object], AnyCallable]
|
||||
underscore_attrs_are_private: bool
|
||||
allow_inf_nan: bool
|
||||
copy_on_model_validation: Literal['none', 'deep', 'shallow']
|
||||
# whether dataclass `__post_init__` should be run after validation
|
||||
post_init_call: Literal['before_validation', 'after_validation']
|
||||
|
||||
else:
|
||||
ConfigDict = dict # type: ignore
|
||||
|
||||
|
||||
class BaseConfig:
|
||||
title: Optional[str] = None
|
||||
anystr_lower: bool = False
|
||||
anystr_upper: bool = False
|
||||
anystr_strip_whitespace: bool = False
|
||||
min_anystr_length: int = 0
|
||||
max_anystr_length: Optional[int] = None
|
||||
validate_all: bool = False
|
||||
extra: Extra = Extra.ignore
|
||||
allow_mutation: bool = True
|
||||
frozen: bool = False
|
||||
allow_population_by_field_name: bool = False
|
||||
use_enum_values: bool = False
|
||||
fields: Dict[str, Union[str, Dict[str, str]]] = {}
|
||||
validate_assignment: bool = False
|
||||
error_msg_templates: Dict[str, str] = {}
|
||||
arbitrary_types_allowed: bool = False
|
||||
orm_mode: bool = False
|
||||
getter_dict: Type[GetterDict] = GetterDict
|
||||
alias_generator: Optional[Callable[[str], str]] = None
|
||||
keep_untouched: Tuple[type, ...] = ()
|
||||
schema_extra: Union[Dict[str, Any], 'SchemaExtraCallable'] = {}
|
||||
json_loads: Callable[[str], Any] = json.loads
|
||||
json_dumps: Callable[..., str] = json.dumps
|
||||
json_encoders: Dict[Union[Type[Any], str, ForwardRef], AnyCallable] = {}
|
||||
underscore_attrs_are_private: bool = False
|
||||
allow_inf_nan: bool = True
|
||||
|
||||
# whether inherited models as fields should be reconstructed as base model,
|
||||
# and whether such a copy should be shallow or deep
|
||||
copy_on_model_validation: Literal['none', 'deep', 'shallow'] = 'shallow'
|
||||
|
||||
# whether `Union` should check all allowed types before even trying to coerce
|
||||
smart_union: bool = False
|
||||
# whether dataclass `__post_init__` should be run before or after validation
|
||||
post_init_call: Literal['before_validation', 'after_validation'] = 'before_validation'
|
||||
|
||||
@classmethod
|
||||
def get_field_info(cls, name: str) -> Dict[str, Any]:
|
||||
"""
|
||||
Get properties of FieldInfo from the `fields` property of the config class.
|
||||
"""
|
||||
|
||||
fields_value = cls.fields.get(name)
|
||||
|
||||
if isinstance(fields_value, str):
|
||||
field_info: Dict[str, Any] = {'alias': fields_value}
|
||||
elif isinstance(fields_value, dict):
|
||||
field_info = fields_value
|
||||
else:
|
||||
field_info = {}
|
||||
|
||||
if 'alias' in field_info:
|
||||
field_info.setdefault('alias_priority', 2)
|
||||
|
||||
if field_info.get('alias_priority', 0) <= 1 and cls.alias_generator:
|
||||
alias = cls.alias_generator(name)
|
||||
if not isinstance(alias, str):
|
||||
raise TypeError(f'Config.alias_generator must return str, not {alias.__class__}')
|
||||
field_info.update(alias=alias, alias_priority=1)
|
||||
return field_info
|
||||
|
||||
@classmethod
|
||||
def prepare_field(cls, field: 'ModelField') -> None:
|
||||
"""
|
||||
Optional hook to check or modify fields during model creation.
|
||||
"""
|
||||
pass
|
||||
|
||||
|
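# Minimal sketch of `get_field_info` above: an explicit `fields` alias wins over a
# generated one via `alias_priority` (hypothetical config and field names).
def _upper(name: str) -> str:
    return name.upper()

class _ExampleConfig(BaseConfig):
    fields = {'sku': 'SKU'}                     # a plain string means an alias
    alias_generator = staticmethod(_upper)      # only applied at alias_priority <= 1

assert _ExampleConfig.get_field_info('sku') == {'alias': 'SKU', 'alias_priority': 2}
assert _ExampleConfig.get_field_info('name') == {'alias': 'NAME', 'alias_priority': 1}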
||||
def get_config(config: Union[ConfigDict, Type[object], None]) -> Type[BaseConfig]:
|
||||
if config is None:
|
||||
return BaseConfig
|
||||
|
||||
else:
|
||||
config_dict = (
|
||||
config
|
||||
if isinstance(config, dict)
|
||||
else {k: getattr(config, k) for k in dir(config) if not k.startswith('__')}
|
||||
)
|
||||
|
||||
class Config(BaseConfig):
|
||||
...
|
||||
|
||||
for k, v in config_dict.items():
|
||||
setattr(Config, k, v)
|
||||
return Config
|
||||
|
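# Minimal sketch: dict-based and class-based configs produce equivalent `Config`
# subclasses of BaseConfig (`_RawConfig` is a hypothetical name).
class _RawConfig:
    allow_mutation = False
    anystr_lower = True

_from_dict = get_config({'allow_mutation': False, 'anystr_lower': True})
_from_class = get_config(_RawConfig)
assert _from_dict.allow_mutation is _from_class.allow_mutation is False
assert _from_dict.anystr_lower is _from_class.anystr_lower is True
assert issubclass(_from_dict, BaseConfig) and issubclass(_from_class, BaseConfig)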
||||
|
||||
def inherit_config(self_config: 'ConfigType', parent_config: 'ConfigType', **namespace: Any) -> 'ConfigType':
|
||||
if not self_config:
|
||||
base_classes: Tuple['ConfigType', ...] = (parent_config,)
|
||||
elif self_config == parent_config:
|
||||
base_classes = (self_config,)
|
||||
else:
|
||||
base_classes = self_config, parent_config
|
||||
|
||||
namespace['json_encoders'] = {
|
||||
**getattr(parent_config, 'json_encoders', {}),
|
||||
**getattr(self_config, 'json_encoders', {}),
|
||||
**namespace.get('json_encoders', {}),
|
||||
}
|
||||
|
||||
return type('Config', base_classes, namespace)
|
||||
|
||||
|
||||
def prepare_config(config: Type[BaseConfig], cls_name: str) -> None:
|
||||
if not isinstance(config.extra, Extra):
|
||||
try:
|
||||
config.extra = Extra(config.extra)
|
||||
except ValueError:
|
||||
raise ValueError(f'"{cls_name}": {config.extra} is not a valid value for "extra"')
|
||||
@@ -0,0 +1,500 @@
|
||||
"""
|
||||
The main purpose is to enhance stdlib dataclasses by adding validation
|
||||
A pydantic dataclass can be generated from scratch or from a stdlib one.
|
||||
|
||||
Behind the scene, a pydantic dataclass is just like a regular one on which we attach
|
||||
a `BaseModel` and magic methods to trigger the validation of the data.
|
||||
`__init__` and `__post_init__` are hence overridden and have extra logic to be
|
||||
able to validate input data.
|
||||
|
||||
When a pydantic dataclass is generated from scratch, it's just a plain dataclass
|
||||
with validation triggered at initialization
|
||||
|
||||
The tricky part if for stdlib dataclasses that are converted after into pydantic ones e.g.
|
||||
|
||||
```py
|
||||
@dataclasses.dataclass
|
||||
class M:
|
||||
x: int
|
||||
|
||||
ValidatedM = pydantic.dataclasses.dataclass(M)
|
||||
```
|
||||
|
||||
We indeed still want to support equality, hashing, repr, ... as if it was the stdlib one!
|
||||
|
||||
```py
|
||||
assert isinstance(ValidatedM(x=1), M)
|
||||
assert ValidatedM(x=1) == M(x=1)
|
||||
```
|
||||
|
||||
This means we **don't want to create a new dataclass that inherits from it**
|
||||
The trick is to create a wrapper around `M` that will act as a proxy to trigger
|
||||
validation without altering default `M` behaviour.
|
||||
"""
|
||||
import copy
|
||||
import dataclasses
|
||||
import sys
|
||||
from contextlib import contextmanager
|
||||
from functools import wraps
|
||||
|
||||
try:
|
||||
from functools import cached_property
|
||||
except ImportError:
|
||||
# cached_property available only for python3.8+
|
||||
pass
|
||||
|
||||
from typing import TYPE_CHECKING, Any, Callable, ClassVar, Dict, Generator, Optional, Type, TypeVar, Union, overload
|
||||
|
||||
from typing_extensions import dataclass_transform
|
||||
|
||||
from pydantic.v1.class_validators import gather_all_validators
|
||||
from pydantic.v1.config import BaseConfig, ConfigDict, Extra, get_config
|
||||
from pydantic.v1.error_wrappers import ValidationError
|
||||
from pydantic.v1.errors import DataclassTypeError
|
||||
from pydantic.v1.fields import Field, FieldInfo, Required, Undefined
|
||||
from pydantic.v1.main import create_model, validate_model
|
||||
from pydantic.v1.utils import ClassAttribute
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from pydantic.v1.main import BaseModel
|
||||
from pydantic.v1.typing import CallableGenerator, NoArgAnyCallable
|
||||
|
||||
DataclassT = TypeVar('DataclassT', bound='Dataclass')
|
||||
|
||||
DataclassClassOrWrapper = Union[Type['Dataclass'], 'DataclassProxy']
|
||||
|
||||
class Dataclass:
|
||||
# stdlib attributes
|
||||
__dataclass_fields__: ClassVar[Dict[str, Any]]
|
||||
__dataclass_params__: ClassVar[Any] # in reality `dataclasses._DataclassParams`
|
||||
__post_init__: ClassVar[Callable[..., None]]
|
||||
|
||||
# Added by pydantic
|
||||
__pydantic_run_validation__: ClassVar[bool]
|
||||
__post_init_post_parse__: ClassVar[Callable[..., None]]
|
||||
__pydantic_initialised__: ClassVar[bool]
|
||||
__pydantic_model__: ClassVar[Type[BaseModel]]
|
||||
__pydantic_validate_values__: ClassVar[Callable[['Dataclass'], None]]
|
||||
__pydantic_has_field_info_default__: ClassVar[bool] # whether a `pydantic.Field` is used as default value
|
||||
|
||||
def __init__(self, *args: object, **kwargs: object) -> None:
|
||||
pass
|
||||
|
||||
@classmethod
|
||||
def __get_validators__(cls: Type['Dataclass']) -> 'CallableGenerator':
|
||||
pass
|
||||
|
||||
@classmethod
|
||||
def __validate__(cls: Type['DataclassT'], v: Any) -> 'DataclassT':
|
||||
pass
|
||||
|
||||
|
||||
__all__ = [
|
||||
'dataclass',
|
||||
'set_validation',
|
||||
'create_pydantic_model_from_dataclass',
|
||||
'is_builtin_dataclass',
|
||||
'make_dataclass_validator',
|
||||
]
|
||||
|
||||
_T = TypeVar('_T')
|
||||
|
||||
if sys.version_info >= (3, 10):
|
||||
|
||||
@dataclass_transform(field_specifiers=(dataclasses.field, Field))
|
||||
@overload
|
||||
def dataclass(
|
||||
*,
|
||||
init: bool = True,
|
||||
repr: bool = True,
|
||||
eq: bool = True,
|
||||
order: bool = False,
|
||||
unsafe_hash: bool = False,
|
||||
frozen: bool = False,
|
||||
config: Union[ConfigDict, Type[object], None] = None,
|
||||
validate_on_init: Optional[bool] = None,
|
||||
use_proxy: Optional[bool] = None,
|
||||
kw_only: bool = ...,
|
||||
) -> Callable[[Type[_T]], 'DataclassClassOrWrapper']:
|
||||
...
|
||||
|
||||
@dataclass_transform(field_specifiers=(dataclasses.field, Field))
|
||||
@overload
|
||||
def dataclass(
|
||||
_cls: Type[_T],
|
||||
*,
|
||||
init: bool = True,
|
||||
repr: bool = True,
|
||||
eq: bool = True,
|
||||
order: bool = False,
|
||||
unsafe_hash: bool = False,
|
||||
frozen: bool = False,
|
||||
config: Union[ConfigDict, Type[object], None] = None,
|
||||
validate_on_init: Optional[bool] = None,
|
||||
use_proxy: Optional[bool] = None,
|
||||
kw_only: bool = ...,
|
||||
) -> 'DataclassClassOrWrapper':
|
||||
...
|
||||
|
||||
else:
|
||||
|
||||
@dataclass_transform(field_specifiers=(dataclasses.field, Field))
|
||||
@overload
|
||||
def dataclass(
|
||||
*,
|
||||
init: bool = True,
|
||||
repr: bool = True,
|
||||
eq: bool = True,
|
||||
order: bool = False,
|
||||
unsafe_hash: bool = False,
|
||||
frozen: bool = False,
|
||||
config: Union[ConfigDict, Type[object], None] = None,
|
||||
validate_on_init: Optional[bool] = None,
|
||||
use_proxy: Optional[bool] = None,
|
||||
) -> Callable[[Type[_T]], 'DataclassClassOrWrapper']:
|
||||
...
|
||||
|
||||
@dataclass_transform(field_specifiers=(dataclasses.field, Field))
|
||||
@overload
|
||||
def dataclass(
|
||||
_cls: Type[_T],
|
||||
*,
|
||||
init: bool = True,
|
||||
repr: bool = True,
|
||||
eq: bool = True,
|
||||
order: bool = False,
|
||||
unsafe_hash: bool = False,
|
||||
frozen: bool = False,
|
||||
config: Union[ConfigDict, Type[object], None] = None,
|
||||
validate_on_init: Optional[bool] = None,
|
||||
use_proxy: Optional[bool] = None,
|
||||
) -> 'DataclassClassOrWrapper':
|
||||
...
|
||||
|
||||
|
||||
@dataclass_transform(field_specifiers=(dataclasses.field, Field))
|
||||
def dataclass(
|
||||
_cls: Optional[Type[_T]] = None,
|
||||
*,
|
||||
init: bool = True,
|
||||
repr: bool = True,
|
||||
eq: bool = True,
|
||||
order: bool = False,
|
||||
unsafe_hash: bool = False,
|
||||
frozen: bool = False,
|
||||
config: Union[ConfigDict, Type[object], None] = None,
|
||||
validate_on_init: Optional[bool] = None,
|
||||
use_proxy: Optional[bool] = None,
|
||||
kw_only: bool = False,
|
||||
) -> Union[Callable[[Type[_T]], 'DataclassClassOrWrapper'], 'DataclassClassOrWrapper']:
|
||||
"""
|
||||
Like the python standard lib dataclasses but with type validation.
|
||||
The result is either a pydantic dataclass that will validate input data
|
||||
or a wrapper that will trigger validation around a stdlib dataclass
|
||||
to avoid modifying it directly
|
||||
"""
|
||||
the_config = get_config(config)
|
||||
|
||||
def wrap(cls: Type[Any]) -> 'DataclassClassOrWrapper':
|
||||
should_use_proxy = (
|
||||
use_proxy
|
||||
if use_proxy is not None
|
||||
else (
|
||||
is_builtin_dataclass(cls)
|
||||
and (cls.__bases__[0] is object or set(dir(cls)) == set(dir(cls.__bases__[0])))
|
||||
)
|
||||
)
|
||||
if should_use_proxy:
|
||||
dc_cls_doc = ''
|
||||
dc_cls = DataclassProxy(cls)
|
||||
default_validate_on_init = False
|
||||
else:
|
||||
dc_cls_doc = cls.__doc__ or '' # needs to be done before generating dataclass
|
||||
if sys.version_info >= (3, 10):
|
||||
dc_cls = dataclasses.dataclass(
|
||||
cls,
|
||||
init=init,
|
||||
repr=repr,
|
||||
eq=eq,
|
||||
order=order,
|
||||
unsafe_hash=unsafe_hash,
|
||||
frozen=frozen,
|
||||
kw_only=kw_only,
|
||||
)
|
||||
else:
|
||||
dc_cls = dataclasses.dataclass( # type: ignore
|
||||
cls, init=init, repr=repr, eq=eq, order=order, unsafe_hash=unsafe_hash, frozen=frozen
|
||||
)
|
||||
default_validate_on_init = True
|
||||
|
||||
should_validate_on_init = default_validate_on_init if validate_on_init is None else validate_on_init
|
||||
_add_pydantic_validation_attributes(cls, the_config, should_validate_on_init, dc_cls_doc)
|
||||
dc_cls.__pydantic_model__.__try_update_forward_refs__(**{cls.__name__: cls})
|
||||
return dc_cls
|
||||
|
||||
if _cls is None:
|
||||
return wrap
|
||||
|
||||
return wrap(_cls)
|
||||
|
||||
|
||||
@contextmanager
|
||||
def set_validation(cls: Type['DataclassT'], value: bool) -> Generator[Type['DataclassT'], None, None]:
|
||||
original_run_validation = cls.__pydantic_run_validation__
|
||||
try:
|
||||
cls.__pydantic_run_validation__ = value
|
||||
yield cls
|
||||
finally:
|
||||
cls.__pydantic_run_validation__ = original_run_validation
|
||||
|
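# Illustrative sketch: temporarily switching validation off, as standalone user code;
# `Point` is a hypothetical pydantic dataclass.
from pydantic.v1 import ValidationError
from pydantic.v1.dataclasses import dataclass as pydantic_dataclass, set_validation

@pydantic_dataclass
class Point:
    x: int

with set_validation(Point, False):
    p = Point(x='not checked')          # no validation inside the block
assert p.x == 'not checked'

try:
    Point(x='boom')                     # validation applies again outside the block
except ValidationError:
    pass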
||||
|
||||
class DataclassProxy:
|
||||
__slots__ = '__dataclass__'
|
||||
|
||||
def __init__(self, dc_cls: Type['Dataclass']) -> None:
|
||||
object.__setattr__(self, '__dataclass__', dc_cls)
|
||||
|
||||
def __call__(self, *args: Any, **kwargs: Any) -> Any:
|
||||
with set_validation(self.__dataclass__, True):
|
||||
return self.__dataclass__(*args, **kwargs)
|
||||
|
||||
def __getattr__(self, name: str) -> Any:
|
||||
return getattr(self.__dataclass__, name)
|
||||
|
||||
def __setattr__(self, __name: str, __value: Any) -> None:
|
||||
return setattr(self.__dataclass__, __name, __value)
|
||||
|
||||
def __instancecheck__(self, instance: Any) -> bool:
|
||||
return isinstance(instance, self.__dataclass__)
|
||||
|
||||
def __copy__(self) -> 'DataclassProxy':
|
||||
return DataclassProxy(copy.copy(self.__dataclass__))
|
||||
|
||||
def __deepcopy__(self, memo: Any) -> 'DataclassProxy':
|
||||
return DataclassProxy(copy.deepcopy(self.__dataclass__, memo))
|
||||
|
||||
|
||||
def _add_pydantic_validation_attributes( # noqa: C901 (ignore complexity)
|
||||
dc_cls: Type['Dataclass'],
|
||||
config: Type[BaseConfig],
|
||||
validate_on_init: bool,
|
||||
dc_cls_doc: str,
|
||||
) -> None:
|
||||
"""
|
||||
We need to replace the right method. If no `__post_init__` has been set in the stdlib dataclass
|
||||
it won't even exist (code is generated on the fly by `dataclasses`)
|
||||
By default, we run validation after `__init__` or `__post_init__` if defined
|
||||
"""
|
||||
init = dc_cls.__init__
|
||||
|
||||
@wraps(init)
|
||||
def handle_extra_init(self: 'Dataclass', *args: Any, **kwargs: Any) -> None:
|
||||
if config.extra == Extra.ignore:
|
||||
init(self, *args, **{k: v for k, v in kwargs.items() if k in self.__dataclass_fields__})
|
||||
|
||||
elif config.extra == Extra.allow:
|
||||
for k, v in kwargs.items():
|
||||
self.__dict__.setdefault(k, v)
|
||||
init(self, *args, **{k: v for k, v in kwargs.items() if k in self.__dataclass_fields__})
|
||||
|
||||
else:
|
||||
init(self, *args, **kwargs)
|
||||
|
||||
if hasattr(dc_cls, '__post_init__'):
|
||||
try:
|
||||
post_init = dc_cls.__post_init__.__wrapped__ # type: ignore[attr-defined]
|
||||
except AttributeError:
|
||||
post_init = dc_cls.__post_init__
|
||||
|
||||
@wraps(post_init)
|
||||
def new_post_init(self: 'Dataclass', *args: Any, **kwargs: Any) -> None:
|
||||
if config.post_init_call == 'before_validation':
|
||||
post_init(self, *args, **kwargs)
|
||||
|
||||
if self.__class__.__pydantic_run_validation__:
|
||||
self.__pydantic_validate_values__()
|
||||
if hasattr(self, '__post_init_post_parse__'):
|
||||
self.__post_init_post_parse__(*args, **kwargs)
|
||||
|
||||
if config.post_init_call == 'after_validation':
|
||||
post_init(self, *args, **kwargs)
|
||||
|
||||
setattr(dc_cls, '__init__', handle_extra_init)
|
||||
setattr(dc_cls, '__post_init__', new_post_init)
|
||||
|
||||
else:
|
||||
|
||||
@wraps(init)
|
||||
def new_init(self: 'Dataclass', *args: Any, **kwargs: Any) -> None:
|
||||
handle_extra_init(self, *args, **kwargs)
|
||||
|
||||
if self.__class__.__pydantic_run_validation__:
|
||||
self.__pydantic_validate_values__()
|
||||
|
||||
if hasattr(self, '__post_init_post_parse__'):
|
||||
                # We need to find the initvars again. To do that we use `__dataclass_fields__`
                # instead of the public `dataclasses.fields` helper.
|
||||
|
||||
# get all initvars and their default values
|
||||
initvars_and_values: Dict[str, Any] = {}
|
||||
for i, f in enumerate(self.__class__.__dataclass_fields__.values()):
|
||||
if f._field_type is dataclasses._FIELD_INITVAR: # type: ignore[attr-defined]
|
||||
try:
|
||||
# set arg value by default
|
||||
initvars_and_values[f.name] = args[i]
|
||||
except IndexError:
|
||||
initvars_and_values[f.name] = kwargs.get(f.name, f.default)
|
||||
|
||||
self.__post_init_post_parse__(**initvars_and_values)
|
||||
|
||||
setattr(dc_cls, '__init__', new_init)
|
||||
|
||||
setattr(dc_cls, '__pydantic_run_validation__', ClassAttribute('__pydantic_run_validation__', validate_on_init))
|
||||
setattr(dc_cls, '__pydantic_initialised__', False)
|
||||
setattr(dc_cls, '__pydantic_model__', create_pydantic_model_from_dataclass(dc_cls, config, dc_cls_doc))
|
||||
setattr(dc_cls, '__pydantic_validate_values__', _dataclass_validate_values)
|
||||
setattr(dc_cls, '__validate__', classmethod(_validate_dataclass))
|
||||
setattr(dc_cls, '__get_validators__', classmethod(_get_validators))
|
||||
|
||||
if dc_cls.__pydantic_model__.__config__.validate_assignment and not dc_cls.__dataclass_params__.frozen:
|
||||
setattr(dc_cls, '__setattr__', _dataclass_validate_assignment_setattr)
|
||||
|
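# Minimal sketch of `handle_extra_init` above, as standalone user code: with
# Extra.ignore, unknown keyword arguments are silently dropped at init time
# (`Item` is a hypothetical name).
from pydantic.v1.config import Extra
from pydantic.v1.dataclasses import dataclass as pydantic_dataclass

@pydantic_dataclass(config={'extra': Extra.ignore})
class Item:
    name: str

assert Item(name='x', unknown=1).name == 'x'   # `unknown` is discarded, not an error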
||||
|
||||
def _get_validators(cls: 'DataclassClassOrWrapper') -> 'CallableGenerator':
|
||||
yield cls.__validate__
|
||||
|
||||
|
||||
def _validate_dataclass(cls: Type['DataclassT'], v: Any) -> 'DataclassT':
|
||||
with set_validation(cls, True):
|
||||
if isinstance(v, cls):
|
||||
v.__pydantic_validate_values__()
|
||||
return v
|
||||
elif isinstance(v, (list, tuple)):
|
||||
return cls(*v)
|
||||
elif isinstance(v, dict):
|
||||
return cls(**v)
|
||||
else:
|
||||
raise DataclassTypeError(class_name=cls.__name__)
|
||||
|
||||
|
||||
def create_pydantic_model_from_dataclass(
|
||||
dc_cls: Type['Dataclass'],
|
||||
config: Type[Any] = BaseConfig,
|
||||
dc_cls_doc: Optional[str] = None,
|
||||
) -> Type['BaseModel']:
|
||||
field_definitions: Dict[str, Any] = {}
|
||||
for field in dataclasses.fields(dc_cls):
|
||||
default: Any = Undefined
|
||||
default_factory: Optional['NoArgAnyCallable'] = None
|
||||
field_info: FieldInfo
|
||||
|
||||
if field.default is not dataclasses.MISSING:
|
||||
default = field.default
|
||||
elif field.default_factory is not dataclasses.MISSING:
|
||||
default_factory = field.default_factory
|
||||
else:
|
||||
default = Required
|
||||
|
||||
if isinstance(default, FieldInfo):
|
||||
field_info = default
|
||||
dc_cls.__pydantic_has_field_info_default__ = True
|
||||
else:
|
||||
field_info = Field(default=default, default_factory=default_factory, **field.metadata)
|
||||
|
||||
field_definitions[field.name] = (field.type, field_info)
|
||||
|
||||
validators = gather_all_validators(dc_cls)
|
||||
model: Type['BaseModel'] = create_model(
|
||||
dc_cls.__name__,
|
||||
__config__=config,
|
||||
__module__=dc_cls.__module__,
|
||||
__validators__=validators,
|
||||
__cls_kwargs__={'__resolve_forward_refs__': False},
|
||||
**field_definitions,
|
||||
)
|
||||
model.__doc__ = dc_cls_doc if dc_cls_doc is not None else dc_cls.__doc__ or ''
|
||||
return model
|
||||
|
||||
|
||||
if sys.version_info >= (3, 8):
|
||||
|
||||
def _is_field_cached_property(obj: 'Dataclass', k: str) -> bool:
|
||||
return isinstance(getattr(type(obj), k, None), cached_property)
|
||||
|
||||
else:
|
||||
|
||||
def _is_field_cached_property(obj: 'Dataclass', k: str) -> bool:
|
||||
return False
|
||||
|
||||
|
||||
def _dataclass_validate_values(self: 'Dataclass') -> None:
|
||||
# validation errors can occur if this function is called twice on an already initialised dataclass.
|
||||
    # For example, if Extra.forbid is enabled, it would consider __pydantic_initialised__ an invalid extra property.
|
||||
if getattr(self, '__pydantic_initialised__'):
|
||||
return
|
||||
if getattr(self, '__pydantic_has_field_info_default__', False):
|
||||
# We need to remove `FieldInfo` values since they are not valid as input
|
||||
# It's ok to do that because they are obviously the default values!
|
||||
input_data = {
|
||||
k: v
|
||||
for k, v in self.__dict__.items()
|
||||
if not (isinstance(v, FieldInfo) or _is_field_cached_property(self, k))
|
||||
}
|
||||
else:
|
||||
input_data = {k: v for k, v in self.__dict__.items() if not _is_field_cached_property(self, k)}
|
||||
d, _, validation_error = validate_model(self.__pydantic_model__, input_data, cls=self.__class__)
|
||||
if validation_error:
|
||||
raise validation_error
|
||||
self.__dict__.update(d)
|
||||
object.__setattr__(self, '__pydantic_initialised__', True)
|
||||
|
||||
|
||||
def _dataclass_validate_assignment_setattr(self: 'Dataclass', name: str, value: Any) -> None:
|
||||
if self.__pydantic_initialised__:
|
||||
d = dict(self.__dict__)
|
||||
d.pop(name, None)
|
||||
known_field = self.__pydantic_model__.__fields__.get(name, None)
|
||||
if known_field:
|
||||
value, error_ = known_field.validate(value, d, loc=name, cls=self.__class__)
|
||||
if error_:
|
||||
raise ValidationError([error_], self.__class__)
|
||||
|
||||
object.__setattr__(self, name, value)
|
||||
|
||||
|
||||
def is_builtin_dataclass(_cls: Type[Any]) -> bool:
|
||||
"""
|
||||
Whether a class is a stdlib dataclass
|
||||
    (useful to discriminate a pydantic dataclass that is actually a wrapper around a stdlib dataclass)
|
||||
|
||||
we check that
|
||||
- `_cls` is a dataclass
|
||||
- `_cls` is not a processed pydantic dataclass (with a basemodel attached)
|
||||
- `_cls` is not a pydantic dataclass inheriting directly from a stdlib dataclass
|
||||
e.g.
|
||||
```
|
||||
@dataclasses.dataclass
|
||||
class A:
|
||||
x: int
|
||||
|
||||
@pydantic.dataclasses.dataclass
|
||||
class B(A):
|
||||
y: int
|
||||
```
|
||||
In this case, when we first check `B`, we make an extra check and look at the annotations ('y'),
|
||||
which won't be a superset of all the dataclass fields (only the stdlib fields i.e. 'x')
|
||||
"""
|
||||
return (
|
||||
dataclasses.is_dataclass(_cls)
|
||||
and not hasattr(_cls, '__pydantic_model__')
|
||||
and set(_cls.__dataclass_fields__).issuperset(set(getattr(_cls, '__annotations__', {})))
|
||||
)
|
||||
|
||||
|
||||
def make_dataclass_validator(dc_cls: Type['Dataclass'], config: Type[BaseConfig]) -> 'CallableGenerator':
|
||||
"""
|
||||
Create a pydantic.dataclass from a builtin dataclass to add type validation
|
||||
and yield the validators
|
||||
It retrieves the parameters of the dataclass and forwards them to the newly created dataclass
|
||||
"""
|
||||
yield from _get_validators(dataclass(dc_cls, config=config, use_proxy=True))
|
||||
@@ -0,0 +1,248 @@
"""
Functions to parse datetime objects.

We're using regular expressions rather than time.strptime because:
- They provide both validation and parsing.
- They're more flexible for datetimes.
- The date/datetime/time constructors produce friendlier error messages.

Stolen from https://raw.githubusercontent.com/django/django/main/django/utils/dateparse.py at
9718fa2e8abe430c3526a9278dd976443d4ae3c6

Changed to:
* use standard python datetime types not django.utils.timezone
* raise ValueError when regex doesn't match rather than returning None
* support parsing unix timestamps for dates and datetimes
"""
import re
from datetime import date, datetime, time, timedelta, timezone
from typing import Dict, Optional, Type, Union

from pydantic.v1 import errors

date_expr = r'(?P<year>\d{4})-(?P<month>\d{1,2})-(?P<day>\d{1,2})'
time_expr = (
    r'(?P<hour>\d{1,2}):(?P<minute>\d{1,2})'
    r'(?::(?P<second>\d{1,2})(?:\.(?P<microsecond>\d{1,6})\d{0,6})?)?'
    r'(?P<tzinfo>Z|[+-]\d{2}(?::?\d{2})?)?$'
)

date_re = re.compile(f'{date_expr}$')
time_re = re.compile(time_expr)
datetime_re = re.compile(f'{date_expr}[T ]{time_expr}')

standard_duration_re = re.compile(
    r'^'
    r'(?:(?P<days>-?\d+) (days?, )?)?'
    r'((?:(?P<hours>-?\d+):)(?=\d+:\d+))?'
    r'(?:(?P<minutes>-?\d+):)?'
    r'(?P<seconds>-?\d+)'
    r'(?:\.(?P<microseconds>\d{1,6})\d{0,6})?'
    r'$'
)

# Support the sections of ISO 8601 date representation that are accepted by timedelta
iso8601_duration_re = re.compile(
    r'^(?P<sign>[-+]?)'
    r'P'
    r'(?:(?P<days>\d+(.\d+)?)D)?'
    r'(?:T'
    r'(?:(?P<hours>\d+(.\d+)?)H)?'
    r'(?:(?P<minutes>\d+(.\d+)?)M)?'
    r'(?:(?P<seconds>\d+(.\d+)?)S)?'
    r')?'
    r'$'
)

EPOCH = datetime(1970, 1, 1)
# if greater than this, the number is in ms, if less than or equal it's in seconds
# (in seconds this is 11th October 2603, in ms it's 20th August 1970)
MS_WATERSHED = int(2e10)
# slightly more than datetime.max in ns - (datetime.max - EPOCH).total_seconds() * 1e9
MAX_NUMBER = int(3e20)
StrBytesIntFloat = Union[str, bytes, int, float]


def get_numeric(value: StrBytesIntFloat, native_expected_type: str) -> Union[None, int, float]:
    if isinstance(value, (int, float)):
        return value
    try:
        return float(value)
    except ValueError:
        return None
    except TypeError:
        raise TypeError(f'invalid type; expected {native_expected_type}, string, bytes, int or float')


def from_unix_seconds(seconds: Union[int, float]) -> datetime:
    if seconds > MAX_NUMBER:
        return datetime.max
    elif seconds < -MAX_NUMBER:
        return datetime.min

    while abs(seconds) > MS_WATERSHED:
        seconds /= 1000
    dt = EPOCH + timedelta(seconds=seconds)
    return dt.replace(tzinfo=timezone.utc)


def _parse_timezone(value: Optional[str], error: Type[Exception]) -> Union[None, int, timezone]:
    if value == 'Z':
        return timezone.utc
    elif value is not None:
        offset_mins = int(value[-2:]) if len(value) > 3 else 0
        offset = 60 * int(value[1:3]) + offset_mins
        if value[0] == '-':
            offset = -offset
        try:
            return timezone(timedelta(minutes=offset))
        except ValueError:
            raise error()
    else:
        return None


def parse_date(value: Union[date, StrBytesIntFloat]) -> date:
    """
    Parse a date/int/float/string and return a datetime.date.

    Raise ValueError if the input is well formatted but not a valid date.
    Raise ValueError if the input isn't well formatted.
    """
    if isinstance(value, date):
        if isinstance(value, datetime):
            return value.date()
        else:
            return value

    number = get_numeric(value, 'date')
    if number is not None:
        return from_unix_seconds(number).date()

    if isinstance(value, bytes):
        value = value.decode()

    match = date_re.match(value)  # type: ignore
    if match is None:
        raise errors.DateError()

    kw = {k: int(v) for k, v in match.groupdict().items()}

    try:
        return date(**kw)
    except ValueError:
        raise errors.DateError()


def parse_time(value: Union[time, StrBytesIntFloat]) -> time:
    """
    Parse a time/string and return a datetime.time.

    Raise ValueError if the input is well formatted but not a valid time.
    Raise ValueError if the input isn't well formatted, in particular if it contains an offset.
    """
    if isinstance(value, time):
        return value

    number = get_numeric(value, 'time')
    if number is not None:
        if number >= 86400:
            # doesn't make sense since the time would loop back around to 0
            raise errors.TimeError()
        return (datetime.min + timedelta(seconds=number)).time()

    if isinstance(value, bytes):
        value = value.decode()

    match = time_re.match(value)  # type: ignore
    if match is None:
        raise errors.TimeError()

    kw = match.groupdict()
    if kw['microsecond']:
        kw['microsecond'] = kw['microsecond'].ljust(6, '0')

    tzinfo = _parse_timezone(kw.pop('tzinfo'), errors.TimeError)
    kw_: Dict[str, Union[None, int, timezone]] = {k: int(v) for k, v in kw.items() if v is not None}
    kw_['tzinfo'] = tzinfo

    try:
        return time(**kw_)  # type: ignore
    except ValueError:
        raise errors.TimeError()


def parse_datetime(value: Union[datetime, StrBytesIntFloat]) -> datetime:
    """
    Parse a datetime/int/float/string and return a datetime.datetime.

    This function supports time zone offsets. When the input contains one,
    the output uses a timezone with a fixed offset from UTC.

    Raise ValueError if the input is well formatted but not a valid datetime.
    Raise ValueError if the input isn't well formatted.
    """
    if isinstance(value, datetime):
        return value

    number = get_numeric(value, 'datetime')
    if number is not None:
        return from_unix_seconds(number)

    if isinstance(value, bytes):
        value = value.decode()

    match = datetime_re.match(value)  # type: ignore
    if match is None:
        raise errors.DateTimeError()

    kw = match.groupdict()
    if kw['microsecond']:
        kw['microsecond'] = kw['microsecond'].ljust(6, '0')

    tzinfo = _parse_timezone(kw.pop('tzinfo'), errors.DateTimeError)
    kw_: Dict[str, Union[None, int, timezone]] = {k: int(v) for k, v in kw.items() if v is not None}
    kw_['tzinfo'] = tzinfo

    try:
        return datetime(**kw_)  # type: ignore
    except ValueError:
        raise errors.DateTimeError()


def parse_duration(value: StrBytesIntFloat) -> timedelta:
    """
    Parse a duration int/float/string and return a datetime.timedelta.

    The preferred format for durations in Django is '%d %H:%M:%S.%f'.

    Also supports ISO 8601 representation.
    """
    if isinstance(value, timedelta):
        return value

    if isinstance(value, (int, float)):
        # below code requires a string
        value = f'{value:f}'
    elif isinstance(value, bytes):
        value = value.decode()

    try:
        match = standard_duration_re.match(value) or iso8601_duration_re.match(value)
    except TypeError:
        raise TypeError('invalid type; expected timedelta, string, bytes, int or float')

    if not match:
        raise errors.DurationError()

    kw = match.groupdict()
    sign = -1 if kw.pop('sign', '+') == '-' else 1
    if kw.get('microseconds'):
        kw['microseconds'] = kw['microseconds'].ljust(6, '0')

    if kw.get('seconds') and kw.get('microseconds') and kw['seconds'].startswith('-'):
        kw['microseconds'] = '-' + kw['microseconds']

    kw_ = {k: float(v) for k, v in kw.items() if v is not None}

    return sign * timedelta(**kw_)
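A short sketch exercising the parsers above; the numeric inputs illustrate the MS_WATERSHED rule (values above ~2e10 are treated as milliseconds rather than seconds). Illustrative only:

```python
from datetime import date, datetime, timedelta, timezone

assert parse_date('2012-04-23') == date(2012, 4, 23)
assert parse_date(1493942400) == date(2017, 5, 5)                    # unix seconds
assert parse_datetime(1493942400000) == parse_datetime(1493942400)  # ms input collapses to the same instant
assert parse_datetime('2012-04-23T09:15:00Z') == datetime(2012, 4, 23, 9, 15, tzinfo=timezone.utc)
assert parse_duration('15:30') == timedelta(minutes=15, seconds=30)  # Django-style '%d %H:%M:%S.%f'
assert parse_duration('P3DT12H') == timedelta(days=3, hours=12)      # ISO 8601
```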
@@ -0,0 +1,264 @@
from functools import wraps
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Mapping, Optional, Tuple, Type, TypeVar, Union, overload

from pydantic.v1 import validator
from pydantic.v1.config import Extra
from pydantic.v1.errors import ConfigError
from pydantic.v1.main import BaseModel, create_model
from pydantic.v1.typing import get_all_type_hints
from pydantic.v1.utils import to_camel

__all__ = ('validate_arguments',)

if TYPE_CHECKING:
    from pydantic.v1.typing import AnyCallable

    AnyCallableT = TypeVar('AnyCallableT', bound=AnyCallable)
    ConfigType = Union[None, Type[Any], Dict[str, Any]]


@overload
def validate_arguments(func: None = None, *, config: 'ConfigType' = None) -> Callable[['AnyCallableT'], 'AnyCallableT']:
    ...


@overload
def validate_arguments(func: 'AnyCallableT') -> 'AnyCallableT':
    ...


def validate_arguments(func: Optional['AnyCallableT'] = None, *, config: 'ConfigType' = None) -> Any:
    """
    Decorator to validate the arguments passed to a function.
    """

    def validate(_func: 'AnyCallable') -> 'AnyCallable':
        vd = ValidatedFunction(_func, config)

        @wraps(_func)
        def wrapper_function(*args: Any, **kwargs: Any) -> Any:
            return vd.call(*args, **kwargs)

        wrapper_function.vd = vd  # type: ignore
        wrapper_function.validate = vd.init_model_instance  # type: ignore
        wrapper_function.raw_function = vd.raw_function  # type: ignore
        wrapper_function.model = vd.model  # type: ignore
        return wrapper_function

    if func:
        return validate(func)
    else:
        return validate


ALT_V_ARGS = 'v__args'
ALT_V_KWARGS = 'v__kwargs'
V_POSITIONAL_ONLY_NAME = 'v__positional_only'
V_DUPLICATE_KWARGS = 'v__duplicate_kwargs'


class ValidatedFunction:
    def __init__(self, function: 'AnyCallableT', config: 'ConfigType'):  # noqa C901
        from inspect import Parameter, signature

        parameters: Mapping[str, Parameter] = signature(function).parameters

        if parameters.keys() & {ALT_V_ARGS, ALT_V_KWARGS, V_POSITIONAL_ONLY_NAME, V_DUPLICATE_KWARGS}:
            raise ConfigError(
                f'"{ALT_V_ARGS}", "{ALT_V_KWARGS}", "{V_POSITIONAL_ONLY_NAME}" and "{V_DUPLICATE_KWARGS}" '
                f'are not permitted as argument names when using the "{validate_arguments.__name__}" decorator'
            )

        self.raw_function = function
        self.arg_mapping: Dict[int, str] = {}
        self.positional_only_args = set()
        self.v_args_name = 'args'
        self.v_kwargs_name = 'kwargs'

        type_hints = get_all_type_hints(function)
        takes_args = False
        takes_kwargs = False
        fields: Dict[str, Tuple[Any, Any]] = {}
        for i, (name, p) in enumerate(parameters.items()):
            if p.annotation is p.empty:
                annotation = Any
            else:
                annotation = type_hints[name]

            default = ... if p.default is p.empty else p.default
            if p.kind == Parameter.POSITIONAL_ONLY:
                self.arg_mapping[i] = name
                fields[name] = annotation, default
                fields[V_POSITIONAL_ONLY_NAME] = List[str], None
                self.positional_only_args.add(name)
            elif p.kind == Parameter.POSITIONAL_OR_KEYWORD:
                self.arg_mapping[i] = name
                fields[name] = annotation, default
                fields[V_DUPLICATE_KWARGS] = List[str], None
            elif p.kind == Parameter.KEYWORD_ONLY:
                fields[name] = annotation, default
            elif p.kind == Parameter.VAR_POSITIONAL:
                self.v_args_name = name
                fields[name] = Tuple[annotation, ...], None
                takes_args = True
            else:
                assert p.kind == Parameter.VAR_KEYWORD, p.kind
                self.v_kwargs_name = name
                fields[name] = Dict[str, annotation], None  # type: ignore
                takes_kwargs = True

        # these checks avoid a clash between "args" and a field with that name
        if not takes_args and self.v_args_name in fields:
            self.v_args_name = ALT_V_ARGS

        # same with "kwargs"
        if not takes_kwargs and self.v_kwargs_name in fields:
            self.v_kwargs_name = ALT_V_KWARGS

        if not takes_args:
            # we add the field so validation below can raise the correct exception
            fields[self.v_args_name] = List[Any], None

        if not takes_kwargs:
            # same with kwargs
            fields[self.v_kwargs_name] = Dict[Any, Any], None

        self.create_model(fields, takes_args, takes_kwargs, config)

    def init_model_instance(self, *args: Any, **kwargs: Any) -> BaseModel:
        values = self.build_values(args, kwargs)
        return self.model(**values)

    def call(self, *args: Any, **kwargs: Any) -> Any:
        m = self.init_model_instance(*args, **kwargs)
        return self.execute(m)

    def build_values(self, args: Tuple[Any, ...], kwargs: Dict[str, Any]) -> Dict[str, Any]:
        values: Dict[str, Any] = {}
        if args:
            arg_iter = enumerate(args)
            while True:
                try:
                    i, a = next(arg_iter)
                except StopIteration:
                    break
                arg_name = self.arg_mapping.get(i)
                if arg_name is not None:
                    values[arg_name] = a
                else:
                    values[self.v_args_name] = [a] + [a for _, a in arg_iter]
                    break

        var_kwargs: Dict[str, Any] = {}
        wrong_positional_args = []
        duplicate_kwargs = []
        fields_alias = [
            field.alias
            for name, field in self.model.__fields__.items()
            if name not in (self.v_args_name, self.v_kwargs_name)
        ]
        non_var_fields = set(self.model.__fields__) - {self.v_args_name, self.v_kwargs_name}
        for k, v in kwargs.items():
            if k in non_var_fields or k in fields_alias:
                if k in self.positional_only_args:
                    wrong_positional_args.append(k)
                if k in values:
                    duplicate_kwargs.append(k)
                values[k] = v
            else:
                var_kwargs[k] = v

        if var_kwargs:
            values[self.v_kwargs_name] = var_kwargs
        if wrong_positional_args:
            values[V_POSITIONAL_ONLY_NAME] = wrong_positional_args
        if duplicate_kwargs:
            values[V_DUPLICATE_KWARGS] = duplicate_kwargs
        return values

    def execute(self, m: BaseModel) -> Any:
        d = {k: v for k, v in m._iter() if k in m.__fields_set__ or m.__fields__[k].default_factory}
        var_kwargs = d.pop(self.v_kwargs_name, {})

        if self.v_args_name in d:
            args_: List[Any] = []
            in_kwargs = False
            kwargs = {}
            for name, value in d.items():
                if in_kwargs:
                    kwargs[name] = value
                elif name == self.v_args_name:
                    args_ += value
                    in_kwargs = True
                else:
                    args_.append(value)
            return self.raw_function(*args_, **kwargs, **var_kwargs)
        elif self.positional_only_args:
            args_ = []
            kwargs = {}
            for name, value in d.items():
                if name in self.positional_only_args:
                    args_.append(value)
                else:
                    kwargs[name] = value
            return self.raw_function(*args_, **kwargs, **var_kwargs)
        else:
            return self.raw_function(**d, **var_kwargs)

    def create_model(self, fields: Dict[str, Any], takes_args: bool, takes_kwargs: bool, config: 'ConfigType') -> None:
        pos_args = len(self.arg_mapping)

        class CustomConfig:
            pass

        if not TYPE_CHECKING:  # pragma: no branch
            if isinstance(config, dict):
                CustomConfig = type('Config', (), config)  # noqa: F811
            elif config is not None:
                CustomConfig = config  # noqa: F811

        if hasattr(CustomConfig, 'fields') or hasattr(CustomConfig, 'alias_generator'):
            raise ConfigError(
                'Setting the "fields" and "alias_generator" property on custom Config for '
                '@validate_arguments is not yet supported, please remove.'
            )

        class DecoratorBaseModel(BaseModel):
            @validator(self.v_args_name, check_fields=False, allow_reuse=True)
            def check_args(cls, v: Optional[List[Any]]) -> Optional[List[Any]]:
                if takes_args or v is None:
                    return v

                raise TypeError(f'{pos_args} positional arguments expected but {pos_args + len(v)} given')

            @validator(self.v_kwargs_name, check_fields=False, allow_reuse=True)
            def check_kwargs(cls, v: Optional[Dict[str, Any]]) -> Optional[Dict[str, Any]]:
                if takes_kwargs or v is None:
                    return v

                plural = '' if len(v) == 1 else 's'
                keys = ', '.join(map(repr, v.keys()))
                raise TypeError(f'unexpected keyword argument{plural}: {keys}')

            @validator(V_POSITIONAL_ONLY_NAME, check_fields=False, allow_reuse=True)
            def check_positional_only(cls, v: Optional[List[str]]) -> None:
                if v is None:
                    return

                plural = '' if len(v) == 1 else 's'
                keys = ', '.join(map(repr, v))
                raise TypeError(f'positional-only argument{plural} passed as keyword argument{plural}: {keys}')

            @validator(V_DUPLICATE_KWARGS, check_fields=False, allow_reuse=True)
            def check_duplicate_kwargs(cls, v: Optional[List[str]]) -> None:
                if v is None:
                    return

                plural = '' if len(v) == 1 else 's'
                keys = ', '.join(map(repr, v))
                raise TypeError(f'multiple values for argument{plural}: {keys}')

            class Config(CustomConfig):
                extra = getattr(CustomConfig, 'extra', Extra.forbid)

        self.model = create_model(to_camel(self.raw_function.__name__), __base__=DecoratorBaseModel, **fields)
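A minimal usage sketch of `validate_arguments`; the function `repeat` is hypothetical, shown only to illustrate coercion and the attributes attached to the wrapper:

```python
@validate_arguments
def repeat(word: str, count: int, *, sep: str = ' ') -> str:
    return sep.join([word] * count)

assert repeat('x', '3', sep='-') == 'x-x-x'   # '3' is coerced to int by the generated model
# repeat('x', 'three') raises pydantic.v1.ValidationError
# repeat.model is the generated BaseModel subclass; repeat.raw_function is the undecorated function
```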
@@ -0,0 +1,350 @@
import os
import warnings
from pathlib import Path
from typing import AbstractSet, Any, Callable, ClassVar, Dict, List, Mapping, Optional, Tuple, Type, Union

from pydantic.v1.config import BaseConfig, Extra
from pydantic.v1.fields import ModelField
from pydantic.v1.main import BaseModel
from pydantic.v1.types import JsonWrapper
from pydantic.v1.typing import StrPath, display_as_type, get_origin, is_union
from pydantic.v1.utils import deep_update, lenient_issubclass, path_type, sequence_like

env_file_sentinel = str(object())

SettingsSourceCallable = Callable[['BaseSettings'], Dict[str, Any]]
DotenvType = Union[StrPath, List[StrPath], Tuple[StrPath, ...]]


class SettingsError(ValueError):
    pass


class BaseSettings(BaseModel):
    """
    Base class for settings, allowing values to be overridden by environment variables.

    This is useful in production for secrets you do not wish to save in code, it plays nicely with docker(-compose),
    Heroku and any 12 factor app design.
    """

    def __init__(
        __pydantic_self__,
        _env_file: Optional[DotenvType] = env_file_sentinel,
        _env_file_encoding: Optional[str] = None,
        _env_nested_delimiter: Optional[str] = None,
        _secrets_dir: Optional[StrPath] = None,
        **values: Any,
    ) -> None:
        # Uses something other than `self` as the first arg to allow "self" as a settable attribute
        super().__init__(
            **__pydantic_self__._build_values(
                values,
                _env_file=_env_file,
                _env_file_encoding=_env_file_encoding,
                _env_nested_delimiter=_env_nested_delimiter,
                _secrets_dir=_secrets_dir,
            )
        )

    def _build_values(
        self,
        init_kwargs: Dict[str, Any],
        _env_file: Optional[DotenvType] = None,
        _env_file_encoding: Optional[str] = None,
        _env_nested_delimiter: Optional[str] = None,
        _secrets_dir: Optional[StrPath] = None,
    ) -> Dict[str, Any]:
        # Configure built-in sources
        init_settings = InitSettingsSource(init_kwargs=init_kwargs)
        env_settings = EnvSettingsSource(
            env_file=(_env_file if _env_file != env_file_sentinel else self.__config__.env_file),
            env_file_encoding=(
                _env_file_encoding if _env_file_encoding is not None else self.__config__.env_file_encoding
            ),
            env_nested_delimiter=(
                _env_nested_delimiter if _env_nested_delimiter is not None else self.__config__.env_nested_delimiter
            ),
            env_prefix_len=len(self.__config__.env_prefix),
        )
        file_secret_settings = SecretsSettingsSource(secrets_dir=_secrets_dir or self.__config__.secrets_dir)
        # Provide a hook to set built-in sources priority and add / remove sources
        sources = self.__config__.customise_sources(
            init_settings=init_settings, env_settings=env_settings, file_secret_settings=file_secret_settings
        )
        if sources:
            return deep_update(*reversed([source(self) for source in sources]))
        else:
            # no one should mean to do this, but I think returning an empty dict is marginally preferable
            # to an informative error and much better than a confusing error
            return {}

    class Config(BaseConfig):
        env_prefix: str = ''
        env_file: Optional[DotenvType] = None
        env_file_encoding: Optional[str] = None
        env_nested_delimiter: Optional[str] = None
        secrets_dir: Optional[StrPath] = None
        validate_all: bool = True
        extra: Extra = Extra.forbid
        arbitrary_types_allowed: bool = True
        case_sensitive: bool = False

        @classmethod
        def prepare_field(cls, field: ModelField) -> None:
            env_names: Union[List[str], AbstractSet[str]]
            field_info_from_config = cls.get_field_info(field.name)

            env = field_info_from_config.get('env') or field.field_info.extra.get('env')
            if env is None:
                if field.has_alias:
                    warnings.warn(
                        'aliases are no longer used by BaseSettings to define which environment variables to read. '
                        'Instead use the "env" field setting. '
                        'See https://pydantic-docs.helpmanual.io/usage/settings/#environment-variable-names',
                        FutureWarning,
                    )
                env_names = {cls.env_prefix + field.name}
            elif isinstance(env, str):
                env_names = {env}
            elif isinstance(env, (set, frozenset)):
                env_names = env
            elif sequence_like(env):
                env_names = list(env)
            else:
                raise TypeError(f'invalid field env: {env!r} ({display_as_type(env)}); should be string, list or set')

            if not cls.case_sensitive:
                env_names = env_names.__class__(n.lower() for n in env_names)
            field.field_info.extra['env_names'] = env_names

        @classmethod
        def customise_sources(
            cls,
            init_settings: SettingsSourceCallable,
            env_settings: SettingsSourceCallable,
            file_secret_settings: SettingsSourceCallable,
        ) -> Tuple[SettingsSourceCallable, ...]:
            return init_settings, env_settings, file_secret_settings

        @classmethod
        def parse_env_var(cls, field_name: str, raw_val: str) -> Any:
            return cls.json_loads(raw_val)

    # populated by the metaclass using the Config class defined above, annotated here to help IDEs only
    __config__: ClassVar[Type[Config]]


class InitSettingsSource:
    __slots__ = ('init_kwargs',)

    def __init__(self, init_kwargs: Dict[str, Any]):
        self.init_kwargs = init_kwargs

    def __call__(self, settings: BaseSettings) -> Dict[str, Any]:
        return self.init_kwargs

    def __repr__(self) -> str:
        return f'InitSettingsSource(init_kwargs={self.init_kwargs!r})'


class EnvSettingsSource:
    __slots__ = ('env_file', 'env_file_encoding', 'env_nested_delimiter', 'env_prefix_len')

    def __init__(
        self,
        env_file: Optional[DotenvType],
        env_file_encoding: Optional[str],
        env_nested_delimiter: Optional[str] = None,
        env_prefix_len: int = 0,
    ):
        self.env_file: Optional[DotenvType] = env_file
        self.env_file_encoding: Optional[str] = env_file_encoding
        self.env_nested_delimiter: Optional[str] = env_nested_delimiter
        self.env_prefix_len: int = env_prefix_len

    def __call__(self, settings: BaseSettings) -> Dict[str, Any]:  # noqa C901
        """
        Build environment variables suitable for passing to the Model.
        """
        d: Dict[str, Any] = {}

        if settings.__config__.case_sensitive:
            env_vars: Mapping[str, Optional[str]] = os.environ
        else:
            env_vars = {k.lower(): v for k, v in os.environ.items()}

        dotenv_vars = self._read_env_files(settings.__config__.case_sensitive)
        if dotenv_vars:
            env_vars = {**dotenv_vars, **env_vars}

        for field in settings.__fields__.values():
            env_val: Optional[str] = None
            for env_name in field.field_info.extra['env_names']:
                env_val = env_vars.get(env_name)
                if env_val is not None:
                    break

            is_complex, allow_parse_failure = self.field_is_complex(field)
            if is_complex:
                if env_val is None:
                    # field is complex but no value found so far, try explode_env_vars
                    env_val_built = self.explode_env_vars(field, env_vars)
                    if env_val_built:
                        d[field.alias] = env_val_built
                else:
                    # field is complex and there's a value, decode that as JSON, then add explode_env_vars
                    try:
                        env_val = settings.__config__.parse_env_var(field.name, env_val)
                    except ValueError as e:
                        if not allow_parse_failure:
                            raise SettingsError(f'error parsing env var "{env_name}"') from e

                    if isinstance(env_val, dict):
                        d[field.alias] = deep_update(env_val, self.explode_env_vars(field, env_vars))
                    else:
                        d[field.alias] = env_val
            elif env_val is not None:
                # simplest case, field is not complex, we only need to add the value if it was found
                d[field.alias] = env_val

        return d

    def _read_env_files(self, case_sensitive: bool) -> Dict[str, Optional[str]]:
        env_files = self.env_file
        if env_files is None:
            return {}

        if isinstance(env_files, (str, os.PathLike)):
            env_files = [env_files]

        dotenv_vars = {}
        for env_file in env_files:
            env_path = Path(env_file).expanduser()
            if env_path.is_file():
                dotenv_vars.update(
                    read_env_file(env_path, encoding=self.env_file_encoding, case_sensitive=case_sensitive)
                )

        return dotenv_vars

    def field_is_complex(self, field: ModelField) -> Tuple[bool, bool]:
        """
        Find out if a field is complex, and if so whether JSON errors should be ignored
        """
        if lenient_issubclass(field.annotation, JsonWrapper):
            return False, False

        if field.is_complex():
            allow_parse_failure = False
        elif is_union(get_origin(field.type_)) and field.sub_fields and any(f.is_complex() for f in field.sub_fields):
            allow_parse_failure = True
        else:
            return False, False

        return True, allow_parse_failure

    def explode_env_vars(self, field: ModelField, env_vars: Mapping[str, Optional[str]]) -> Dict[str, Any]:
        """
        Process env_vars and extract the values of keys containing env_nested_delimiter into nested dictionaries.

        This is applied to a single field, hence filtering by env_var prefix.
        """
        prefixes = [f'{env_name}{self.env_nested_delimiter}' for env_name in field.field_info.extra['env_names']]
        result: Dict[str, Any] = {}
        for env_name, env_val in env_vars.items():
            if not any(env_name.startswith(prefix) for prefix in prefixes):
                continue
            # we remove the prefix before splitting in case the prefix has characters in common with the delimiter
            env_name_without_prefix = env_name[self.env_prefix_len :]
            _, *keys, last_key = env_name_without_prefix.split(self.env_nested_delimiter)
            env_var = result
            for key in keys:
                env_var = env_var.setdefault(key, {})
            env_var[last_key] = env_val

        return result

    def __repr__(self) -> str:
        return (
            f'EnvSettingsSource(env_file={self.env_file!r}, env_file_encoding={self.env_file_encoding!r}, '
            f'env_nested_delimiter={self.env_nested_delimiter!r})'
        )


class SecretsSettingsSource:
    __slots__ = ('secrets_dir',)

    def __init__(self, secrets_dir: Optional[StrPath]):
        self.secrets_dir: Optional[StrPath] = secrets_dir

    def __call__(self, settings: BaseSettings) -> Dict[str, Any]:
        """
        Build fields from "secrets" files.
        """
        secrets: Dict[str, Optional[str]] = {}

        if self.secrets_dir is None:
            return secrets

        secrets_path = Path(self.secrets_dir).expanduser()

        if not secrets_path.exists():
            warnings.warn(f'directory "{secrets_path}" does not exist')
            return secrets

        if not secrets_path.is_dir():
            raise SettingsError(f'secrets_dir must reference a directory, not a {path_type(secrets_path)}')

        for field in settings.__fields__.values():
            for env_name in field.field_info.extra['env_names']:
                path = find_case_path(secrets_path, env_name, settings.__config__.case_sensitive)
                if not path:
                    # path does not exist, we currently don't return a warning for this
                    continue

                if path.is_file():
                    secret_value = path.read_text().strip()
                    if field.is_complex():
                        try:
                            secret_value = settings.__config__.parse_env_var(field.name, secret_value)
                        except ValueError as e:
                            raise SettingsError(f'error parsing env var "{env_name}"') from e

                    secrets[field.alias] = secret_value
                else:
                    warnings.warn(
                        f'attempted to load secret file "{path}" but found a {path_type(path)} instead.',
                        stacklevel=4,
                    )
        return secrets

    def __repr__(self) -> str:
        return f'SecretsSettingsSource(secrets_dir={self.secrets_dir!r})'


def read_env_file(
    file_path: StrPath, *, encoding: Optional[str] = None, case_sensitive: bool = False
) -> Dict[str, Optional[str]]:
    try:
        from dotenv import dotenv_values
    except ImportError as e:
        raise ImportError('python-dotenv is not installed, run `pip install pydantic[dotenv]`') from e

    file_vars: Dict[str, Optional[str]] = dotenv_values(file_path, encoding=encoding or 'utf8')
    if not case_sensitive:
        return {k.lower(): v for k, v in file_vars.items()}
    else:
        return file_vars


def find_case_path(dir_path: Path, file_name: str, case_sensitive: bool) -> Optional[Path]:
    """
    Find a file within path's directory matching filename, optionally ignoring case.
    """
    for f in dir_path.iterdir():
        if f.name == file_name:
            return f
        elif not case_sensitive and f.name.lower() == file_name.lower():
            return f
    return None
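A minimal sketch of the settings machinery above; `AppSettings`, the `app_` prefix and the field names are hypothetical:

```python
import os

class AppSettings(BaseSettings):
    host: str = 'localhost'
    port: int = 8000

    class Config:
        env_prefix = 'app_'  # lookups are case-insensitive unless case_sensitive = True

os.environ['APP_PORT'] = '9000'  # matched via the lowercased env name 'app_port'
settings = AppSettings()
assert settings.port == 9000 and settings.host == 'localhost'
```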
@@ -0,0 +1,161 @@
import json
from typing import TYPE_CHECKING, Any, Dict, Generator, List, Optional, Sequence, Tuple, Type, Union

from pydantic.v1.json import pydantic_encoder
from pydantic.v1.utils import Representation

if TYPE_CHECKING:
    from typing_extensions import TypedDict

    from pydantic.v1.config import BaseConfig
    from pydantic.v1.types import ModelOrDc
    from pydantic.v1.typing import ReprArgs

    Loc = Tuple[Union[int, str], ...]

    class _ErrorDictRequired(TypedDict):
        loc: Loc
        msg: str
        type: str

    class ErrorDict(_ErrorDictRequired, total=False):
        ctx: Dict[str, Any]


__all__ = 'ErrorWrapper', 'ValidationError'


class ErrorWrapper(Representation):
    __slots__ = 'exc', '_loc'

    def __init__(self, exc: Exception, loc: Union[str, 'Loc']) -> None:
        self.exc = exc
        self._loc = loc

    def loc_tuple(self) -> 'Loc':
        if isinstance(self._loc, tuple):
            return self._loc
        else:
            return (self._loc,)

    def __repr_args__(self) -> 'ReprArgs':
        return [('exc', self.exc), ('loc', self.loc_tuple())]


# ErrorList is something like Union[List[Union[List[ErrorWrapper], ErrorWrapper]], ErrorWrapper]
# but recursive, therefore just use:
ErrorList = Union[Sequence[Any], ErrorWrapper]


class ValidationError(Representation, ValueError):
    __slots__ = 'raw_errors', 'model', '_error_cache'

    def __init__(self, errors: Sequence[ErrorList], model: 'ModelOrDc') -> None:
        self.raw_errors = errors
        self.model = model
        self._error_cache: Optional[List['ErrorDict']] = None

    def errors(self) -> List['ErrorDict']:
        if self._error_cache is None:
            try:
                config = self.model.__config__  # type: ignore
            except AttributeError:
                config = self.model.__pydantic_model__.__config__  # type: ignore
            self._error_cache = list(flatten_errors(self.raw_errors, config))
        return self._error_cache

    def json(self, *, indent: Union[None, int, str] = 2) -> str:
        return json.dumps(self.errors(), indent=indent, default=pydantic_encoder)

    def __str__(self) -> str:
        errors = self.errors()
        no_errors = len(errors)
        return (
            f'{no_errors} validation error{"" if no_errors == 1 else "s"} for {self.model.__name__}\n'
            f'{display_errors(errors)}'
        )

    def __repr_args__(self) -> 'ReprArgs':
        return [('model', self.model.__name__), ('errors', self.errors())]


def display_errors(errors: List['ErrorDict']) -> str:
    return '\n'.join(f'{_display_error_loc(e)}\n  {e["msg"]} ({_display_error_type_and_ctx(e)})' for e in errors)


def _display_error_loc(error: 'ErrorDict') -> str:
    return ' -> '.join(str(e) for e in error['loc'])


def _display_error_type_and_ctx(error: 'ErrorDict') -> str:
    t = 'type=' + error['type']
    ctx = error.get('ctx')
    if ctx:
        return t + ''.join(f'; {k}={v}' for k, v in ctx.items())
    else:
        return t


def flatten_errors(
    errors: Sequence[Any], config: Type['BaseConfig'], loc: Optional['Loc'] = None
) -> Generator['ErrorDict', None, None]:
    for error in errors:
        if isinstance(error, ErrorWrapper):
            if loc:
                error_loc = loc + error.loc_tuple()
            else:
                error_loc = error.loc_tuple()

            if isinstance(error.exc, ValidationError):
                yield from flatten_errors(error.exc.raw_errors, config, error_loc)
            else:
                yield error_dict(error.exc, config, error_loc)
        elif isinstance(error, list):
            yield from flatten_errors(error, config, loc=loc)
        else:
            raise RuntimeError(f'Unknown error object: {error}')


def error_dict(exc: Exception, config: Type['BaseConfig'], loc: 'Loc') -> 'ErrorDict':
    type_ = get_exc_type(exc.__class__)
    msg_template = config.error_msg_templates.get(type_) or getattr(exc, 'msg_template', None)
    ctx = exc.__dict__
    if msg_template:
        msg = msg_template.format(**ctx)
    else:
        msg = str(exc)

    d: 'ErrorDict' = {'loc': loc, 'msg': msg, 'type': type_}

    if ctx:
        d['ctx'] = ctx

    return d


_EXC_TYPE_CACHE: Dict[Type[Exception], str] = {}


def get_exc_type(cls: Type[Exception]) -> str:
    # slightly more efficient than using lru_cache since we don't need to worry about the cache filling up
    try:
        return _EXC_TYPE_CACHE[cls]
    except KeyError:
        r = _get_exc_type(cls)
        _EXC_TYPE_CACHE[cls] = r
        return r


def _get_exc_type(cls: Type[Exception]) -> str:
    if issubclass(cls, AssertionError):
        return 'assertion_error'

    base_name = 'type_error' if issubclass(cls, TypeError) else 'value_error'
    if cls in (TypeError, ValueError):
        # just TypeError or ValueError, no extra code
        return base_name

    # if it's not a TypeError or ValueError, we just take the lowercase of the exception name
    # no chaining or snake case logic, use "code" for more complex error types.
    code = getattr(cls, 'code', None) or cls.__name__.replace('Error', '').lower()
    return base_name + '.' + code
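How the pieces above fit together, assuming `BaseModel` from `pydantic.v1.main`: a failing int field flows through `flatten_errors`/`error_dict` into this shape (sketch, not part of the vendored file):

```python
from pydantic.v1.main import BaseModel

class Item(BaseModel):
    x: int

try:
    Item(x='not a number')
except ValidationError as e:
    assert e.errors() == [
        {'loc': ('x',), 'msg': 'value is not a valid integer', 'type': 'type_error.integer'}
    ]
```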
@@ -0,0 +1,646 @@
from decimal import Decimal
from pathlib import Path
from typing import TYPE_CHECKING, Any, Callable, Sequence, Set, Tuple, Type, Union

from pydantic.v1.typing import display_as_type

if TYPE_CHECKING:
    from pydantic.v1.typing import DictStrAny

# explicitly state exports to avoid "from pydantic.v1.errors import *" also importing Decimal, Path etc.
__all__ = (
    'PydanticTypeError',
    'PydanticValueError',
    'ConfigError',
    'MissingError',
    'ExtraError',
    'NoneIsNotAllowedError',
    'NoneIsAllowedError',
    'WrongConstantError',
    'NotNoneError',
    'BoolError',
    'BytesError',
    'DictError',
    'EmailError',
    'UrlError',
    'UrlSchemeError',
    'UrlSchemePermittedError',
    'UrlUserInfoError',
    'UrlHostError',
    'UrlHostTldError',
    'UrlPortError',
    'UrlExtraError',
    'EnumError',
    'IntEnumError',
    'EnumMemberError',
    'IntegerError',
    'FloatError',
    'PathError',
    'PathNotExistsError',
    'PathNotAFileError',
    'PathNotADirectoryError',
    'PyObjectError',
    'SequenceError',
    'ListError',
    'SetError',
    'FrozenSetError',
    'TupleError',
    'TupleLengthError',
    'ListMinLengthError',
    'ListMaxLengthError',
    'ListUniqueItemsError',
    'SetMinLengthError',
    'SetMaxLengthError',
    'FrozenSetMinLengthError',
    'FrozenSetMaxLengthError',
    'AnyStrMinLengthError',
    'AnyStrMaxLengthError',
    'StrError',
    'StrRegexError',
    'NumberNotGtError',
    'NumberNotGeError',
    'NumberNotLtError',
    'NumberNotLeError',
    'NumberNotMultipleError',
    'DecimalError',
    'DecimalIsNotFiniteError',
    'DecimalMaxDigitsError',
    'DecimalMaxPlacesError',
    'DecimalWholeDigitsError',
    'DateTimeError',
    'DateError',
    'DateNotInThePastError',
    'DateNotInTheFutureError',
    'TimeError',
    'DurationError',
    'HashableError',
    'UUIDError',
    'UUIDVersionError',
    'ArbitraryTypeError',
    'ClassError',
    'SubclassError',
    'JsonError',
    'JsonTypeError',
    'PatternError',
    'DataclassTypeError',
    'CallableError',
    'IPvAnyAddressError',
    'IPvAnyInterfaceError',
    'IPvAnyNetworkError',
    'IPv4AddressError',
    'IPv6AddressError',
    'IPv4NetworkError',
    'IPv6NetworkError',
    'IPv4InterfaceError',
    'IPv6InterfaceError',
    'ColorError',
    'StrictBoolError',
    'NotDigitError',
    'LuhnValidationError',
    'InvalidLengthForBrand',
    'InvalidByteSize',
    'InvalidByteSizeUnit',
    'MissingDiscriminator',
    'InvalidDiscriminator',
)


def cls_kwargs(cls: Type['PydanticErrorMixin'], ctx: 'DictStrAny') -> 'PydanticErrorMixin':
    """
    For built-in exceptions like ValueError or TypeError, we need to implement
    __reduce__ to override the default behaviour (instead of __getstate__/__setstate__)
    By default pickle protocol 2 calls `cls.__new__(cls, *args)`.
    Since we only use kwargs, we need a little constructor to change that.
    Note: the callable can't be a lambda as pickle looks in the namespace to find it
    """
    return cls(**ctx)


class PydanticErrorMixin:
    code: str
    msg_template: str

    def __init__(self, **ctx: Any) -> None:
        self.__dict__ = ctx

    def __str__(self) -> str:
        return self.msg_template.format(**self.__dict__)

    def __reduce__(self) -> Tuple[Callable[..., 'PydanticErrorMixin'], Tuple[Type['PydanticErrorMixin'], 'DictStrAny']]:
        return cls_kwargs, (self.__class__, self.__dict__)


class PydanticTypeError(PydanticErrorMixin, TypeError):
    pass


class PydanticValueError(PydanticErrorMixin, ValueError):
    pass


class ConfigError(RuntimeError):
    pass


class MissingError(PydanticValueError):
    msg_template = 'field required'


class ExtraError(PydanticValueError):
    msg_template = 'extra fields not permitted'


class NoneIsNotAllowedError(PydanticTypeError):
    code = 'none.not_allowed'
    msg_template = 'none is not an allowed value'


class NoneIsAllowedError(PydanticTypeError):
    code = 'none.allowed'
    msg_template = 'value is not none'


class WrongConstantError(PydanticValueError):
    code = 'const'

    def __str__(self) -> str:
        permitted = ', '.join(repr(v) for v in self.permitted)  # type: ignore
        return f'unexpected value; permitted: {permitted}'


class NotNoneError(PydanticTypeError):
    code = 'not_none'
    msg_template = 'value is not None'


class BoolError(PydanticTypeError):
    msg_template = 'value could not be parsed to a boolean'


class BytesError(PydanticTypeError):
    msg_template = 'byte type expected'


class DictError(PydanticTypeError):
    msg_template = 'value is not a valid dict'


class EmailError(PydanticValueError):
    msg_template = 'value is not a valid email address'


class UrlError(PydanticValueError):
    code = 'url'


class UrlSchemeError(UrlError):
    code = 'url.scheme'
    msg_template = 'invalid or missing URL scheme'


class UrlSchemePermittedError(UrlError):
    code = 'url.scheme'
    msg_template = 'URL scheme not permitted'

    def __init__(self, allowed_schemes: Set[str]):
        super().__init__(allowed_schemes=allowed_schemes)


class UrlUserInfoError(UrlError):
    code = 'url.userinfo'
    msg_template = 'userinfo required in URL but missing'


class UrlHostError(UrlError):
    code = 'url.host'
    msg_template = 'URL host invalid'


class UrlHostTldError(UrlError):
    code = 'url.host'
    msg_template = 'URL host invalid, top level domain required'


class UrlPortError(UrlError):
    code = 'url.port'
    msg_template = 'URL port invalid, port cannot exceed 65535'


class UrlExtraError(UrlError):
    code = 'url.extra'
    msg_template = 'URL invalid, extra characters found after valid URL: {extra!r}'


class EnumMemberError(PydanticTypeError):
    code = 'enum'

    def __str__(self) -> str:
        permitted = ', '.join(repr(v.value) for v in self.enum_values)  # type: ignore
        return f'value is not a valid enumeration member; permitted: {permitted}'


class IntegerError(PydanticTypeError):
    msg_template = 'value is not a valid integer'


class FloatError(PydanticTypeError):
    msg_template = 'value is not a valid float'


class PathError(PydanticTypeError):
    msg_template = 'value is not a valid path'


class _PathValueError(PydanticValueError):
    def __init__(self, *, path: Path) -> None:
        super().__init__(path=str(path))


class PathNotExistsError(_PathValueError):
    code = 'path.not_exists'
    msg_template = 'file or directory at path "{path}" does not exist'


class PathNotAFileError(_PathValueError):
    code = 'path.not_a_file'
    msg_template = 'path "{path}" does not point to a file'


class PathNotADirectoryError(_PathValueError):
    code = 'path.not_a_directory'
    msg_template = 'path "{path}" does not point to a directory'


class PyObjectError(PydanticTypeError):
    msg_template = 'ensure this value contains valid import path or valid callable: {error_message}'


class SequenceError(PydanticTypeError):
    msg_template = 'value is not a valid sequence'


class IterableError(PydanticTypeError):
    msg_template = 'value is not a valid iterable'


class ListError(PydanticTypeError):
    msg_template = 'value is not a valid list'


class SetError(PydanticTypeError):
    msg_template = 'value is not a valid set'


class FrozenSetError(PydanticTypeError):
    msg_template = 'value is not a valid frozenset'


class DequeError(PydanticTypeError):
    msg_template = 'value is not a valid deque'


class TupleError(PydanticTypeError):
    msg_template = 'value is not a valid tuple'


class TupleLengthError(PydanticValueError):
    code = 'tuple.length'
    msg_template = 'wrong tuple length {actual_length}, expected {expected_length}'

    def __init__(self, *, actual_length: int, expected_length: int) -> None:
        super().__init__(actual_length=actual_length, expected_length=expected_length)


class ListMinLengthError(PydanticValueError):
    code = 'list.min_items'
    msg_template = 'ensure this value has at least {limit_value} items'

    def __init__(self, *, limit_value: int) -> None:
        super().__init__(limit_value=limit_value)


class ListMaxLengthError(PydanticValueError):
    code = 'list.max_items'
    msg_template = 'ensure this value has at most {limit_value} items'

    def __init__(self, *, limit_value: int) -> None:
        super().__init__(limit_value=limit_value)


class ListUniqueItemsError(PydanticValueError):
    code = 'list.unique_items'
    msg_template = 'the list has duplicated items'


class SetMinLengthError(PydanticValueError):
    code = 'set.min_items'
    msg_template = 'ensure this value has at least {limit_value} items'

    def __init__(self, *, limit_value: int) -> None:
        super().__init__(limit_value=limit_value)


class SetMaxLengthError(PydanticValueError):
    code = 'set.max_items'
    msg_template = 'ensure this value has at most {limit_value} items'

    def __init__(self, *, limit_value: int) -> None:
        super().__init__(limit_value=limit_value)


class FrozenSetMinLengthError(PydanticValueError):
    code = 'frozenset.min_items'
    msg_template = 'ensure this value has at least {limit_value} items'

    def __init__(self, *, limit_value: int) -> None:
        super().__init__(limit_value=limit_value)


class FrozenSetMaxLengthError(PydanticValueError):
    code = 'frozenset.max_items'
    msg_template = 'ensure this value has at most {limit_value} items'

    def __init__(self, *, limit_value: int) -> None:
        super().__init__(limit_value=limit_value)


class AnyStrMinLengthError(PydanticValueError):
    code = 'any_str.min_length'
    msg_template = 'ensure this value has at least {limit_value} characters'

    def __init__(self, *, limit_value: int) -> None:
        super().__init__(limit_value=limit_value)


class AnyStrMaxLengthError(PydanticValueError):
    code = 'any_str.max_length'
    msg_template = 'ensure this value has at most {limit_value} characters'

    def __init__(self, *, limit_value: int) -> None:
        super().__init__(limit_value=limit_value)


class StrError(PydanticTypeError):
    msg_template = 'str type expected'


class StrRegexError(PydanticValueError):
    code = 'str.regex'
    msg_template = 'string does not match regex "{pattern}"'

    def __init__(self, *, pattern: str) -> None:
        super().__init__(pattern=pattern)


class _NumberBoundError(PydanticValueError):
    def __init__(self, *, limit_value: Union[int, float, Decimal]) -> None:
        super().__init__(limit_value=limit_value)


class NumberNotGtError(_NumberBoundError):
    code = 'number.not_gt'
    msg_template = 'ensure this value is greater than {limit_value}'


class NumberNotGeError(_NumberBoundError):
    code = 'number.not_ge'
    msg_template = 'ensure this value is greater than or equal to {limit_value}'


class NumberNotLtError(_NumberBoundError):
    code = 'number.not_lt'
    msg_template = 'ensure this value is less than {limit_value}'


class NumberNotLeError(_NumberBoundError):
    code = 'number.not_le'
    msg_template = 'ensure this value is less than or equal to {limit_value}'


class NumberNotFiniteError(PydanticValueError):
    code = 'number.not_finite_number'
    msg_template = 'ensure this value is a finite number'


class NumberNotMultipleError(PydanticValueError):
    code = 'number.not_multiple'
    msg_template = 'ensure this value is a multiple of {multiple_of}'

    def __init__(self, *, multiple_of: Union[int, float, Decimal]) -> None:
        super().__init__(multiple_of=multiple_of)


class DecimalError(PydanticTypeError):
    msg_template = 'value is not a valid decimal'


class DecimalIsNotFiniteError(PydanticValueError):
    code = 'decimal.not_finite'
    msg_template = 'value is not a valid decimal'


class DecimalMaxDigitsError(PydanticValueError):
    code = 'decimal.max_digits'
    msg_template = 'ensure that there are no more than {max_digits} digits in total'

    def __init__(self, *, max_digits: int) -> None:
        super().__init__(max_digits=max_digits)


class DecimalMaxPlacesError(PydanticValueError):
    code = 'decimal.max_places'
    msg_template = 'ensure that there are no more than {decimal_places} decimal places'

    def __init__(self, *, decimal_places: int) -> None:
        super().__init__(decimal_places=decimal_places)


class DecimalWholeDigitsError(PydanticValueError):
    code = 'decimal.whole_digits'
    msg_template = 'ensure that there are no more than {whole_digits} digits before the decimal point'

    def __init__(self, *, whole_digits: int) -> None:
        super().__init__(whole_digits=whole_digits)


class DateTimeError(PydanticValueError):
    msg_template = 'invalid datetime format'


class DateError(PydanticValueError):
    msg_template = 'invalid date format'


class DateNotInThePastError(PydanticValueError):
    code = 'date.not_in_the_past'
    msg_template = 'date is not in the past'


class DateNotInTheFutureError(PydanticValueError):
    code = 'date.not_in_the_future'
    msg_template = 'date is not in the future'


class TimeError(PydanticValueError):
    msg_template = 'invalid time format'


class DurationError(PydanticValueError):
    msg_template = 'invalid duration format'


class HashableError(PydanticTypeError):
    msg_template = 'value is not a valid hashable'


class UUIDError(PydanticTypeError):
    msg_template = 'value is not a valid uuid'


class UUIDVersionError(PydanticValueError):
    code = 'uuid.version'
    msg_template = 'uuid version {required_version} expected'

    def __init__(self, *, required_version: int) -> None:
        super().__init__(required_version=required_version)


class ArbitraryTypeError(PydanticTypeError):
    code = 'arbitrary_type'
    msg_template = 'instance of {expected_arbitrary_type} expected'

    def __init__(self, *, expected_arbitrary_type: Type[Any]) -> None:
        super().__init__(expected_arbitrary_type=display_as_type(expected_arbitrary_type))


class ClassError(PydanticTypeError):
    code = 'class'
    msg_template = 'a class is expected'


class SubclassError(PydanticTypeError):
    code = 'subclass'
    msg_template = 'subclass of {expected_class} expected'

    def __init__(self, *, expected_class: Type[Any]) -> None:
        super().__init__(expected_class=display_as_type(expected_class))


class JsonError(PydanticValueError):
    msg_template = 'Invalid JSON'


class JsonTypeError(PydanticTypeError):
    code = 'json'
    msg_template = 'JSON object must be str, bytes or bytearray'


class PatternError(PydanticValueError):
    code = 'regex_pattern'
    msg_template = 'Invalid regular expression'


class DataclassTypeError(PydanticTypeError):
    code = 'dataclass'
    msg_template = 'instance of {class_name}, tuple or dict expected'


class CallableError(PydanticTypeError):
    msg_template = '{value} is not callable'


class EnumError(PydanticTypeError):
    code = 'enum_instance'
    msg_template = '{value} is not a valid Enum instance'


class IntEnumError(PydanticTypeError):
    code = 'int_enum_instance'
    msg_template = '{value} is not a valid IntEnum instance'


class IPvAnyAddressError(PydanticValueError):
    msg_template = 'value is not a valid IPv4 or IPv6 address'


class IPvAnyInterfaceError(PydanticValueError):
    msg_template = 'value is not a valid IPv4 or IPv6 interface'


class IPvAnyNetworkError(PydanticValueError):
    msg_template = 'value is not a valid IPv4 or IPv6 network'


class IPv4AddressError(PydanticValueError):
    msg_template = 'value is not a valid IPv4 address'


class IPv6AddressError(PydanticValueError):
    msg_template = 'value is not a valid IPv6 address'


class IPv4NetworkError(PydanticValueError):
    msg_template = 'value is not a valid IPv4 network'


class IPv6NetworkError(PydanticValueError):
    msg_template = 'value is not a valid IPv6 network'


class IPv4InterfaceError(PydanticValueError):
    msg_template = 'value is not a valid IPv4 interface'


class IPv6InterfaceError(PydanticValueError):
    msg_template = 'value is not a valid IPv6 interface'


class ColorError(PydanticValueError):
    msg_template = 'value is not a valid color: {reason}'


class StrictBoolError(PydanticValueError):
    msg_template = 'value is not a valid boolean'


class NotDigitError(PydanticValueError):
    code = 'payment_card_number.digits'
    msg_template = 'card number is not all digits'


class LuhnValidationError(PydanticValueError):
    code = 'payment_card_number.luhn_check'
    msg_template = 'card number is not luhn valid'


class InvalidLengthForBrand(PydanticValueError):
    code = 'payment_card_number.invalid_length_for_brand'
    msg_template = 'Length for a {brand} card must be {required_length}'


class InvalidByteSize(PydanticValueError):
    msg_template = 'could not parse value and unit from byte string'


class InvalidByteSizeUnit(PydanticValueError):
    msg_template = 'could not interpret byte unit: {unit}'


class MissingDiscriminator(PydanticValueError):
    code = 'discriminated_union.missing_discriminator'
    msg_template = 'Discriminator {discriminator_key!r} is missing in value'


class InvalidDiscriminator(PydanticValueError):
    code = 'discriminated_union.invalid_discriminator'
    msg_template = (
        'No match for discriminator {discriminator_key!r} and value {discriminator_value!r} '
        '(allowed values: {allowed_values})'
    )

    def __init__(self, *, discriminator_key: str, discriminator_value: Any, allowed_values: Sequence[Any]) -> None:
        super().__init__(
            discriminator_key=discriminator_key,
            discriminator_value=discriminator_value,
            allowed_values=', '.join(map(repr, allowed_values)),
        )
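Custom errors follow the same pattern as the classes above; a hypothetical subclass to show how `code`, `msg_template` and the kwargs-only `__init__` interact:

```python
class QuotaExceededError(PydanticValueError):
    code = 'quota_exceeded'
    msg_template = 'quota of {limit} requests exceeded'

err = QuotaExceededError(limit=100)
assert str(err) == 'quota of 100 requests exceeded'
# error_wrappers.get_exc_type(QuotaExceededError) == 'value_error.quota_exceeded'
```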
File diff suppressed because it is too large
@@ -0,0 +1,400 @@
import sys
import types
import typing
from typing import (
    TYPE_CHECKING,
    Any,
    ClassVar,
    Dict,
    ForwardRef,
    Generic,
    Iterator,
    List,
    Mapping,
    Optional,
    Tuple,
    Type,
    TypeVar,
    Union,
    cast,
)
from weakref import WeakKeyDictionary, WeakValueDictionary

from typing_extensions import Annotated, Literal as ExtLiteral

from pydantic.v1.class_validators import gather_all_validators
from pydantic.v1.fields import DeferredType
from pydantic.v1.main import BaseModel, create_model
from pydantic.v1.types import JsonWrapper
from pydantic.v1.typing import display_as_type, get_all_type_hints, get_args, get_origin, typing_base
from pydantic.v1.utils import all_identical, lenient_issubclass

if sys.version_info >= (3, 10):
    from typing import _UnionGenericAlias
if sys.version_info >= (3, 8):
    from typing import Literal

GenericModelT = TypeVar('GenericModelT', bound='GenericModel')
TypeVarType = Any  # since mypy doesn't allow the use of TypeVar as a type

CacheKey = Tuple[Type[Any], Any, Tuple[Any, ...]]
Parametrization = Mapping[TypeVarType, Type[Any]]

# Weak dictionaries allow the dynamically created parametrized versions of generic models to get collected
# once they are no longer referenced by the caller.
if sys.version_info >= (3, 9):  # Typing for weak dictionaries available at 3.9
    GenericTypesCache = WeakValueDictionary[CacheKey, Type[BaseModel]]
    AssignedParameters = WeakKeyDictionary[Type[BaseModel], Parametrization]
else:
    GenericTypesCache = WeakValueDictionary
    AssignedParameters = WeakKeyDictionary

# _generic_types_cache is a Mapping from __class_getitem__ arguments to the parametrized version of generic models.
# This ensures that multiple calls of e.g. A[B] always return the same class.
_generic_types_cache = GenericTypesCache()

# _assigned_parameters is a Mapping from parametrized versions of generic models to the types assigned during
# parametrization, as captured during construction of the class (not instances).
# E.g., for generic model `Model[A, B]`, when parametrized model `Model[int, str]` is created,
# `Model[int, str]: {A: int, B: str}` will be stored in `_assigned_parameters`.
# (This information is only otherwise available after creation from the class name string.)
_assigned_parameters = AssignedParameters()
class GenericModel(BaseModel):
    __slots__ = ()
    __concrete__: ClassVar[bool] = False

    if TYPE_CHECKING:
        # Putting this in a TYPE_CHECKING block allows us to replace `if Generic not in cls.__bases__` with
        # `not hasattr(cls, "__parameters__")`. This means we don't need to force non-concrete subclasses of
        # `GenericModel` to also inherit from `Generic`, which would require changes to the use of `create_model` below.
        __parameters__: ClassVar[Tuple[TypeVarType, ...]]

    # Setting the return type as Type[Any] instead of Type[BaseModel] prevents PyCharm warnings
    def __class_getitem__(cls: Type[GenericModelT], params: Union[Type[Any], Tuple[Type[Any], ...]]) -> Type[Any]:
        """Instantiates a new class from a generic class `cls` and type variables `params`.

        :param params: Tuple of types the class is parameterized with. Given a generic class
            `Model` with 2 type variables and a concrete model `Model[str, int]`,
            the value `(str, int)` would be passed to `params`.
        :return: New model class inheriting from `cls` with instantiated
            types described by `params`. If no parameters are given, `cls` is
            returned as is.

        """

        def _cache_key(_params: Any) -> CacheKey:
            args = get_args(_params)
            # python returns a list for Callables, which is not hashable
            if len(args) == 2 and isinstance(args[0], list):
                args = (tuple(args[0]), args[1])
            return cls, _params, args

        cached = _generic_types_cache.get(_cache_key(params))
        if cached is not None:
            return cached
        if cls.__concrete__ and Generic not in cls.__bases__:
            raise TypeError('Cannot parameterize a concrete instantiation of a generic model')
        if not isinstance(params, tuple):
            params = (params,)
        if cls is GenericModel and any(isinstance(param, TypeVar) for param in params):
            raise TypeError('Type parameters should be placed on typing.Generic, not GenericModel')
        if not hasattr(cls, '__parameters__'):
            raise TypeError(f'Type {cls.__name__} must inherit from typing.Generic before being parameterized')

        check_parameters_count(cls, params)
        # Build map from generic typevars to passed params
        typevars_map: Dict[TypeVarType, Type[Any]] = dict(zip(cls.__parameters__, params))
        if all_identical(typevars_map.keys(), typevars_map.values()) and typevars_map:
            return cls  # if arguments are equal to parameters it's the same object

        # Create new model with original model as parent inserting fields with DeferredType.
        model_name = cls.__concrete_name__(params)
        validators = gather_all_validators(cls)

        type_hints = get_all_type_hints(cls).items()
        instance_type_hints = {k: v for k, v in type_hints if get_origin(v) is not ClassVar}

        fields = {k: (DeferredType(), cls.__fields__[k].field_info) for k in instance_type_hints if k in cls.__fields__}

        model_module, called_globally = get_caller_frame_info()
        created_model = cast(
            Type[GenericModel],  # casting ensures mypy is aware of the __concrete__ and __parameters__ attributes
            create_model(
                model_name,
                __module__=model_module or cls.__module__,
                __base__=(cls,) + tuple(cls.__parameterized_bases__(typevars_map)),
                __config__=None,
                __validators__=validators,
                __cls_kwargs__=None,
                **fields,
            ),
        )

        _assigned_parameters[created_model] = typevars_map

        if called_globally:  # create global reference and therefore allow pickling
            object_by_reference = None
            reference_name = model_name
            reference_module_globals = sys.modules[created_model.__module__].__dict__
            while object_by_reference is not created_model:
                object_by_reference = reference_module_globals.setdefault(reference_name, created_model)
                reference_name += '_'

        created_model.Config = cls.Config

        # Find any typevars that are still present in the model.
        # If none are left, the model is fully "concrete", otherwise the new
        # class is a generic class as well taking the found typevars as
        # parameters.
        new_params = tuple(
            {param: None for param in iter_contained_typevars(typevars_map.values())}
        )  # use dict as ordered set
        created_model.__concrete__ = not new_params
        if new_params:
            created_model.__parameters__ = new_params

        # Save created model in cache so we don't end up creating duplicate
        # models that should be identical.
        _generic_types_cache[_cache_key(params)] = created_model
        if len(params) == 1:
            _generic_types_cache[_cache_key(params[0])] = created_model

        # Recursively walk class type hints and replace generic typevars
        # with concrete types that were passed.
        _prepare_model_fields(created_model, fields, instance_type_hints, typevars_map)

        return created_model

    @classmethod
    def __concrete_name__(cls: Type[Any], params: Tuple[Type[Any], ...]) -> str:
        """Compute class name for child classes.

        :param params: Tuple of types the class is parameterized with. Given a generic class
            `Model` with 2 type variables and a concrete model `Model[str, int]`,
            the value `(str, int)` would be passed to `params`.
        :return: String representing the new class where `params` are
            passed to `cls` as type variables.

        This method can be overridden to achieve a custom naming scheme for GenericModels.
        """
        param_names = [display_as_type(param) for param in params]
        params_component = ', '.join(param_names)
        return f'{cls.__name__}[{params_component}]'

    @classmethod
    def __parameterized_bases__(cls, typevars_map: Parametrization) -> Iterator[Type[Any]]:
        """
        Returns unbound bases of cls parameterised to given type variables.

        :param typevars_map: Dictionary of type applications for binding subclasses.
            Given a generic class `Model` with 2 type variables [S, T]
            and a concrete model `Model[str, int]`,
            the value `{S: str, T: int}` would be passed to `typevars_map`.
        :return: an iterator of generic sub classes, parameterised by `typevars_map`
            and other assigned parameters of `cls`

        e.g.:
        ```
        class A(GenericModel, Generic[T]):
            ...

        class B(A[V], Generic[V]):
            ...

        assert A[int] in B.__parameterized_bases__({V: int})
        ```
        """

        def build_base_model(
            base_model: Type[GenericModel], mapped_types: Parametrization
        ) -> Iterator[Type[GenericModel]]:
            base_parameters = tuple(mapped_types[param] for param in base_model.__parameters__)
            parameterized_base = base_model.__class_getitem__(base_parameters)
            if parameterized_base is base_model or parameterized_base is cls:
                # Avoid duplication in MRO
                return
            yield parameterized_base

        for base_model in cls.__bases__:
            if not issubclass(base_model, GenericModel):
                # not a class that can be meaningfully parameterized
                continue
            elif not getattr(base_model, '__parameters__', None):
                # base_model is "GenericModel" (and has no __parameters__)
                # or
                # base_model is already concrete, and will be included transitively via cls.
                continue
            elif cls in _assigned_parameters:
                if base_model in _assigned_parameters:
                    # cls is partially parameterised but not from base_model
                    # e.g. cls = B[S], base_model = A[S]
                    # B[S][int] should subclass A[int], (and will be transitively via B[int])
                    # but it's not viable to consistently subclass types with arbitrary construction
                    # So don't attempt to include A[S][int]
                    continue
                else:  # base_model not in _assigned_parameters:
                    # cls is partially parameterized, base_model is original generic
                    # e.g. cls = B[str, T], base_model = B[S, T]
                    # Need to determine the mapping for the base_model parameters
                    mapped_types: Parametrization = {
                        key: typevars_map.get(value, value) for key, value in _assigned_parameters[cls].items()
                    }
                    yield from build_base_model(base_model, mapped_types)
            else:
                # cls is base generic, so base_class has a distinct base
                # can construct the Parameterised base model using typevars_map directly
                yield from build_base_model(base_model, typevars_map)
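# Illustrative usage (not part of the vendored source): a minimal sketch of the
# parametrization machinery above, using hypothetical names. Parametrizing a
# GenericModel creates a concrete subclass, and `_generic_types_cache` guarantees that
# repeated parametrizations return the identical class object.
#
#     from typing import Generic, TypeVar
#     T = TypeVar('T')
#
#     class Response(GenericModel, Generic[T]):
#         data: T
#
#     IntResponse = Response[int]
#     assert IntResponse is Response[int]       # served from _generic_types_cache
#     assert IntResponse.__concrete__
#     assert IntResponse(data='3').data == 3    # int coercion applies to the bound field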
def replace_types(type_: Any, type_map: Mapping[Any, Any]) -> Any:
    """Return type with all occurrences of `type_map` keys recursively replaced with their values.

    :param type_: Any type, class or generic alias
    :param type_map: Mapping from `TypeVar` instance to concrete types.
    :return: New type representing the basic structure of `type_` with all
        `type_map` keys recursively replaced.

    >>> replace_types(Tuple[str, Union[List[str], float]], {str: int})
    Tuple[int, Union[List[int], float]]

    """
    if not type_map:
        return type_

    type_args = get_args(type_)
    origin_type = get_origin(type_)

    if origin_type is Annotated:
        annotated_type, *annotations = type_args
        return Annotated[replace_types(annotated_type, type_map), tuple(annotations)]

    if (origin_type is ExtLiteral) or (sys.version_info >= (3, 8) and origin_type is Literal):
        return type_map.get(type_, type_)
    # Having type args is a good indicator that this is a typing module
    # class instantiation or a generic alias of some sort.
    if type_args:
        resolved_type_args = tuple(replace_types(arg, type_map) for arg in type_args)
        if all_identical(type_args, resolved_type_args):
            # If all arguments are the same, there is no need to modify the
            # type or create a new object at all
            return type_
        if (
            origin_type is not None
            and isinstance(type_, typing_base)
            and not isinstance(origin_type, typing_base)
            and getattr(type_, '_name', None) is not None
        ):
            # In python < 3.9 generic aliases don't exist so any of these like `list`,
            # `type` or `collections.abc.Callable` need to be translated.
            # See: https://www.python.org/dev/peps/pep-0585
            origin_type = getattr(typing, type_._name)
        assert origin_type is not None
        # PEP-604 syntax (Ex.: list | str) is represented with a types.UnionType object that does not have __getitem__.
        # We also cannot use isinstance() since we have to compare types.
        if sys.version_info >= (3, 10) and origin_type is types.UnionType:  # noqa: E721
            return _UnionGenericAlias(origin_type, resolved_type_args)
        return origin_type[resolved_type_args]

    # We handle pydantic generic models separately as they don't have the same
    # semantics as "typing" classes or generic aliases
    if not origin_type and lenient_issubclass(type_, GenericModel) and not type_.__concrete__:
        type_args = type_.__parameters__
        resolved_type_args = tuple(replace_types(t, type_map) for t in type_args)
        if all_identical(type_args, resolved_type_args):
            return type_
        return type_[resolved_type_args]

    # Handle special case for typehints that can have lists as arguments.
    # `typing.Callable[[int, str], int]` is an example for this.
    if isinstance(type_, (List, list)):
        resolved_list = list(replace_types(element, type_map) for element in type_)
        if all_identical(type_, resolved_list):
            return type_
        return resolved_list

    # For JsonWrapper, we need to handle its inner type to allow correct parsing
    # of generic Json arguments like Json[T]
    if not origin_type and lenient_issubclass(type_, JsonWrapper):
        type_.inner_type = replace_types(type_.inner_type, type_map)
        return type_

    # If all else fails, we try to resolve the type directly and otherwise just
    # return the input with no modifications.
    new_type = type_map.get(type_, type_)
    # Convert string to ForwardRef
    if isinstance(new_type, str):
        return ForwardRef(new_type)
    else:
        return new_type
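# Illustrative usage (not part of the vendored source): a minimal sketch of
# `replace_types` substituting TypeVars inside nested generic aliases. `T` here is a
# hypothetical type variable introduced only for this example.
#
#     from typing import Dict, List, Optional, TypeVar
#     T = TypeVar('T')
#
#     hint = Dict[str, Optional[List[T]]]
#     assert replace_types(hint, {T: int}) == Dict[str, Optional[List[int]]]
#     # An empty map returns the original object unchanged:
#     assert replace_types(hint, {}) is hint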
def check_parameters_count(cls: Type[GenericModel], parameters: Tuple[Any, ...]) -> None:
    actual = len(parameters)
    expected = len(cls.__parameters__)
    if actual != expected:
        description = 'many' if actual > expected else 'few'
        raise TypeError(f'Too {description} parameters for {cls.__name__}; actual {actual}, expected {expected}')


DictValues: Type[Any] = {}.values().__class__


def iter_contained_typevars(v: Any) -> Iterator[TypeVarType]:
    """Recursively iterate through all subtypes and type args of `v` and yield any typevars that are found."""
    if isinstance(v, TypeVar):
        yield v
    elif hasattr(v, '__parameters__') and not get_origin(v) and lenient_issubclass(v, GenericModel):
        yield from v.__parameters__
    elif isinstance(v, (DictValues, list)):
        for var in v:
            yield from iter_contained_typevars(var)
    else:
        args = get_args(v)
        for arg in args:
            yield from iter_contained_typevars(arg)


def get_caller_frame_info() -> Tuple[Optional[str], bool]:
    """
    Used inside a function to check whether it was called globally.

    Will only work against non-compiled code, therefore used only in pydantic.generics.

    :returns: Tuple[module_name, called_globally]
    """
    try:
        previous_caller_frame = sys._getframe(2)
    except ValueError as e:
        raise RuntimeError('This function must be used inside another function') from e
    except AttributeError:  # sys module does not have _getframe function, so there's nothing we can do about it
        return None, False
    frame_globals = previous_caller_frame.f_globals
    return frame_globals.get('__name__'), previous_caller_frame.f_locals is frame_globals


def _prepare_model_fields(
    created_model: Type[GenericModel],
    fields: Mapping[str, Any],
    instance_type_hints: Mapping[str, type],
    typevars_map: Mapping[Any, type],
) -> None:
    """
    Replace DeferredType fields with concrete type hints and prepare them.
    """

    for key, field in created_model.__fields__.items():
        if key not in fields:
            assert field.type_.__class__ is not DeferredType
            # https://github.com/nedbat/coveragepy/issues/198
            continue  # pragma: no cover

        assert field.type_.__class__ is DeferredType, field.type_.__class__

        field_type_hint = instance_type_hints[key]
        concrete_type = replace_types(field_type_hint, typevars_map)
        field.type_ = concrete_type
        field.outer_type_ = concrete_type
        field.prepare()
        created_model.__annotations__[key] = concrete_type
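# Illustrative usage (not part of the vendored source): partial parametrization leaves a
# model generic over the remaining TypeVars, per `__class_getitem__` above. `S` and `T`
# are hypothetical type variables for this sketch only.
#
#     S = TypeVar('S'); T = TypeVar('T')
#
#     class Pair(GenericModel, Generic[S, T]):
#         first: S
#         second: T
#
#     Partial = Pair[str, T]   # still generic: Partial.__concrete__ is False
#     Full = Partial[int]      # fills T with int; Full.__concrete__ is True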
@@ -0,0 +1,112 @@
import datetime
from collections import deque
from decimal import Decimal
from enum import Enum
from ipaddress import IPv4Address, IPv4Interface, IPv4Network, IPv6Address, IPv6Interface, IPv6Network
from pathlib import Path
from re import Pattern
from types import GeneratorType
from typing import Any, Callable, Dict, Type, Union
from uuid import UUID

from pydantic.v1.color import Color
from pydantic.v1.networks import NameEmail
from pydantic.v1.types import SecretBytes, SecretStr

__all__ = 'pydantic_encoder', 'custom_pydantic_encoder', 'timedelta_isoformat'


def isoformat(o: Union[datetime.date, datetime.time]) -> str:
    return o.isoformat()


def decimal_encoder(dec_value: Decimal) -> Union[int, float]:
    """
    Encodes a Decimal as int if there's no exponent, otherwise float.

    This is useful when we use ConstrainedDecimal to represent Numeric(x,0)
    where an integer (but not int typed) is used. Encoding this as a float
    results in failed round-tripping between encode and parse.
    Our Id type is a prime example of this.

    >>> decimal_encoder(Decimal("1.0"))
    1.0

    >>> decimal_encoder(Decimal("1"))
    1
    """
    if dec_value.as_tuple().exponent >= 0:
        return int(dec_value)
    else:
        return float(dec_value)


ENCODERS_BY_TYPE: Dict[Type[Any], Callable[[Any], Any]] = {
    bytes: lambda o: o.decode(),
    Color: str,
    datetime.date: isoformat,
    datetime.datetime: isoformat,
    datetime.time: isoformat,
    datetime.timedelta: lambda td: td.total_seconds(),
    Decimal: decimal_encoder,
    Enum: lambda o: o.value,
    frozenset: list,
    deque: list,
    GeneratorType: list,
    IPv4Address: str,
    IPv4Interface: str,
    IPv4Network: str,
    IPv6Address: str,
    IPv6Interface: str,
    IPv6Network: str,
    NameEmail: str,
    Path: str,
    Pattern: lambda o: o.pattern,
    SecretBytes: str,
    SecretStr: str,
    set: list,
    UUID: str,
}


def pydantic_encoder(obj: Any) -> Any:
    from dataclasses import asdict, is_dataclass

    from pydantic.v1.main import BaseModel

    if isinstance(obj, BaseModel):
        return obj.dict()
    elif is_dataclass(obj):
        return asdict(obj)

    # Check the class type and its superclasses for a matching encoder
    for base in obj.__class__.__mro__[:-1]:
        try:
            encoder = ENCODERS_BY_TYPE[base]
        except KeyError:
            continue
        return encoder(obj)
    else:  # We have exited the for loop without finding a suitable encoder
        raise TypeError(f"Object of type '{obj.__class__.__name__}' is not JSON serializable")
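# Illustrative usage (not part of the vendored source): a minimal sketch of wiring
# `pydantic_encoder` into the standard library's json module. Relies only on the
# functions defined in this file plus `json` and `uuid` from the stdlib.
#
#     import json
#     import uuid
#     payload = {'id': uuid.uuid4(), 'tags': {'a', 'b'}, 'price': Decimal('9.90')}
#     json.dumps(payload, default=pydantic_encoder)
#     # -> '{"id": "…", "tags": ["a", "b"], "price": 9.9}'  (set order not guaranteed)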
def custom_pydantic_encoder(type_encoders: Dict[Any, Callable[[Type[Any]], Any]], obj: Any) -> Any:
    # Check the class type and its superclasses for a matching encoder
    for base in obj.__class__.__mro__[:-1]:
        try:
            encoder = type_encoders[base]
        except KeyError:
            continue

        return encoder(obj)
    else:  # We have exited the for loop without finding a suitable encoder
        return pydantic_encoder(obj)


def timedelta_isoformat(td: datetime.timedelta) -> str:
    """
    ISO 8601 encoding for Python timedelta object.
    """
    minutes, seconds = divmod(td.seconds, 60)
    hours, minutes = divmod(minutes, 60)
    return f'{"-" if td.days < 0 else ""}P{abs(td.days)}DT{hours:d}H{minutes:d}M{seconds:d}.{td.microseconds:06d}S'
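# Illustrative usage (not part of the vendored source): `timedelta_isoformat` renders an
# ISO 8601 duration from the timedelta's days/seconds/microseconds components, e.g.:
#
#     import datetime
#     timedelta_isoformat(datetime.timedelta(days=1, hours=2, minutes=3, seconds=4))
#     # -> 'P1DT2H3M4.000000S'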
File diff suppressed because it is too large
@@ -0,0 +1,949 @@
import sys
from configparser import ConfigParser
from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Type as TypingType, Union

from mypy.errorcodes import ErrorCode
from mypy.nodes import (
    ARG_NAMED,
    ARG_NAMED_OPT,
    ARG_OPT,
    ARG_POS,
    ARG_STAR2,
    MDEF,
    Argument,
    AssignmentStmt,
    Block,
    CallExpr,
    ClassDef,
    Context,
    Decorator,
    EllipsisExpr,
    FuncBase,
    FuncDef,
    JsonDict,
    MemberExpr,
    NameExpr,
    PassStmt,
    PlaceholderNode,
    RefExpr,
    StrExpr,
    SymbolNode,
    SymbolTableNode,
    TempNode,
    TypeInfo,
    TypeVarExpr,
    Var,
)
from mypy.options import Options
from mypy.plugin import (
    CheckerPluginInterface,
    ClassDefContext,
    FunctionContext,
    MethodContext,
    Plugin,
    ReportConfigContext,
    SemanticAnalyzerPluginInterface,
)
from mypy.plugins import dataclasses
from mypy.semanal import set_callable_name  # type: ignore
from mypy.server.trigger import make_wildcard_trigger
from mypy.types import (
    AnyType,
    CallableType,
    Instance,
    NoneType,
    Overloaded,
    ProperType,
    Type,
    TypeOfAny,
    TypeType,
    TypeVarId,
    TypeVarType,
    UnionType,
    get_proper_type,
)
from mypy.typevars import fill_typevars
from mypy.util import get_unique_redefinition_name
from mypy.version import __version__ as mypy_version

from pydantic.v1.utils import is_valid_field

try:
    from mypy.types import TypeVarDef  # type: ignore[attr-defined]
except ImportError:  # pragma: no cover
    # Backward-compatible with TypeVarDef from Mypy 0.910.
    from mypy.types import TypeVarType as TypeVarDef

CONFIGFILE_KEY = 'pydantic-mypy'
METADATA_KEY = 'pydantic-mypy-metadata'
_NAMESPACE = __name__[:-5]  # 'pydantic' in 1.10.X, 'pydantic.v1' in v2.X
BASEMODEL_FULLNAME = f'{_NAMESPACE}.main.BaseModel'
BASESETTINGS_FULLNAME = f'{_NAMESPACE}.env_settings.BaseSettings'
MODEL_METACLASS_FULLNAME = f'{_NAMESPACE}.main.ModelMetaclass'
FIELD_FULLNAME = f'{_NAMESPACE}.fields.Field'
DATACLASS_FULLNAME = f'{_NAMESPACE}.dataclasses.dataclass'


def parse_mypy_version(version: str) -> Tuple[int, ...]:
    return tuple(map(int, version.partition('+')[0].split('.')))
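# Illustrative usage (not part of the vendored source): `partition('+')` drops any
# local-version suffix before the release digits are parsed, e.g.:
#
#     assert parse_mypy_version('0.910') == (0, 910)
#     assert parse_mypy_version('1.4.1+dev.abc123') == (1, 4, 1)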
MYPY_VERSION_TUPLE = parse_mypy_version(mypy_version)
BUILTINS_NAME = 'builtins' if MYPY_VERSION_TUPLE >= (0, 930) else '__builtins__'

# Increment version if plugin changes and mypy caches should be invalidated
__version__ = 2


def plugin(version: str) -> 'TypingType[Plugin]':
    """
    `version` is the mypy version string.

    We might want to use this to print a warning if the mypy version being used is
    newer, or especially older, than we expect (or need).
    """
    return PydanticPlugin


class PydanticPlugin(Plugin):
    def __init__(self, options: Options) -> None:
        self.plugin_config = PydanticPluginConfig(options)
        self._plugin_data = self.plugin_config.to_data()
        super().__init__(options)

    def get_base_class_hook(self, fullname: str) -> 'Optional[Callable[[ClassDefContext], None]]':
        sym = self.lookup_fully_qualified(fullname)
        if sym and isinstance(sym.node, TypeInfo):  # pragma: no branch
            # No branching may occur if the mypy cache has not been cleared
            if any(get_fullname(base) == BASEMODEL_FULLNAME for base in sym.node.mro):
                return self._pydantic_model_class_maker_callback
        return None

    def get_metaclass_hook(self, fullname: str) -> Optional[Callable[[ClassDefContext], None]]:
        if fullname == MODEL_METACLASS_FULLNAME:
            return self._pydantic_model_metaclass_marker_callback
        return None

    def get_function_hook(self, fullname: str) -> 'Optional[Callable[[FunctionContext], Type]]':
        sym = self.lookup_fully_qualified(fullname)
        if sym and sym.fullname == FIELD_FULLNAME:
            return self._pydantic_field_callback
        return None

    def get_method_hook(self, fullname: str) -> Optional[Callable[[MethodContext], Type]]:
        if fullname.endswith('.from_orm'):
            return from_orm_callback
        return None

    def get_class_decorator_hook(self, fullname: str) -> Optional[Callable[[ClassDefContext], None]]:
        """Mark pydantic.dataclasses as dataclass.

        Mypy version 1.1.1 added support for the `@dataclass_transform` decorator.
        """
        if fullname == DATACLASS_FULLNAME and MYPY_VERSION_TUPLE < (1, 1):
            return dataclasses.dataclass_class_maker_callback  # type: ignore[return-value]
        return None

    def report_config_data(self, ctx: ReportConfigContext) -> Dict[str, Any]:
        """Return all plugin config data.

        Used by mypy to determine if cache needs to be discarded.
        """
        return self._plugin_data

    def _pydantic_model_class_maker_callback(self, ctx: ClassDefContext) -> None:
        transformer = PydanticModelTransformer(ctx, self.plugin_config)
        transformer.transform()

    def _pydantic_model_metaclass_marker_callback(self, ctx: ClassDefContext) -> None:
        """Reset dataclass_transform_spec attribute of ModelMetaclass.

        Let the plugin handle it. This behavior can be disabled
        if 'debug_dataclass_transform' is set to True, for testing purposes.
        """
        if self.plugin_config.debug_dataclass_transform:
            return
        info_metaclass = ctx.cls.info.declared_metaclass
        assert info_metaclass, "callback not passed from 'get_metaclass_hook'"
        if getattr(info_metaclass.type, 'dataclass_transform_spec', None):
            info_metaclass.type.dataclass_transform_spec = None  # type: ignore[attr-defined]

    def _pydantic_field_callback(self, ctx: FunctionContext) -> 'Type':
        """
        Extract the type of the `default` argument from the Field function, and use it as the return type.

        In particular:
        * Check whether the default and default_factory argument is specified.
        * Output an error if both are specified.
        * Retrieve the type of the argument which is specified, and use it as return type for the function.
        """
        default_any_type = ctx.default_return_type

        assert ctx.callee_arg_names[0] == 'default', '"default" is no longer first argument in Field()'
        assert ctx.callee_arg_names[1] == 'default_factory', '"default_factory" is no longer second argument in Field()'
        default_args = ctx.args[0]
        default_factory_args = ctx.args[1]

        if default_args and default_factory_args:
            error_default_and_default_factory_specified(ctx.api, ctx.context)
            return default_any_type

        if default_args:
            default_type = ctx.arg_types[0][0]
            default_arg = default_args[0]

            # Fallback to default Any type if the field is required
            if not isinstance(default_arg, EllipsisExpr):
                return default_type

        elif default_factory_args:
            default_factory_type = ctx.arg_types[1][0]

            # Functions which use `ParamSpec` can be overloaded, exposing the callable's types as a parameter
            # Pydantic calls the default factory without any argument, so we retrieve the first item
            if isinstance(default_factory_type, Overloaded):
                if MYPY_VERSION_TUPLE > (0, 910):
                    default_factory_type = default_factory_type.items[0]
                else:
                    # Mypy 0.910 exposes the items of overloaded types in a function
                    default_factory_type = default_factory_type.items()[0]  # type: ignore[operator]

            if isinstance(default_factory_type, CallableType):
                ret_type = default_factory_type.ret_type
                # mypy doesn't think `ret_type` has `args`, you'd think mypy should know;
                # add this check in case it varies by version
                args = getattr(ret_type, 'args', None)
                if args:
                    if all(isinstance(arg, TypeVarType) for arg in args):
                        # Looks like the default factory is a type like `list` or `dict`, replace all args with `Any`
                        ret_type.args = tuple(default_any_type for _ in args)  # type: ignore[attr-defined]
                return ret_type

        return default_any_type


class PydanticPluginConfig:
    __slots__ = (
        'init_forbid_extra',
        'init_typed',
        'warn_required_dynamic_aliases',
        'warn_untyped_fields',
        'debug_dataclass_transform',
    )
    init_forbid_extra: bool
    init_typed: bool
    warn_required_dynamic_aliases: bool
    warn_untyped_fields: bool
    debug_dataclass_transform: bool  # undocumented

    def __init__(self, options: Options) -> None:
        if options.config_file is None:  # pragma: no cover
            return

        toml_config = parse_toml(options.config_file)
        if toml_config is not None:
            config = toml_config.get('tool', {}).get('pydantic-mypy', {})
            for key in self.__slots__:
                setting = config.get(key, False)
                if not isinstance(setting, bool):
                    raise ValueError(f'Configuration value must be a boolean for key: {key}')
                setattr(self, key, setting)
        else:
            plugin_config = ConfigParser()
            plugin_config.read(options.config_file)
            for key in self.__slots__:
                setting = plugin_config.getboolean(CONFIGFILE_KEY, key, fallback=False)
                setattr(self, key, setting)

    def to_data(self) -> Dict[str, Any]:
        return {key: getattr(self, key) for key in self.__slots__}
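# Illustrative configuration (not part of the vendored source): the slots above map
# one-to-one onto user-facing settings. A minimal sketch of enabling them in
# pyproject.toml, assuming the plugin is importable as `pydantic.v1.mypy` in this
# vendored layout:
#
#     [tool.mypy]
#     plugins = ["pydantic.v1.mypy"]
#
#     [tool.pydantic-mypy]
#     init_forbid_extra = true
#     init_typed = true
#     warn_required_dynamic_aliases = true
#     warn_untyped_fields = true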
def from_orm_callback(ctx: MethodContext) -> Type:
    """
    Raise an error if orm_mode is not enabled.
    """
    model_type: Instance
    ctx_type = ctx.type
    if isinstance(ctx_type, TypeType):
        ctx_type = ctx_type.item
    if isinstance(ctx_type, CallableType) and isinstance(ctx_type.ret_type, Instance):
        model_type = ctx_type.ret_type  # called on the class
    elif isinstance(ctx_type, Instance):
        model_type = ctx_type  # called on an instance (unusual, but still valid)
    else:  # pragma: no cover
        detail = f'ctx.type: {ctx_type} (of type {ctx_type.__class__.__name__})'
        error_unexpected_behavior(detail, ctx.api, ctx.context)
        return ctx.default_return_type
    pydantic_metadata = model_type.type.metadata.get(METADATA_KEY)
    if pydantic_metadata is None:
        return ctx.default_return_type
    orm_mode = pydantic_metadata.get('config', {}).get('orm_mode')
    if orm_mode is not True:
        error_from_orm(get_name(model_type.type), ctx.api, ctx.context)
    return ctx.default_return_type
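# Illustrative (not part of the vendored source): the kind of user code this hook flags,
# using a hypothetical model. Without `Config.orm_mode = True`, calling `from_orm` is
# reported via `error_from_orm` above.
#
#     class User(BaseModel):          # no Config.orm_mode set
#         name: str
#
#     User.from_orm(some_orm_object)  # mypy: "User" does not have orm_mode=True  [pydantic-orm]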
class PydanticModelTransformer:
    tracked_config_fields: Set[str] = {
        'extra',
        'allow_mutation',
        'frozen',
        'orm_mode',
        'allow_population_by_field_name',
        'alias_generator',
    }

    def __init__(self, ctx: ClassDefContext, plugin_config: PydanticPluginConfig) -> None:
        self._ctx = ctx
        self.plugin_config = plugin_config

    def transform(self) -> None:
        """
        Configures the BaseModel subclass according to the plugin settings.

        In particular:
        * determines the model config and fields,
        * adds a fields-aware signature for the initializer and construct methods
        * freezes the class if allow_mutation = False or frozen = True
        * stores the fields, config, and if the class is settings in the mypy metadata for access by subclasses
        """
        ctx = self._ctx
        info = ctx.cls.info

        self.adjust_validator_signatures()
        config = self.collect_config()
        fields = self.collect_fields(config)
        is_settings = any(get_fullname(base) == BASESETTINGS_FULLNAME for base in info.mro[:-1])
        self.add_initializer(fields, config, is_settings)
        self.add_construct_method(fields)
        self.set_frozen(fields, frozen=config.allow_mutation is False or config.frozen is True)
        info.metadata[METADATA_KEY] = {
            'fields': {field.name: field.serialize() for field in fields},
            'config': config.set_values_dict(),
        }

    def adjust_validator_signatures(self) -> None:
        """When we decorate a function `f` with `pydantic.validator(...)`, mypy sees
        `f` as a regular method taking a `self` instance, even though pydantic
        internally wraps `f` with `classmethod` if necessary.

        Teach mypy this by marking any function whose outermost decorator is a
        `validator()` call as a classmethod.
        """
        for name, sym in self._ctx.cls.info.names.items():
            if isinstance(sym.node, Decorator):
                first_dec = sym.node.original_decorators[0]
                if (
                    isinstance(first_dec, CallExpr)
                    and isinstance(first_dec.callee, NameExpr)
                    and first_dec.callee.fullname == f'{_NAMESPACE}.class_validators.validator'
                ):
                    sym.node.func.is_class = True

    def collect_config(self) -> 'ModelConfigData':
        """
        Collects the values of the config attributes that are used by the plugin, accounting for parent classes.
        """
        ctx = self._ctx
        cls = ctx.cls
        config = ModelConfigData()
        for stmt in cls.defs.body:
            if not isinstance(stmt, ClassDef):
                continue
            if stmt.name == 'Config':
                for substmt in stmt.defs.body:
                    if not isinstance(substmt, AssignmentStmt):
                        continue
                    config.update(self.get_config_update(substmt))
                if (
                    config.has_alias_generator
                    and not config.allow_population_by_field_name
                    and self.plugin_config.warn_required_dynamic_aliases
                ):
                    error_required_dynamic_aliases(ctx.api, stmt)
        for info in cls.info.mro[1:]:  # 0 is the current class
            if METADATA_KEY not in info.metadata:
                continue

            # Each class depends on the set of fields in its ancestors
            ctx.api.add_plugin_dependency(make_wildcard_trigger(get_fullname(info)))
            for name, value in info.metadata[METADATA_KEY]['config'].items():
                config.setdefault(name, value)
        return config

    def collect_fields(self, model_config: 'ModelConfigData') -> List['PydanticModelField']:
        """
        Collects the fields for the model, accounting for parent classes.
        """
        # First, collect fields belonging to the current class.
        ctx = self._ctx
        cls = self._ctx.cls
        fields = []  # type: List[PydanticModelField]
        known_fields = set()  # type: Set[str]
        for stmt in cls.defs.body:
            if not isinstance(stmt, AssignmentStmt):  # `and stmt.new_syntax` to require annotation
                continue

            lhs = stmt.lvalues[0]
            if not isinstance(lhs, NameExpr) or not is_valid_field(lhs.name):
                continue

            if not stmt.new_syntax and self.plugin_config.warn_untyped_fields:
                error_untyped_fields(ctx.api, stmt)

            # if lhs.name == '__config__':  # BaseConfig not well handled; I'm not sure why yet
            #     continue

            sym = cls.info.names.get(lhs.name)
            if sym is None:  # pragma: no cover
                # This is likely due to a star import (see the dataclasses plugin for a more detailed explanation)
                # This is the same logic used in the dataclasses plugin
                continue

            node = sym.node
            if isinstance(node, PlaceholderNode):  # pragma: no cover
                # See the PlaceholderNode docstring for more detail about how this can occur
                # Basically, it is an edge case when dealing with complex import logic
                # This is the same logic used in the dataclasses plugin
                continue
            if not isinstance(node, Var):  # pragma: no cover
                # Don't know if this edge case still happens with the `is_valid_field` check above
                # but better safe than sorry
                continue

            # x: ClassVar[int] is ignored by dataclasses.
            if node.is_classvar:
                continue

            is_required = self.get_is_required(cls, stmt, lhs)
            alias, has_dynamic_alias = self.get_alias_info(stmt)
            if (
                has_dynamic_alias
                and not model_config.allow_population_by_field_name
                and self.plugin_config.warn_required_dynamic_aliases
            ):
                error_required_dynamic_aliases(ctx.api, stmt)
            fields.append(
                PydanticModelField(
                    name=lhs.name,
                    is_required=is_required,
                    alias=alias,
                    has_dynamic_alias=has_dynamic_alias,
                    line=stmt.line,
                    column=stmt.column,
                )
            )
            known_fields.add(lhs.name)
        all_fields = fields.copy()
        for info in cls.info.mro[1:]:  # 0 is the current class, -2 is BaseModel, -1 is object
            if METADATA_KEY not in info.metadata:
                continue

            superclass_fields = []
            # Each class depends on the set of fields in its ancestors
            ctx.api.add_plugin_dependency(make_wildcard_trigger(get_fullname(info)))

            for name, data in info.metadata[METADATA_KEY]['fields'].items():
                if name not in known_fields:
                    field = PydanticModelField.deserialize(info, data)
                    known_fields.add(name)
                    superclass_fields.append(field)
                else:
                    (field,) = (a for a in all_fields if a.name == name)
                    all_fields.remove(field)
                    superclass_fields.append(field)
            all_fields = superclass_fields + all_fields
        return all_fields

    def add_initializer(self, fields: List['PydanticModelField'], config: 'ModelConfigData', is_settings: bool) -> None:
        """
        Adds a fields-aware `__init__` method to the class.

        The added `__init__` will be annotated with types vs. all `Any` depending on the plugin settings.
        """
        ctx = self._ctx
        typed = self.plugin_config.init_typed
        use_alias = config.allow_population_by_field_name is not True
        force_all_optional = is_settings or bool(
            config.has_alias_generator and not config.allow_population_by_field_name
        )
        init_arguments = self.get_field_arguments(
            fields, typed=typed, force_all_optional=force_all_optional, use_alias=use_alias
        )
        if not self.should_init_forbid_extra(fields, config):
            var = Var('kwargs')
            init_arguments.append(Argument(var, AnyType(TypeOfAny.explicit), None, ARG_STAR2))

        if '__init__' not in ctx.cls.info.names:
            add_method(ctx, '__init__', init_arguments, NoneType())

    def add_construct_method(self, fields: List['PydanticModelField']) -> None:
        """
        Adds a fully typed `construct` classmethod to the class.

        Similar to the fields-aware __init__ method, but always uses the field names (not aliases),
        and does not treat settings fields as optional.
        """
        ctx = self._ctx
        set_str = ctx.api.named_type(f'{BUILTINS_NAME}.set', [ctx.api.named_type(f'{BUILTINS_NAME}.str')])
        optional_set_str = UnionType([set_str, NoneType()])
        fields_set_argument = Argument(Var('_fields_set', optional_set_str), optional_set_str, None, ARG_OPT)
        construct_arguments = self.get_field_arguments(fields, typed=True, force_all_optional=False, use_alias=False)
        construct_arguments = [fields_set_argument] + construct_arguments

        obj_type = ctx.api.named_type(f'{BUILTINS_NAME}.object')
        self_tvar_name = '_PydanticBaseModel'  # Make sure it does not conflict with other names in the class
        tvar_fullname = ctx.cls.fullname + '.' + self_tvar_name
        if MYPY_VERSION_TUPLE >= (1, 4):
            tvd = TypeVarType(
                self_tvar_name,
                tvar_fullname,
                (
                    TypeVarId(-1, namespace=ctx.cls.fullname + '.construct')
                    if MYPY_VERSION_TUPLE >= (1, 11)
                    else TypeVarId(-1)
                ),
                [],
                obj_type,
                AnyType(TypeOfAny.from_omitted_generics),  # type: ignore[arg-type]
            )
            self_tvar_expr = TypeVarExpr(
                self_tvar_name,
                tvar_fullname,
                [],
                obj_type,
                AnyType(TypeOfAny.from_omitted_generics),  # type: ignore[arg-type]
            )
        else:
            tvd = TypeVarDef(self_tvar_name, tvar_fullname, -1, [], obj_type)
            self_tvar_expr = TypeVarExpr(self_tvar_name, tvar_fullname, [], obj_type)
        ctx.cls.info.names[self_tvar_name] = SymbolTableNode(MDEF, self_tvar_expr)

        # Backward-compatible with TypeVarDef from Mypy 0.910.
        if isinstance(tvd, TypeVarType):
            self_type = tvd
        else:
            self_type = TypeVarType(tvd)

        add_method(
            ctx,
            'construct',
            construct_arguments,
            return_type=self_type,
            self_type=self_type,
            tvar_def=tvd,
            is_classmethod=True,
        )

    def set_frozen(self, fields: List['PydanticModelField'], frozen: bool) -> None:
        """
        Marks all fields as properties so that attempts to set them trigger mypy errors.

        This is the same approach used by the attrs and dataclasses plugins.
        """
        ctx = self._ctx
        info = ctx.cls.info
        for field in fields:
            sym_node = info.names.get(field.name)
            if sym_node is not None:
                var = sym_node.node
                if isinstance(var, Var):
                    var.is_property = frozen
                elif isinstance(var, PlaceholderNode) and not ctx.api.final_iteration:
                    # See https://github.com/pydantic/pydantic/issues/5191 to hit this branch for test coverage
                    ctx.api.defer()
                else:  # pragma: no cover
                    # I don't know whether it's possible to hit this branch, but I've added it for safety
                    try:
                        var_str = str(var)
                    except TypeError:
                        # This happens for PlaceholderNode; perhaps it will happen for other types in the future..
                        var_str = repr(var)
                    detail = f'sym_node.node: {var_str} (of type {var.__class__})'
                    error_unexpected_behavior(detail, ctx.api, ctx.cls)
            else:
                var = field.to_var(info, use_alias=False)
                var.info = info
                var.is_property = frozen
                var._fullname = get_fullname(info) + '.' + get_name(var)
                info.names[get_name(var)] = SymbolTableNode(MDEF, var)

    def get_config_update(self, substmt: AssignmentStmt) -> Optional['ModelConfigData']:
        """
        Determines the config update due to a single statement in the Config class definition.

        Warns if a tracked config attribute is set to a value the plugin doesn't know how to interpret (e.g., an int).
        """
        lhs = substmt.lvalues[0]
        if not (isinstance(lhs, NameExpr) and lhs.name in self.tracked_config_fields):
            return None
        if lhs.name == 'extra':
            if isinstance(substmt.rvalue, StrExpr):
                forbid_extra = substmt.rvalue.value == 'forbid'
            elif isinstance(substmt.rvalue, MemberExpr):
                forbid_extra = substmt.rvalue.name == 'forbid'
            else:
                error_invalid_config_value(lhs.name, self._ctx.api, substmt)
                return None
            return ModelConfigData(forbid_extra=forbid_extra)
        if lhs.name == 'alias_generator':
            has_alias_generator = True
            if isinstance(substmt.rvalue, NameExpr) and substmt.rvalue.fullname == 'builtins.None':
                has_alias_generator = False
            return ModelConfigData(has_alias_generator=has_alias_generator)
        if isinstance(substmt.rvalue, NameExpr) and substmt.rvalue.fullname in ('builtins.True', 'builtins.False'):
            return ModelConfigData(**{lhs.name: substmt.rvalue.fullname == 'builtins.True'})
        error_invalid_config_value(lhs.name, self._ctx.api, substmt)
        return None

    @staticmethod
    def get_is_required(cls: ClassDef, stmt: AssignmentStmt, lhs: NameExpr) -> bool:
        """
        Returns a boolean indicating whether the field defined in `stmt` is a required field.
        """
        expr = stmt.rvalue
        if isinstance(expr, TempNode):
            # TempNode means annotation-only, so only non-required if Optional
            value_type = get_proper_type(cls.info[lhs.name].type)
            return not PydanticModelTransformer.type_has_implicit_default(value_type)
        if isinstance(expr, CallExpr) and isinstance(expr.callee, RefExpr) and expr.callee.fullname == FIELD_FULLNAME:
            # The "default value" is a call to `Field`; at this point, the field is
            # only required if default is Ellipsis (i.e., `field_name: Annotation = Field(...)`) or if default_factory
            # is specified.
            for arg, name in zip(expr.args, expr.arg_names):
                # If name is None, then this arg is the default because it is the only positional argument.
                if name is None or name == 'default':
                    return arg.__class__ is EllipsisExpr
                if name == 'default_factory':
                    return False
            # In this case, default and default_factory are not specified, so we need to look at the annotation
            value_type = get_proper_type(cls.info[lhs.name].type)
            return not PydanticModelTransformer.type_has_implicit_default(value_type)
        # Only required if the "default value" is Ellipsis (i.e., `field_name: Annotation = ...`)
        return isinstance(expr, EllipsisExpr)
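# Illustrative (not part of the vendored source): field declarations as classified by
# `get_is_required` above, assuming a hypothetical model:
#
#     class M(BaseModel):
#         a: int                                # required (no default, not Optional)
#         b: Optional[int]                      # not required (implicit default None)
#         c: int = 3                            # not required
#         d: int = Field(...)                   # required (Ellipsis default)
#         e: int = Field(default_factory=int)   # not required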
@staticmethod
|
||||
def type_has_implicit_default(type_: Optional[ProperType]) -> bool:
|
||||
"""
|
||||
Returns True if the passed type will be given an implicit default value.
|
||||
|
||||
In pydantic v1, this is the case for Optional types and Any (with default value None).
|
||||
"""
|
||||
if isinstance(type_, AnyType):
|
||||
# Annotated as Any
|
||||
return True
|
||||
if isinstance(type_, UnionType) and any(
|
||||
isinstance(item, NoneType) or isinstance(item, AnyType) for item in type_.items
|
||||
):
|
||||
# Annotated as Optional, or otherwise having NoneType or AnyType in the union
|
||||
return True
|
||||
return False
|
||||
|
||||
@staticmethod
|
||||
def get_alias_info(stmt: AssignmentStmt) -> Tuple[Optional[str], bool]:
|
||||
"""
|
||||
Returns a pair (alias, has_dynamic_alias), extracted from the declaration of the field defined in `stmt`.
|
||||
|
||||
`has_dynamic_alias` is True if and only if an alias is provided, but not as a string literal.
|
||||
If `has_dynamic_alias` is True, `alias` will be None.
|
||||
"""
|
||||
expr = stmt.rvalue
|
||||
if isinstance(expr, TempNode):
|
||||
# TempNode means annotation-only
|
||||
return None, False
|
||||
|
||||
if not (
|
||||
isinstance(expr, CallExpr) and isinstance(expr.callee, RefExpr) and expr.callee.fullname == FIELD_FULLNAME
|
||||
):
|
||||
# Assigned value is not a call to pydantic.fields.Field
|
||||
return None, False
|
||||
|
||||
for i, arg_name in enumerate(expr.arg_names):
|
||||
if arg_name != 'alias':
|
||||
continue
|
||||
arg = expr.args[i]
|
||||
if isinstance(arg, StrExpr):
|
||||
return arg.value, False
|
||||
else:
|
||||
return None, True
|
||||
return None, False
|
||||
|
||||
def get_field_arguments(
|
||||
self, fields: List['PydanticModelField'], typed: bool, force_all_optional: bool, use_alias: bool
|
||||
) -> List[Argument]:
|
||||
"""
|
||||
Helper function used during the construction of the `__init__` and `construct` method signatures.
|
||||
|
||||
Returns a list of mypy Argument instances for use in the generated signatures.
|
||||
"""
|
||||
info = self._ctx.cls.info
|
||||
arguments = [
|
||||
field.to_argument(info, typed=typed, force_optional=force_all_optional, use_alias=use_alias)
|
||||
for field in fields
|
||||
if not (use_alias and field.has_dynamic_alias)
|
||||
]
|
||||
return arguments
|
||||
|
||||
def should_init_forbid_extra(self, fields: List['PydanticModelField'], config: 'ModelConfigData') -> bool:
|
||||
"""
|
||||
Indicates whether the generated `__init__` should get a `**kwargs` at the end of its signature
|
||||
|
||||
We disallow arbitrary kwargs if the extra config setting is "forbid", or if the plugin config says to,
|
||||
*unless* a required dynamic alias is present (since then we can't determine a valid signature).
|
||||
"""
|
||||
if not config.allow_population_by_field_name:
|
||||
if self.is_dynamic_alias_present(fields, bool(config.has_alias_generator)):
|
||||
return False
|
||||
if config.forbid_extra:
|
||||
return True
|
||||
return self.plugin_config.init_forbid_extra
|
||||
|
||||
@staticmethod
|
||||
def is_dynamic_alias_present(fields: List['PydanticModelField'], has_alias_generator: bool) -> bool:
|
||||
"""
|
||||
Returns whether any fields on the model have a "dynamic alias", i.e., an alias that cannot be
|
||||
determined during static analysis.
|
||||
"""
|
||||
for field in fields:
|
||||
if field.has_dynamic_alias:
|
||||
return True
|
||||
if has_alias_generator:
|
||||
for field in fields:
|
||||
if field.alias is None:
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
class PydanticModelField:
|
||||
def __init__(
|
||||
self, name: str, is_required: bool, alias: Optional[str], has_dynamic_alias: bool, line: int, column: int
|
||||
):
|
||||
self.name = name
|
||||
self.is_required = is_required
|
||||
self.alias = alias
|
||||
self.has_dynamic_alias = has_dynamic_alias
|
||||
self.line = line
|
||||
self.column = column
|
||||
|
||||
def to_var(self, info: TypeInfo, use_alias: bool) -> Var:
|
||||
name = self.name
|
||||
if use_alias and self.alias is not None:
|
||||
name = self.alias
|
||||
return Var(name, info[self.name].type)
|
||||
|
||||
def to_argument(self, info: TypeInfo, typed: bool, force_optional: bool, use_alias: bool) -> Argument:
|
||||
if typed and info[self.name].type is not None:
|
||||
type_annotation = info[self.name].type
|
||||
else:
|
||||
type_annotation = AnyType(TypeOfAny.explicit)
|
||||
return Argument(
|
||||
variable=self.to_var(info, use_alias),
|
||||
type_annotation=type_annotation,
|
||||
initializer=None,
|
||||
kind=ARG_NAMED_OPT if force_optional or not self.is_required else ARG_NAMED,
|
||||
)
|
||||
|
||||
def serialize(self) -> JsonDict:
|
||||
return self.__dict__
|
||||
|
||||
@classmethod
|
||||
def deserialize(cls, info: TypeInfo, data: JsonDict) -> 'PydanticModelField':
|
||||
return cls(**data)
|
||||
|
||||
|
||||
class ModelConfigData:
|
||||
def __init__(
|
||||
self,
|
||||
forbid_extra: Optional[bool] = None,
|
||||
allow_mutation: Optional[bool] = None,
|
||||
frozen: Optional[bool] = None,
|
||||
orm_mode: Optional[bool] = None,
|
||||
allow_population_by_field_name: Optional[bool] = None,
|
||||
has_alias_generator: Optional[bool] = None,
|
||||
):
|
||||
self.forbid_extra = forbid_extra
|
||||
self.allow_mutation = allow_mutation
|
||||
self.frozen = frozen
|
||||
self.orm_mode = orm_mode
|
||||
self.allow_population_by_field_name = allow_population_by_field_name
|
||||
self.has_alias_generator = has_alias_generator
|
||||
|
||||
def set_values_dict(self) -> Dict[str, Any]:
|
||||
return {k: v for k, v in self.__dict__.items() if v is not None}
|
||||
|
||||
def update(self, config: Optional['ModelConfigData']) -> None:
|
||||
if config is None:
|
||||
return
|
||||
for k, v in config.set_values_dict().items():
|
||||
setattr(self, k, v)
|
||||
|
||||
def setdefault(self, key: str, value: Any) -> None:
|
||||
if getattr(self, key) is None:
|
||||
setattr(self, key, value)


ERROR_ORM = ErrorCode('pydantic-orm', 'Invalid from_orm call', 'Pydantic')
ERROR_CONFIG = ErrorCode('pydantic-config', 'Invalid config value', 'Pydantic')
ERROR_ALIAS = ErrorCode('pydantic-alias', 'Dynamic alias disallowed', 'Pydantic')
ERROR_UNEXPECTED = ErrorCode('pydantic-unexpected', 'Unexpected behavior', 'Pydantic')
ERROR_UNTYPED = ErrorCode('pydantic-field', 'Untyped field disallowed', 'Pydantic')
ERROR_FIELD_DEFAULTS = ErrorCode('pydantic-field', 'Invalid Field defaults', 'Pydantic')


def error_from_orm(model_name: str, api: CheckerPluginInterface, context: Context) -> None:
    api.fail(f'"{model_name}" does not have orm_mode=True', context, code=ERROR_ORM)


def error_invalid_config_value(name: str, api: SemanticAnalyzerPluginInterface, context: Context) -> None:
    api.fail(f'Invalid value for "Config.{name}"', context, code=ERROR_CONFIG)


def error_required_dynamic_aliases(api: SemanticAnalyzerPluginInterface, context: Context) -> None:
    api.fail('Required dynamic aliases disallowed', context, code=ERROR_ALIAS)


def error_unexpected_behavior(
    detail: str, api: Union[CheckerPluginInterface, SemanticAnalyzerPluginInterface], context: Context
) -> None:  # pragma: no cover
    # Can't think of a good way to test this, but I confirmed it renders as desired by adding to a non-error path
    link = 'https://github.com/pydantic/pydantic/issues/new/choose'
    full_message = f'The pydantic mypy plugin ran into unexpected behavior: {detail}\n'
    full_message += f'Please consider reporting this bug at {link} so we can try to fix it!'
    api.fail(full_message, context, code=ERROR_UNEXPECTED)


def error_untyped_fields(api: SemanticAnalyzerPluginInterface, context: Context) -> None:
    api.fail('Untyped fields disallowed', context, code=ERROR_UNTYPED)


def error_default_and_default_factory_specified(api: CheckerPluginInterface, context: Context) -> None:
    api.fail('Field default and default_factory cannot be specified together', context, code=ERROR_FIELD_DEFAULTS)


def add_method(
    ctx: ClassDefContext,
    name: str,
    args: List[Argument],
    return_type: Type,
    self_type: Optional[Type] = None,
    tvar_def: Optional[TypeVarDef] = None,
    is_classmethod: bool = False,
    is_new: bool = False,
    # is_staticmethod: bool = False,
) -> None:
    """
    Adds a new method to a class.

    This can be dropped if/when https://github.com/python/mypy/issues/7301 is merged
    """
    info = ctx.cls.info

    # First remove any previously generated methods with the same name
    # to avoid clashes and problems in the semantic analyzer.
    if name in info.names:
        sym = info.names[name]
        if sym.plugin_generated and isinstance(sym.node, FuncDef):
            ctx.cls.defs.body.remove(sym.node)  # pragma: no cover

    self_type = self_type or fill_typevars(info)
    if is_classmethod or is_new:
        first = [Argument(Var('_cls'), TypeType.make_normalized(self_type), None, ARG_POS)]
    # elif is_staticmethod:
    #     first = []
    else:
        self_type = self_type or fill_typevars(info)
        first = [Argument(Var('__pydantic_self__'), self_type, None, ARG_POS)]
    args = first + args
    arg_types, arg_names, arg_kinds = [], [], []
    for arg in args:
        assert arg.type_annotation, 'All arguments must be fully typed.'
        arg_types.append(arg.type_annotation)
        arg_names.append(get_name(arg.variable))
        arg_kinds.append(arg.kind)

    function_type = ctx.api.named_type(f'{BUILTINS_NAME}.function')
    signature = CallableType(arg_types, arg_kinds, arg_names, return_type, function_type)
    if tvar_def:
        signature.variables = [tvar_def]

    func = FuncDef(name, args, Block([PassStmt()]))
    func.info = info
    func.type = set_callable_name(signature, func)
    func.is_class = is_classmethod
    # func.is_static = is_staticmethod
    func._fullname = get_fullname(info) + '.' + name
    func.line = info.line

    # NOTE: we would like the plugin generated node to dominate, but we still
    # need to keep any existing definitions so they get semantically analyzed.
    if name in info.names:
        # Get a nice unique name instead.
        r_name = get_unique_redefinition_name(name, info.names)
        info.names[r_name] = info.names[name]

    if is_classmethod:  # or is_staticmethod:
        func.is_decorated = True
        v = Var(name, func.type)
        v.info = info
        v._fullname = func._fullname
        # if is_classmethod:
        v.is_classmethod = True
        dec = Decorator(func, [NameExpr('classmethod')], v)
        # else:
        #     v.is_staticmethod = True
        #     dec = Decorator(func, [NameExpr('staticmethod')], v)

        dec.line = info.line
        sym = SymbolTableNode(MDEF, dec)
    else:
        sym = SymbolTableNode(MDEF, func)
    sym.plugin_generated = True

    info.names[name] = sym
    info.defn.defs.body.append(func)


def get_fullname(x: Union[FuncBase, SymbolNode]) -> str:
    """
    Used for compatibility with mypy 0.740; can be dropped once support for 0.740 is dropped.
    """
    fn = x.fullname
    if callable(fn):  # pragma: no cover
        return fn()
    return fn


def get_name(x: Union[FuncBase, SymbolNode]) -> str:
    """
    Used for compatibility with mypy 0.740; can be dropped once support for 0.740 is dropped.
    """
    fn = x.name
    if callable(fn):  # pragma: no cover
        return fn()
    return fn


def parse_toml(config_file: str) -> Optional[Dict[str, Any]]:
    if not config_file.endswith('.toml'):
        return None

    read_mode = 'rb'
    if sys.version_info >= (3, 11):
        import tomllib as toml_
    else:
        try:
            import tomli as toml_
        except ImportError:
            # older versions of mypy have toml as a dependency, not tomli
            read_mode = 'r'
            try:
                import toml as toml_  # type: ignore[no-redef]
            except ImportError:  # pragma: no cover
                import warnings

                warnings.warn('No TOML parser installed, cannot read configuration from `pyproject.toml`.')
                return None

    with open(config_file, read_mode) as rf:
        return toml_.load(rf)  # type: ignore[arg-type]
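
A minimal usage sketch for `parse_toml` above (illustrative, not part of the file): the pydantic mypy plugin reads its settings from the `[tool.pydantic-mypy]` table of `pyproject.toml`; the path and the setting name here are assumptions for illustration.

# Hedged sketch: consuming parse_toml() the way a plugin config loader might.
config = parse_toml('pyproject.toml')  # returns None unless the path ends in .toml
if config is not None:
    settings = config.get('tool', {}).get('pydantic-mypy', {})
    warn_untyped_fields = settings.get('warn_untyped_fields', False)
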
@@ -0,0 +1,747 @@
import re
from ipaddress import (
    IPv4Address,
    IPv4Interface,
    IPv4Network,
    IPv6Address,
    IPv6Interface,
    IPv6Network,
    _BaseAddress,
    _BaseNetwork,
)
from typing import (
    TYPE_CHECKING,
    Any,
    Collection,
    Dict,
    Generator,
    List,
    Match,
    Optional,
    Pattern,
    Set,
    Tuple,
    Type,
    Union,
    cast,
    no_type_check,
)

from pydantic.v1 import errors
from pydantic.v1.utils import Representation, update_not_none
from pydantic.v1.validators import constr_length_validator, str_validator

if TYPE_CHECKING:
    import email_validator
    from typing_extensions import TypedDict

    from pydantic.v1.config import BaseConfig
    from pydantic.v1.fields import ModelField
    from pydantic.v1.typing import AnyCallable

    CallableGenerator = Generator[AnyCallable, None, None]

    class Parts(TypedDict, total=False):
        scheme: str
        user: Optional[str]
        password: Optional[str]
        ipv4: Optional[str]
        ipv6: Optional[str]
        domain: Optional[str]
        port: Optional[str]
        path: Optional[str]
        query: Optional[str]
        fragment: Optional[str]

    class HostParts(TypedDict, total=False):
        host: str
        tld: Optional[str]
        host_type: Optional[str]
        port: Optional[str]
        rebuild: bool

else:
    email_validator = None

    class Parts(dict):
        pass


NetworkType = Union[str, bytes, int, Tuple[Union[str, bytes, int], Union[str, int]]]

__all__ = [
    'AnyUrl',
    'AnyHttpUrl',
    'FileUrl',
    'HttpUrl',
    'stricturl',
    'EmailStr',
    'NameEmail',
    'IPvAnyAddress',
    'IPvAnyInterface',
    'IPvAnyNetwork',
    'PostgresDsn',
    'CockroachDsn',
    'AmqpDsn',
    'RedisDsn',
    'MongoDsn',
    'KafkaDsn',
    'validate_email',
]

_url_regex_cache = None
_multi_host_url_regex_cache = None
_ascii_domain_regex_cache = None
_int_domain_regex_cache = None
_host_regex_cache = None

_host_regex = (
    r'(?:'
    r'(?P<ipv4>(?:\d{1,3}\.){3}\d{1,3})(?=$|[/:#?])|'  # ipv4
    r'(?P<ipv6>\[[A-F0-9]*:[A-F0-9:]+\])(?=$|[/:#?])|'  # ipv6
    r'(?P<domain>[^\s/:?#]+)'  # domain, validation occurs later
    r')?'
    r'(?::(?P<port>\d+))?'  # port
)
_scheme_regex = r'(?:(?P<scheme>[a-z][a-z0-9+\-.]+)://)?'  # scheme https://tools.ietf.org/html/rfc3986#appendix-A
_user_info_regex = r'(?:(?P<user>[^\s:/]*)(?::(?P<password>[^\s/]*))?@)?'
_path_regex = r'(?P<path>/[^\s?#]*)?'
_query_regex = r'(?:\?(?P<query>[^\s#]*))?'
_fragment_regex = r'(?:#(?P<fragment>[^\s#]*))?'


def url_regex() -> Pattern[str]:
    global _url_regex_cache
    if _url_regex_cache is None:
        _url_regex_cache = re.compile(
            rf'{_scheme_regex}{_user_info_regex}{_host_regex}{_path_regex}{_query_regex}{_fragment_regex}',
            re.IGNORECASE,
        )
    return _url_regex_cache


def multi_host_url_regex() -> Pattern[str]:
    """
    Compiled multi host url regex.

    Additionally to `url_regex` it allows to match multiple hosts.
    E.g. host1.db.net,host2.db.net
    """
    global _multi_host_url_regex_cache
    if _multi_host_url_regex_cache is None:
        _multi_host_url_regex_cache = re.compile(
            rf'{_scheme_regex}{_user_info_regex}'
            r'(?P<hosts>([^/]*))'  # validation occurs later
            rf'{_path_regex}{_query_regex}{_fragment_regex}',
            re.IGNORECASE,
        )
    return _multi_host_url_regex_cache


def ascii_domain_regex() -> Pattern[str]:
    global _ascii_domain_regex_cache
    if _ascii_domain_regex_cache is None:
        ascii_chunk = r'[_0-9a-z](?:[-_0-9a-z]{0,61}[_0-9a-z])?'
        ascii_domain_ending = r'(?P<tld>\.[a-z]{2,63})?\.?'
        _ascii_domain_regex_cache = re.compile(
            fr'(?:{ascii_chunk}\.)*?{ascii_chunk}{ascii_domain_ending}', re.IGNORECASE
        )
    return _ascii_domain_regex_cache


def int_domain_regex() -> Pattern[str]:
    global _int_domain_regex_cache
    if _int_domain_regex_cache is None:
        int_chunk = r'[_0-9a-\U00040000](?:[-_0-9a-\U00040000]{0,61}[_0-9a-\U00040000])?'
        int_domain_ending = r'(?P<tld>(\.[^\W\d_]{2,63})|(\.(?:xn--)[_0-9a-z-]{2,63}))?\.?'
        _int_domain_regex_cache = re.compile(fr'(?:{int_chunk}\.)*?{int_chunk}{int_domain_ending}', re.IGNORECASE)
    return _int_domain_regex_cache


def host_regex() -> Pattern[str]:
    global _host_regex_cache
    if _host_regex_cache is None:
        _host_regex_cache = re.compile(
            _host_regex,
            re.IGNORECASE,
        )
    return _host_regex_cache
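
A quick illustration (not part of the module) of the named groups the composed pattern above produces; the example URL is arbitrary.

m = url_regex().match('https://user:pass@example.com:8000/path?q=1#frag')
assert m is not None
# groupdict() keys follow the named groups defined above:
# scheme='https', user='user', password='pass', domain='example.com',
# port='8000', path='/path', query='q=1', fragment='frag' (ipv4/ipv6 are None here)
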
class AnyUrl(str):
    strip_whitespace = True
    min_length = 1
    max_length = 2**16
    allowed_schemes: Optional[Collection[str]] = None
    tld_required: bool = False
    user_required: bool = False
    host_required: bool = True
    hidden_parts: Set[str] = set()

    __slots__ = ('scheme', 'user', 'password', 'host', 'tld', 'host_type', 'port', 'path', 'query', 'fragment')

    @no_type_check
    def __new__(cls, url: Optional[str], **kwargs) -> object:
        return str.__new__(cls, cls.build(**kwargs) if url is None else url)

    def __init__(
        self,
        url: str,
        *,
        scheme: str,
        user: Optional[str] = None,
        password: Optional[str] = None,
        host: Optional[str] = None,
        tld: Optional[str] = None,
        host_type: str = 'domain',
        port: Optional[str] = None,
        path: Optional[str] = None,
        query: Optional[str] = None,
        fragment: Optional[str] = None,
    ) -> None:
        str.__init__(url)
        self.scheme = scheme
        self.user = user
        self.password = password
        self.host = host
        self.tld = tld
        self.host_type = host_type
        self.port = port
        self.path = path
        self.query = query
        self.fragment = fragment

    @classmethod
    def build(
        cls,
        *,
        scheme: str,
        user: Optional[str] = None,
        password: Optional[str] = None,
        host: str,
        port: Optional[str] = None,
        path: Optional[str] = None,
        query: Optional[str] = None,
        fragment: Optional[str] = None,
        **_kwargs: str,
    ) -> str:
        parts = Parts(
            scheme=scheme,
            user=user,
            password=password,
            host=host,
            port=port,
            path=path,
            query=query,
            fragment=fragment,
            **_kwargs,  # type: ignore[misc]
        )

        url = scheme + '://'
        if user:
            url += user
        if password:
            url += ':' + password
        if user or password:
            url += '@'
        url += host
        if port and ('port' not in cls.hidden_parts or cls.get_default_parts(parts).get('port') != port):
            url += ':' + port
        if path:
            url += path
        if query:
            url += '?' + query
        if fragment:
            url += '#' + fragment
        return url

    @classmethod
    def __modify_schema__(cls, field_schema: Dict[str, Any]) -> None:
        update_not_none(field_schema, minLength=cls.min_length, maxLength=cls.max_length, format='uri')

    @classmethod
    def __get_validators__(cls) -> 'CallableGenerator':
        yield cls.validate

    @classmethod
    def validate(cls, value: Any, field: 'ModelField', config: 'BaseConfig') -> 'AnyUrl':
        if value.__class__ == cls:
            return value
        value = str_validator(value)
        if cls.strip_whitespace:
            value = value.strip()
        url: str = cast(str, constr_length_validator(value, field, config))

        m = cls._match_url(url)
        # the regex should always match, if it doesn't please report with details of the URL tried
        assert m, 'URL regex failed unexpectedly'

        original_parts = cast('Parts', m.groupdict())
        parts = cls.apply_default_parts(original_parts)
        parts = cls.validate_parts(parts)

        if m.end() != len(url):
            raise errors.UrlExtraError(extra=url[m.end() :])

        return cls._build_url(m, url, parts)

    @classmethod
    def _build_url(cls, m: Match[str], url: str, parts: 'Parts') -> 'AnyUrl':
        """
        Validate hosts and build the AnyUrl object. Split from `validate` so this method
        can be altered in `MultiHostDsn`.
        """
        host, tld, host_type, rebuild = cls.validate_host(parts)

        return cls(
            None if rebuild else url,
            scheme=parts['scheme'],
            user=parts['user'],
            password=parts['password'],
            host=host,
            tld=tld,
            host_type=host_type,
            port=parts['port'],
            path=parts['path'],
            query=parts['query'],
            fragment=parts['fragment'],
        )

    @staticmethod
    def _match_url(url: str) -> Optional[Match[str]]:
        return url_regex().match(url)

    @staticmethod
    def _validate_port(port: Optional[str]) -> None:
        if port is not None and int(port) > 65_535:
            raise errors.UrlPortError()

    @classmethod
    def validate_parts(cls, parts: 'Parts', validate_port: bool = True) -> 'Parts':
        """
        A method used to validate parts of a URL.
        Could be overridden to set default values for parts if missing
        """
        scheme = parts['scheme']
        if scheme is None:
            raise errors.UrlSchemeError()

        if cls.allowed_schemes and scheme.lower() not in cls.allowed_schemes:
            raise errors.UrlSchemePermittedError(set(cls.allowed_schemes))

        if validate_port:
            cls._validate_port(parts['port'])

        user = parts['user']
        if cls.user_required and user is None:
            raise errors.UrlUserInfoError()

        return parts

    @classmethod
    def validate_host(cls, parts: 'Parts') -> Tuple[str, Optional[str], str, bool]:
        tld, host_type, rebuild = None, None, False
        for f in ('domain', 'ipv4', 'ipv6'):
            host = parts[f]  # type: ignore[literal-required]
            if host:
                host_type = f
                break

        if host is None:
            if cls.host_required:
                raise errors.UrlHostError()
        elif host_type == 'domain':
            is_international = False
            d = ascii_domain_regex().fullmatch(host)
            if d is None:
                d = int_domain_regex().fullmatch(host)
                if d is None:
                    raise errors.UrlHostError()
                is_international = True

            tld = d.group('tld')
            if tld is None and not is_international:
                d = int_domain_regex().fullmatch(host)
                assert d is not None
                tld = d.group('tld')
                is_international = True

            if tld is not None:
                tld = tld[1:]
            elif cls.tld_required:
                raise errors.UrlHostTldError()

            if is_international:
                host_type = 'int_domain'
                rebuild = True
                host = host.encode('idna').decode('ascii')
                if tld is not None:
                    tld = tld.encode('idna').decode('ascii')

        return host, tld, host_type, rebuild  # type: ignore

    @staticmethod
    def get_default_parts(parts: 'Parts') -> 'Parts':
        return {}

    @classmethod
    def apply_default_parts(cls, parts: 'Parts') -> 'Parts':
        for key, value in cls.get_default_parts(parts).items():
            if not parts[key]:  # type: ignore[literal-required]
                parts[key] = value  # type: ignore[literal-required]
        return parts

    def __repr__(self) -> str:
        extra = ', '.join(f'{n}={getattr(self, n)!r}' for n in self.__slots__ if getattr(self, n) is not None)
        return f'{self.__class__.__name__}({super().__repr__()}, {extra})'
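
A hedged usage sketch of `AnyUrl.build` above: it assembles and returns a plain string, skipping parts that are None.

assert AnyUrl.build(scheme='https', host='example.com', path='/docs', query='page=1') == 'https://example.com/docs?page=1'
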
class AnyHttpUrl(AnyUrl):
    allowed_schemes = {'http', 'https'}

    __slots__ = ()


class HttpUrl(AnyHttpUrl):
    tld_required = True
    # https://stackoverflow.com/questions/417142/what-is-the-maximum-length-of-a-url-in-different-browsers
    max_length = 2083
    hidden_parts = {'port'}

    @staticmethod
    def get_default_parts(parts: 'Parts') -> 'Parts':
        return {'port': '80' if parts['scheme'] == 'http' else '443'}


class FileUrl(AnyUrl):
    allowed_schemes = {'file'}
    host_required = False

    __slots__ = ()


class MultiHostDsn(AnyUrl):
    __slots__ = AnyUrl.__slots__ + ('hosts',)

    def __init__(self, *args: Any, hosts: Optional[List['HostParts']] = None, **kwargs: Any):
        super().__init__(*args, **kwargs)
        self.hosts = hosts

    @staticmethod
    def _match_url(url: str) -> Optional[Match[str]]:
        return multi_host_url_regex().match(url)

    @classmethod
    def validate_parts(cls, parts: 'Parts', validate_port: bool = True) -> 'Parts':
        return super().validate_parts(parts, validate_port=False)

    @classmethod
    def _build_url(cls, m: Match[str], url: str, parts: 'Parts') -> 'MultiHostDsn':
        hosts_parts: List['HostParts'] = []
        host_re = host_regex()
        for host in m.groupdict()['hosts'].split(','):
            d: Parts = host_re.match(host).groupdict()  # type: ignore
            host, tld, host_type, rebuild = cls.validate_host(d)
            port = d.get('port')
            cls._validate_port(port)
            hosts_parts.append(
                {
                    'host': host,
                    'host_type': host_type,
                    'tld': tld,
                    'rebuild': rebuild,
                    'port': port,
                }
            )

        if len(hosts_parts) > 1:
            return cls(
                None if any([hp['rebuild'] for hp in hosts_parts]) else url,
                scheme=parts['scheme'],
                user=parts['user'],
                password=parts['password'],
                path=parts['path'],
                query=parts['query'],
                fragment=parts['fragment'],
                host_type=None,
                hosts=hosts_parts,
            )
        else:
            # backwards compatibility with single host
            host_part = hosts_parts[0]
            return cls(
                None if host_part['rebuild'] else url,
                scheme=parts['scheme'],
                user=parts['user'],
                password=parts['password'],
                host=host_part['host'],
                tld=host_part['tld'],
                host_type=host_part['host_type'],
                port=host_part.get('port'),
                path=parts['path'],
                query=parts['query'],
                fragment=parts['fragment'],
            )


class PostgresDsn(MultiHostDsn):
    allowed_schemes = {
        'postgres',
        'postgresql',
        'postgresql+asyncpg',
        'postgresql+pg8000',
        'postgresql+psycopg',
        'postgresql+psycopg2',
        'postgresql+psycopg2cffi',
        'postgresql+py-postgresql',
        'postgresql+pygresql',
    }
    user_required = True

    __slots__ = ()


class CockroachDsn(AnyUrl):
    allowed_schemes = {
        'cockroachdb',
        'cockroachdb+psycopg2',
        'cockroachdb+asyncpg',
    }
    user_required = True


class AmqpDsn(AnyUrl):
    allowed_schemes = {'amqp', 'amqps'}
    host_required = False


class RedisDsn(AnyUrl):
    __slots__ = ()
    allowed_schemes = {'redis', 'rediss'}
    host_required = False

    @staticmethod
    def get_default_parts(parts: 'Parts') -> 'Parts':
        return {
            'domain': 'localhost' if not (parts['ipv4'] or parts['ipv6']) else '',
            'port': '6379',
            'path': '/0',
        }


class MongoDsn(AnyUrl):
    allowed_schemes = {'mongodb'}

    # TODO: "Parts" needs to be generalized for "Replica Set", "Sharded Cluster", and other mongodb deployment modes
    @staticmethod
    def get_default_parts(parts: 'Parts') -> 'Parts':
        return {
            'port': '27017',
        }


class KafkaDsn(AnyUrl):
    allowed_schemes = {'kafka'}

    @staticmethod
    def get_default_parts(parts: 'Parts') -> 'Parts':
        return {
            'domain': 'localhost',
            'port': '9092',
        }


def stricturl(
    *,
    strip_whitespace: bool = True,
    min_length: int = 1,
    max_length: int = 2**16,
    tld_required: bool = True,
    host_required: bool = True,
    allowed_schemes: Optional[Collection[str]] = None,
) -> Type[AnyUrl]:
    # use kwargs then define conf in a dict to aid with IDE type hinting
    namespace = dict(
        strip_whitespace=strip_whitespace,
        min_length=min_length,
        max_length=max_length,
        tld_required=tld_required,
        host_required=host_required,
        allowed_schemes=allowed_schemes,
    )
    return type('UrlValue', (AnyUrl,), namespace)
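
A short sketch of `stricturl` above: it returns a freshly configured `AnyUrl` subclass, so the class-level knobs behave exactly as on the DSN classes that precede it; the scheme here is an arbitrary example.

SftpUrl = stricturl(allowed_schemes={'sftp'}, tld_required=False)
assert issubclass(SftpUrl, AnyUrl) and SftpUrl.allowed_schemes == {'sftp'}
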
def import_email_validator() -> None:
    global email_validator
    try:
        import email_validator
    except ImportError as e:
        raise ImportError('email-validator is not installed, run `pip install pydantic[email]`') from e


class EmailStr(str):
    @classmethod
    def __modify_schema__(cls, field_schema: Dict[str, Any]) -> None:
        field_schema.update(type='string', format='email')

    @classmethod
    def __get_validators__(cls) -> 'CallableGenerator':
        # included here and below so the error happens straight away
        import_email_validator()

        yield str_validator
        yield cls.validate

    @classmethod
    def validate(cls, value: Union[str]) -> str:
        return validate_email(value)[1]


class NameEmail(Representation):
    __slots__ = 'name', 'email'

    def __init__(self, name: str, email: str):
        self.name = name
        self.email = email

    def __eq__(self, other: Any) -> bool:
        return isinstance(other, NameEmail) and (self.name, self.email) == (other.name, other.email)

    @classmethod
    def __modify_schema__(cls, field_schema: Dict[str, Any]) -> None:
        field_schema.update(type='string', format='name-email')

    @classmethod
    def __get_validators__(cls) -> 'CallableGenerator':
        import_email_validator()

        yield cls.validate

    @classmethod
    def validate(cls, value: Any) -> 'NameEmail':
        if value.__class__ == cls:
            return value
        value = str_validator(value)
        return cls(*validate_email(value))

    def __str__(self) -> str:
        return f'{self.name} <{self.email}>'


class IPvAnyAddress(_BaseAddress):
    __slots__ = ()

    @classmethod
    def __modify_schema__(cls, field_schema: Dict[str, Any]) -> None:
        field_schema.update(type='string', format='ipvanyaddress')

    @classmethod
    def __get_validators__(cls) -> 'CallableGenerator':
        yield cls.validate

    @classmethod
    def validate(cls, value: Union[str, bytes, int]) -> Union[IPv4Address, IPv6Address]:
        try:
            return IPv4Address(value)
        except ValueError:
            pass

        try:
            return IPv6Address(value)
        except ValueError:
            raise errors.IPvAnyAddressError()


class IPvAnyInterface(_BaseAddress):
    __slots__ = ()

    @classmethod
    def __modify_schema__(cls, field_schema: Dict[str, Any]) -> None:
        field_schema.update(type='string', format='ipvanyinterface')

    @classmethod
    def __get_validators__(cls) -> 'CallableGenerator':
        yield cls.validate

    @classmethod
    def validate(cls, value: NetworkType) -> Union[IPv4Interface, IPv6Interface]:
        try:
            return IPv4Interface(value)
        except ValueError:
            pass

        try:
            return IPv6Interface(value)
        except ValueError:
            raise errors.IPvAnyInterfaceError()


class IPvAnyNetwork(_BaseNetwork):  # type: ignore
    @classmethod
    def __modify_schema__(cls, field_schema: Dict[str, Any]) -> None:
        field_schema.update(type='string', format='ipvanynetwork')

    @classmethod
    def __get_validators__(cls) -> 'CallableGenerator':
        yield cls.validate

    @classmethod
    def validate(cls, value: NetworkType) -> Union[IPv4Network, IPv6Network]:
        # Assume IP Network is defined with a default value for ``strict`` argument.
        # Define your own class if you want to specify network address check strictness.
        try:
            return IPv4Network(value)
        except ValueError:
            pass

        try:
            return IPv6Network(value)
        except ValueError:
            raise errors.IPvAnyNetworkError()


pretty_email_regex = re.compile(r'([\w ]*?) *<(.*)> *')
MAX_EMAIL_LENGTH = 2048
"""Maximum length for an email.
A somewhat arbitrary but very generous number compared to what is allowed by most implementations.
"""


def validate_email(value: Union[str]) -> Tuple[str, str]:
    """
    Email address validation using https://pypi.org/project/email-validator/
    Notes:
    * raw ip address (literal) domain parts are not allowed.
    * "John Doe <local_part@domain.com>" style "pretty" email addresses are processed
    * spaces are stripped from the beginning and end of addresses but no error is raised
    """
    if email_validator is None:
        import_email_validator()

    if len(value) > MAX_EMAIL_LENGTH:
        raise errors.EmailError()

    m = pretty_email_regex.fullmatch(value)
    name: Union[str, None] = None
    if m:
        name, value = m.groups()
    email = value.strip()
    try:
        parts = email_validator.validate_email(email, check_deliverability=False)
    except email_validator.EmailNotValidError as e:
        raise errors.EmailError from e

    if hasattr(parts, 'normalized'):
        # email-validator >= 2
        email = parts.normalized
        assert email is not None
        name = name or parts.local_part
        return name, email
    else:
        # email-validator >1, <2
        at_index = email.index('@')
        local_part = email[:at_index]  # RFC 5321, local part must be case-sensitive.
        global_part = email[at_index:].lower()

        return name or local_part, local_part + global_part
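
An illustrative call to `validate_email` above (requires the optional email-validator dependency; the address is made up): the pretty-format name is split off and the address is normalized.

name, email = validate_email('John Doe <john.doe@example.com>')
# name == 'John Doe'; email is the normalized address per the installed email-validator
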
@@ -0,0 +1,66 @@
import json
import pickle
from enum import Enum
from pathlib import Path
from typing import Any, Callable, Union

from pydantic.v1.types import StrBytes


class Protocol(str, Enum):
    json = 'json'
    pickle = 'pickle'


def load_str_bytes(
    b: StrBytes,
    *,
    content_type: str = None,
    encoding: str = 'utf8',
    proto: Protocol = None,
    allow_pickle: bool = False,
    json_loads: Callable[[str], Any] = json.loads,
) -> Any:
    if proto is None and content_type:
        if content_type.endswith(('json', 'javascript')):
            pass
        elif allow_pickle and content_type.endswith('pickle'):
            proto = Protocol.pickle
        else:
            raise TypeError(f'Unknown content-type: {content_type}')

    proto = proto or Protocol.json

    if proto == Protocol.json:
        if isinstance(b, bytes):
            b = b.decode(encoding)
        return json_loads(b)
    elif proto == Protocol.pickle:
        if not allow_pickle:
            raise RuntimeError('Trying to decode with pickle with allow_pickle=False')
        bb = b if isinstance(b, bytes) else b.encode()
        return pickle.loads(bb)
    else:
        raise TypeError(f'Unknown protocol: {proto}')


def load_file(
    path: Union[str, Path],
    *,
    content_type: str = None,
    encoding: str = 'utf8',
    proto: Protocol = None,
    allow_pickle: bool = False,
    json_loads: Callable[[str], Any] = json.loads,
) -> Any:
    path = Path(path)
    b = path.read_bytes()
    if content_type is None:
        if path.suffix in ('.js', '.json'):
            proto = Protocol.json
        elif path.suffix == '.pkl':
            proto = Protocol.pickle

    return load_str_bytes(
        b, proto=proto, content_type=content_type, encoding=encoding, allow_pickle=allow_pickle, json_loads=json_loads
    )
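
A hedged sketch of `load_str_bytes` above: JSON is the default protocol, and pickle must be opted into explicitly.

assert load_str_bytes(b'{"a": 1}') == {'a': 1}
try:
    load_str_bytes(b'\x80\x04N.', proto=Protocol.pickle)  # allow_pickle defaults to False
except RuntimeError:
    pass  # 'Trying to decode with pickle with allow_pickle=False'
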
File diff suppressed because it is too large
@@ -0,0 +1,92 @@
import json
from functools import lru_cache
from pathlib import Path
from typing import TYPE_CHECKING, Any, Callable, Optional, Type, TypeVar, Union

from pydantic.v1.parse import Protocol, load_file, load_str_bytes
from pydantic.v1.types import StrBytes
from pydantic.v1.typing import display_as_type

__all__ = ('parse_file_as', 'parse_obj_as', 'parse_raw_as', 'schema_of', 'schema_json_of')

NameFactory = Union[str, Callable[[Type[Any]], str]]

if TYPE_CHECKING:
    from pydantic.v1.typing import DictStrAny


def _generate_parsing_type_name(type_: Any) -> str:
    return f'ParsingModel[{display_as_type(type_)}]'


@lru_cache(maxsize=2048)
def _get_parsing_type(type_: Any, *, type_name: Optional[NameFactory] = None) -> Any:
    from pydantic.v1.main import create_model

    if type_name is None:
        type_name = _generate_parsing_type_name
    if not isinstance(type_name, str):
        type_name = type_name(type_)
    return create_model(type_name, __root__=(type_, ...))


T = TypeVar('T')


def parse_obj_as(type_: Type[T], obj: Any, *, type_name: Optional[NameFactory] = None) -> T:
    model_type = _get_parsing_type(type_, type_name=type_name)  # type: ignore[arg-type]
    return model_type(__root__=obj).__root__


def parse_file_as(
    type_: Type[T],
    path: Union[str, Path],
    *,
    content_type: str = None,
    encoding: str = 'utf8',
    proto: Protocol = None,
    allow_pickle: bool = False,
    json_loads: Callable[[str], Any] = json.loads,
    type_name: Optional[NameFactory] = None,
) -> T:
    obj = load_file(
        path,
        proto=proto,
        content_type=content_type,
        encoding=encoding,
        allow_pickle=allow_pickle,
        json_loads=json_loads,
    )
    return parse_obj_as(type_, obj, type_name=type_name)


def parse_raw_as(
    type_: Type[T],
    b: StrBytes,
    *,
    content_type: str = None,
    encoding: str = 'utf8',
    proto: Protocol = None,
    allow_pickle: bool = False,
    json_loads: Callable[[str], Any] = json.loads,
    type_name: Optional[NameFactory] = None,
) -> T:
    obj = load_str_bytes(
        b,
        proto=proto,
        content_type=content_type,
        encoding=encoding,
        allow_pickle=allow_pickle,
        json_loads=json_loads,
    )
    return parse_obj_as(type_, obj, type_name=type_name)


def schema_of(type_: Any, *, title: Optional[NameFactory] = None, **schema_kwargs: Any) -> 'DictStrAny':
    """Generate a JSON schema (as dict) for the passed model or dynamically generated one"""
    return _get_parsing_type(type_, type_name=title).schema(**schema_kwargs)


def schema_json_of(type_: Any, *, title: Optional[NameFactory] = None, **schema_json_kwargs: Any) -> str:
    """Generate a JSON schema (as JSON) for the passed model or dynamically generated one"""
    return _get_parsing_type(type_, type_name=title).schema_json(**schema_json_kwargs)
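
A hedged sketch of `parse_obj_as` above, the core helper of this module: it wraps the target type in a throwaway model with a `__root__` field and validates (and coerces) through it.

from typing import List  # illustrative; not imported by the module itself

assert parse_obj_as(List[int], ['1', 2, '3']) == [1, 2, 3]
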
File diff suppressed because it is too large
@@ -0,0 +1,608 @@
import sys
import typing
from collections.abc import Callable
from os import PathLike
from typing import (  # type: ignore
    TYPE_CHECKING,
    AbstractSet,
    Any,
    Callable as TypingCallable,
    ClassVar,
    Dict,
    ForwardRef,
    Generator,
    Iterable,
    List,
    Mapping,
    NewType,
    Optional,
    Sequence,
    Set,
    Tuple,
    Type,
    TypeVar,
    Union,
    _eval_type,
    cast,
    get_type_hints,
)

from typing_extensions import (
    Annotated,
    Final,
    Literal,
    NotRequired as TypedDictNotRequired,
    Required as TypedDictRequired,
)

try:
    from typing import _TypingBase as typing_base  # type: ignore
except ImportError:
    from typing import _Final as typing_base  # type: ignore

try:
    from typing import GenericAlias as TypingGenericAlias  # type: ignore
except ImportError:
    # python < 3.9 does not have GenericAlias (list[int], tuple[str, ...] and so on)
    TypingGenericAlias = ()

try:
    from types import UnionType as TypesUnionType  # type: ignore
except ImportError:
    # python < 3.10 does not have UnionType (str | int, bytes | bool and so on)
    TypesUnionType = ()


if sys.version_info < (3, 9):

    def evaluate_forwardref(type_: ForwardRef, globalns: Any, localns: Any) -> Any:
        return type_._evaluate(globalns, localns)

else:

    def evaluate_forwardref(type_: ForwardRef, globalns: Any, localns: Any) -> Any:
        # Even though it is the right signature for python 3.9, mypy complains with
        # `error: Too many arguments for "_evaluate" of "ForwardRef"` hence the cast...
        # Python 3.13/3.12.4+ made `recursive_guard` a kwarg, so name it explicitly to avoid:
        # TypeError: ForwardRef._evaluate() missing 1 required keyword-only argument: 'recursive_guard'
        return cast(Any, type_)._evaluate(globalns, localns, recursive_guard=set())


if sys.version_info < (3, 9):
    # Ensure we always get the whole `Annotated` hint, not just the annotated type.
    # For 3.7 to 3.8, `get_type_hints` doesn't recognize `typing_extensions.Annotated`,
    # so it already returns the full annotation
    get_all_type_hints = get_type_hints

else:

    def get_all_type_hints(obj: Any, globalns: Any = None, localns: Any = None) -> Any:
        return get_type_hints(obj, globalns, localns, include_extras=True)


_T = TypeVar('_T')

AnyCallable = TypingCallable[..., Any]
NoArgAnyCallable = TypingCallable[[], Any]

# workaround for https://github.com/python/mypy/issues/9496
AnyArgTCallable = TypingCallable[..., _T]


# Annotated[...] is implemented by returning an instance of one of these classes, depending on
# python/typing_extensions version.
AnnotatedTypeNames = {'AnnotatedMeta', '_AnnotatedAlias'}


LITERAL_TYPES: Set[Any] = {Literal}
if hasattr(typing, 'Literal'):
    LITERAL_TYPES.add(typing.Literal)


if sys.version_info < (3, 8):

    def get_origin(t: Type[Any]) -> Optional[Type[Any]]:
        if type(t).__name__ in AnnotatedTypeNames:
            # weirdly this is a runtime requirement, as well as for mypy
            return cast(Type[Any], Annotated)
        return getattr(t, '__origin__', None)

else:
    from typing import get_origin as _typing_get_origin

    def get_origin(tp: Type[Any]) -> Optional[Type[Any]]:
        """
        We can't directly use `typing.get_origin` since we need a fallback to support
        custom generic classes like `ConstrainedList`
        It should be useless once https://github.com/cython/cython/issues/3537 is
        solved and https://github.com/pydantic/pydantic/pull/1753 is merged.
        """
        if type(tp).__name__ in AnnotatedTypeNames:
            return cast(Type[Any], Annotated)  # mypy complains about _SpecialForm
        return _typing_get_origin(tp) or getattr(tp, '__origin__', None)


if sys.version_info < (3, 8):
    from typing import _GenericAlias

    def get_args(t: Type[Any]) -> Tuple[Any, ...]:
        """Compatibility version of get_args for python 3.7.

        Mostly compatible with the python 3.8 `typing` module version
        and able to handle almost all use cases.
        """
        if type(t).__name__ in AnnotatedTypeNames:
            return t.__args__ + t.__metadata__
        if isinstance(t, _GenericAlias):
            res = t.__args__
            if t.__origin__ is Callable and res and res[0] is not Ellipsis:
                res = (list(res[:-1]), res[-1])
            return res
        return getattr(t, '__args__', ())

else:
    from typing import get_args as _typing_get_args

    def _generic_get_args(tp: Type[Any]) -> Tuple[Any, ...]:
        """
        In python 3.9, `typing.Dict`, `typing.List`, ...
        do have an empty `__args__` by default (instead of the generic ~T for example).
        In order to still support `Dict` for example and consider it as `Dict[Any, Any]`,
        we retrieve the `_nparams` value that tells us how many parameters it needs.
        """
        if hasattr(tp, '_nparams'):
            return (Any,) * tp._nparams
        # Special case for `tuple[()]`, which used to return ((),) with `typing.Tuple`
        # in python 3.10- but now returns () for `tuple` and `Tuple`.
        # This will probably be clarified in pydantic v2
        try:
            if tp == Tuple[()] or sys.version_info >= (3, 9) and tp == tuple[()]:  # type: ignore[misc]
                return ((),)
        # there is a TypeError when compiled with cython
        except TypeError:  # pragma: no cover
            pass
        return ()

    def get_args(tp: Type[Any]) -> Tuple[Any, ...]:
        """Get type arguments with all substitutions performed.

        For unions, basic simplifications used by Union constructor are performed.
        Examples::
            get_args(Dict[str, int]) == (str, int)
            get_args(int) == ()
            get_args(Union[int, Union[T, int], str][int]) == (int, str)
            get_args(Union[int, Tuple[T, int]][str]) == (int, Tuple[str, int])
            get_args(Callable[[], T][int]) == ([], int)
        """
        if type(tp).__name__ in AnnotatedTypeNames:
            return tp.__args__ + tp.__metadata__
        # the fallback is needed for the same reasons as `get_origin` (see above)
        return _typing_get_args(tp) or getattr(tp, '__args__', ()) or _generic_get_args(tp)
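
An illustration of the compatibility wrappers above on the modern (3.9+) branches; older branches fall back as documented in each version guard.

assert get_origin(Dict[str, int]) is dict
assert get_args(Dict[str, int]) == (str, int)
assert get_args(Dict) == (Any, Any)  # bare Dict resolved via _generic_get_args/_nparams
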
if sys.version_info < (3, 9):

    def convert_generics(tp: Type[Any]) -> Type[Any]:
        """Python 3.8 and older only support generics from the `typing` module.
        They convert strings to ForwardRef automatically.

        Examples::
            typing.List['Hero'] == typing.List[ForwardRef('Hero')]
        """
        return tp

else:
    from typing import _UnionGenericAlias  # type: ignore

    from typing_extensions import _AnnotatedAlias

    def convert_generics(tp: Type[Any]) -> Type[Any]:
        """
        Recursively searches for `str` type hints and replaces them with ForwardRef.

        Examples::
            convert_generics(list['Hero']) == list[ForwardRef('Hero')]
            convert_generics(dict['Hero', 'Team']) == dict[ForwardRef('Hero'), ForwardRef('Team')]
            convert_generics(typing.Dict['Hero', 'Team']) == typing.Dict[ForwardRef('Hero'), ForwardRef('Team')]
            convert_generics(list[str | 'Hero'] | int) == list[str | ForwardRef('Hero')] | int
        """
        origin = get_origin(tp)
        if not origin or not hasattr(tp, '__args__'):
            return tp

        args = get_args(tp)

        # typing.Annotated needs special treatment
        if origin is Annotated:
            return _AnnotatedAlias(convert_generics(args[0]), args[1:])

        # recursively replace `str` instances inside of `GenericAlias` with `ForwardRef(arg)`
        converted = tuple(
            ForwardRef(arg) if isinstance(arg, str) and isinstance(tp, TypingGenericAlias) else convert_generics(arg)
            for arg in args
        )

        if converted == args:
            return tp
        elif isinstance(tp, TypingGenericAlias):
            return TypingGenericAlias(origin, converted)
        elif isinstance(tp, TypesUnionType):
            # recreate types.UnionType (PEP604, Python >= 3.10)
            return _UnionGenericAlias(origin, converted)
        else:
            try:
                setattr(tp, '__args__', converted)
            except AttributeError:
                pass
            return tp


if sys.version_info < (3, 10):

    def is_union(tp: Optional[Type[Any]]) -> bool:
        return tp is Union

    WithArgsTypes = (TypingGenericAlias,)

else:
    import types
    import typing

    def is_union(tp: Optional[Type[Any]]) -> bool:
        return tp is Union or tp is types.UnionType  # noqa: E721

    WithArgsTypes = (typing._GenericAlias, types.GenericAlias, types.UnionType)


StrPath = Union[str, PathLike]


if TYPE_CHECKING:
    from pydantic.v1.fields import ModelField

    TupleGenerator = Generator[Tuple[str, Any], None, None]
    DictStrAny = Dict[str, Any]
    DictAny = Dict[Any, Any]
    SetStr = Set[str]
    ListStr = List[str]
    IntStr = Union[int, str]
    AbstractSetIntStr = AbstractSet[IntStr]
    DictIntStrAny = Dict[IntStr, Any]
    MappingIntStrAny = Mapping[IntStr, Any]
    CallableGenerator = Generator[AnyCallable, None, None]
    ReprArgs = Sequence[Tuple[Optional[str], Any]]

    MYPY = False
    if MYPY:
        AnyClassMethod = classmethod[Any]
    else:
        # classmethod[TargetType, CallableParamSpecType, CallableReturnType]
        AnyClassMethod = classmethod[Any, Any, Any]

__all__ = (
    'AnyCallable',
    'NoArgAnyCallable',
    'NoneType',
    'is_none_type',
    'display_as_type',
    'resolve_annotations',
    'is_callable_type',
    'is_literal_type',
    'all_literal_values',
    'is_namedtuple',
    'is_typeddict',
    'is_typeddict_special',
    'is_new_type',
    'new_type_supertype',
    'is_classvar',
    'is_finalvar',
    'update_field_forward_refs',
    'update_model_forward_refs',
    'TupleGenerator',
    'DictStrAny',
    'DictAny',
    'SetStr',
    'ListStr',
    'IntStr',
    'AbstractSetIntStr',
    'DictIntStrAny',
    'CallableGenerator',
    'ReprArgs',
    'AnyClassMethod',
    'CallableGenerator',
    'WithArgsTypes',
    'get_args',
    'get_origin',
    'get_sub_types',
    'typing_base',
    'get_all_type_hints',
    'is_union',
    'StrPath',
    'MappingIntStrAny',
)


NoneType = None.__class__


NONE_TYPES: Tuple[Any, Any, Any] = (None, NoneType, Literal[None])


if sys.version_info < (3, 8):
    # Even though this implementation is slower, we need it for python 3.7:
    # In python 3.7 "Literal" is not a builtin type and uses a different
    # mechanism.
    # for this reason `Literal[None] is Literal[None]` evaluates to `False`,
    # breaking the faster implementation used for the other python versions.

    def is_none_type(type_: Any) -> bool:
        return type_ in NONE_TYPES

elif sys.version_info[:2] == (3, 8):

    def is_none_type(type_: Any) -> bool:
        for none_type in NONE_TYPES:
            if type_ is none_type:
                return True
        # With python 3.8, specifically 3.8.10, Literal "is" checks are very flaky and
        # can change on very subtle changes like use of types in other modules,
        # hopefully this check avoids that issue.
        if is_literal_type(type_):  # pragma: no cover
            return all_literal_values(type_) == (None,)
        return False

else:

    def is_none_type(type_: Any) -> bool:
        return type_ in NONE_TYPES


def display_as_type(v: Type[Any]) -> str:
    if not isinstance(v, typing_base) and not isinstance(v, WithArgsTypes) and not isinstance(v, type):
        v = v.__class__

    if is_union(get_origin(v)):
        return f'Union[{", ".join(map(display_as_type, get_args(v)))}]'

    if isinstance(v, WithArgsTypes):
        # Generic alias are constructs like `list[int]`
        return str(v).replace('typing.', '')

    try:
        return v.__name__
    except AttributeError:
        # happens with typing objects
        return str(v).replace('typing.', '')


def resolve_annotations(raw_annotations: Dict[str, Type[Any]], module_name: Optional[str]) -> Dict[str, Type[Any]]:
    """
    Partially taken from typing.get_type_hints.

    Resolve string or ForwardRef annotations into type objects if possible.
    """
    base_globals: Optional[Dict[str, Any]] = None
    if module_name:
        try:
            module = sys.modules[module_name]
        except KeyError:
            # happens occasionally, see https://github.com/pydantic/pydantic/issues/2363
            pass
        else:
            base_globals = module.__dict__

    annotations = {}
    for name, value in raw_annotations.items():
        if isinstance(value, str):
            if (3, 10) > sys.version_info >= (3, 9, 8) or sys.version_info >= (3, 10, 1):
                value = ForwardRef(value, is_argument=False, is_class=True)
            else:
                value = ForwardRef(value, is_argument=False)
        try:
            if sys.version_info >= (3, 13):
                value = _eval_type(value, base_globals, None, type_params=())
            else:
                value = _eval_type(value, base_globals, None)
        except NameError:
            # this is ok, it can be fixed with update_forward_refs
            pass
        annotations[name] = value
    return annotations


def is_callable_type(type_: Type[Any]) -> bool:
    return type_ is Callable or get_origin(type_) is Callable


def is_literal_type(type_: Type[Any]) -> bool:
    return Literal is not None and get_origin(type_) in LITERAL_TYPES


def literal_values(type_: Type[Any]) -> Tuple[Any, ...]:
    return get_args(type_)


def all_literal_values(type_: Type[Any]) -> Tuple[Any, ...]:
    """
    This method is used to retrieve all Literal values as
    Literal can be used recursively (see https://www.python.org/dev/peps/pep-0586)
    e.g. `Literal[Literal[Literal[1, 2, 3], "foo"], 5, None]`
    """
    if not is_literal_type(type_):
        return (type_,)

    values = literal_values(type_)
    return tuple(x for value in values for x in all_literal_values(value))


def is_namedtuple(type_: Type[Any]) -> bool:
    """
    Check if a given class is a named tuple.
    It can be either a `typing.NamedTuple` or `collections.namedtuple`
    """
    from pydantic.v1.utils import lenient_issubclass

    return lenient_issubclass(type_, tuple) and hasattr(type_, '_fields')


def is_typeddict(type_: Type[Any]) -> bool:
    """
    Check if a given class is a typed dict (from `typing` or `typing_extensions`)
    In 3.10, there will be a public method (https://docs.python.org/3.10/library/typing.html#typing.is_typeddict)
    """
    from pydantic.v1.utils import lenient_issubclass

    return lenient_issubclass(type_, dict) and hasattr(type_, '__total__')


def _check_typeddict_special(type_: Any) -> bool:
    return type_ is TypedDictRequired or type_ is TypedDictNotRequired


def is_typeddict_special(type_: Any) -> bool:
    """
    Check if type is a TypedDict special form (Required or NotRequired).
    """
    return _check_typeddict_special(type_) or _check_typeddict_special(get_origin(type_))


test_type = NewType('test_type', str)


def is_new_type(type_: Type[Any]) -> bool:
    """
    Check whether type_ was created using typing.NewType
    """
    return isinstance(type_, test_type.__class__) and hasattr(type_, '__supertype__')  # type: ignore


def new_type_supertype(type_: Type[Any]) -> Type[Any]:
    while hasattr(type_, '__supertype__'):
        type_ = type_.__supertype__
    return type_


def _check_classvar(v: Optional[Type[Any]]) -> bool:
    if v is None:
        return False

    return v.__class__ == ClassVar.__class__ and getattr(v, '_name', None) == 'ClassVar'


def _check_finalvar(v: Optional[Type[Any]]) -> bool:
    """
    Check if a given type is a `typing.Final` type.
    """
    if v is None:
        return False

    return v.__class__ == Final.__class__ and (sys.version_info < (3, 8) or getattr(v, '_name', None) == 'Final')


def is_classvar(ann_type: Type[Any]) -> bool:
    if _check_classvar(ann_type) or _check_classvar(get_origin(ann_type)):
        return True

    # this is an ugly workaround for class vars that contain forward references and are therefore themselves
    # forward references, see #3679
    if ann_type.__class__ == ForwardRef and ann_type.__forward_arg__.startswith('ClassVar['):
        return True

    return False


def is_finalvar(ann_type: Type[Any]) -> bool:
    return _check_finalvar(ann_type) or _check_finalvar(get_origin(ann_type))


def update_field_forward_refs(field: 'ModelField', globalns: Any, localns: Any) -> None:
    """
    Try to update ForwardRefs on fields based on this ModelField, globalns and localns.
    """
    prepare = False
    if field.type_.__class__ == ForwardRef:
        prepare = True
        field.type_ = evaluate_forwardref(field.type_, globalns, localns or None)
    if field.outer_type_.__class__ == ForwardRef:
        prepare = True
        field.outer_type_ = evaluate_forwardref(field.outer_type_, globalns, localns or None)
    if prepare:
        field.prepare()

    if field.sub_fields:
        for sub_f in field.sub_fields:
            update_field_forward_refs(sub_f, globalns=globalns, localns=localns)

    if field.discriminator_key is not None:
        field.prepare_discriminated_union_sub_fields()


def update_model_forward_refs(
    model: Type[Any],
    fields: Iterable['ModelField'],
    json_encoders: Dict[Union[Type[Any], str, ForwardRef], AnyCallable],
    localns: 'DictStrAny',
    exc_to_suppress: Tuple[Type[BaseException], ...] = (),
) -> None:
    """
    Try to update model fields ForwardRefs based on model and localns.
    """
    if model.__module__ in sys.modules:
        globalns = sys.modules[model.__module__].__dict__.copy()
    else:
        globalns = {}

    globalns.setdefault(model.__name__, model)

    for f in fields:
        try:
            update_field_forward_refs(f, globalns=globalns, localns=localns)
        except exc_to_suppress:
            pass

    for key in set(json_encoders.keys()):
        if isinstance(key, str):
            fr: ForwardRef = ForwardRef(key)
        elif isinstance(key, ForwardRef):
            fr = key
        else:
            continue

        try:
            new_key = evaluate_forwardref(fr, globalns, localns or None)
        except exc_to_suppress:  # pragma: no cover
            continue

        json_encoders[new_key] = json_encoders.pop(key)


def get_class(type_: Type[Any]) -> Union[None, bool, Type[Any]]:
    """
    Tries to get the class of a Type[T] annotation. Returns True if Type is used
    without brackets. Otherwise returns None.
    """
    if type_ is type:
        return True

    if get_origin(type_) is None:
        return None

    args = get_args(type_)
    if not args or not isinstance(args[0], type):
        return True
    else:
        return args[0]


def get_sub_types(tp: Any) -> List[Any]:
    """
    Return all the types that are allowed by type `tp`
    `tp` can be a `Union` of allowed types or an `Annotated` type
    """
    origin = get_origin(tp)
    if origin is Annotated:
        return get_sub_types(get_args(tp)[0])
    elif is_union(origin):
        return [x for t in get_args(tp) for x in get_sub_types(t)]
    else:
        return [tp]
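
A hedged illustration of `get_sub_types` above with nested Union/Annotated forms:

assert get_sub_types(Union[int, str]) == [int, str]
assert get_sub_types(Annotated[Union[int, str], 'meta']) == [int, str]
assert get_sub_types(bytes) == [bytes]
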
@@ -0,0 +1,804 @@
|
||||
import keyword
|
||||
import warnings
|
||||
import weakref
|
||||
from collections import OrderedDict, defaultdict, deque
|
||||
from copy import deepcopy
|
||||
from itertools import islice, zip_longest
|
||||
from types import BuiltinFunctionType, CodeType, FunctionType, GeneratorType, LambdaType, ModuleType
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
AbstractSet,
|
||||
Any,
|
||||
Callable,
|
||||
Collection,
|
||||
Dict,
|
||||
Generator,
|
||||
Iterable,
|
||||
Iterator,
|
||||
List,
|
||||
Mapping,
|
||||
NoReturn,
|
||||
Optional,
|
||||
Set,
|
||||
Tuple,
|
||||
Type,
|
||||
TypeVar,
|
||||
Union,
|
||||
)
|
||||
|
||||
from typing_extensions import Annotated
|
||||
|
||||
from pydantic.v1.errors import ConfigError
|
||||
from pydantic.v1.typing import (
|
||||
NoneType,
|
||||
WithArgsTypes,
|
||||
all_literal_values,
|
||||
display_as_type,
|
||||
get_args,
|
||||
get_origin,
|
||||
is_literal_type,
|
||||
is_union,
|
||||
)
|
||||
from pydantic.v1.version import version_info
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from inspect import Signature
|
||||
from pathlib import Path
|
||||
|
||||
from pydantic.v1.config import BaseConfig
|
||||
from pydantic.v1.dataclasses import Dataclass
|
||||
from pydantic.v1.fields import ModelField
|
||||
from pydantic.v1.main import BaseModel
|
||||
from pydantic.v1.typing import AbstractSetIntStr, DictIntStrAny, IntStr, MappingIntStrAny, ReprArgs
|
||||
|
||||
RichReprResult = Iterable[Union[Any, Tuple[Any], Tuple[str, Any], Tuple[str, Any, Any]]]
|
||||
|
||||
__all__ = (
|
||||
'import_string',
|
||||
'sequence_like',
|
||||
'validate_field_name',
|
||||
'lenient_isinstance',
|
||||
'lenient_issubclass',
|
||||
'in_ipython',
|
||||
'is_valid_identifier',
|
||||
'deep_update',
|
||||
'update_not_none',
|
||||
'almost_equal_floats',
|
||||
'get_model',
|
||||
'to_camel',
|
||||
'to_lower_camel',
|
||||
'is_valid_field',
|
||||
'smart_deepcopy',
|
||||
'PyObjectStr',
|
||||
'Representation',
|
||||
'GetterDict',
|
||||
'ValueItems',
|
||||
'version_info', # required here to match behaviour in v1.3
|
||||
'ClassAttribute',
|
||||
'path_type',
|
||||
'ROOT_KEY',
|
||||
'get_unique_discriminator_alias',
|
||||
'get_discriminator_alias_and_values',
|
||||
'DUNDER_ATTRIBUTES',
|
||||
)
|
||||
|
||||
ROOT_KEY = '__root__'
|
||||
# these are types that are returned unchanged by deepcopy
|
||||
IMMUTABLE_NON_COLLECTIONS_TYPES: Set[Type[Any]] = {
|
||||
int,
|
||||
float,
|
||||
complex,
|
||||
str,
|
||||
bool,
|
||||
bytes,
|
||||
type,
|
||||
NoneType,
|
||||
FunctionType,
|
||||
BuiltinFunctionType,
|
||||
LambdaType,
|
||||
weakref.ref,
|
||||
CodeType,
|
||||
# note: including ModuleType will differ from behaviour of deepcopy by not producing error.
|
||||
# It might be not a good idea in general, but considering that this function used only internally
|
||||
# against default values of fields, this will allow to actually have a field with module as default value
|
||||
ModuleType,
|
||||
NotImplemented.__class__,
|
||||
Ellipsis.__class__,
|
||||
}
|
||||
|
||||
# these are types that if empty, might be copied with simple copy() instead of deepcopy()
|
||||
BUILTIN_COLLECTIONS: Set[Type[Any]] = {
|
||||
list,
|
||||
set,
|
||||
tuple,
|
||||
frozenset,
|
||||
dict,
|
||||
OrderedDict,
|
||||
defaultdict,
|
||||
deque,
|
||||
}
|
||||
|
||||
|
||||
def import_string(dotted_path: str) -> Any:
|
||||
"""
|
||||
Stolen approximately from django. Import a dotted module path and return the attribute/class designated by the
|
||||
last name in the path. Raise ImportError if the import fails.
|
||||
"""
|
||||
from importlib import import_module
|
||||
|
||||
try:
|
||||
module_path, class_name = dotted_path.strip(' ').rsplit('.', 1)
|
||||
except ValueError as e:
|
||||
raise ImportError(f'"{dotted_path}" doesn\'t look like a module path') from e
|
||||
|
||||
module = import_module(module_path)
|
||||
try:
|
||||
return getattr(module, class_name)
|
||||
except AttributeError as e:
|
||||
raise ImportError(f'Module "{module_path}" does not define a "{class_name}" attribute') from e


def truncate(v: Union[str], *, max_len: int = 80) -> str:
    """
    Truncate a value and add a unicode ellipsis (three dots) to the end if it was too long
    """
    warnings.warn('`truncate` is no longer used by pydantic and is deprecated', DeprecationWarning)
    if isinstance(v, str) and len(v) > (max_len - 2):
        # -3 so quote + string + … + quote has correct length
        return (v[: (max_len - 3)] + '…').__repr__()
    try:
        v = v.__repr__()
    except TypeError:
        v = v.__class__.__repr__(v)  # in case v is a type
    if len(v) > max_len:
        v = v[: max_len - 1] + '…'
    return v


def sequence_like(v: Any) -> bool:
    return isinstance(v, (list, tuple, set, frozenset, GeneratorType, deque))


def validate_field_name(bases: List[Type['BaseModel']], field_name: str) -> None:
    """
    Ensure that the field's name does not shadow an existing attribute of the model.
    """
    for base in bases:
        if getattr(base, field_name, None):
            raise NameError(
                f'Field name "{field_name}" shadows a BaseModel attribute; '
                f'use a different field name with "alias=\'{field_name}\'".'
            )


def lenient_isinstance(o: Any, class_or_tuple: Union[Type[Any], Tuple[Type[Any], ...], None]) -> bool:
    try:
        return isinstance(o, class_or_tuple)  # type: ignore[arg-type]
    except TypeError:
        return False


def lenient_issubclass(cls: Any, class_or_tuple: Union[Type[Any], Tuple[Type[Any], ...], None]) -> bool:
    try:
        return isinstance(cls, type) and issubclass(cls, class_or_tuple)  # type: ignore[arg-type]
    except TypeError:
        if isinstance(cls, WithArgsTypes):
            return False
        raise  # pragma: no cover


def in_ipython() -> bool:
    """
    Check whether we're in an ipython environment, including jupyter notebooks.
    """
    try:
        eval('__IPYTHON__')
    except NameError:
        return False
    else:  # pragma: no cover
        return True


def is_valid_identifier(identifier: str) -> bool:
    """
    Checks that a string is a valid identifier and not a Python keyword.
    :param identifier: The identifier to test.
    :return: True if the identifier is valid.
    """
    return identifier.isidentifier() and not keyword.iskeyword(identifier)


KeyType = TypeVar('KeyType')


def deep_update(mapping: Dict[KeyType, Any], *updating_mappings: Dict[KeyType, Any]) -> Dict[KeyType, Any]:
    updated_mapping = mapping.copy()
    for updating_mapping in updating_mappings:
        for k, v in updating_mapping.items():
            if k in updated_mapping and isinstance(updated_mapping[k], dict) and isinstance(v, dict):
                updated_mapping[k] = deep_update(updated_mapping[k], v)
            else:
                updated_mapping[k] = v
    return updated_mapping
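
# Illustrative behaviour of `deep_update` (comment only, not part of the upstream module):
# nested dicts are merged recursively, while non-dict values are overwritten.
#
#     >>> deep_update({'a': {'x': 1}, 'b': 2}, {'a': {'y': 3}, 'b': 4})
#     {'a': {'x': 1, 'y': 3}, 'b': 4}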


def update_not_none(mapping: Dict[Any, Any], **update: Any) -> None:
    mapping.update({k: v for k, v in update.items() if v is not None})


def almost_equal_floats(value_1: float, value_2: float, *, delta: float = 1e-8) -> bool:
    """
    Return True if two floats are almost equal
    """
    return abs(value_1 - value_2) <= delta
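
# Illustrative usage of `almost_equal_floats` (comment only, not part of the upstream module);
# this absolute-tolerance check is what the `multiple_of` validator in validators.py relies on:
#
#     >>> almost_equal_floats(0.1 + 0.2, 0.3)
#     True
#     >>> 0.1 + 0.2 == 0.3
#     False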


def generate_model_signature(
    init: Callable[..., None], fields: Dict[str, 'ModelField'], config: Type['BaseConfig']
) -> 'Signature':
    """
    Generate signature for model based on its fields
    """
    from inspect import Parameter, Signature, signature

    from pydantic.v1.config import Extra

    present_params = signature(init).parameters.values()
    merged_params: Dict[str, Parameter] = {}
    var_kw = None
    use_var_kw = False

    for param in islice(present_params, 1, None):  # skip self arg
        if param.kind is param.VAR_KEYWORD:
            var_kw = param
            continue
        merged_params[param.name] = param

    if var_kw:  # if custom init has no var_kw, fields which are not declared in it cannot be passed through
        allow_names = config.allow_population_by_field_name
        for field_name, field in fields.items():
            param_name = field.alias
            if field_name in merged_params or param_name in merged_params:
                continue
            elif not is_valid_identifier(param_name):
                if allow_names and is_valid_identifier(field_name):
                    param_name = field_name
                else:
                    use_var_kw = True
                    continue

            # TODO: replace annotation with actual expected types once #1055 solved
            kwargs = {'default': field.default} if not field.required else {}
            merged_params[param_name] = Parameter(
                param_name, Parameter.KEYWORD_ONLY, annotation=field.annotation, **kwargs
            )

    if config.extra is Extra.allow:
        use_var_kw = True

    if var_kw and use_var_kw:
        # Make sure the parameter for extra kwargs
        # does not have the same name as a field
        default_model_signature = [
            ('__pydantic_self__', Parameter.POSITIONAL_OR_KEYWORD),
            ('data', Parameter.VAR_KEYWORD),
        ]
        if [(p.name, p.kind) for p in present_params] == default_model_signature:
            # if this is the standard model signature, use extra_data as the extra args name
            var_kw_name = 'extra_data'
        else:
            # else start from var_kw
            var_kw_name = var_kw.name

        # generate a name that's definitely unique
        while var_kw_name in fields:
            var_kw_name += '_'
        merged_params[var_kw_name] = var_kw.replace(name=var_kw_name)

    return Signature(parameters=list(merged_params.values()), return_annotation=None)
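
# Illustrative effect of `generate_model_signature` (comment only, not part of the upstream module):
# for a model with fields `id: int` and `name: str = 'Jane'` and the standard `__init__`, the
# generated signature that IDEs display for `Model(...)` is roughly
# `(*, id: int, name: str = 'Jane') -> None`, with aliases used as parameter names when they
# are valid identifiers.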


def get_model(obj: Union[Type['BaseModel'], Type['Dataclass']]) -> Type['BaseModel']:
    from pydantic.v1.main import BaseModel

    try:
        model_cls = obj.__pydantic_model__  # type: ignore
    except AttributeError:
        model_cls = obj

    if not issubclass(model_cls, BaseModel):
        raise TypeError('Unsupported type, must be either BaseModel or dataclass')
    return model_cls


def to_camel(string: str) -> str:
    return ''.join(word.capitalize() for word in string.split('_'))


def to_lower_camel(string: str) -> str:
    if len(string) >= 1:
        pascal_string = to_camel(string)
        return pascal_string[0].lower() + pascal_string[1:]
    return string.lower()
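
# Illustrative behaviour of the casing helpers (comment only, not part of the upstream module).
# Despite its name, `to_camel` produces PascalCase; `to_lower_camel` produces true camelCase:
#
#     >>> to_camel('device_serial_number')
#     'DeviceSerialNumber'
#     >>> to_lower_camel('device_serial_number')
#     'deviceSerialNumber'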


T = TypeVar('T')


def unique_list(
    input_list: Union[List[T], Tuple[T, ...]],
    *,
    name_factory: Callable[[T], str] = str,
) -> List[T]:
    """
    Make a list unique while maintaining order.
    We update the list if another one with the same name is set
    (e.g. root validator overridden in subclass)
    """
    result: List[T] = []
    result_names: List[str] = []
    for v in input_list:
        v_name = name_factory(v)
        if v_name not in result_names:
            result_names.append(v_name)
            result.append(v)
        else:
            result[result_names.index(v_name)] = v

    return result
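
# Illustrative behaviour of `unique_list` (comment only, not part of the upstream module):
# a later item with the same derived name replaces the earlier one, in place:
#
#     >>> unique_list(['a', 'b', 'a'])
#     ['a', 'b']
#     >>> unique_list([1, 'x', 2], name_factory=lambda v: type(v).__name__)
#     [2, 'x']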


class PyObjectStr(str):
    """
    String class where repr doesn't include quotes. Useful with Representation when you want to return a string
    representation of something that is valid (or pseudo-valid) python.
    """

    def __repr__(self) -> str:
        return str(self)


class Representation:
    """
    Mixin to provide __str__, __repr__, and __pretty__ methods. See #884 for more details.

    __pretty__ is used by [devtools](https://python-devtools.helpmanual.io/) to provide human readable representations
    of objects.
    """

    __slots__: Tuple[str, ...] = tuple()

    def __repr_args__(self) -> 'ReprArgs':
        """
        Returns the attributes to show in __str__, __repr__, and __pretty__; this is generally overridden.

        Can either return:
        * name - value pairs, e.g.: `[('foo_name', 'foo'), ('bar_name', ['b', 'a', 'r'])]`
        * or, just values, e.g.: `[(None, 'foo'), (None, ['b', 'a', 'r'])]`
        """
        attrs = ((s, getattr(self, s)) for s in self.__slots__)
        return [(a, v) for a, v in attrs if v is not None]

    def __repr_name__(self) -> str:
        """
        Name of the instance's class, used in __repr__.
        """
        return self.__class__.__name__

    def __repr_str__(self, join_str: str) -> str:
        return join_str.join(repr(v) if a is None else f'{a}={v!r}' for a, v in self.__repr_args__())

    def __pretty__(self, fmt: Callable[[Any], Any], **kwargs: Any) -> Generator[Any, None, None]:
        """
        Used by devtools (https://python-devtools.helpmanual.io/) to provide a human readable representation of objects
        """
        yield self.__repr_name__() + '('
        yield 1
        for name, value in self.__repr_args__():
            if name is not None:
                yield name + '='
            yield fmt(value)
            yield ','
            yield 0
        yield -1
        yield ')'

    def __str__(self) -> str:
        return self.__repr_str__(' ')

    def __repr__(self) -> str:
        return f'{self.__repr_name__()}({self.__repr_str__(", ")})'

    def __rich_repr__(self) -> 'RichReprResult':
        """Get fields for Rich library"""
        for name, field_repr in self.__repr_args__():
            if name is None:
                yield field_repr
            else:
                yield name, field_repr
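
# Minimal illustrative subclass of `Representation` (comment only, not part of the upstream module):
#
#     >>> class Point(Representation):
#     ...     __slots__ = ('x', 'y')
#     ...     def __init__(self, x, y):
#     ...         self.x, self.y = x, y
#     >>> repr(Point(1, 2))
#     'Point(x=1, y=2)'
#     >>> str(Point(1, 2))
#     'x=1 y=2'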


class GetterDict(Representation):
    """
    Hack to make objects smell just enough like dicts for validate_model.

    We can't inherit from Mapping[str, Any] because it upsets cython so we have to implement all methods ourselves.
    """

    __slots__ = ('_obj',)

    def __init__(self, obj: Any):
        self._obj = obj

    def __getitem__(self, key: str) -> Any:
        try:
            return getattr(self._obj, key)
        except AttributeError as e:
            raise KeyError(key) from e

    def get(self, key: Any, default: Any = None) -> Any:
        return getattr(self._obj, key, default)

    def extra_keys(self) -> Set[Any]:
        """
        We don't want to get any other attributes of obj if the model didn't explicitly ask for them
        """
        return set()

    def keys(self) -> List[Any]:
        """
        Keys of the pseudo dictionary, uses a list not set so order information can be maintained like python
        dictionaries.
        """
        return list(self)

    def values(self) -> List[Any]:
        return [self[k] for k in self]

    def items(self) -> Iterator[Tuple[str, Any]]:
        for k in self:
            yield k, self.get(k)

    def __iter__(self) -> Iterator[str]:
        for name in dir(self._obj):
            if not name.startswith('_'):
                yield name

    def __len__(self) -> int:
        return sum(1 for _ in self)

    def __contains__(self, item: Any) -> bool:
        return item in self.keys()

    def __eq__(self, other: Any) -> bool:
        return dict(self) == dict(other.items())

    def __repr_args__(self) -> 'ReprArgs':
        return [(None, dict(self))]

    def __repr_name__(self) -> str:
        return f'GetterDict[{display_as_type(self._obj)}]'
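
# Illustrative use of `GetterDict` (comment only, not part of the upstream module):
# it exposes an arbitrary object's public attributes through a dict-like interface,
# which is what `Config.orm_mode` relies on:
#
#     >>> class User:
#     ...     name = 'Jane'
#     ...     age = 30
#     >>> gd = GetterDict(User())
#     >>> gd['name'], gd.get('missing', 'default'), 'age' in gd
#     ('Jane', 'default', True)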


class ValueItems(Representation):
    """
    Class for more convenient calculation of excluded or included fields on values.
    """

    __slots__ = ('_items', '_type')

    def __init__(self, value: Any, items: Union['AbstractSetIntStr', 'MappingIntStrAny']) -> None:
        items = self._coerce_items(items)

        if isinstance(value, (list, tuple)):
            items = self._normalize_indexes(items, len(value))

        self._items: 'MappingIntStrAny' = items

    def is_excluded(self, item: Any) -> bool:
        """
        Check if item is fully excluded.

        :param item: key or index of a value
        """
        return self.is_true(self._items.get(item))

    def is_included(self, item: Any) -> bool:
        """
        Check if value is contained in self._items

        :param item: key or index of value
        """
        return item in self._items

    def for_element(self, e: 'IntStr') -> Optional[Union['AbstractSetIntStr', 'MappingIntStrAny']]:
        """
        :param e: key or index of element on value
        :return: raw values for element if self._items is dict and contains the needed element
        """

        item = self._items.get(e)
        return item if not self.is_true(item) else None

    def _normalize_indexes(self, items: 'MappingIntStrAny', v_length: int) -> 'DictIntStrAny':
        """
        :param items: dict or set of indexes which will be normalized
        :param v_length: length of the sequence whose indexes will be normalized

        >>> self._normalize_indexes({0: True, -2: True, -1: True}, 4)
        {0: True, 2: True, 3: True}
        >>> self._normalize_indexes({'__all__': True}, 4)
        {0: True, 1: True, 2: True, 3: True}
        """

        normalized_items: 'DictIntStrAny' = {}
        all_items = None
        for i, v in items.items():
            if not (isinstance(v, Mapping) or isinstance(v, AbstractSet) or self.is_true(v)):
                raise TypeError(f'Unexpected type of exclude value for index "{i}" {v.__class__}')
            if i == '__all__':
                all_items = self._coerce_value(v)
                continue
            if not isinstance(i, int):
                raise TypeError(
                    'Excluding fields from a sequence of sub-models or dicts must be performed index-wise: '
                    'expected integer keys or keyword "__all__"'
                )
            normalized_i = v_length + i if i < 0 else i
            normalized_items[normalized_i] = self.merge(v, normalized_items.get(normalized_i))

        if not all_items:
            return normalized_items
        if self.is_true(all_items):
            for i in range(v_length):
                normalized_items.setdefault(i, ...)
            return normalized_items
        for i in range(v_length):
            normalized_item = normalized_items.setdefault(i, {})
            if not self.is_true(normalized_item):
                normalized_items[i] = self.merge(all_items, normalized_item)
        return normalized_items

    @classmethod
    def merge(cls, base: Any, override: Any, intersect: bool = False) -> Any:
        """
        Merge a ``base`` item with an ``override`` item.

        Both ``base`` and ``override`` are converted to dictionaries if possible.
        Sets are converted to dictionaries with the set's entries as keys and
        Ellipsis as values.

        Each key-value pair existing in ``base`` is merged with ``override``,
        while the rest of the key-value pairs are updated recursively with this function.

        Merging takes place based on the "union" of keys if ``intersect`` is
        set to ``False`` (default) and on the intersection of keys if
        ``intersect`` is set to ``True``.
        """
        override = cls._coerce_value(override)
        base = cls._coerce_value(base)
        if override is None:
            return base
        if cls.is_true(base) or base is None:
            return override
        if cls.is_true(override):
            return base if intersect else override

        # intersection or union of keys while preserving ordering:
        if intersect:
            merge_keys = [k for k in base if k in override] + [k for k in override if k in base]
        else:
            merge_keys = list(base) + [k for k in override if k not in base]

        merged: 'DictIntStrAny' = {}
        for k in merge_keys:
            merged_item = cls.merge(base.get(k), override.get(k), intersect=intersect)
            if merged_item is not None:
                merged[k] = merged_item

        return merged

    @staticmethod
    def _coerce_items(items: Union['AbstractSetIntStr', 'MappingIntStrAny']) -> 'MappingIntStrAny':
        if isinstance(items, Mapping):
            pass
        elif isinstance(items, AbstractSet):
            items = dict.fromkeys(items, ...)
        else:
            class_name = getattr(items, '__class__', '???')
            assert_never(
                items,
                f'Unexpected type of exclude value {class_name}',
            )
        return items

    @classmethod
    def _coerce_value(cls, value: Any) -> Any:
        if value is None or cls.is_true(value):
            return value
        return cls._coerce_items(value)

    @staticmethod
    def is_true(v: Any) -> bool:
        return v is True or v is ...

    def __repr_args__(self) -> 'ReprArgs':
        return [(None, self._items)]
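
# Illustrative behaviour of `ValueItems.merge` (comment only, not part of the upstream module):
# sets are coerced to {key: ...} mappings before merging, so include/exclude specs compose:
#
#     >>> ValueItems.merge({'a': {'x'}}, {'a': {'y'}, 'b': ...})
#     {'a': {'x': Ellipsis, 'y': Ellipsis}, 'b': Ellipsis}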


class ClassAttribute:
    """
    Hide class attribute from its instances
    """

    __slots__ = (
        'name',
        'value',
    )

    def __init__(self, name: str, value: Any) -> None:
        self.name = name
        self.value = value

    def __get__(self, instance: Any, owner: Type[Any]) -> None:
        if instance is None:
            return self.value
        raise AttributeError(f'{self.name!r} attribute of {owner.__name__!r} is class-only')


path_types = {
    'is_dir': 'directory',
    'is_file': 'file',
    'is_mount': 'mount point',
    'is_symlink': 'symlink',
    'is_block_device': 'block device',
    'is_char_device': 'char device',
    'is_fifo': 'FIFO',
    'is_socket': 'socket',
}


def path_type(p: 'Path') -> str:
    """
    Find out what sort of thing a path is.
    """
    assert p.exists(), 'path does not exist'
    for method, name in path_types.items():
        if getattr(p, method)():
            return name

    return 'unknown'


Obj = TypeVar('Obj')


def smart_deepcopy(obj: Obj) -> Obj:
    """
    Return type as is for immutable built-in types
    Use obj.copy() for built-in empty collections
    Use copy.deepcopy() for non-empty collections and unknown objects
    """

    obj_type = obj.__class__
    if obj_type in IMMUTABLE_NON_COLLECTIONS_TYPES:
        return obj  # fastest case: obj is immutable and not a collection, therefore will not be copied anyway
    try:
        if not obj and obj_type in BUILTIN_COLLECTIONS:
            # faster way for empty collections, no need to copy its members
            return obj if obj_type is tuple else obj.copy()  # type: ignore  # tuple doesn't have copy method
    except (TypeError, ValueError, RuntimeError):
        # do we really dare to catch ALL errors? Seems a bit risky
        pass

    return deepcopy(obj)  # slowest way when we actually might need a deepcopy
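
# Illustrative behaviour of `smart_deepcopy` (comment only, not part of the upstream module):
# immutable scalars and empty built-in collections skip the expensive `copy.deepcopy` call:
#
#     >>> x = 'abc'
#     >>> smart_deepcopy(x) is x          # immutable: returned unchanged
#     True
#     >>> d = {}
#     >>> smart_deepcopy(d) is d          # empty dict: shallow-copied instead
#     False
#     >>> smart_deepcopy({'k': [1]})      # non-empty: full deepcopy
#     {'k': [1]}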


def is_valid_field(name: str) -> bool:
    if not name.startswith('_'):
        return True
    return ROOT_KEY == name


DUNDER_ATTRIBUTES = {
    '__annotations__',
    '__classcell__',
    '__doc__',
    '__module__',
    '__orig_bases__',
    '__orig_class__',
    '__qualname__',
}


def is_valid_private_name(name: str) -> bool:
    return not is_valid_field(name) and name not in DUNDER_ATTRIBUTES


_EMPTY = object()


def all_identical(left: Iterable[Any], right: Iterable[Any]) -> bool:
    """
    Check that the items of `left` are the same objects as those in `right`.

    >>> a, b = object(), object()
    >>> all_identical([a, b, a], [a, b, a])
    True
    >>> all_identical([a, b, [a]], [a, b, [a]])  # new list object, while "equal" is not "identical"
    False
    """
    for left_item, right_item in zip_longest(left, right, fillvalue=_EMPTY):
        if left_item is not right_item:
            return False
    return True


def assert_never(obj: NoReturn, msg: str) -> NoReturn:
    """
    Helper to make sure that we have covered all possible types.

    This is mostly useful for ``mypy``, docs:
    https://mypy.readthedocs.io/en/latest/literal_types.html#exhaustive-checks
    """
    raise TypeError(msg)


def get_unique_discriminator_alias(all_aliases: Collection[str], discriminator_key: str) -> str:
    """Validate that all aliases are the same and if that's the case return the alias"""
    unique_aliases = set(all_aliases)
    if len(unique_aliases) > 1:
        raise ConfigError(
            f'Aliases for discriminator {discriminator_key!r} must be the same (got {", ".join(sorted(all_aliases))})'
        )
    return unique_aliases.pop()


def get_discriminator_alias_and_values(tp: Any, discriminator_key: str) -> Tuple[str, Tuple[str, ...]]:
    """
    Get the alias and all valid values in the `Literal` type of the discriminator field.
    `tp` can be a `BaseModel` class or directly an `Annotated` `Union` of many.
    """
    is_root_model = getattr(tp, '__custom_root_type__', False)

    if get_origin(tp) is Annotated:
        tp = get_args(tp)[0]

    if hasattr(tp, '__pydantic_model__'):
        tp = tp.__pydantic_model__

    if is_union(get_origin(tp)):
        alias, all_values = _get_union_alias_and_all_values(tp, discriminator_key)
        return alias, tuple(v for values in all_values for v in values)
    elif is_root_model:
        union_type = tp.__fields__[ROOT_KEY].type_
        alias, all_values = _get_union_alias_and_all_values(union_type, discriminator_key)

        if len(set(all_values)) > 1:
            raise ConfigError(
                f'Field {discriminator_key!r} is not the same for all submodels of {display_as_type(tp)!r}'
            )

        return alias, all_values[0]

    else:
        try:
            t_discriminator_type = tp.__fields__[discriminator_key].type_
        except AttributeError as e:
            raise TypeError(f'Type {tp.__name__!r} is not a valid `BaseModel` or `dataclass`') from e
        except KeyError as e:
            raise ConfigError(f'Model {tp.__name__!r} needs a discriminator field for key {discriminator_key!r}') from e

        if not is_literal_type(t_discriminator_type):
            raise ConfigError(f'Field {discriminator_key!r} of model {tp.__name__!r} needs to be a `Literal`')

        return tp.__fields__[discriminator_key].alias, all_literal_values(t_discriminator_type)


def _get_union_alias_and_all_values(
    union_type: Type[Any], discriminator_key: str
) -> Tuple[str, Tuple[Tuple[str, ...], ...]]:
    zipped_aliases_values = [get_discriminator_alias_and_values(t, discriminator_key) for t in get_args(union_type)]
    # unzip: [('alias_a', ('v1', 'v2')), ('alias_b', ('v3',))] => [('alias_a', 'alias_b'), (('v1', 'v2'), ('v3',))]
    all_aliases, all_values = zip(*zipped_aliases_values)
    return get_unique_discriminator_alias(all_aliases, discriminator_key), all_values
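
# Illustrative behaviour of `get_unique_discriminator_alias` (comment only, not part of the upstream module):
#
#     >>> get_unique_discriminator_alias(['pet_type', 'pet_type'], 'pet_type')
#     'pet_type'
#     >>> get_unique_discriminator_alias(['pet_type', 'petType'], 'pet_type')
#     Traceback (most recent call last):
#         ...
#     ConfigError: Aliases for discriminator 'pet_type' must be the same (got petType, pet_type)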
@@ -0,0 +1,768 @@
import math
import re
from collections import OrderedDict, deque
from collections.abc import Hashable as CollectionsHashable
from datetime import date, datetime, time, timedelta
from decimal import Decimal, DecimalException
from enum import Enum, IntEnum
from ipaddress import IPv4Address, IPv4Interface, IPv4Network, IPv6Address, IPv6Interface, IPv6Network
from pathlib import Path
from typing import (
    TYPE_CHECKING,
    Any,
    Callable,
    Deque,
    Dict,
    ForwardRef,
    FrozenSet,
    Generator,
    Hashable,
    List,
    NamedTuple,
    Pattern,
    Set,
    Tuple,
    Type,
    TypeVar,
    Union,
)
from uuid import UUID
from warnings import warn

from pydantic.v1 import errors
from pydantic.v1.datetime_parse import parse_date, parse_datetime, parse_duration, parse_time
from pydantic.v1.typing import (
    AnyCallable,
    all_literal_values,
    display_as_type,
    get_class,
    is_callable_type,
    is_literal_type,
    is_namedtuple,
    is_none_type,
    is_typeddict,
)
from pydantic.v1.utils import almost_equal_floats, lenient_issubclass, sequence_like

if TYPE_CHECKING:
    from typing_extensions import Literal, TypedDict

    from pydantic.v1.config import BaseConfig
    from pydantic.v1.fields import ModelField
    from pydantic.v1.types import ConstrainedDecimal, ConstrainedFloat, ConstrainedInt

    ConstrainedNumber = Union[ConstrainedDecimal, ConstrainedFloat, ConstrainedInt]
    AnyOrderedDict = OrderedDict[Any, Any]
    Number = Union[int, float, Decimal]
    StrBytes = Union[str, bytes]


def str_validator(v: Any) -> Union[str]:
    if isinstance(v, str):
        if isinstance(v, Enum):
            return v.value
        else:
            return v
    elif isinstance(v, (float, int, Decimal)):
        # is there anything else we want to add here? If you think so, create an issue.
        return str(v)
    elif isinstance(v, (bytes, bytearray)):
        return v.decode()
    else:
        raise errors.StrError()


def strict_str_validator(v: Any) -> Union[str]:
    if isinstance(v, str) and not isinstance(v, Enum):
        return v
    raise errors.StrError()


def bytes_validator(v: Any) -> Union[bytes]:
    if isinstance(v, bytes):
        return v
    elif isinstance(v, bytearray):
        return bytes(v)
    elif isinstance(v, str):
        return v.encode()
    elif isinstance(v, (float, int, Decimal)):
        return str(v).encode()
    else:
        raise errors.BytesError()


def strict_bytes_validator(v: Any) -> Union[bytes]:
    if isinstance(v, bytes):
        return v
    elif isinstance(v, bytearray):
        return bytes(v)
    else:
        raise errors.BytesError()


BOOL_FALSE = {0, '0', 'off', 'f', 'false', 'n', 'no'}
BOOL_TRUE = {1, '1', 'on', 't', 'true', 'y', 'yes'}


def bool_validator(v: Any) -> bool:
    if v is True or v is False:
        return v
    if isinstance(v, bytes):
        v = v.decode()
    if isinstance(v, str):
        v = v.lower()
    try:
        if v in BOOL_TRUE:
            return True
        if v in BOOL_FALSE:
            return False
    except TypeError:
        raise errors.BoolError()
    raise errors.BoolError()
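
# Illustrative behaviour of `bool_validator` (comment only, not part of the upstream module):
# bytes are decoded and strings lower-cased before the BOOL_TRUE/BOOL_FALSE lookup:
#
#     >>> bool_validator(b'YES'), bool_validator('Off'), bool_validator(1)
#     (True, False, True)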


# matches the default limit in cpython, see https://github.com/python/cpython/pull/96500
max_str_int = 4_300


def int_validator(v: Any) -> int:
    if isinstance(v, int) and not (v is True or v is False):
        return v

    # see https://github.com/pydantic/pydantic/issues/1477 and in turn, https://github.com/python/cpython/issues/95778
    # this check should be unnecessary once patch releases are out for 3.7, 3.8, 3.9 and 3.10
    # but better to check here until then.
    # NOTICE: this does not fully protect the user from the DOS risk since the standard library JSON implementation
    # (and other std lib modules like xml) use `int()` and are likely called before this; the best workarounds are to
    # 1. update to the latest patch release of python once released, 2. use a different JSON library like ujson
    if isinstance(v, (str, bytes, bytearray)) and len(v) > max_str_int:
        raise errors.IntegerError()

    try:
        return int(v)
    except (TypeError, ValueError, OverflowError):
        raise errors.IntegerError()


def strict_int_validator(v: Any) -> int:
    if isinstance(v, int) and not (v is True or v is False):
        return v
    raise errors.IntegerError()


def float_validator(v: Any) -> float:
    if isinstance(v, float):
        return v

    try:
        return float(v)
    except (TypeError, ValueError):
        raise errors.FloatError()


def strict_float_validator(v: Any) -> float:
    if isinstance(v, float):
        return v
    raise errors.FloatError()


def float_finite_validator(v: 'Number', field: 'ModelField', config: 'BaseConfig') -> 'Number':
    allow_inf_nan = getattr(field.type_, 'allow_inf_nan', None)
    if allow_inf_nan is None:
        allow_inf_nan = config.allow_inf_nan

    if allow_inf_nan is False and (math.isnan(v) or math.isinf(v)):
        raise errors.NumberNotFiniteError()
    return v


def number_multiple_validator(v: 'Number', field: 'ModelField') -> 'Number':
    field_type: ConstrainedNumber = field.type_
    if field_type.multiple_of is not None:
        mod = float(v) / float(field_type.multiple_of) % 1
        if not almost_equal_floats(mod, 0.0) and not almost_equal_floats(mod, 1.0):
            raise errors.NumberNotMultipleError(multiple_of=field_type.multiple_of)
    return v


def number_size_validator(v: 'Number', field: 'ModelField') -> 'Number':
    field_type: ConstrainedNumber = field.type_
    if field_type.gt is not None and not v > field_type.gt:
        raise errors.NumberNotGtError(limit_value=field_type.gt)
    elif field_type.ge is not None and not v >= field_type.ge:
        raise errors.NumberNotGeError(limit_value=field_type.ge)

    if field_type.lt is not None and not v < field_type.lt:
        raise errors.NumberNotLtError(limit_value=field_type.lt)
    if field_type.le is not None and not v <= field_type.le:
        raise errors.NumberNotLeError(limit_value=field_type.le)

    return v


def constant_validator(v: 'Any', field: 'ModelField') -> 'Any':
    """Validate ``const`` fields.

    The value provided for a ``const`` field must be equal to the default value
    of the field. This is to support the keyword of the same name in JSON
    Schema.
    """
    if v != field.default:
        raise errors.WrongConstantError(given=v, permitted=[field.default])

    return v


def anystr_length_validator(v: 'StrBytes', config: 'BaseConfig') -> 'StrBytes':
    v_len = len(v)

    min_length = config.min_anystr_length
    if v_len < min_length:
        raise errors.AnyStrMinLengthError(limit_value=min_length)

    max_length = config.max_anystr_length
    if max_length is not None and v_len > max_length:
        raise errors.AnyStrMaxLengthError(limit_value=max_length)

    return v


def anystr_strip_whitespace(v: 'StrBytes') -> 'StrBytes':
    return v.strip()


def anystr_upper(v: 'StrBytes') -> 'StrBytes':
    return v.upper()


def anystr_lower(v: 'StrBytes') -> 'StrBytes':
    return v.lower()


def ordered_dict_validator(v: Any) -> 'AnyOrderedDict':
    if isinstance(v, OrderedDict):
        return v

    try:
        return OrderedDict(v)
    except (TypeError, ValueError):
        raise errors.DictError()


def dict_validator(v: Any) -> Dict[Any, Any]:
    if isinstance(v, dict):
        return v

    try:
        return dict(v)
    except (TypeError, ValueError):
        raise errors.DictError()


def list_validator(v: Any) -> List[Any]:
    if isinstance(v, list):
        return v
    elif sequence_like(v):
        return list(v)
    else:
        raise errors.ListError()


def tuple_validator(v: Any) -> Tuple[Any, ...]:
    if isinstance(v, tuple):
        return v
    elif sequence_like(v):
        return tuple(v)
    else:
        raise errors.TupleError()


def set_validator(v: Any) -> Set[Any]:
    if isinstance(v, set):
        return v
    elif sequence_like(v):
        return set(v)
    else:
        raise errors.SetError()


def frozenset_validator(v: Any) -> FrozenSet[Any]:
    if isinstance(v, frozenset):
        return v
    elif sequence_like(v):
        return frozenset(v)
    else:
        raise errors.FrozenSetError()


def deque_validator(v: Any) -> Deque[Any]:
    if isinstance(v, deque):
        return v
    elif sequence_like(v):
        return deque(v)
    else:
        raise errors.DequeError()


def enum_member_validator(v: Any, field: 'ModelField', config: 'BaseConfig') -> Enum:
    try:
        enum_v = field.type_(v)
    except ValueError:
        # field.type_ should be an enum, so will be iterable
        raise errors.EnumMemberError(enum_values=list(field.type_))
    return enum_v.value if config.use_enum_values else enum_v


def uuid_validator(v: Any, field: 'ModelField') -> UUID:
    try:
        if isinstance(v, str):
            v = UUID(v)
        elif isinstance(v, (bytes, bytearray)):
            try:
                v = UUID(v.decode())
            except ValueError:
                # 16 bytes in big-endian order as the bytes argument fail
                # the above check
                v = UUID(bytes=v)
    except ValueError:
        raise errors.UUIDError()

    if not isinstance(v, UUID):
        raise errors.UUIDError()

    required_version = getattr(field.type_, '_required_version', None)
    if required_version and v.version != required_version:
        raise errors.UUIDVersionError(required_version=required_version)

    return v


def decimal_validator(v: Any) -> Decimal:
    if isinstance(v, Decimal):
        return v
    elif isinstance(v, (bytes, bytearray)):
        v = v.decode()

    v = str(v).strip()

    try:
        v = Decimal(v)
    except DecimalException:
        raise errors.DecimalError()

    if not v.is_finite():
        raise errors.DecimalIsNotFiniteError()

    return v
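
# Illustrative behaviour of `decimal_validator` (comment only, not part of the upstream module):
# values are round-tripped through `str` so floats don't drag binary noise into the Decimal:
#
#     >>> decimal_validator('1.50')
#     Decimal('1.50')
#     >>> decimal_validator('NaN')  # parses, but fails the is_finite() check
#     Traceback (most recent call last):
#         ...
#     DecimalIsNotFiniteError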


def hashable_validator(v: Any) -> Hashable:
    if isinstance(v, Hashable):
        return v

    raise errors.HashableError()


def ip_v4_address_validator(v: Any) -> IPv4Address:
    if isinstance(v, IPv4Address):
        return v

    try:
        return IPv4Address(v)
    except ValueError:
        raise errors.IPv4AddressError()


def ip_v6_address_validator(v: Any) -> IPv6Address:
    if isinstance(v, IPv6Address):
        return v

    try:
        return IPv6Address(v)
    except ValueError:
        raise errors.IPv6AddressError()


def ip_v4_network_validator(v: Any) -> IPv4Network:
    """
    Assume IPv4Network is initialised with a default ``strict`` argument

    See more:
    https://docs.python.org/library/ipaddress.html#ipaddress.IPv4Network
    """
    if isinstance(v, IPv4Network):
        return v

    try:
        return IPv4Network(v)
    except ValueError:
        raise errors.IPv4NetworkError()


def ip_v6_network_validator(v: Any) -> IPv6Network:
    """
    Assume IPv6Network is initialised with a default ``strict`` argument

    See more:
    https://docs.python.org/library/ipaddress.html#ipaddress.IPv6Network
    """
    if isinstance(v, IPv6Network):
        return v

    try:
        return IPv6Network(v)
    except ValueError:
        raise errors.IPv6NetworkError()


def ip_v4_interface_validator(v: Any) -> IPv4Interface:
    if isinstance(v, IPv4Interface):
        return v

    try:
        return IPv4Interface(v)
    except ValueError:
        raise errors.IPv4InterfaceError()


def ip_v6_interface_validator(v: Any) -> IPv6Interface:
    if isinstance(v, IPv6Interface):
        return v

    try:
        return IPv6Interface(v)
    except ValueError:
        raise errors.IPv6InterfaceError()


def path_validator(v: Any) -> Path:
    if isinstance(v, Path):
        return v

    try:
        return Path(v)
    except TypeError:
        raise errors.PathError()


def path_exists_validator(v: Any) -> Path:
    if not v.exists():
        raise errors.PathNotExistsError(path=v)

    return v


def callable_validator(v: Any) -> AnyCallable:
    """
    Perform a simple check if the value is callable.

    Note: complete matching of argument type hints and return types is not performed
    """
    if callable(v):
        return v

    raise errors.CallableError(value=v)


def enum_validator(v: Any) -> Enum:
    if isinstance(v, Enum):
        return v

    raise errors.EnumError(value=v)


def int_enum_validator(v: Any) -> IntEnum:
    if isinstance(v, IntEnum):
        return v

    raise errors.IntEnumError(value=v)


def make_literal_validator(type_: Any) -> Callable[[Any], Any]:
    permitted_choices = all_literal_values(type_)

    # To have O(1) complexity and still return one of the values set inside the `Literal`,
    # we create a dict with the set values (a set causes some problems with the way intersection works).
    # In some cases the set value and checked value can indeed be different (see `test_literal_validator_str_enum`)
    allowed_choices = {v: v for v in permitted_choices}

    def literal_validator(v: Any) -> Any:
        try:
            return allowed_choices[v]
        except (KeyError, TypeError):
            raise errors.WrongConstantError(given=v, permitted=permitted_choices)

    return literal_validator
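
# Illustrative usage of `make_literal_validator` (comment only, not part of the upstream module):
#
#     >>> from typing_extensions import Literal
#     >>> validate_colour = make_literal_validator(Literal['red', 'green'])
#     >>> validate_colour('red')
#     'red'
#     >>> validate_colour('blue')  # raises errors.WrongConstantError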


def constr_length_validator(v: 'StrBytes', field: 'ModelField', config: 'BaseConfig') -> 'StrBytes':
    v_len = len(v)

    min_length = field.type_.min_length if field.type_.min_length is not None else config.min_anystr_length
    if v_len < min_length:
        raise errors.AnyStrMinLengthError(limit_value=min_length)

    max_length = field.type_.max_length if field.type_.max_length is not None else config.max_anystr_length
    if max_length is not None and v_len > max_length:
        raise errors.AnyStrMaxLengthError(limit_value=max_length)

    return v


def constr_strip_whitespace(v: 'StrBytes', field: 'ModelField', config: 'BaseConfig') -> 'StrBytes':
    strip_whitespace = field.type_.strip_whitespace or config.anystr_strip_whitespace
    if strip_whitespace:
        v = v.strip()

    return v


def constr_upper(v: 'StrBytes', field: 'ModelField', config: 'BaseConfig') -> 'StrBytes':
    upper = field.type_.to_upper or config.anystr_upper
    if upper:
        v = v.upper()

    return v


def constr_lower(v: 'StrBytes', field: 'ModelField', config: 'BaseConfig') -> 'StrBytes':
    lower = field.type_.to_lower or config.anystr_lower
    if lower:
        v = v.lower()
    return v


def validate_json(v: Any, config: 'BaseConfig') -> Any:
    if v is None:
        # pass None through to other validators
        return v
    try:
        return config.json_loads(v)  # type: ignore
    except ValueError:
        raise errors.JsonError()
    except TypeError:
        raise errors.JsonTypeError()


T = TypeVar('T')


def make_arbitrary_type_validator(type_: Type[T]) -> Callable[[T], T]:
    def arbitrary_type_validator(v: Any) -> T:
        if isinstance(v, type_):
            return v
        raise errors.ArbitraryTypeError(expected_arbitrary_type=type_)

    return arbitrary_type_validator


def make_class_validator(type_: Type[T]) -> Callable[[Any], Type[T]]:
    def class_validator(v: Any) -> Type[T]:
        if lenient_issubclass(v, type_):
            return v
        raise errors.SubclassError(expected_class=type_)

    return class_validator


def any_class_validator(v: Any) -> Type[T]:
    if isinstance(v, type):
        return v
    raise errors.ClassError()


def none_validator(v: Any) -> 'Literal[None]':
    if v is None:
        return v
    raise errors.NotNoneError()


def pattern_validator(v: Any) -> Pattern[str]:
    if isinstance(v, Pattern):
        return v

    str_value = str_validator(v)

    try:
        return re.compile(str_value)
    except re.error:
        raise errors.PatternError()


NamedTupleT = TypeVar('NamedTupleT', bound=NamedTuple)


def make_namedtuple_validator(
    namedtuple_cls: Type[NamedTupleT], config: Type['BaseConfig']
) -> Callable[[Tuple[Any, ...]], NamedTupleT]:
    from pydantic.v1.annotated_types import create_model_from_namedtuple

    NamedTupleModel = create_model_from_namedtuple(
        namedtuple_cls,
        __config__=config,
        __module__=namedtuple_cls.__module__,
    )
    namedtuple_cls.__pydantic_model__ = NamedTupleModel  # type: ignore[attr-defined]

    def namedtuple_validator(values: Tuple[Any, ...]) -> NamedTupleT:
        annotations = NamedTupleModel.__annotations__

        if len(values) > len(annotations):
            raise errors.ListMaxLengthError(limit_value=len(annotations))

        dict_values: Dict[str, Any] = dict(zip(annotations, values))
        validated_dict_values: Dict[str, Any] = dict(NamedTupleModel(**dict_values))
        return namedtuple_cls(**validated_dict_values)

    return namedtuple_validator


def make_typeddict_validator(
    typeddict_cls: Type['TypedDict'], config: Type['BaseConfig']  # type: ignore[valid-type]
) -> Callable[[Any], Dict[str, Any]]:
    from pydantic.v1.annotated_types import create_model_from_typeddict

    TypedDictModel = create_model_from_typeddict(
        typeddict_cls,
        __config__=config,
        __module__=typeddict_cls.__module__,
    )
    typeddict_cls.__pydantic_model__ = TypedDictModel  # type: ignore[attr-defined]

    def typeddict_validator(values: 'TypedDict') -> Dict[str, Any]:  # type: ignore[valid-type]
        return TypedDictModel.parse_obj(values).dict(exclude_unset=True)

    return typeddict_validator


class IfConfig:
    def __init__(self, validator: AnyCallable, *config_attr_names: str, ignored_value: Any = False) -> None:
        self.validator = validator
        self.config_attr_names = config_attr_names
        self.ignored_value = ignored_value

    def check(self, config: Type['BaseConfig']) -> bool:
        return any(getattr(config, name) not in {None, self.ignored_value} for name in self.config_attr_names)


# order is important here, for example: bool is a subclass of int so has to come first, datetime before date for the
# same reason, IPv4Interface before IPv4Address, etc.
_VALIDATORS: List[Tuple[Type[Any], List[Any]]] = [
    (IntEnum, [int_validator, enum_member_validator]),
    (Enum, [enum_member_validator]),
    (
        str,
        [
            str_validator,
            IfConfig(anystr_strip_whitespace, 'anystr_strip_whitespace'),
            IfConfig(anystr_upper, 'anystr_upper'),
            IfConfig(anystr_lower, 'anystr_lower'),
            IfConfig(anystr_length_validator, 'min_anystr_length', 'max_anystr_length'),
        ],
    ),
    (
        bytes,
        [
            bytes_validator,
            IfConfig(anystr_strip_whitespace, 'anystr_strip_whitespace'),
            IfConfig(anystr_upper, 'anystr_upper'),
            IfConfig(anystr_lower, 'anystr_lower'),
            IfConfig(anystr_length_validator, 'min_anystr_length', 'max_anystr_length'),
        ],
    ),
    (bool, [bool_validator]),
    (int, [int_validator]),
    (float, [float_validator, IfConfig(float_finite_validator, 'allow_inf_nan', ignored_value=True)]),
    (Path, [path_validator]),
    (datetime, [parse_datetime]),
    (date, [parse_date]),
    (time, [parse_time]),
    (timedelta, [parse_duration]),
    (OrderedDict, [ordered_dict_validator]),
    (dict, [dict_validator]),
    (list, [list_validator]),
    (tuple, [tuple_validator]),
    (set, [set_validator]),
    (frozenset, [frozenset_validator]),
    (deque, [deque_validator]),
    (UUID, [uuid_validator]),
    (Decimal, [decimal_validator]),
    (IPv4Interface, [ip_v4_interface_validator]),
    (IPv6Interface, [ip_v6_interface_validator]),
    (IPv4Address, [ip_v4_address_validator]),
    (IPv6Address, [ip_v6_address_validator]),
    (IPv4Network, [ip_v4_network_validator]),
    (IPv6Network, [ip_v6_network_validator]),
]


def find_validators(  # noqa: C901 (ignore complexity)
    type_: Type[Any], config: Type['BaseConfig']
) -> Generator[AnyCallable, None, None]:
    from pydantic.v1.dataclasses import is_builtin_dataclass, make_dataclass_validator

    if type_ is Any or type_ is object:
        return
    type_type = type_.__class__
    if type_type == ForwardRef or type_type == TypeVar:
        return

    if is_none_type(type_):
        yield none_validator
        return
    if type_ is Pattern or type_ is re.Pattern:
        yield pattern_validator
        return
    if type_ is Hashable or type_ is CollectionsHashable:
        yield hashable_validator
        return
    if is_callable_type(type_):
        yield callable_validator
        return
    if is_literal_type(type_):
        yield make_literal_validator(type_)
        return
    if is_builtin_dataclass(type_):
        yield from make_dataclass_validator(type_, config)
        return
    if type_ is Enum:
        yield enum_validator
        return
    if type_ is IntEnum:
        yield int_enum_validator
        return
    if is_namedtuple(type_):
        yield tuple_validator
        yield make_namedtuple_validator(type_, config)
        return
    if is_typeddict(type_):
        yield make_typeddict_validator(type_, config)
        return

    class_ = get_class(type_)
    if class_ is not None:
        if class_ is not Any and isinstance(class_, type):
            yield make_class_validator(class_)
        else:
            yield any_class_validator
        return

    for val_type, validators in _VALIDATORS:
        try:
            if issubclass(type_, val_type):
                for v in validators:
                    if isinstance(v, IfConfig):
                        if v.check(config):
                            yield v.validator
                    else:
                        yield v
                return
        except TypeError:
            raise RuntimeError(f'error checking inheritance of {type_!r} (type: {display_as_type(type_)})')

    if config.arbitrary_types_allowed:
        yield make_arbitrary_type_validator(type_)
    else:
        if hasattr(type_, '__pydantic_core_schema__'):
            warn(f'Mixing V1 and V2 models is not supported. `{type_.__name__}` is a V2 model.', UserWarning)
        raise RuntimeError(f'no validator found for {type_}, see `arbitrary_types_allowed` in Config')
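
# Illustrative use of `find_validators` (comment only, not part of the upstream module):
# the first entry of _VALIDATORS whose type matches wins, so `bool` never falls through to `int`:
#
#     >>> from pydantic.v1 import BaseConfig
#     >>> [f.__name__ for f in find_validators(bool, BaseConfig)]
#     ['bool_validator']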