add BaseManager.create() typechecking

This commit is contained in:
Maxim Kurnikov
2019-02-12 03:54:37 +03:00
parent 7aafca2e5d
commit 9eb95fbab3
19 changed files with 557 additions and 267 deletions

View File

@@ -11,7 +11,7 @@ CHANGE: int
DELETION: int
ACTION_FLAG_CHOICES: Any
class LogEntryManager(models.Manager):
class LogEntryManager(models.Manager["LogEntry"]):
def log_action(
self,
user_id: int,

View File

@@ -16,7 +16,7 @@ class Permission(models.Model):
content_type_id: int
name: models.CharField = ...
content_type: models.ForeignKey[ContentType] = ...
codename: str = ...
codename: models.CharField = ...
def natural_key(self) -> Tuple[str, str, str]: ...
class GroupManager(models.Manager):

View File

@@ -1,10 +1,6 @@
from typing import Any, Optional
from django.db import models
class Redirect(models.Model):
id: None
site_id: int
site: Any = ...
old_path: str = ...
new_path: str = ...
site: models.ForeignKey = ...
old_path: models.CharField = ...
new_path: models.CharField = ...

View File

@@ -1,13 +1,14 @@
from typing import Any, Dict, Optional, Type, Union
from typing import Dict, Optional, Type, Union
from django.contrib.sessions.backends.base import SessionBase
from django.contrib.sessions.base_session import AbstractBaseSession
from django.contrib.sessions.models import Session
from django.core.signing import Serializer
from django.db.models.base import Model
class SessionStore(SessionBase):
accessed: bool
serializer: Type[django.core.signing.JSONSerializer]
serializer: Type[Serializer]
def __init__(self, session_key: Optional[str] = ...) -> None: ...
@classmethod
def get_model_class(cls) -> Type[Session]: ...

View File

@@ -1,4 +1,4 @@
from typing import Any, Dict, List, Optional, Set, Tuple, TypeVar, Union, ClassVar, Sequence
from typing import Any, Dict, List, Optional, Set, Tuple, TypeVar, Union, ClassVar, Sequence, Generic
from django.db.models.manager import Manager
@@ -10,9 +10,9 @@ class Model(metaclass=ModelBase):
class DoesNotExist(Exception): ...
class Meta: ...
_meta: Any
_default_manager: Manager[Model]
pk: Any = ...
objects: Manager[Model]
def __init__(self, *args, **kwargs) -> None: ...
def __init__(self: _Self, *args, **kwargs) -> None: ...
def delete(self, using: Any = ..., keep_parents: bool = ...) -> Tuple[int, Dict[str, int]]: ...
def full_clean(self, exclude: Optional[List[str]] = ..., validate_unique: bool = ...) -> None: ...
def clean_fields(self, exclude: List[str] = ...) -> None: ...

View File

@@ -1,6 +1,6 @@
import uuid
from datetime import date, time, datetime, timedelta
from typing import Any, Optional, Tuple, Iterable, Callable, Dict, Union, Type, TypeVar
from typing import Any, Optional, Tuple, Iterable, Callable, Dict, Union, Type, TypeVar, Generic
import decimal
from typing_extensions import Literal
@@ -14,7 +14,7 @@ from django.forms import Widget, Field as FormField
from .mixins import NOT_PROVIDED as NOT_PROVIDED
_Choice = Tuple[Any, Any]
_ChoiceNamedGroup = Union[Tuple[str, Iterable[_Choice]], Tuple[str, Any]]
_ChoiceNamedGroup = Tuple[str, Iterable[_Choice]]
_FieldChoices = Iterable[Union[_Choice, _ChoiceNamedGroup]]
_ValidatorCallable = Callable[..., None]
@@ -76,7 +76,7 @@ class SmallIntegerField(IntegerField): ...
class BigIntegerField(IntegerField): ...
class FloatField(Field):
def __set__(self, instance, value: Union[float, int, Combinable]) -> float: ...
def __set__(self, instance, value: Union[float, int, str, Combinable]) -> float: ...
def __get__(self, instance, owner) -> float: ...
class DecimalField(Field):
@@ -102,7 +102,7 @@ class DecimalField(Field):
validators: Iterable[_ValidatorCallable] = ...,
error_messages: Optional[_ErrorMessagesToOverride] = ...,
): ...
def __set__(self, instance, value: Union[str, Combinable]) -> decimal.Decimal: ...
def __set__(self, instance, value: Union[str, float, decimal.Decimal, Combinable]) -> decimal.Decimal: ...
def __get__(self, instance, owner) -> decimal.Decimal: ...
class AutoField(Field):
@@ -167,15 +167,15 @@ class EmailField(CharField): ...
class URLField(CharField): ...
class TextField(Field):
def __set__(self, instance, value: str) -> None: ...
def __set__(self, instance, value: Union[str, Combinable]) -> None: ...
def __get__(self, instance, owner) -> str: ...
class BooleanField(Field):
def __set__(self, instance, value: bool) -> None: ...
def __set__(self, instance, value: Union[bool, Combinable]) -> None: ...
def __get__(self, instance, owner) -> bool: ...
class NullBooleanField(Field):
def __set__(self, instance, value: Optional[bool]) -> None: ...
def __set__(self, instance, value: Optional[Union[bool, Combinable]]) -> None: ...
def __get__(self, instance, owner) -> Optional[bool]: ...
class IPAddressField(Field):

View File

