mirror of
https://github.com/davidhalter/django-stubs.git
synced 2025-12-06 12:14:28 +08:00
add BaseManager.create() typechecking
This commit is contained in:
@@ -11,7 +11,7 @@ CHANGE: int
|
||||
DELETION: int
|
||||
ACTION_FLAG_CHOICES: Any
|
||||
|
||||
class LogEntryManager(models.Manager):
|
||||
class LogEntryManager(models.Manager["LogEntry"]):
|
||||
def log_action(
|
||||
self,
|
||||
user_id: int,
|
||||
|
||||
@@ -16,7 +16,7 @@ class Permission(models.Model):
|
||||
content_type_id: int
|
||||
name: models.CharField = ...
|
||||
content_type: models.ForeignKey[ContentType] = ...
|
||||
codename: str = ...
|
||||
codename: models.CharField = ...
|
||||
def natural_key(self) -> Tuple[str, str, str]: ...
|
||||
|
||||
class GroupManager(models.Manager):
|
||||
|
||||
@@ -1,10 +1,6 @@
|
||||
from typing import Any, Optional
|
||||
|
||||
from django.db import models
|
||||
|
||||
class Redirect(models.Model):
|
||||
id: None
|
||||
site_id: int
|
||||
site: Any = ...
|
||||
old_path: str = ...
|
||||
new_path: str = ...
|
||||
site: models.ForeignKey = ...
|
||||
old_path: models.CharField = ...
|
||||
new_path: models.CharField = ...
|
||||
|
||||
@@ -1,13 +1,14 @@
|
||||
from typing import Any, Dict, Optional, Type, Union
|
||||
from typing import Dict, Optional, Type, Union
|
||||
|
||||
from django.contrib.sessions.backends.base import SessionBase
|
||||
from django.contrib.sessions.base_session import AbstractBaseSession
|
||||
from django.contrib.sessions.models import Session
|
||||
from django.core.signing import Serializer
|
||||
from django.db.models.base import Model
|
||||
|
||||
class SessionStore(SessionBase):
|
||||
accessed: bool
|
||||
serializer: Type[django.core.signing.JSONSerializer]
|
||||
serializer: Type[Serializer]
|
||||
def __init__(self, session_key: Optional[str] = ...) -> None: ...
|
||||
@classmethod
|
||||
def get_model_class(cls) -> Type[Session]: ...
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
from typing import Any, Dict, List, Optional, Set, Tuple, TypeVar, Union, ClassVar, Sequence
|
||||
from typing import Any, Dict, List, Optional, Set, Tuple, TypeVar, Union, ClassVar, Sequence, Generic
|
||||
|
||||
from django.db.models.manager import Manager
|
||||
|
||||
@@ -10,9 +10,9 @@ class Model(metaclass=ModelBase):
|
||||
class DoesNotExist(Exception): ...
|
||||
class Meta: ...
|
||||
_meta: Any
|
||||
_default_manager: Manager[Model]
|
||||
pk: Any = ...
|
||||
objects: Manager[Model]
|
||||
def __init__(self, *args, **kwargs) -> None: ...
|
||||
def __init__(self: _Self, *args, **kwargs) -> None: ...
|
||||
def delete(self, using: Any = ..., keep_parents: bool = ...) -> Tuple[int, Dict[str, int]]: ...
|
||||
def full_clean(self, exclude: Optional[List[str]] = ..., validate_unique: bool = ...) -> None: ...
|
||||
def clean_fields(self, exclude: List[str] = ...) -> None: ...
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
import uuid
|
||||
from datetime import date, time, datetime, timedelta
|
||||
from typing import Any, Optional, Tuple, Iterable, Callable, Dict, Union, Type, TypeVar
|
||||
from typing import Any, Optional, Tuple, Iterable, Callable, Dict, Union, Type, TypeVar, Generic
|
||||
import decimal
|
||||
|
||||
from typing_extensions import Literal
|
||||
@@ -14,7 +14,7 @@ from django.forms import Widget, Field as FormField
|
||||
from .mixins import NOT_PROVIDED as NOT_PROVIDED
|
||||
|
||||
_Choice = Tuple[Any, Any]
|
||||
_ChoiceNamedGroup = Union[Tuple[str, Iterable[_Choice]], Tuple[str, Any]]
|
||||
_ChoiceNamedGroup = Tuple[str, Iterable[_Choice]]
|
||||
_FieldChoices = Iterable[Union[_Choice, _ChoiceNamedGroup]]
|
||||
|
||||
_ValidatorCallable = Callable[..., None]
|
||||
@@ -76,7 +76,7 @@ class SmallIntegerField(IntegerField): ...
|
||||
class BigIntegerField(IntegerField): ...
|
||||
|
||||
class FloatField(Field):
|
||||
def __set__(self, instance, value: Union[float, int, Combinable]) -> float: ...
|
||||
def __set__(self, instance, value: Union[float, int, str, Combinable]) -> float: ...
|
||||
def __get__(self, instance, owner) -> float: ...
|
||||
|
||||
class DecimalField(Field):
|
||||
@@ -102,7 +102,7 @@ class DecimalField(Field):
|
||||
validators: Iterable[_ValidatorCallable] = ...,
|
||||
error_messages: Optional[_ErrorMessagesToOverride] = ...,
|
||||
): ...
|
||||
def __set__(self, instance, value: Union[str, Combinable]) -> decimal.Decimal: ...
|
||||
def __set__(self, instance, value: Union[str, float, decimal.Decimal, Combinable]) -> decimal.Decimal: ...
|
||||
def __get__(self, instance, owner) -> decimal.Decimal: ...
|
||||
|
||||
class AutoField(Field):
|
||||
@@ -167,15 +167,15 @@ class EmailField(CharField): ...
|
||||
class URLField(CharField): ...
|
||||
|
||||
class TextField(Field):
|
||||
def __set__(self, instance, value: str) -> None: ...
|
||||
def __set__(self, instance, value: Union[str, Combinable]) -> None: ...
|
||||
def __get__(self, instance, owner) -> str: ...
|
||||
|
||||
class BooleanField(Field):
|
||||
def __set__(self, instance, value: bool) -> None: ...
|
||||
def __set__(self, instance, value: Union[bool, Combinable]) -> None: ...
|
||||
def __get__(self, instance, owner) -> bool: ...
|
||||
|
||||
class NullBooleanField(Field):
|
||||
def __set__(self, instance, value: Optional[bool]) -> None: ...
|
||||
def __set__(self, instance, value: Optional[Union[bool, Combinable]]) -> None: ...
|
||||
def __get__(self, instance, owner) -> Optional[bool]: ...
|
||||
|
||||
class IPAddressField(Field):
|
||||
|
||||
@@ -1,9 +1,8 @@
|
||||
from importlib.abc import SourceLoader
|
||||
from typing import Any, Callable, Dict, List, Optional, Type, Union
|
||||
from types import TracebackType
|
||||
from typing import Any, Callable, Dict, List, MutableMapping, Optional, Type, Union
|
||||
|
||||
from django.core.handlers.wsgi import WSGIRequest
|
||||
from django.http.request import QueryDict
|
||||
from django.http.request import HttpRequest, QueryDict
|
||||
from django.http.response import Http404, HttpResponse
|
||||
from django.utils.safestring import SafeText
|
||||
|
||||
@@ -19,21 +18,21 @@ def cleanse_setting(key: Union[int, str], value: Any) -> Any: ...
|
||||
def get_safe_settings() -> Dict[str, Any]: ...
|
||||
def technical_500_response(request: Any, exc_type: Any, exc_value: Any, tb: Any, status_code: int = ...): ...
|
||||
def get_default_exception_reporter_filter() -> ExceptionReporterFilter: ...
|
||||
def get_exception_reporter_filter(request: Optional[WSGIRequest]) -> ExceptionReporterFilter: ...
|
||||
def get_exception_reporter_filter(request: Optional[HttpRequest]) -> ExceptionReporterFilter: ...
|
||||
|
||||
class ExceptionReporterFilter:
|
||||
def get_post_parameters(self, request: Any): ...
|
||||
def get_traceback_frame_variables(self, request: Any, tb_frame: Any): ...
|
||||
|
||||
class SafeExceptionReporterFilter(ExceptionReporterFilter):
|
||||
def is_active(self, request: Optional[WSGIRequest]) -> bool: ...
|
||||
def get_cleansed_multivaluedict(self, request: WSGIRequest, multivaluedict: QueryDict) -> QueryDict: ...
|
||||
def get_post_parameters(self, request: Optional[WSGIRequest]) -> Union[Dict[Any, Any], QueryDict]: ...
|
||||
def cleanse_special_types(self, request: Optional[WSGIRequest], value: Any) -> Any: ...
|
||||
def is_active(self, request: Optional[HttpRequest]) -> bool: ...
|
||||
def get_cleansed_multivaluedict(self, request: HttpRequest, multivaluedict: QueryDict) -> QueryDict: ...
|
||||
def get_post_parameters(self, request: Optional[HttpRequest]) -> MutableMapping[str, Any]: ...
|
||||
def cleanse_special_types(self, request: Optional[HttpRequest], value: Any) -> Any: ...
|
||||
def get_traceback_frame_variables(self, request: Any, tb_frame: Any): ...
|
||||
|
||||
class ExceptionReporter:
|
||||
request: Optional[WSGIRequest] = ...
|
||||
request: Optional[HttpRequest] = ...
|
||||
filter: ExceptionReporterFilter = ...
|
||||
exc_type: None = ...
|
||||
exc_value: Optional[str] = ...
|
||||
@@ -44,7 +43,7 @@ class ExceptionReporter:
|
||||
postmortem: None = ...
|
||||
def __init__(
|
||||
self,
|
||||
request: Optional[WSGIRequest],
|
||||
request: Optional[HttpRequest],
|
||||
exc_type: Optional[Type[BaseException]],
|
||||
exc_value: Optional[Union[str, BaseException]],
|
||||
tb: Optional[TracebackType],
|
||||
@@ -63,5 +62,5 @@ class ExceptionReporter:
|
||||
module_name: Optional[str] = None,
|
||||
): ...
|
||||
|
||||
def technical_404_response(request: WSGIRequest, exception: Http404) -> HttpResponse: ...
|
||||
def default_urlconf(request: WSGIRequest) -> HttpResponse: ...
|
||||
def technical_404_response(request: HttpRequest, exception: Http404) -> HttpResponse: ...
|
||||
def default_urlconf(request: HttpRequest) -> HttpResponse: ...
|
||||
|
||||
21
mypy_django_plugin/config.py
Normal file
21
mypy_django_plugin/config.py
Normal file
@@ -0,0 +1,21 @@
|
||||
from configparser import ConfigParser
|
||||
from typing import Optional
|
||||
|
||||
from dataclasses import dataclass
|
||||
|
||||
|
||||
@dataclass
|
||||
class Config:
|
||||
django_settings_module: Optional[str] = None
|
||||
ignore_missing_settings: bool = False
|
||||
|
||||
@classmethod
|
||||
def from_config_file(self, fpath: str) -> 'Config':
|
||||
ini_config = ConfigParser()
|
||||
ini_config.read(fpath)
|
||||
if not ini_config.has_section('mypy_django_plugin'):
|
||||
raise ValueError('Invalid config file: no [mypy_django_plugin] section')
|
||||
return Config(django_settings_module=ini_config.get('mypy_django_plugin', 'django_settings',
|
||||
fallback=None),
|
||||
ignore_missing_settings=ini_config.get('mypy_django_plugin', 'ignore_missing_settings',
|
||||
fallback=False))
|
||||
@@ -1,10 +1,9 @@
|
||||
import typing
|
||||
from typing import Dict, Optional
|
||||
|
||||
from mypy.nodes import Expression, FuncDef, ImportedName, MypyFile, NameExpr, SymbolNode, TypeInfo, Var, AssignmentStmt, \
|
||||
CallExpr
|
||||
from mypy.nodes import Expression, ImportedName, MypyFile, NameExpr, SymbolNode, TypeInfo
|
||||
from mypy.plugin import FunctionContext
|
||||
from mypy.types import AnyType, CallableType, Instance, Type, TypeOfAny, TypeVarType, UnionType
|
||||
from mypy.types import AnyType, Instance, Type, TypeOfAny, TypeVarType
|
||||
|
||||
MODEL_CLASS_FULLNAME = 'django.db.models.base.Model'
|
||||
FIELD_FULLNAME = 'django.db.models.fields.Field'
|
||||
@@ -119,74 +118,6 @@ def fill_typevars(tp: Instance, type_to_fill: Instance) -> Instance:
|
||||
return reparametrize_with(type_to_fill, typevar_values)
|
||||
|
||||
|
||||
def extract_field_setter_type(tp: Instance) -> Optional[Type]:
|
||||
if tp.type.has_base(FIELD_FULLNAME):
|
||||
set_method = tp.type.get_method('__set__')
|
||||
if isinstance(set_method, FuncDef) and isinstance(set_method.type, CallableType):
|
||||
if 'value' in set_method.type.arg_names:
|
||||
set_value_type = set_method.type.arg_types[set_method.type.arg_names.index('value')]
|
||||
if isinstance(set_value_type, Instance):
|
||||
set_value_type = fill_typevars(tp, set_value_type)
|
||||
return set_value_type
|
||||
elif isinstance(set_value_type, UnionType):
|
||||
items_no_typevars = []
|
||||
for item in set_value_type.items:
|
||||
if isinstance(item, Instance):
|
||||
item = fill_typevars(tp, item)
|
||||
items_no_typevars.append(item)
|
||||
return UnionType(items_no_typevars)
|
||||
|
||||
get_method = tp.type.get_method('__get__')
|
||||
if isinstance(get_method, FuncDef) and isinstance(get_method.type, CallableType):
|
||||
return get_method.type.ret_type
|
||||
# GenericForeignKey
|
||||
if tp.type.has_base(GENERIC_FOREIGN_KEY_FULLNAME):
|
||||
return AnyType(TypeOfAny.special_form)
|
||||
return None
|
||||
|
||||
|
||||
def extract_primary_key_type(model: TypeInfo) -> Optional[Type]:
|
||||
# only primary keys defined in current class for now
|
||||
for stmt in model.defn.defs.body:
|
||||
if isinstance(stmt, AssignmentStmt) and isinstance(stmt.rvalue, CallExpr):
|
||||
name_expr = stmt.lvalues[0]
|
||||
if isinstance(name_expr, NameExpr):
|
||||
name = name_expr.name
|
||||
if 'primary_key' in stmt.rvalue.arg_names:
|
||||
is_primary_key = stmt.rvalue.args[stmt.rvalue.arg_names.index('primary_key')]
|
||||
if is_primary_key:
|
||||
return extract_field_setter_type(model.names[name].type)
|
||||
return None
|
||||
|
||||
|
||||
def extract_expected_types(ctx: FunctionContext, model: TypeInfo) -> Dict[str, Type]:
|
||||
expected_types: Dict[str, Type] = {}
|
||||
|
||||
primary_key_type = extract_primary_key_type(model)
|
||||
if not primary_key_type:
|
||||
# no explicit primary key, set pk to Any and add id
|
||||
primary_key_type = AnyType(TypeOfAny.special_form)
|
||||
expected_types['id'] = ctx.api.named_generic_type('builtins.int', [])
|
||||
|
||||
expected_types['pk'] = primary_key_type
|
||||
|
||||
for base in model.mro:
|
||||
for name, sym in base.names.items():
|
||||
if isinstance(sym.node, Var) and isinstance(sym.node.type, Instance):
|
||||
tp = sym.node.type
|
||||
field_type = extract_field_setter_type(tp)
|
||||
if tp.type.fullname() in {FOREIGN_KEY_FULLNAME, ONETOONE_FIELD_FULLNAME}:
|
||||
ref_to_model = tp.args[0]
|
||||
if isinstance(ref_to_model, Instance) and ref_to_model.type.has_base(MODEL_CLASS_FULLNAME):
|
||||
primary_key_type = extract_primary_key_type(ref_to_model.type)
|
||||
if not primary_key_type:
|
||||
primary_key_type = AnyType(TypeOfAny.special_form)
|
||||
expected_types[name + '_id'] = primary_key_type
|
||||
if field_type:
|
||||
expected_types[name] = field_type
|
||||
return expected_types
|
||||
|
||||
|
||||
def get_argument_by_name(ctx: FunctionContext, name: str) -> Optional[Expression]:
|
||||
"""Return the expression for the specific argument.
|
||||
|
||||
|
||||
@@ -1,8 +1,6 @@
|
||||
import os
|
||||
from configparser import ConfigParser
|
||||
from typing import Callable, Dict, Optional, Set, cast
|
||||
from typing import Callable, Dict, Optional, cast
|
||||
|
||||
from dataclasses import dataclass
|
||||
from mypy.checker import TypeChecker
|
||||
from mypy.nodes import TypeInfo
|
||||
from mypy.options import Options
|
||||
@@ -10,7 +8,9 @@ from mypy.plugin import ClassDefContext, FunctionContext, MethodContext, Plugin
|
||||
from mypy.types import Instance, Type
|
||||
|
||||
from mypy_django_plugin import helpers, monkeypatch
|
||||
from mypy_django_plugin.plugins.fields import determine_type_of_array_field
|
||||
from mypy_django_plugin.config import Config
|
||||
from mypy_django_plugin.plugins.fields import determine_type_of_array_field, record_field_properties_into_outer_model_class
|
||||
from mypy_django_plugin.plugins.init_create import redefine_and_typecheck_model_init, redefine_and_typecheck_model_create
|
||||
from mypy_django_plugin.plugins.migrations import determine_model_cls_from_string_for_migrations
|
||||
from mypy_django_plugin.plugins.models import process_model_class
|
||||
from mypy_django_plugin.plugins.related_fields import extract_to_parameter_as_get_ret_type_for_related_field, reparametrize_with
|
||||
@@ -56,81 +56,6 @@ def determine_proper_manager_type(ctx: FunctionContext) -> Type:
|
||||
return ret
|
||||
|
||||
|
||||
def extract_base_pointer_args(model: TypeInfo) -> Set[str]:
|
||||
pointer_args: Set[str] = set()
|
||||
for base in model.bases:
|
||||
if base.type.has_base(helpers.MODEL_CLASS_FULLNAME):
|
||||
parent_name = base.type.name().lower()
|
||||
pointer_args.add(f'{parent_name}_ptr')
|
||||
pointer_args.add(f'{parent_name}_ptr_id')
|
||||
return pointer_args
|
||||
|
||||
|
||||
def redefine_model_init(ctx: FunctionContext) -> Type:
|
||||
assert isinstance(ctx.default_return_type, Instance)
|
||||
|
||||
api = cast(TypeChecker, ctx.api)
|
||||
model: TypeInfo = ctx.default_return_type.type
|
||||
|
||||
expected_types = helpers.extract_expected_types(ctx, model)
|
||||
# order is preserved, can use for positionals
|
||||
positional_names = list(expected_types.keys())
|
||||
positional_names.remove('pk')
|
||||
visited_positionals = set()
|
||||
|
||||
# check positionals
|
||||
for i, (_, actual_pos_type) in enumerate(zip(ctx.arg_names[0], ctx.arg_types[0])):
|
||||
actual_pos_name = positional_names[i]
|
||||
api.check_subtype(actual_pos_type, expected_types[actual_pos_name],
|
||||
ctx.context,
|
||||
'Incompatible type for "{}" of "{}"'.format(actual_pos_name,
|
||||
model.name()),
|
||||
'got', 'expected')
|
||||
visited_positionals.add(actual_pos_name)
|
||||
|
||||
# extract name of base models for _ptr
|
||||
base_pointer_args = extract_base_pointer_args(model)
|
||||
|
||||
# check kwargs
|
||||
for i, (actual_name, actual_type) in enumerate(zip(ctx.arg_names[1], ctx.arg_types[1])):
|
||||
if actual_name in base_pointer_args:
|
||||
# parent_ptr args are not supported
|
||||
continue
|
||||
if actual_name in visited_positionals:
|
||||
continue
|
||||
if actual_name is None:
|
||||
# unpacked dict as kwargs is not supported
|
||||
continue
|
||||
if actual_name not in expected_types:
|
||||
ctx.api.fail('Unexpected attribute "{}" for model "{}"'.format(actual_name,
|
||||
model.name()),
|
||||
ctx.context)
|
||||
continue
|
||||
api.check_subtype(actual_type, expected_types[actual_name],
|
||||
ctx.context,
|
||||
'Incompatible type for "{}" of "{}"'.format(actual_name,
|
||||
model.name()),
|
||||
'got', 'expected')
|
||||
return ctx.default_return_type
|
||||
|
||||
|
||||
@dataclass
|
||||
class Config:
|
||||
django_settings_module: Optional[str] = None
|
||||
ignore_missing_settings: bool = False
|
||||
|
||||
@classmethod
|
||||
def from_config_file(self, fpath: str) -> 'Config':
|
||||
ini_config = ConfigParser()
|
||||
ini_config.read(fpath)
|
||||
if not ini_config.has_section('mypy_django_plugin'):
|
||||
raise ValueError('Invalid config file: no [mypy_django_plugin] section')
|
||||
return Config(django_settings_module=ini_config.get('mypy_django_plugin', 'django_settings',
|
||||
fallback=None),
|
||||
ignore_missing_settings=ini_config.get('mypy_django_plugin', 'ignore_missing_settings',
|
||||
fallback=False))
|
||||
|
||||
|
||||
class DjangoPlugin(Plugin):
|
||||
def __init__(self, options: Options) -> None:
|
||||
super().__init__(options)
|
||||
@@ -194,11 +119,18 @@ class DjangoPlugin(Plugin):
|
||||
|
||||
sym = self.lookup_fully_qualified(fullname)
|
||||
if sym and isinstance(sym.node, TypeInfo):
|
||||
if sym.node.has_base(helpers.FIELD_FULLNAME):
|
||||
return record_field_properties_into_outer_model_class
|
||||
if sym.node.metadata.get('django', {}).get('generated_init'):
|
||||
return redefine_model_init
|
||||
return redefine_and_typecheck_model_init
|
||||
|
||||
def get_method_hook(self, fullname: str
|
||||
) -> Optional[Callable[[MethodContext], Type]]:
|
||||
manager_classes = self._get_current_manager_bases()
|
||||
class_fullname, _, method_name = fullname.rpartition('.')
|
||||
if class_fullname in manager_classes and method_name == 'create':
|
||||
return redefine_and_typecheck_model_create
|
||||
|
||||
if fullname in {'django.apps.registry.Apps.get_model',
|
||||
'django.db.migrations.state.StateApps.get_model'}:
|
||||
return determine_model_cls_from_string_for_migrations
|
||||
|
||||
@@ -1,7 +1,12 @@
|
||||
from typing import cast
|
||||
|
||||
from mypy.checker import TypeChecker
|
||||
from mypy.nodes import ListExpr, NameExpr, TupleExpr
|
||||
from mypy.plugin import FunctionContext
|
||||
from mypy.types import Type, Instance
|
||||
from mypy.types import Instance, TupleType, Type
|
||||
|
||||
from mypy_django_plugin import helpers
|
||||
from mypy_django_plugin.plugins.models import iter_over_assignments
|
||||
|
||||
|
||||
def determine_type_of_array_field(ctx: FunctionContext) -> Type:
|
||||
@@ -16,3 +21,54 @@ def determine_type_of_array_field(ctx: FunctionContext) -> Type:
|
||||
|
||||
return ctx.api.named_generic_type(ctx.context.callee.fullname,
|
||||
args=[get_method.type.ret_type])
|
||||
|
||||
|
||||
def record_field_properties_into_outer_model_class(ctx: FunctionContext) -> Type:
|
||||
api = cast(TypeChecker, ctx.api)
|
||||
outer_model = api.scope.active_class()
|
||||
if outer_model is None or not outer_model.has_base(helpers.MODEL_CLASS_FULLNAME):
|
||||
# outside models.Model class, undetermined
|
||||
return ctx.default_return_type
|
||||
|
||||
field_name = None
|
||||
for name_expr, stmt in iter_over_assignments(outer_model.defn):
|
||||
if stmt == ctx.context and isinstance(name_expr, NameExpr):
|
||||
field_name = name_expr.name
|
||||
break
|
||||
if field_name is None:
|
||||
return ctx.default_return_type
|
||||
|
||||
fields_metadata = outer_model.metadata.setdefault('django', {}).setdefault('fields', {})
|
||||
|
||||
# primary key
|
||||
is_primary_key = False
|
||||
primary_key_arg = helpers.get_argument_by_name(ctx, 'primary_key')
|
||||
if primary_key_arg:
|
||||
is_primary_key = helpers.parse_bool(primary_key_arg)
|
||||
fields_metadata[field_name] = {'primary_key': is_primary_key}
|
||||
|
||||
# choices
|
||||
choices_arg = helpers.get_argument_by_name(ctx, 'choices')
|
||||
if choices_arg and isinstance(choices_arg, (TupleExpr, ListExpr)):
|
||||
# iterable of 2 element tuples of two kinds
|
||||
_, analyzed_choices = api.analyze_iterable_item_type(choices_arg)
|
||||
if isinstance(analyzed_choices, TupleType):
|
||||
first_element_type = analyzed_choices.items[0]
|
||||
if isinstance(first_element_type, Instance):
|
||||
fields_metadata[field_name]['choices'] = first_element_type.type.fullname()
|
||||
|
||||
# nullability
|
||||
null_arg = helpers.get_argument_by_name(ctx, 'null')
|
||||
is_nullable = False
|
||||
if null_arg:
|
||||
is_nullable = helpers.parse_bool(null_arg)
|
||||
fields_metadata[field_name]['null'] = is_nullable
|
||||
|
||||
# is_blankable
|
||||
blank_arg = helpers.get_argument_by_name(ctx, 'blank')
|
||||
is_blankable = False
|
||||
if blank_arg:
|
||||
is_blankable = helpers.parse_bool(blank_arg)
|
||||
fields_metadata[field_name]['blank'] = is_blankable
|
||||
|
||||
return ctx.default_return_type
|
||||
|
||||
182
mypy_django_plugin/plugins/init_create.py
Normal file
182
mypy_django_plugin/plugins/init_create.py
Normal file
@@ -0,0 +1,182 @@
|
||||
from typing import Dict, Optional, Set, cast, Any
|
||||
|
||||
from mypy.checker import TypeChecker
|
||||
from mypy.nodes import FuncDef, TypeInfo, Var
|
||||
from mypy.plugin import FunctionContext, MethodContext
|
||||
from mypy.types import AnyType, CallableType, Instance, Type, TypeOfAny, UnionType
|
||||
|
||||
from mypy_django_plugin import helpers
|
||||
|
||||
|
||||
def extract_base_pointer_args(model: TypeInfo) -> Set[str]:
|
||||
pointer_args: Set[str] = set()
|
||||
for base in model.bases:
|
||||
if base.type.has_base(helpers.MODEL_CLASS_FULLNAME):
|
||||
parent_name = base.type.name().lower()
|
||||
pointer_args.add(f'{parent_name}_ptr')
|
||||
pointer_args.add(f'{parent_name}_ptr_id')
|
||||
return pointer_args
|
||||
|
||||
|
||||
def redefine_and_typecheck_model_init(ctx: FunctionContext) -> Type:
|
||||
assert isinstance(ctx.default_return_type, Instance)
|
||||
|
||||
api = cast(TypeChecker, ctx.api)
|
||||
model: TypeInfo = ctx.default_return_type.type
|
||||
|
||||
expected_types = extract_expected_types(ctx, model)
|
||||
# order is preserved, can use for positionals
|
||||
positional_names = list(expected_types.keys())
|
||||
positional_names.remove('pk')
|
||||
visited_positionals = set()
|
||||
|
||||
# check positionals
|
||||
for i, (_, actual_pos_type) in enumerate(zip(ctx.arg_names[0], ctx.arg_types[0])):
|
||||
actual_pos_name = positional_names[i]
|
||||
api.check_subtype(actual_pos_type, expected_types[actual_pos_name],
|
||||
ctx.context,
|
||||
'Incompatible type for "{}" of "{}"'.format(actual_pos_name,
|
||||
model.name()),
|
||||
'got', 'expected')
|
||||
visited_positionals.add(actual_pos_name)
|
||||
|
||||
# extract name of base models for _ptr
|
||||
base_pointer_args = extract_base_pointer_args(model)
|
||||
|
||||
# check kwargs
|
||||
for i, (actual_name, actual_type) in enumerate(zip(ctx.arg_names[1], ctx.arg_types[1])):
|
||||
if actual_name in base_pointer_args:
|
||||
# parent_ptr args are not supported
|
||||
continue
|
||||
if actual_name in visited_positionals:
|
||||
continue
|
||||
if actual_name is None:
|
||||
# unpacked dict as kwargs is not supported
|
||||
continue
|
||||
if actual_name not in expected_types:
|
||||
ctx.api.fail('Unexpected attribute "{}" for model "{}"'.format(actual_name,
|
||||
model.name()),
|
||||
ctx.context)
|
||||
continue
|
||||
api.check_subtype(actual_type, expected_types[actual_name],
|
||||
ctx.context,
|
||||
'Incompatible type for "{}" of "{}"'.format(actual_name,
|
||||
model.name()),
|
||||
'got', 'expected')
|
||||
return ctx.default_return_type
|
||||
|
||||
|
||||
def redefine_and_typecheck_model_create(ctx: MethodContext) -> Type:
|
||||
api = cast(TypeChecker, ctx.api)
|
||||
if isinstance(ctx.type, Instance) and len(ctx.type.args) > 0:
|
||||
model: TypeInfo = ctx.type.args[0].type
|
||||
else:
|
||||
if isinstance(ctx.default_return_type, AnyType):
|
||||
return ctx.default_return_type
|
||||
model: TypeInfo = ctx.default_return_type.type
|
||||
|
||||
# extract name of base models for _ptr
|
||||
base_pointer_args = extract_base_pointer_args(model)
|
||||
expected_types = extract_expected_types(ctx, model)
|
||||
|
||||
for actual_name, actual_type in zip(ctx.arg_names[0], ctx.arg_types[0]):
|
||||
if actual_name in base_pointer_args:
|
||||
# parent_ptr args are not supported
|
||||
continue
|
||||
if actual_name is None:
|
||||
# unpacked dict as kwargs is not supported
|
||||
continue
|
||||
if actual_name not in expected_types:
|
||||
api.fail('Unexpected attribute "{}" for model "{}"'.format(actual_name,
|
||||
model.name()),
|
||||
ctx.context)
|
||||
continue
|
||||
api.check_subtype(actual_type, expected_types[actual_name],
|
||||
ctx.context,
|
||||
'Incompatible type for "{}" of "{}"'.format(actual_name,
|
||||
model.name()),
|
||||
'got', 'expected')
|
||||
|
||||
return ctx.default_return_type
|
||||
|
||||
|
||||
def extract_field_setter_type(tp: Instance) -> Optional[Type]:
|
||||
if not isinstance(tp, Instance):
|
||||
return None
|
||||
if tp.type.has_base(helpers.FIELD_FULLNAME):
|
||||
set_method = tp.type.get_method('__set__')
|
||||
if isinstance(set_method, FuncDef) and isinstance(set_method.type, CallableType):
|
||||
if 'value' in set_method.type.arg_names:
|
||||
set_value_type = set_method.type.arg_types[set_method.type.arg_names.index('value')]
|
||||
if isinstance(set_value_type, Instance):
|
||||
set_value_type = helpers.fill_typevars(tp, set_value_type)
|
||||
return set_value_type
|
||||
elif isinstance(set_value_type, UnionType):
|
||||
items_no_typevars = []
|
||||
for item in set_value_type.items:
|
||||
if isinstance(item, Instance):
|
||||
item = helpers.fill_typevars(tp, item)
|
||||
items_no_typevars.append(item)
|
||||
return UnionType(items_no_typevars)
|
||||
|
||||
get_method = tp.type.get_method('__get__')
|
||||
if isinstance(get_method, FuncDef) and isinstance(get_method.type, CallableType):
|
||||
return get_method.type.ret_type
|
||||
# GenericForeignKey
|
||||
if tp.type.has_base(helpers.GENERIC_FOREIGN_KEY_FULLNAME):
|
||||
return AnyType(TypeOfAny.special_form)
|
||||
return None
|
||||
|
||||
|
||||
def get_fields_metadata(model: TypeInfo) -> Dict[str, Any]:
|
||||
return model.metadata.setdefault('django', {}).setdefault('fields', {})
|
||||
|
||||
|
||||
def extract_primary_key_type(model: TypeInfo) -> Optional[Type]:
|
||||
for field_name, props in get_fields_metadata(model).items():
|
||||
is_primary_key = props.get('primary_key', False)
|
||||
if is_primary_key:
|
||||
return extract_field_setter_type(model.names[field_name].type)
|
||||
return None
|
||||
|
||||
|
||||
def extract_choices_type(model: TypeInfo, field_name: str) -> Optional[str]:
|
||||
field_metadata = get_fields_metadata(model).get(field_name, {})
|
||||
if 'choices' in field_metadata:
|
||||
return field_metadata['choices']
|
||||
return None
|
||||
|
||||
|
||||
def extract_expected_types(ctx: FunctionContext, model: TypeInfo) -> Dict[str, Type]:
|
||||
expected_types: Dict[str, Type] = {}
|
||||
|
||||
primary_key_type = extract_primary_key_type(model)
|
||||
if not primary_key_type:
|
||||
# no explicit primary key, set pk to Any and add id
|
||||
primary_key_type = AnyType(TypeOfAny.special_form)
|
||||
expected_types['id'] = ctx.api.named_generic_type('builtins.int', [])
|
||||
|
||||
expected_types['pk'] = primary_key_type
|
||||
|
||||
for base in model.mro:
|
||||
for name, sym in base.names.items():
|
||||
if isinstance(sym.node, Var) and isinstance(sym.node.type, Instance):
|
||||
tp = sym.node.type
|
||||
field_type = extract_field_setter_type(tp)
|
||||
if field_type is None:
|
||||
continue
|
||||
|
||||
choices_type_fullname = extract_choices_type(model, name)
|
||||
if choices_type_fullname:
|
||||
field_type = UnionType([field_type, ctx.api.named_generic_type(choices_type_fullname, [])])
|
||||
|
||||
if tp.type.fullname() in {helpers.FOREIGN_KEY_FULLNAME, helpers.ONETOONE_FIELD_FULLNAME}:
|
||||
ref_to_model = tp.args[0]
|
||||
if isinstance(ref_to_model, Instance) and ref_to_model.type.has_base(helpers.MODEL_CLASS_FULLNAME):
|
||||
primary_key_type = extract_primary_key_type(ref_to_model.type)
|
||||
if not primary_key_type:
|
||||
primary_key_type = AnyType(TypeOfAny.special_form)
|
||||
expected_types[name + '_id'] = primary_key_type
|
||||
if field_type:
|
||||
expected_types[name] = field_type
|
||||
return expected_types
|
||||
@@ -1,9 +1,9 @@
|
||||
from typing import cast, Optional
|
||||
from typing import Optional, cast
|
||||
|
||||
from mypy.checker import TypeChecker
|
||||
from mypy.nodes import TypeInfo, Expression, StrExpr, NameExpr, RefExpr, Var
|
||||
from mypy.nodes import Expression, StrExpr, TypeInfo
|
||||
from mypy.plugin import MethodContext
|
||||
from mypy.types import Type, Instance, TypeType
|
||||
from mypy.types import Instance, Type, TypeType
|
||||
|
||||
from mypy_django_plugin import helpers
|
||||
|
||||
|
||||
@@ -1,13 +1,13 @@
|
||||
from abc import ABCMeta, abstractmethod
|
||||
from typing import Dict, Iterator, Optional, Tuple, cast
|
||||
from typing import Dict, Iterator, List, Optional, Tuple, cast
|
||||
|
||||
import dataclasses
|
||||
from mypy.nodes import AssignmentStmt, CallExpr, ClassDef, Context, Expression, Lvalue, MDEF, MemberExpr, \
|
||||
MypyFile, NameExpr, StrExpr, SymbolTableNode, TypeInfo, Var, Argument, ARG_STAR2, ARG_STAR
|
||||
from mypy.nodes import ARG_STAR, ARG_STAR2, Argument, AssignmentStmt, CallExpr, ClassDef, Context, Expression, IndexExpr, \
|
||||
Lvalue, MDEF, MemberExpr, MypyFile, NameExpr, StrExpr, SymbolTableNode, TypeInfo, Var
|
||||
from mypy.plugin import ClassDefContext
|
||||
from mypy.plugins.common import add_method
|
||||
from mypy.semanal import SemanticAnalyzerPass2
|
||||
from mypy.types import Instance, AnyType, TypeOfAny, NoneTyp
|
||||
from mypy.types import AnyType, Instance, NoneTyp, TypeOfAny
|
||||
|
||||
from mypy_django_plugin import helpers
|
||||
|
||||
@@ -27,18 +27,20 @@ class ModelClassInitializer(metaclass=ABCMeta):
|
||||
return metaclass_sym.node
|
||||
return None
|
||||
|
||||
def is_abstract_model(self) -> bool:
|
||||
def get_meta_attribute(self, name: str) -> Optional[Expression]:
|
||||
meta_node = self.get_nested_meta_node()
|
||||
if meta_node is None:
|
||||
return False
|
||||
return None
|
||||
|
||||
for lvalue, rvalue in iter_over_assignments(meta_node.defn):
|
||||
if isinstance(lvalue, NameExpr) and lvalue.name == 'abstract':
|
||||
is_abstract = self.api.parse_bool(rvalue)
|
||||
if is_abstract:
|
||||
# abstract model do not need 'objects' queryset
|
||||
return True
|
||||
return False
|
||||
if isinstance(lvalue, NameExpr) and lvalue.name == name:
|
||||
return rvalue
|
||||
|
||||
def is_abstract_model(self) -> bool:
|
||||
is_abstract_expr = self.get_meta_attribute('abstract')
|
||||
if is_abstract_expr is None:
|
||||
return False
|
||||
return self.api.parse_bool(is_abstract_expr)
|
||||
|
||||
def add_new_node_to_model_class(self, name: str, typ: Instance) -> None:
|
||||
var = Var(name=name, type=typ)
|
||||
@@ -93,25 +95,65 @@ class InjectAnyAsBaseForNestedMeta(ModelClassInitializer):
|
||||
meta_node.fallback_to_any = True
|
||||
|
||||
|
||||
def get_model_argument(manager_info: TypeInfo) -> Optional[Instance]:
|
||||
for base in manager_info.bases:
|
||||
if base.args:
|
||||
model_arg = base.args[0]
|
||||
if isinstance(model_arg, Instance) and model_arg.type.has_base(helpers.MODEL_CLASS_FULLNAME):
|
||||
return model_arg
|
||||
return None
|
||||
|
||||
|
||||
class AddDefaultObjectsManager(ModelClassInitializer):
|
||||
def is_default_objects_attr(self, sym: SymbolTableNode) -> bool:
|
||||
return sym.fullname == helpers.MODEL_CLASS_FULLNAME + '.' + 'objects'
|
||||
def add_new_manager(self, name: str, manager_type: Optional[Instance]) -> None:
|
||||
if manager_type is None:
|
||||
return None
|
||||
self.add_new_node_to_model_class(name, manager_type)
|
||||
|
||||
def add_private_default_manager(self, manager_type: Optional[Instance]) -> None:
|
||||
if manager_type is None:
|
||||
return None
|
||||
self.add_new_node_to_model_class('_default_manager', manager_type)
|
||||
|
||||
def get_existing_managers(self) -> List[Tuple[str, TypeInfo]]:
|
||||
managers = []
|
||||
for base in self.model_classdef.info.mro:
|
||||
for name_expr, member_expr in iter_call_assignments(base.defn):
|
||||
manager_name = name_expr.name
|
||||
callee_expr = member_expr.callee
|
||||
if isinstance(callee_expr, IndexExpr):
|
||||
callee_expr = callee_expr.analyzed.expr
|
||||
if isinstance(callee_expr, (MemberExpr, NameExpr)) \
|
||||
and isinstance(callee_expr.node, TypeInfo) \
|
||||
and callee_expr.node.has_base(helpers.BASE_MANAGER_CLASS_FULLNAME):
|
||||
managers.append((manager_name, callee_expr.node))
|
||||
return managers
|
||||
|
||||
def run(self) -> None:
|
||||
existing_objects_sym = self.model_classdef.info.get('objects')
|
||||
if (existing_objects_sym is not None
|
||||
and not self.is_default_objects_attr(existing_objects_sym)):
|
||||
return None
|
||||
existing_managers = self.get_existing_managers()
|
||||
if existing_managers:
|
||||
first_manager_type = None
|
||||
for manager_name, manager_type_info in existing_managers:
|
||||
manager_type = Instance(manager_type_info, args=[Instance(self.model_classdef.info, [])])
|
||||
self.add_new_manager(name=manager_name, manager_type=manager_type)
|
||||
if first_manager_type is None:
|
||||
first_manager_type = manager_type
|
||||
else:
|
||||
if self.is_abstract_model():
|
||||
# abstract models do not need 'objects' queryset
|
||||
return None
|
||||
|
||||
first_manager_type = self.api.named_type_or_none(helpers.MANAGER_CLASS_FULLNAME,
|
||||
args=[Instance(self.model_classdef.info, [])])
|
||||
self.add_new_manager('objects', manager_type=first_manager_type)
|
||||
|
||||
if self.is_abstract_model():
|
||||
# abstract models do not need 'objects' queryset
|
||||
return None
|
||||
|
||||
typ = self.api.named_type_or_none(helpers.MANAGER_CLASS_FULLNAME,
|
||||
args=[Instance(self.model_classdef.info, [])])
|
||||
if not typ:
|
||||
return None
|
||||
self.add_new_node_to_model_class('objects', typ)
|
||||
default_manager_name_expr = self.get_meta_attribute('default_manager_name')
|
||||
if isinstance(default_manager_name_expr, StrExpr):
|
||||
self.add_private_default_manager(self.model_classdef.info.get(default_manager_name_expr.value).type)
|
||||
else:
|
||||
self.add_private_default_manager(first_manager_type)
|
||||
|
||||
|
||||
class AddIdAttributeIfPrimaryKeyTrueIsNotSet(ModelClassInitializer):
|
||||
|
||||
@@ -75,7 +75,9 @@ IGNORED_ERRORS = {
|
||||
'Argument 1 to "FileWrapper" has incompatible type "StringIO"; expected "IO[bytes]"',
|
||||
'Incompatible types in assignment',
|
||||
'"object" not callable',
|
||||
'Incompatible type for "pk" of "Collector" (got "int", expected "str")'
|
||||
'Incompatible type for "pk" of "Collector" (got "int", expected "str")',
|
||||
re.compile('Unexpected attribute "[a-z]+" for model "Model"'),
|
||||
'Unexpected attribute "two_id" for model "CyclicOne"'
|
||||
],
|
||||
'aggregation': [
|
||||
'Incompatible types in assignment (expression has type "QuerySet[Any]", variable has type "List[Any]")',
|
||||
@@ -207,7 +209,8 @@ IGNORED_ERRORS = {
|
||||
'Incompatible types in assignment (expression has type "Type[Field]',
|
||||
'DummyArrayField',
|
||||
'DummyJSONField',
|
||||
'Argument "encoder" to "JSONField" has incompatible type "DjangoJSONEncoder"; expected "Optional[Type[JSONEncoder]]"'
|
||||
'Argument "encoder" to "JSONField" has incompatible type "DjangoJSONEncoder"; expected "Optional[Type[JSONEncoder]]"',
|
||||
'for model "CITestModel"'
|
||||
],
|
||||
'properties': [
|
||||
re.compile('Unexpected attribute "(full_name|full_name_2)" for model "Person"')
|
||||
@@ -430,7 +433,6 @@ TESTS_DIRS = [
|
||||
'model_options',
|
||||
'model_package',
|
||||
'model_regress',
|
||||
# not practical
|
||||
'modeladmin',
|
||||
# TODO: 'multiple_database',
|
||||
'mutually_referential',
|
||||
@@ -495,9 +497,7 @@ TESTS_DIRS = [
|
||||
'transaction_hooks',
|
||||
'transactions',
|
||||
'unmanaged_models',
|
||||
|
||||
'update',
|
||||
|
||||
'update_only_fields',
|
||||
'urlpatterns',
|
||||
|
||||
|
||||
139
test-data/typecheck/managers.test
Normal file
139
test-data/typecheck/managers.test
Normal file
@@ -0,0 +1,139 @@
|
||||
[CASE test_every_model_has_objects_queryset_available]
|
||||
from django.db import models
|
||||
class User(models.Model):
|
||||
pass
|
||||
reveal_type(User.objects) # E: Revealed type is 'django.db.models.manager.Manager[main.User]'
|
||||
reveal_type(User.objects.get()) # E: Revealed type is 'main.User*'
|
||||
|
||||
[CASE every_model_has_its_own_objects_queryset]
|
||||
from django.db import models
|
||||
class Parent(models.Model):
|
||||
pass
|
||||
class Child(Parent):
|
||||
pass
|
||||
reveal_type(Parent.objects) # E: Revealed type is 'django.db.models.manager.Manager[main.Parent]'
|
||||
reveal_type(Child.objects) # E: Revealed type is 'django.db.models.manager.Manager[main.Child]'
|
||||
[out]
|
||||
|
||||
[CASE if_manager_is_defined_on_model_do_not_add_objects]
|
||||
from django.db import models
|
||||
|
||||
class MyModel(models.Model):
|
||||
authors = models.Manager[MyModel]()
|
||||
reveal_type(MyModel.authors) # E: Revealed type is 'django.db.models.manager.Manager[main.MyModel]'
|
||||
reveal_type(MyModel.objects) # E: Revealed type is 'Any'
|
||||
[out]
|
||||
|
||||
[CASE test_model_objects_attribute_present_in_case_of_model_cls_passed_as_generic_parameter]
|
||||
from typing import TypeVar, Generic, Type
|
||||
from django.db import models
|
||||
|
||||
_T = TypeVar('_T', bound=models.Model)
|
||||
class Base(Generic[_T]):
|
||||
def __init__(self, model_cls: Type[_T]):
|
||||
self.model_cls = model_cls
|
||||
reveal_type(self.model_cls._default_manager) # E: Revealed type is 'django.db.models.manager.Manager[django.db.models.base.Model]'
|
||||
class MyModel(models.Model):
|
||||
pass
|
||||
base_instance = Base(MyModel)
|
||||
reveal_type(base_instance.model_cls._default_manager) # E: Revealed type is 'django.db.models.manager.Manager[main.MyModel]'
|
||||
|
||||
class Child(Base[MyModel]):
|
||||
def method(self) -> None:
|
||||
reveal_type(self.model_cls._default_manager) # E: Revealed type is 'django.db.models.manager.Manager[main.MyModel]'
|
||||
|
||||
[CASE if_custom_manager_defined_it_is_set_to_default_manager]
|
||||
from typing import TypeVar
|
||||
from django.db import models
|
||||
_T = TypeVar('_T', bound=models.Model)
|
||||
class CustomManager(models.Manager[_T]):
|
||||
pass
|
||||
class MyModel(models.Model):
|
||||
manager = CustomManager[MyModel]()
|
||||
reveal_type(MyModel._default_manager) # E: Revealed type is 'main.CustomManager[main.MyModel]'
|
||||
|
||||
[CASE if_default_manager_name_is_passed_set_default_manager_to_it]
|
||||
from typing import TypeVar
|
||||
from django.db import models
|
||||
|
||||
_T = TypeVar('_T', bound=models.Model)
|
||||
class Manager1(models.Manager[_T]):
|
||||
pass
|
||||
class Manager2(models.Manager[_T]):
|
||||
pass
|
||||
class MyModel(models.Model):
|
||||
class Meta:
|
||||
default_manager_name = 'm2'
|
||||
m1: Manager1[MyModel]
|
||||
m2: Manager2[MyModel]
|
||||
reveal_type(MyModel._default_manager) # E: Revealed type is 'main.Manager2[main.MyModel]'
|
||||
|
||||
[CASE test_leave_as_is_if_objects_is_set_and_fill_typevars_with_outer_class]
|
||||
from django.db import models
|
||||
|
||||
class UserManager(models.Manager[MyUser]):
|
||||
def get_or_404(self) -> MyUser:
|
||||
pass
|
||||
|
||||
class MyUser(models.Model):
|
||||
objects = UserManager()
|
||||
|
||||
reveal_type(MyUser.objects) # E: Revealed type is 'main.UserManager[main.MyUser]'
|
||||
reveal_type(MyUser.objects.get()) # E: Revealed type is 'main.MyUser*'
|
||||
reveal_type(MyUser.objects.get_or_404()) # E: Revealed type is 'main.MyUser'
|
||||
|
||||
[CASE model_imported_from_different_file]
|
||||
from django.db import models
|
||||
from models.main import Inventory
|
||||
|
||||
class Band(models.Model):
|
||||
pass
|
||||
reveal_type(Inventory.objects) # E: Revealed type is 'django.db.models.manager.Manager[models.main.Inventory]'
|
||||
reveal_type(Band.objects) # E: Revealed type is 'django.db.models.manager.Manager[main.Band]'
|
||||
[file models/__init__.py]
|
||||
[file models/main.py]
|
||||
from django.db import models
|
||||
class Inventory(models.Model):
|
||||
pass
|
||||
|
||||
[CASE managers_that_defined_on_other_models_do_not_influence]
|
||||
from django.db import models
|
||||
|
||||
class AbstractPerson(models.Model):
|
||||
abstract_persons = models.Manager[AbstractPerson]()
|
||||
class PublishedBookManager(models.Manager[Book]):
|
||||
pass
|
||||
class AnnotatedBookManager(models.Manager[Book]):
|
||||
pass
|
||||
class Book(models.Model):
|
||||
title = models.CharField(max_length=50)
|
||||
published_objects = PublishedBookManager()
|
||||
annotated_objects = AnnotatedBookManager()
|
||||
|
||||
reveal_type(AbstractPerson.abstract_persons) # E: Revealed type is 'django.db.models.manager.Manager[main.AbstractPerson]'
|
||||
reveal_type(Book.published_objects) # E: Revealed type is 'main.PublishedBookManager[main.Book]'
|
||||
Book.published_objects.create(title='hello')
|
||||
reveal_type(Book.annotated_objects) # E: Revealed type is 'main.AnnotatedBookManager[main.Book]'
|
||||
Book.annotated_objects.create(title='hello')
|
||||
[out]
|
||||
|
||||
[CASE managers_inherited_from_abstract_classes_multiple_inheritance]
|
||||
from django.db import models
|
||||
class CustomManager1(models.Manager[AbstractBase1]):
|
||||
pass
|
||||
class AbstractBase1(models.Model):
|
||||
class Meta:
|
||||
abstract = True
|
||||
name = models.CharField(max_length=50)
|
||||
manager1 = CustomManager1()
|
||||
class CustomManager2(models.Manager[AbstractBase2]):
|
||||
pass
|
||||
class AbstractBase2(models.Model):
|
||||
class Meta:
|
||||
abstract = True
|
||||
value = models.CharField(max_length=50)
|
||||
restricted = CustomManager2()
|
||||
|
||||
class Child(AbstractBase1, AbstractBase2):
|
||||
pass
|
||||
[out]
|
||||
34
test-data/typecheck/model_create.test
Normal file
34
test-data/typecheck/model_create.test
Normal file
@@ -0,0 +1,34 @@
|
||||
[CASE default_manager_create_is_typechecked]
|
||||
from django.db import models
|
||||
|
||||
class User(models.Model):
|
||||
name = models.CharField(max_length=100)
|
||||
age = models.IntegerField()
|
||||
|
||||
User.objects.create(name='Max', age=10)
|
||||
User.objects.create(name=1010) # E: Incompatible type for "name" of "User" (got "int", expected "Union[str, Combinable]")
|
||||
[out]
|
||||
|
||||
[CASE model_recognises_parent_attributes]
|
||||
from django.db import models
|
||||
|
||||
class Parent(models.Model):
|
||||
name = models.CharField(max_length=100)
|
||||
class Child(Parent):
|
||||
lastname = models.CharField(max_length=100)
|
||||
Child.objects.create(name='Maxim', lastname='Maxim2')
|
||||
[out]
|
||||
|
||||
[CASE deep_multiple_inheritance_with_create]
|
||||
from django.db import models
|
||||
|
||||
class Parent1(models.Model):
|
||||
name1 = models.CharField(max_length=50)
|
||||
class Parent2(models.Model):
|
||||
id2 = models.AutoField(primary_key=True)
|
||||
name2 = models.CharField(max_length=50)
|
||||
class Child1(Parent1, Parent2):
|
||||
value = models.IntegerField()
|
||||
class Child4(Child1):
|
||||
value4 = models.IntegerField()
|
||||
Child4.objects.create(name1='n1', name2='n2', value=1, value4=4)
|
||||
@@ -149,3 +149,9 @@ MyModel(field=time())
|
||||
MyModel(field='12:00')
|
||||
MyModel(field=100) # E: Incompatible type for "field" of "MyModel" (got "int", expected "Union[str, time]")
|
||||
|
||||
[CASE charfield_with_integer_choices]
|
||||
from django.db import models
|
||||
class MyModel(models.Model):
|
||||
day = models.CharField(max_length=3, choices=((1, 'Fri'), (2, 'Sat')))
|
||||
MyModel(day=1)
|
||||
[out]
|
||||
|
||||
@@ -1,49 +0,0 @@
|
||||
[CASE test_every_model_has_objects_queryset_available]
|
||||
from django.db import models
|
||||
class User(models.Model):
|
||||
pass
|
||||
reveal_type(User.objects) # E: Revealed type is 'django.db.models.manager.Manager[main.User]'
|
||||
reveal_type(User.objects.get()) # E: Revealed type is 'main.User*'
|
||||
|
||||
[CASE test_leave_as_is_if_objects_is_set_and_fill_typevars_with_outer_class]
|
||||
from django.db import models
|
||||
|
||||
class UserManager(models.Manager[User]):
|
||||
def get_or_404(self) -> User:
|
||||
pass
|
||||
|
||||
class User(models.Model):
|
||||
objects = UserManager()
|
||||
|
||||
reveal_type(User.objects) # E: Revealed type is 'main.UserManager'
|
||||
reveal_type(User.objects.get()) # E: Revealed type is 'main.User*'
|
||||
reveal_type(User.objects.get_or_404()) # E: Revealed type is 'main.User'
|
||||
|
||||
[CASE test_model_objects_attribute_present_in_case_of_model_cls_passed_as_parameter]
|
||||
from typing import Type
|
||||
from django.db import models
|
||||
|
||||
class Base:
|
||||
def __init__(self, model_cls: Type[models.Model]):
|
||||
self.model_cls = model_cls
|
||||
class MyModel(models.Model):
|
||||
pass
|
||||
reveal_type(Base(MyModel).model_cls.objects) # E: Revealed type is 'django.db.models.manager.Manager[django.db.models.base.Model]'
|
||||
|
||||
[CASE test_model_objects_attribute_present_in_case_of_model_cls_passed_as_generic_parameter]
|
||||
from typing import TypeVar, Generic, Type
|
||||
from django.db import models
|
||||
|
||||
_T = TypeVar('_T', bound=models.Model)
|
||||
class Base(Generic[_T]):
|
||||
def __init__(self, model_cls: Type[_T]):
|
||||
self.model_cls = model_cls
|
||||
reveal_type(self.model_cls.objects) # E: Revealed type is 'django.db.models.manager.Manager[django.db.models.base.Model]'
|
||||
class MyModel(models.Model):
|
||||
pass
|
||||
base_instance = Base(MyModel)
|
||||
reveal_type(base_instance.model_cls.objects) # E: Revealed type is 'django.db.models.manager.Manager[main.MyModel]'
|
||||
|
||||
class Child(Base[MyModel]):
|
||||
def method(self) -> None:
|
||||
reveal_type(self.model_cls.objects) # E: Revealed type is 'django.db.models.manager.Manager[main.MyModel]'
|
||||
Reference in New Issue
Block a user