add BaseManager.create() typechecking

This commit is contained in:
Maxim Kurnikov
2019-02-12 03:54:37 +03:00
parent 7aafca2e5d
commit 9eb95fbab3
19 changed files with 557 additions and 267 deletions

View File

@@ -1,8 +1,6 @@
import os
from configparser import ConfigParser
from typing import Callable, Dict, Optional, Set, cast
from typing import Callable, Dict, Optional, cast
from dataclasses import dataclass
from mypy.checker import TypeChecker
from mypy.nodes import TypeInfo
from mypy.options import Options
@@ -10,7 +8,9 @@ from mypy.plugin import ClassDefContext, FunctionContext, MethodContext, Plugin
from mypy.types import Instance, Type
from mypy_django_plugin import helpers, monkeypatch
from mypy_django_plugin.plugins.fields import determine_type_of_array_field
from mypy_django_plugin.config import Config
from mypy_django_plugin.plugins.fields import determine_type_of_array_field, record_field_properties_into_outer_model_class
from mypy_django_plugin.plugins.init_create import redefine_and_typecheck_model_init, redefine_and_typecheck_model_create
from mypy_django_plugin.plugins.migrations import determine_model_cls_from_string_for_migrations
from mypy_django_plugin.plugins.models import process_model_class
from mypy_django_plugin.plugins.related_fields import extract_to_parameter_as_get_ret_type_for_related_field, reparametrize_with
@@ -56,81 +56,6 @@ def determine_proper_manager_type(ctx: FunctionContext) -> Type:
return ret
def extract_base_pointer_args(model: TypeInfo) -> Set[str]:
pointer_args: Set[str] = set()
for base in model.bases:
if base.type.has_base(helpers.MODEL_CLASS_FULLNAME):
parent_name = base.type.name().lower()
pointer_args.add(f'{parent_name}_ptr')
pointer_args.add(f'{parent_name}_ptr_id')
return pointer_args
def redefine_model_init(ctx: FunctionContext) -> Type:
assert isinstance(ctx.default_return_type, Instance)
api = cast(TypeChecker, ctx.api)
model: TypeInfo = ctx.default_return_type.type
expected_types = helpers.extract_expected_types(ctx, model)
# order is preserved, can use for positionals
positional_names = list(expected_types.keys())
positional_names.remove('pk')
visited_positionals = set()
# check positionals
for i, (_, actual_pos_type) in enumerate(zip(ctx.arg_names[0], ctx.arg_types[0])):
actual_pos_name = positional_names[i]
api.check_subtype(actual_pos_type, expected_types[actual_pos_name],
ctx.context,
'Incompatible type for "{}" of "{}"'.format(actual_pos_name,
model.name()),
'got', 'expected')
visited_positionals.add(actual_pos_name)
# extract name of base models for _ptr
base_pointer_args = extract_base_pointer_args(model)
# check kwargs
for i, (actual_name, actual_type) in enumerate(zip(ctx.arg_names[1], ctx.arg_types[1])):
if actual_name in base_pointer_args:
# parent_ptr args are not supported
continue
if actual_name in visited_positionals:
continue
if actual_name is None:
# unpacked dict as kwargs is not supported
continue
if actual_name not in expected_types:
ctx.api.fail('Unexpected attribute "{}" for model "{}"'.format(actual_name,
model.name()),
ctx.context)
continue
api.check_subtype(actual_type, expected_types[actual_name],
ctx.context,
'Incompatible type for "{}" of "{}"'.format(actual_name,
model.name()),
'got', 'expected')
return ctx.default_return_type
@dataclass
class Config:
django_settings_module: Optional[str] = None
ignore_missing_settings: bool = False
@classmethod
def from_config_file(self, fpath: str) -> 'Config':
ini_config = ConfigParser()
ini_config.read(fpath)
if not ini_config.has_section('mypy_django_plugin'):
raise ValueError('Invalid config file: no [mypy_django_plugin] section')
return Config(django_settings_module=ini_config.get('mypy_django_plugin', 'django_settings',
fallback=None),
ignore_missing_settings=ini_config.get('mypy_django_plugin', 'ignore_missing_settings',
fallback=False))
class DjangoPlugin(Plugin):
def __init__(self, options: Options) -> None:
super().__init__(options)
@@ -194,11 +119,18 @@ class DjangoPlugin(Plugin):
sym = self.lookup_fully_qualified(fullname)
if sym and isinstance(sym.node, TypeInfo):
if sym.node.has_base(helpers.FIELD_FULLNAME):
return record_field_properties_into_outer_model_class
if sym.node.metadata.get('django', {}).get('generated_init'):
return redefine_model_init
return redefine_and_typecheck_model_init
def get_method_hook(self, fullname: str
) -> Optional[Callable[[MethodContext], Type]]:
manager_classes = self._get_current_manager_bases()
class_fullname, _, method_name = fullname.rpartition('.')
if class_fullname in manager_classes and method_name == 'create':
return redefine_and_typecheck_model_create
if fullname in {'django.apps.registry.Apps.get_model',
'django.db.migrations.state.StateApps.get_model'}:
return determine_model_cls_from_string_for_migrations