@@ -1,9 +1,8 @@
from importlib.abc import SourceLoader
from typing import Any, Callable, Dict, List, Optional, Type, Union
from types import TracebackType
from typing import Any, Callable, Dict, List, MutableMapping, Optional, Type, Union
from django.core.handlers.wsgi import WSGIRequest
from django.http.request import QueryDict
from django.http.request import HttpRequest, QueryDict
from django.http.response import Http404, HttpResponse
from django.utils.safestring import SafeText
@@ -19,21 +18,21 @@ def cleanse_setting(key: Union[int, str], value: Any) -> Any: ...
def get_safe_settings() -> Dict[str, Any]: ...
def technical_500_response(request: Any, exc_type: Any, exc_value: Any, tb: Any, status_code: int = ...): ...
def get_default_exception_reporter_filter() -> ExceptionReporterFilter: ...
def get_exception_reporter_filter(request: Optional[WSGIRequest]) -> ExceptionReporterFilter: ...
def get_exception_reporter_filter(request: Optional[HttpRequest]) -> ExceptionReporterFilter: ...
class ExceptionReporterFilter:
def get_post_parameters(self, request: Any): ...
def get_traceback_frame_variables(self, request: Any, tb_frame: Any): ...
class SafeExceptionReporterFilter(ExceptionReporterFilter):
def is_active(self, request: Optional[WSGIRequest]) -> bool: ...
def get_cleansed_multivaluedict(self, request: WSGIRequest, multivaluedict: QueryDict) -> QueryDict: ...
def get_post_parameters(self, request: Optional[WSGIRequest]) -> Union[Dict[Any, Any], QueryDict]: ...
def cleanse_special_types(self, request: Optional[WSGIRequest], value: Any) -> Any: ...
def is_active(self, request: Optional[HttpRequest]) -> bool: ...
def get_cleansed_multivaluedict(self, request: HttpRequest, multivaluedict: QueryDict) -> QueryDict: ...
def get_post_parameters(self, request: Optional[HttpRequest]) -> MutableMapping[str, Any]: ...
def cleanse_special_types(self, request: Optional[HttpRequest], value: Any) -> Any: ...
def get_traceback_frame_variables(self, request: Any, tb_frame: Any): ...
class ExceptionReporter:
request: Optional[WSGIRequest] = ...
request: Optional[HttpRequest] = ...
filter: ExceptionReporterFilter = ...
exc_type: None = ...
exc_value: Optional[str] = ...
@@ -44,7 +43,7 @@ class ExceptionReporter:
postmortem: None = ...
def __init__(
self,
request: Optional[WSGIRequest],
request: Optional[HttpRequest],
exc_type: Optional[Type[BaseException]],
exc_value: Optional[Union[str, BaseException]],
tb: Optional[TracebackType],
@@ -63,5 +62,5 @@ class ExceptionReporter:
module_name: Optional[str] = None,
): ...
def technical_404_response(request: WSGIRequest, exception: Http404) -> HttpResponse: ...
def default_urlconf(request: WSGIRequest) -> HttpResponse: ...
def technical_404_response(request: HttpRequest, exception: Http404) -> HttpResponse: ...
def default_urlconf(request: HttpRequest) -> HttpResponse: ...

View File

@@ -0,0 +1,21 @@
from configparser import ConfigParser
from typing import Optional
from dataclasses import dataclass
@dataclass
class Config:
django_settings_module: Optional[str] = None
ignore_missing_settings: bool = False
@classmethod
def from_config_file(self, fpath: str) -> 'Config':
ini_config = ConfigParser()
ini_config.read(fpath)
if not ini_config.has_section('mypy_django_plugin'):
raise ValueError('Invalid config file: no [mypy_django_plugin] section')
return Config(django_settings_module=ini_config.get('mypy_django_plugin', 'django_settings',
fallback=None),
ignore_missing_settings=ini_config.get('mypy_django_plugin', 'ignore_missing_settings',
fallback=False))

View File

@@ -1,10 +1,9 @@
import typing
from typing import Dict, Optional
from mypy.nodes import Expression, FuncDef, ImportedName, MypyFile, NameExpr, SymbolNode, TypeInfo, Var, AssignmentStmt, \
CallExpr
from mypy.nodes import Expression, ImportedName, MypyFile, NameExpr, SymbolNode, TypeInfo
from mypy.plugin import FunctionContext
from mypy.types import AnyType, CallableType, Instance, Type, TypeOfAny, TypeVarType, UnionType
from mypy.types import AnyType, Instance, Type, TypeOfAny, TypeVarType
MODEL_CLASS_FULLNAME = 'django.db.models.base.Model'
FIELD_FULLNAME = 'django.db.models.fields.Field'
@@ -119,74 +118,6 @@ def fill_typevars(tp: Instance, type_to_fill: Instance) -> Instance:
return reparametrize_with(type_to_fill, typevar_values)
def extract_field_setter_type(tp: Instance) -> Optional[Type]:
if tp.type.has_base(FIELD_FULLNAME):
set_method = tp.type.get_method('__set__')
if isinstance(set_method, FuncDef) and isinstance(set_method.type, CallableType):
if 'value' in set_method.type.arg_names:
set_value_type = set_method.type.arg_types[set_method.type.arg_names.index('value')]
if isinstance(set_value_type, Instance):
set_value_type = fill_typevars(tp, set_value_type)
return set_value_type
elif isinstance(set_value_type, UnionType):
items_no_typevars = []
for item in set_value_type.items:
if isinstance(item, Instance):
item = fill_typevars(tp, item)
items_no_typevars.append(item)
return UnionType(items_no_typevars)
get_method = tp.type.get_method('__get__')
if isinstance(get_method, FuncDef) and isinstance(get_method.type, CallableType):
return get_method.type.ret_type
# GenericForeignKey
if tp.type.has_base(GENERIC_FOREIGN_KEY_FULLNAME):
return AnyType(TypeOfAny.special_form)
return None
def extract_primary_key_type(model: TypeInfo) -> Optional[Type]:
# only primary keys defined in current class for now
for stmt in model.defn.defs.body:
if isinstance(stmt, AssignmentStmt) and isinstance(stmt.rvalue, CallExpr):
name_expr = stmt.lvalues[0]
if isinstance(name_expr, NameExpr):
name = name_expr.name
if 'primary_key' in stmt.rvalue.arg_names:
is_primary_key = stmt.rvalue.args[stmt.rvalue.arg_names.index('primary_key')]
if is_primary_key:
return extract_field_setter_type(model.names[name].type)
return None
def extract_expected_types(ctx: FunctionContext, model: TypeInfo) -> Dict[str, Type]:
expected_types: Dict[str, Type] = {}
primary_key_type = extract_primary_key_type(model)
if not primary_key_type:
# no explicit primary key, set pk to Any and add id
primary_key_type = AnyType(TypeOfAny.special_form)
expected_types['id'] = ctx.api.named_generic_type('builtins.int', [])
expected_types['pk'] = primary_key_type
for base in model.mro:
for name, sym in base.names.items():
if isinstance(sym.node, Var) and isinstance(sym.node.type, Instance):
tp = sym.node.type
field_type = extract_field_setter_type(tp)
if tp.type.fullname() in {FOREIGN_KEY_FULLNAME, ONETOONE_FIELD_FULLNAME}:
ref_to_model = tp.args[0]
if isinstance(ref_to_model, Instance) and ref_to_model.type.has_base(MODEL_CLASS_FULLNAME):
primary_key_type = extract_primary_key_type(ref_to_model.type)
if not primary_key_type:
primary_key_type = AnyType(TypeOfAny.special_form)
expected_types[name + '_id'] = primary_key_type
if field_type:
expected_types[name] = field_type
return expected_types
def get_argument_by_name(ctx: FunctionContext, name: str) -> Optional[Expression]:
"""Return the expression for the specific argument.

View File

@@ -1,8 +1,6 @@
import os
from configparser import ConfigParser
from typing import Callable, Dict, Optional, Set, cast
from typing import Callable, Dict, Optional, cast
from dataclasses import dataclass
from mypy.checker import TypeChecker
from mypy.nodes import TypeInfo
from mypy.options import Options
@@ -10,7 +8,9 @@ from mypy.plugin import ClassDefContext, FunctionContext, MethodContext, Plugin
from mypy.types import Instance, Type
from mypy_django_plugin import helpers, monkeypatch
from mypy_django_plugin.plugins.fields import determine_type_of_array_field
from mypy_django_plugin.config import Config
from mypy_django_plugin.plugins.fields import determine_type_of_array_field, record_field_properties_into_outer_model_class
from mypy_django_plugin.plugins.init_create import redefine_and_typecheck_model_init, redefine_and_typecheck_model_create
from mypy_django_plugin.plugins.migrations import determine_model_cls_from_string_for_migrations
from mypy_django_plugin.plugins.models import process_model_class
from mypy_django_plugin.plugins.related_fields import extract_to_parameter_as_get_ret_type_for_related_field, reparametrize_with
@@ -56,81 +56,6 @@ def determine_proper_manager_type(ctx: FunctionContext) -> Type:
return ret
def extract_base_pointer_args(model: TypeInfo) -> Set[str]:
pointer_args: Set[str] = set()
for base in model.bases:
if base.type.has_base(helpers.MODEL_CLASS_FULLNAME):
parent_name = base.type.name().lower()
pointer_args.add(f'{parent_name}_ptr')
pointer_args.add(f'{parent_name}_ptr_id')
return pointer_args
def redefine_model_init(ctx: FunctionContext) -> Type:
assert isinstance(ctx.default_return_type, Instance)
api = cast(TypeChecker, ctx.api)
model: TypeInfo = ctx.default_return_type.type
expected_types = helpers.extract_expected_types(ctx, model)
# order is preserved, can use for positionals
positional_names = list(expected_types.keys())
positional_names.remove('pk')
visited_positionals = set()
# check positionals
for i, (_, actual_pos_type) in enumerate(zip(ctx.arg_names[0], ctx.arg_types[0])):
actual_pos_name = positional_names[i]
api.check_subtype(actual_pos_type, expected_types[actual_pos_name],
ctx.context,
'Incompatible type for "{}" of "{}"'.format(actual_pos_name,
model.name()),
'got', 'expected')
visited_positionals.add(actual_pos_name)
# extract name of base models for _ptr
base_pointer_args = extract_base_pointer_args(model)
# check kwargs
for i, (actual_name, actual_type) in enumerate(zip(ctx.arg_names[1], ctx.arg_types[1])):
if actual_name in base_pointer_args:
# parent_ptr args are not supported
continue
if actual_name in visited_positionals:
continue
if actual_name is None:
# unpacked dict as kwargs is not supported
continue
if actual_name not in expected_types:
ctx.api.fail('Unexpected attribute "{}" for model "{}"'.format(actual_name,
model.name()),
ctx.context)
continue
api.check_subtype(actual_type, expected_types[actual_name],
ctx.context,
'Incompatible type for "{}" of "{}"'.format(actual_name,
model.name()),
'got', 'expected')
return ctx.default_return_type
@dataclass
class Config:
django_settings_module: Optional[str] = None
ignore_missing_settings: bool = False
@classmethod
def from_config_file(self, fpath: str) -> 'Config':
ini_config = ConfigParser()
ini_config.read(fpath)
if not ini_config.has_section('mypy_django_plugin'):
raise ValueError('Invalid config file: no [mypy_django_plugin] section')
return Config(django_settings_module=ini_config.get('mypy_django_plugin', 'django_settings',
fallback=None),
ignore_missing_settings=ini_config.get('mypy_django_plugin', 'ignore_missing_settings',
fallback=False))
class DjangoPlugin(Plugin):
def __init__(self, options: Options) -> None:
super().__init__(options)
@@ -194,11 +119,18 @@ class DjangoPlugin(Plugin):
sym = self.lookup_fully_qualified(fullname)
if sym and isinstance(sym.node, TypeInfo):
if sym.node.has_base(helpers.FIELD_FULLNAME):
return record_field_properties_into_outer_model_class
if sym.node.metadata.get('django', {}).get('generated_init'):
return redefine_model_init
return redefine_and_typecheck_model_init
def get_method_hook(self, fullname: str
) -> Optional[Callable[[MethodContext], Type]]:
manager_classes = self._get_current_manager_bases()
class_fullname, _, method_name = fullname.rpartition('.')
if class_fullname in manager_classes and method_name == 'create':
return redefine_and_typecheck_model_create
if fullname in {'django.apps.registry.Apps.get_model',
'django.db.migrations.state.StateApps.get_model'}:
return determine_model_cls_from_string_for_migrations

View File

@@ -1,7 +1,12 @@
from typing import cast
from mypy.checker import TypeChecker
from mypy.nodes import ListExpr, NameExpr, TupleExpr
from mypy.plugin import FunctionContext
from mypy.types import Type, Instance
from mypy.types import Instance, TupleType, Type
from mypy_django_plugin import helpers
from mypy_django_plugin.plugins.models import iter_over_assignments
def determine_type_of_array_field(ctx: FunctionContext) -> Type:
@@ -16,3 +21,54 @@ def determine_type_of_array_field(ctx: FunctionContext) -> Type:
return ctx.api.named_generic_type(ctx.context.callee.fullname,
args=[get_method.type.ret_type])
def record_field_properties_into_outer_model_class(ctx: FunctionContext) -> Type:
api = cast(TypeChecker, ctx.api)
outer_model = api.scope.active_class()
if outer_model is None or not outer_model.has_base(helpers.MODEL_CLASS_FULLNAME):
# outside models.Model class, undetermined
return ctx.default_return_type
field_name = None
for name_expr, stmt in iter_over_assignments(outer_model.defn):
if stmt == ctx.context and isinstance(name_expr, NameExpr):
field_name = name_expr.name
break
if field_name is None:
return ctx.default_return_type
fields_metadata = outer_model.metadata.setdefault('django', {}).setdefault('fields', {})
# primary key
is_primary_key = False
primary_key_arg = helpers.get_argument_by_name(ctx, 'primary_key')
if primary_key_arg:
is_primary_key = helpers.parse_bool(primary_key_arg)
fields_metadata[field_name] = {'primary_key': is_primary_key}
# choices
choices_arg = helpers.get_argument_by_name(ctx, 'choices')
if choices_arg and isinstance(choices_arg, (TupleExpr, ListExpr)):
# iterable of 2 element tuples of two kinds
_, analyzed_choices = api.analyze_iterable_item_type(choices_arg)
if isinstance(analyzed_choices, TupleType):
first_element_type = analyzed_choices.items[0]
if isinstance(first_element_type, Instance):
fields_metadata[field_name]['choices'] = first_element_type.type.fullname()
# nullability
null_arg = helpers.get_argument_by_name(ctx, 'null')
is_nullable = False
if null_arg:
is_nullable = helpers.parse_bool(null_arg)
fields_metadata[field_name]['null'] = is_nullable
# is_blankable
blank_arg = helpers.get_argument_by_name(ctx, 'blank')
is_blankable = False
if blank_arg:
is_blankable = helpers.parse_bool(blank_arg)
fields_metadata[field_name]['blank'] = is_blankable
return ctx.default_return_type

View File

@@ -0,0 +1,182 @@
from typing import Dict, Optional, Set, cast, Any
from mypy.checker import TypeChecker
from mypy.nodes import FuncDef, TypeInfo, Var
from mypy.plugin import FunctionContext, MethodContext
from mypy.types import AnyType, CallableType, Instance, Type, TypeOfAny, UnionType
from mypy_django_plugin import helpers
def extract_base_pointer_args(model: TypeInfo) -> Set[str]:
pointer_args: Set[str] = set()
for base in model.bases:
if base.type.has_base(helpers.MODEL_CLASS_FULLNAME):
parent_name = base.type.name().lower()
pointer_args.add(f'{parent_name}_ptr')
pointer_args.add(f'{parent_name}_ptr_id')
return pointer_args
def redefine_and_typecheck_model_init(ctx: FunctionContext) -> Type:
assert isinstance(ctx.default_return_type, Instance)
api = cast(TypeChecker, ctx.api)
model: TypeInfo = ctx.default_return_type.type
expected_types = extract_expected_types(ctx, model)
# order is preserved, can use for positionals
positional_names = list(expected_types.keys())
positional_names.remove('pk')
visited_positionals = set()
# check positionals
for i, (_, actual_pos_type) in enumerate(zip(ctx.arg_names[0], ctx.arg_types[0])):
actual_pos_name = positional_names[i]
api.check_subtype(actual_pos_type, expected_types[actual_pos_name],
ctx.context,
'Incompatible type for "{}" of "{}"'.format(actual_pos_name,
model.name()),
'got', 'expected')
visited_positionals.add(actual_pos_name)
# extract name of base models for _ptr
base_pointer_args = extract_base_pointer_args(model)
# check kwargs
for i, (actual_name, actual_type) in enumerate(zip(ctx.arg_names[1], ctx.arg_types[1])):
if actual_name in base_pointer_args:
# parent_ptr args are not supported
continue
if actual_name in visited_positionals:
continue
if actual_name is None:
# unpacked dict as kwargs is not supported
continue
if actual_name not in expected_types:
ctx.api.fail('Unexpected attribute "{}" for model "{}"'.format(actual_name,
model.name()),
ctx.context)
continue
api.check_subtype(actual_type, expected_types[actual_name],
ctx.context,
'Incompatible type for "{}" of "{}"'.format(actual_name,
model.name()),
'got', 'expected')
return ctx.default_return_type
def redefine_and_typecheck_model_create(ctx: MethodContext) -> Type:
api = cast(TypeChecker, ctx.api)
if isinstance(ctx.type, Instance) and len(ctx.type.args) > 0:
model: TypeInfo = ctx.type.args[0].type
else:
if isinstance(ctx.default_return_type, AnyType):
return ctx.default_return_type
model: TypeInfo = ctx.default_return_type.type
# extract name of base models for _ptr
base_pointer_args = extract_base_pointer_args(model)
expected_types = extract_expected_types(ctx, model)
for actual_name, actual_type in zip(ctx.arg_names[0], ctx.arg_types[0]):
if actual_name in base_pointer_args:
# parent_ptr args are not supported
continue
if actual_name is None:
# unpacked dict as kwargs is not supported
continue
if actual_name not in expected_types:
api.fail('Unexpected attribute "{}" for model "{}"'.format(actual_name,
model.name()),
ctx.context)
continue
api.check_subtype(actual_type, expected_types[actual_name],
ctx.context,
'Incompatible type for "{}" of "{}"'.format(actual_name,
model.name()),
'got', 'expected')
return ctx.default_return_type
def extract_field_setter_type(tp: Instance) -> Optional[Type]:
if not isinstance(tp, Instance):
return None
if tp.type.has_base(helpers.FIELD_FULLNAME):
set_method = tp.type.get_method('__set__')
if isinstance(set_method, FuncDef) and isinstance(set_method.type, CallableType):
if 'value' in set_method.type.arg_names:
set_value_type = set_method.type.arg_types[set_method.type.arg_names.index('value')]
if isinstance(set_value_type, Instance):
set_value_type = helpers.fill_typevars(tp, set_value_type)
return set_value_type
elif isinstance(set_value_type, UnionType):
items_no_typevars = []
for item in set_value_type.items:
if isinstance(item, Instance):
item = helpers.fill_typevars(tp, item)
items_no_typevars.append(item)
return UnionType(items_no_typevars)
get_method = tp.type.get_method('__get__')
if isinstance(get_method, FuncDef) and isinstance(get_method.type, CallableType):
return get_method.type.ret_type
# GenericForeignKey
if tp.type.has_base(helpers.GENERIC_FOREIGN_KEY_FULLNAME):
return AnyType(TypeOfAny.special_form)
return None
def get_fields_metadata(model: TypeInfo) -> Dict[str, Any]:
return model.metadata.setdefault('django', {}).setdefault('fields', {})
def extract_primary_key_type(model: TypeInfo) -> Optional[Type]:
for field_name, props in get_fields_metadata(model).items():
is_primary_key = props.get('primary_key', False)
if is_primary_key:
return extract_field_setter_type(model.names[field_name].type)
return None
def extract_choices_type(model: TypeInfo, field_name: str) -> Optional[str]:
field_metadata = get_fields_metadata(model).get(field_name, {})
if 'choices' in field_metadata:
return field_metadata['choices']
return None
def extract_expected_types(ctx: FunctionContext, model: TypeInfo) -> Dict[str, Type]:
expected_types: Dict[str, Type] = {}
primary_key_type = extract_primary_key_type(model)
if not primary_key_type:
# no explicit primary key, set pk to Any and add id
primary_key_type = AnyType(TypeOfAny.special_form)
expected_types['id'] = ctx.api.named_generic_type('builtins.int', [])
expected_types['pk'] = primary_key_type
for base in model.mro:
for name, sym in base.names.items():
if isinstance(sym.node, Var) and isinstance(sym.node.type, Instance):
tp = sym.node.type
field_type = extract_field_setter_type(tp)
if field_type is None:
continue
choices_type_fullname = extract_choices_type(model, name)
if choices_type_fullname:
field_type = UnionType([field_type, ctx.api.named_generic_type(choices_type_fullname, [])])
if tp.type.fullname() in {helpers.FOREIGN_KEY_FULLNAME, helpers.ONETOONE_FIELD_FULLNAME}:
ref_to_model = tp.args[0]
if isinstance(ref_to_model, Instance) and ref_to_model.type.has_base(helpers.MODEL_CLASS_FULLNAME):
primary_key_type = extract_primary_key_type(ref_to_model.type)
if not primary_key_type:
primary_key_type = AnyType(TypeOfAny.special_form)
expected_types[name + '_id'] = primary_key_type
if field_type:
expected_types[name] = field_type
return expected_types

View File

@@ -1,9 +1,9 @@
from typing import cast, Optional
from typing import Optional, cast
from mypy.checker import TypeChecker
from mypy.nodes import TypeInfo, Expression, StrExpr, NameExpr, RefExpr, Var
from mypy.nodes import Expression, StrExpr, TypeInfo
from mypy.plugin import MethodContext
from mypy.types import Type, Instance, TypeType
from mypy.types import Instance, Type, TypeType
from mypy_django_plugin import helpers

View File

@@ -1,13 +1,13 @@
from abc import ABCMeta, abstractmethod
from typing import Dict, Iterator, Optional, Tuple, cast
from typing import Dict, Iterator, List, Optional, Tuple, cast
import dataclasses
from mypy.nodes import AssignmentStmt, CallExpr, ClassDef, Context, Expression, Lvalue, MDEF, MemberExpr, \
MypyFile, NameExpr, StrExpr, SymbolTableNode, TypeInfo, Var, Argument, ARG_STAR2, ARG_STAR
from mypy.nodes import ARG_STAR, ARG_STAR2, Argument, AssignmentStmt, CallExpr, ClassDef, Context, Expression, IndexExpr, \
Lvalue, MDEF, MemberExpr, MypyFile, NameExpr, StrExpr, SymbolTableNode, TypeInfo, Var
from mypy.plugin import ClassDefContext
from mypy.plugins.common import add_method
from mypy.semanal import SemanticAnalyzerPass2
from mypy.types import Instance, AnyType, TypeOfAny, NoneTyp
from mypy.types import AnyType, Instance, NoneTyp, TypeOfAny
from mypy_django_plugin import helpers
@@ -27,18 +27,20 @@ class ModelClassInitializer(metaclass=ABCMeta):
return metaclass_sym.node
return None
def is_abstract_model(self) -> bool:
def get_meta_attribute(self, name: str) -> Optional[Expression]:
meta_node = self.get_nested_meta_node()
if meta_node is None:
return False
return None
for lvalue, rvalue in iter_over_assignments(meta_node.defn):
if isinstance(lvalue, NameExpr) and lvalue.name == 'abstract':
is_abstract = self.api.parse_bool(rvalue)
if is_abstract:
# abstract model do not need 'objects' queryset
return True
return False
if isinstance(lvalue, NameExpr) and lvalue.name == name:
return rvalue
def is_abstract_model(self) -> bool:
is_abstract_expr = self.get_meta_attribute('abstract')
if is_abstract_expr is None:
return False
return self.api.parse_bool(is_abstract_expr)
def add_new_node_to_model_class(self, name: str, typ: Instance) -> None:
var = Var(name=name, type=typ)
@@ -93,25 +95,65 @@ class InjectAnyAsBaseForNestedMeta(ModelClassInitializer):
meta_node.fallback_to_any = True
def get_model_argument(manager_info: TypeInfo) -> Optional[Instance]:
for base in manager_info.bases:
if base.args:
model_arg = base.args[0]
if isinstance(model_arg, Instance) and model_arg.type.has_base(helpers.MODEL_CLASS_FULLNAME):
return model_arg
return None
class AddDefaultObjectsManager(ModelClassInitializer):
def is_default_objects_attr(self, sym: SymbolTableNode) -> bool:
return sym.fullname == helpers.MODEL_CLASS_FULLNAME + '.' + 'objects'
def add_new_manager(self, name: str, manager_type: Optional[Instance]) -> None:
if manager_type is None:
return None
self.add_new_node_to_model_class(name, manager_type)
def add_private_default_manager(self, manager_type: Optional[Instance]) -> None:
if manager_type is None:
return None
self.add_new_node_to_model_class('_default_manager', manager_type)
def get_existing_managers(self) -> List[Tuple[str, TypeInfo]]:
managers = []
for base in self.model_classdef.info.mro:
for name_expr, member_expr in iter_call_assignments(base.defn):
manager_name = name_expr.name
callee_expr = member_expr.callee
if isinstance(callee_expr, IndexExpr):
callee_expr = callee_expr.analyzed.expr
if isinstance(callee_expr, (MemberExpr, NameExpr)) \
and isinstance(callee_expr.node, TypeInfo) \
and callee_expr.node.has_base(helpers.BASE_MANAGER_CLASS_FULLNAME):
managers.append((manager_name, callee_expr.node))
return managers
def run(self) -> None:
existing_objects_sym = self.model_classdef.info.get('objects')
if (existing_objects_sym is not None
and not self.is_default_objects_attr(existing_objects_sym)):
return None
existing_managers = self.get_existing_managers()
if existing_managers:
first_manager_type = None
for manager_name, manager_type_info in existing_managers:
manager_type = Instance(manager_type_info, args=[Instance(self.model_classdef.info, [])])
self.add_new_manager(name=manager_name, manager_type=manager_type)
if first_manager_type is None:
first_manager_type = manager_type
else:
if self.is_abstract_model():
# abstract models do not need 'objects' queryset
return None
first_manager_type = self.api.named_type_or_none(helpers.MANAGER_CLASS_FULLNAME,
args=[Instance(self.model_classdef.info, [])])
self.add_new_manager('objects', manager_type=first_manager_type)
if self.is_abstract_model():
# abstract models do not need 'objects' queryset
return None
typ = self.api.named_type_or_none(helpers.MANAGER_CLASS_FULLNAME,
args=[Instance(self.model_classdef.info, [])])
if not typ:
return None
self.add_new_node_to_model_class('objects', typ)
default_manager_name_expr = self.get_meta_attribute('default_manager_name')
if isinstance(default_manager_name_expr, StrExpr):
self.add_private_default_manager(self.model_classdef.info.get(default_manager_name_expr.value).type)
else:
self.add_private_default_manager(first_manager_type)
class AddIdAttributeIfPrimaryKeyTrueIsNotSet(ModelClassInitializer):

View File

@@ -75,7 +75,9 @@ IGNORED_ERRORS = {
'Argument 1 to "FileWrapper" has incompatible type "StringIO"; expected "IO[bytes]"',
'Incompatible types in assignment',
'"object" not callable',
'Incompatible type for "pk" of "Collector" (got "int", expected "str")'
'Incompatible type for "pk" of "Collector" (got "int", expected "str")',
re.compile('Unexpected attribute "[a-z]+" for model "Model"'),
'Unexpected attribute "two_id" for model "CyclicOne"'
],
'aggregation': [
'Incompatible types in assignment (expression has type "QuerySet[Any]", variable has type "List[Any]")',
@@ -207,7 +209,8 @@ IGNORED_ERRORS = {
'Incompatible types in assignment (expression has type "Type[Field]',
'DummyArrayField',
'DummyJSONField',
'Argument "encoder" to "JSONField" has incompatible type "DjangoJSONEncoder"; expected "Optional[Type[JSONEncoder]]"'
'Argument "encoder" to "JSONField" has incompatible type "DjangoJSONEncoder"; expected "Optional[Type[JSONEncoder]]"',
'for model "CITestModel"'
],
'properties': [
re.compile('Unexpected attribute "(full_name|full_name_2)" for model "Person"')
@@ -430,7 +433,6 @@ TESTS_DIRS = [
'model_options',
'model_package',
'model_regress',
# not practical
'modeladmin',
# TODO: 'multiple_database',
'mutually_referential',
@@ -495,9 +497,7 @@ TESTS_DIRS = [
'transaction_hooks',
'transactions',
'unmanaged_models',
'update',
'update_only_fields',
'urlpatterns',

View File

@@ -0,0 +1,139 @@
[CASE test_every_model_has_objects_queryset_available]
from django.db import models
class User(models.Model):
pass
reveal_type(User.objects) # E: Revealed type is 'django.db.models.manager.Manager[main.User]'
reveal_type(User.objects.get()) # E: Revealed type is 'main.User*'
[CASE every_model_has_its_own_objects_queryset]
from django.db import models
class Parent(models.Model):
pass
class Child(Parent):
pass
reveal_type(Parent.objects) # E: Revealed type is 'django.db.models.manager.Manager[main.Parent]'
reveal_type(Child.objects) # E: Revealed type is 'django.db.models.manager.Manager[main.Child]'
[out]
[CASE if_manager_is_defined_on_model_do_not_add_objects]
from django.db import models
class MyModel(models.Model):
authors = models.Manager[MyModel]()
reveal_type(MyModel.authors) # E: Revealed type is 'django.db.models.manager.Manager[main.MyModel]'
reveal_type(MyModel.objects) # E: Revealed type is 'Any'
[out]
[CASE test_model_objects_attribute_present_in_case_of_model_cls_passed_as_generic_parameter]
from typing import TypeVar, Generic, Type
from django.db import models
_T = TypeVar('_T', bound=models.Model)
class Base(Generic[_T]):
def __init__(self, model_cls: Type[_T]):
self.model_cls = model_cls
reveal_type(self.model_cls._default_manager) # E: Revealed type is 'django.db.models.manager.Manager[django.db.models.base.Model]'
class MyModel(models.Model):
pass
base_instance = Base(MyModel)
reveal_type(base_instance.model_cls._default_manager) # E: Revealed type is 'django.db.models.manager.Manager[main.MyModel]'
class Child(Base[MyModel]):
def method(self) -> None:
reveal_type(self.model_cls._default_manager) # E: Revealed type is 'django.db.models.manager.Manager[main.MyModel]'
[CASE if_custom_manager_defined_it_is_set_to_default_manager]
from typing import TypeVar
from django.db import models
_T = TypeVar('_T', bound=models.Model)
class CustomManager(models.Manager[_T]):
pass
class MyModel(models.Model):
manager = CustomManager[MyModel]()
reveal_type(MyModel._default_manager) # E: Revealed type is 'main.CustomManager[main.MyModel]'
[CASE if_default_manager_name_is_passed_set_default_manager_to_it]
from typing import TypeVar
from django.db import models
_T = TypeVar('_T', bound=models.Model)
class Manager1(models.Manager[_T]):
pass
class Manager2(models.Manager[_T]):
pass
class MyModel(models.Model):
class Meta:
default_manager_name = 'm2'
m1: Manager1[MyModel]
m2: Manager2[MyModel]
reveal_type(MyModel._default_manager) # E: Revealed type is 'main.Manager2[main.MyModel]'
[CASE test_leave_as_is_if_objects_is_set_and_fill_typevars_with_outer_class]
from django.db import models
class UserManager(models.Manager[MyUser]):
def get_or_404(self) -> MyUser:
pass
class MyUser(models.Model):
objects = UserManager()
reveal_type(MyUser.objects) # E: Revealed type is 'main.UserManager[main.MyUser]'
reveal_type(MyUser.objects.get()) # E: Revealed type is 'main.MyUser*'
reveal_type(MyUser.objects.get_or_404()) # E: Revealed type is 'main.MyUser'
[CASE model_imported_from_different_file]
from django.db import models
from models.main import Inventory
class Band(models.Model):
pass
reveal_type(Inventory.objects) # E: Revealed type is 'django.db.models.manager.Manager[models.main.Inventory]'
reveal_type(Band.objects) # E: Revealed type is 'django.db.models.manager.Manager[main.Band]'
[file models/__init__.py]
[file models/main.py]
from django.db import models
class Inventory(models.Model):
pass
[CASE managers_that_defined_on_other_models_do_not_influence]
from django.db import models
class AbstractPerson(models.Model):
abstract_persons = models.Manager[AbstractPerson]()
class PublishedBookManager(models.Manager[Book]):
pass
class AnnotatedBookManager(models.Manager[Book]):
pass
class Book(models.Model):
title = models.CharField(max_length=50)
published_objects = PublishedBookManager()
annotated_objects = AnnotatedBookManager()
reveal_type(AbstractPerson.abstract_persons) # E: Revealed type is 'django.db.models.manager.Manager[main.AbstractPerson]'
reveal_type(Book.published_objects) # E: Revealed type is 'main.PublishedBookManager[main.Book]'
Book.published_objects.create(title='hello')
reveal_type(Book.annotated_objects) # E: Revealed type is 'main.AnnotatedBookManager[main.Book]'
Book.annotated_objects.create(title='hello')
[out]
[CASE managers_inherited_from_abstract_classes_multiple_inheritance]
from django.db import models
class CustomManager1(models.Manager[AbstractBase1]):
pass
class AbstractBase1(models.Model):
class Meta:
abstract = True
name = models.CharField(max_length=50)
manager1 = CustomManager1()
class CustomManager2(models.Manager[AbstractBase2]):
pass
class AbstractBase2(models.Model):
class Meta:
abstract = True
value = models.CharField(max_length=50)
restricted = CustomManager2()
class Child(AbstractBase1, AbstractBase2):
pass
[out]

View File

@@ -0,0 +1,34 @@
[CASE default_manager_create_is_typechecked]
from django.db import models
class User(models.Model):
name = models.CharField(max_length=100)
age = models.IntegerField()
User.objects.create(name='Max', age=10)
User.objects.create(name=1010) # E: Incompatible type for "name" of "User" (got "int", expected "Union[str, Combinable]")
[out]
[CASE model_recognises_parent_attributes]
from django.db import models
class Parent(models.Model):
name = models.CharField(max_length=100)
class Child(Parent):
lastname = models.CharField(max_length=100)
Child.objects.create(name='Maxim', lastname='Maxim2')
[out]
[CASE deep_multiple_inheritance_with_create]
from django.db import models
class Parent1(models.Model):
name1 = models.CharField(max_length=50)
class Parent2(models.Model):
id2 = models.AutoField(primary_key=True)
name2 = models.CharField(max_length=50)
class Child1(Parent1, Parent2):
value = models.IntegerField()
class Child4(Child1):
value4 = models.IntegerField()
Child4.objects.create(name1='n1', name2='n2', value=1, value4=4)

View File

@@ -149,3 +149,9 @@ MyModel(field=time())
MyModel(field='12:00')
MyModel(field=100) # E: Incompatible type for "field" of "MyModel" (got "int", expected "Union[str, time]")
[CASE charfield_with_integer_choices]
from django.db import models
class MyModel(models.Model):
day = models.CharField(max_length=3, choices=((1, 'Fri'), (2, 'Sat')))
MyModel(day=1)
[out]

View File

@@ -1,49 +0,0 @@
[CASE test_every_model_has_objects_queryset_available]
from django.db import models
class User(models.Model):
pass
reveal_type(User.objects) # E: Revealed type is 'django.db.models.manager.Manager[main.User]'
reveal_type(User.objects.get()) # E: Revealed type is 'main.User*'
[CASE test_leave_as_is_if_objects_is_set_and_fill_typevars_with_outer_class]
from django.db import models
class UserManager(models.Manager[User]):
def get_or_404(self) -> User:
pass
class User(models.Model):
objects = UserManager()
reveal_type(User.objects) # E: Revealed type is 'main.UserManager'
reveal_type(User.objects.get()) # E: Revealed type is 'main.User*'
reveal_type(User.objects.get_or_404()) # E: Revealed type is 'main.User'
[CASE test_model_objects_attribute_present_in_case_of_model_cls_passed_as_parameter]
from typing import Type
from django.db import models
class Base:
def __init__(self, model_cls: Type[models.Model]):
self.model_cls = model_cls
class MyModel(models.Model):
pass
reveal_type(Base(MyModel).model_cls.objects) # E: Revealed type is 'django.db.models.manager.Manager[django.db.models.base.Model]'
[CASE test_model_objects_attribute_present_in_case_of_model_cls_passed_as_generic_parameter]
from typing import TypeVar, Generic, Type
from django.db import models
_T = TypeVar('_T', bound=models.Model)
class Base(Generic[_T]):
def __init__(self, model_cls: Type[_T]):
self.model_cls = model_cls
reveal_type(self.model_cls.objects) # E: Revealed type is 'django.db.models.manager.Manager[django.db.models.base.Model]'
class MyModel(models.Model):
pass
base_instance = Base(MyModel)
reveal_type(base_instance.model_cls.objects) # E: Revealed type is 'django.db.models.manager.Manager[main.MyModel]'
class Child(Base[MyModel]):
def method(self) -> None:
reveal_type(self.model_cls.objects) # E: Revealed type is 'django.db.models.manager.Manager[main.MyModel]'