mirror of
https://github.com/davidhalter/django-stubs.git
synced 2025-12-07 20:54:29 +08:00
add BaseManager.create() typechecking
This commit is contained in:
@@ -1,8 +1,6 @@
|
||||
import os
|
||||
from configparser import ConfigParser
|
||||
from typing import Callable, Dict, Optional, Set, cast
|
||||
from typing import Callable, Dict, Optional, cast
|
||||
|
||||
from dataclasses import dataclass
|
||||
from mypy.checker import TypeChecker
|
||||
from mypy.nodes import TypeInfo
|
||||
from mypy.options import Options
|
||||
@@ -10,7 +8,9 @@ from mypy.plugin import ClassDefContext, FunctionContext, MethodContext, Plugin
|
||||
from mypy.types import Instance, Type
|
||||
|
||||
from mypy_django_plugin import helpers, monkeypatch
|
||||
from mypy_django_plugin.plugins.fields import determine_type_of_array_field
|
||||
from mypy_django_plugin.config import Config
|
||||
from mypy_django_plugin.plugins.fields import determine_type_of_array_field, record_field_properties_into_outer_model_class
|
||||
from mypy_django_plugin.plugins.init_create import redefine_and_typecheck_model_init, redefine_and_typecheck_model_create
|
||||
from mypy_django_plugin.plugins.migrations import determine_model_cls_from_string_for_migrations
|
||||
from mypy_django_plugin.plugins.models import process_model_class
|
||||
from mypy_django_plugin.plugins.related_fields import extract_to_parameter_as_get_ret_type_for_related_field, reparametrize_with
|
||||
@@ -56,81 +56,6 @@ def determine_proper_manager_type(ctx: FunctionContext) -> Type:
|
||||
return ret
|
||||
|
||||
|
||||
def extract_base_pointer_args(model: TypeInfo) -> Set[str]:
|
||||
pointer_args: Set[str] = set()
|
||||
for base in model.bases:
|
||||
if base.type.has_base(helpers.MODEL_CLASS_FULLNAME):
|
||||
parent_name = base.type.name().lower()
|
||||
pointer_args.add(f'{parent_name}_ptr')
|
||||
pointer_args.add(f'{parent_name}_ptr_id')
|
||||
return pointer_args
|
||||
|
||||
|
||||
def redefine_model_init(ctx: FunctionContext) -> Type:
|
||||
assert isinstance(ctx.default_return_type, Instance)
|
||||
|
||||
api = cast(TypeChecker, ctx.api)
|
||||
model: TypeInfo = ctx.default_return_type.type
|
||||
|
||||
expected_types = helpers.extract_expected_types(ctx, model)
|
||||
# order is preserved, can use for positionals
|
||||
positional_names = list(expected_types.keys())
|
||||
positional_names.remove('pk')
|
||||
visited_positionals = set()
|
||||
|
||||
# check positionals
|
||||
for i, (_, actual_pos_type) in enumerate(zip(ctx.arg_names[0], ctx.arg_types[0])):
|
||||
actual_pos_name = positional_names[i]
|
||||
api.check_subtype(actual_pos_type, expected_types[actual_pos_name],
|
||||
ctx.context,
|
||||
'Incompatible type for "{}" of "{}"'.format(actual_pos_name,
|
||||
model.name()),
|
||||
'got', 'expected')
|
||||
visited_positionals.add(actual_pos_name)
|
||||
|
||||
# extract name of base models for _ptr
|
||||
base_pointer_args = extract_base_pointer_args(model)
|
||||
|
||||
# check kwargs
|
||||
for i, (actual_name, actual_type) in enumerate(zip(ctx.arg_names[1], ctx.arg_types[1])):
|
||||
if actual_name in base_pointer_args:
|
||||
# parent_ptr args are not supported
|
||||
continue
|
||||
if actual_name in visited_positionals:
|
||||
continue
|
||||
if actual_name is None:
|
||||
# unpacked dict as kwargs is not supported
|
||||
continue
|
||||
if actual_name not in expected_types:
|
||||
ctx.api.fail('Unexpected attribute "{}" for model "{}"'.format(actual_name,
|
||||
model.name()),
|
||||
ctx.context)
|
||||
continue
|
||||
api.check_subtype(actual_type, expected_types[actual_name],
|
||||
ctx.context,
|
||||
'Incompatible type for "{}" of "{}"'.format(actual_name,
|
||||
model.name()),
|
||||
'got', 'expected')
|
||||
return ctx.default_return_type
|
||||
|
||||
|
||||
@dataclass
|
||||
class Config:
|
||||
django_settings_module: Optional[str] = None
|
||||
ignore_missing_settings: bool = False
|
||||
|
||||
@classmethod
|
||||
def from_config_file(self, fpath: str) -> 'Config':
|
||||
ini_config = ConfigParser()
|
||||
ini_config.read(fpath)
|
||||
if not ini_config.has_section('mypy_django_plugin'):
|
||||
raise ValueError('Invalid config file: no [mypy_django_plugin] section')
|
||||
return Config(django_settings_module=ini_config.get('mypy_django_plugin', 'django_settings',
|
||||
fallback=None),
|
||||
ignore_missing_settings=ini_config.get('mypy_django_plugin', 'ignore_missing_settings',
|
||||
fallback=False))
|
||||
|
||||
|
||||
class DjangoPlugin(Plugin):
|
||||
def __init__(self, options: Options) -> None:
|
||||
super().__init__(options)
|
||||
@@ -194,11 +119,18 @@ class DjangoPlugin(Plugin):
|
||||
|
||||
sym = self.lookup_fully_qualified(fullname)
|
||||
if sym and isinstance(sym.node, TypeInfo):
|
||||
if sym.node.has_base(helpers.FIELD_FULLNAME):
|
||||
return record_field_properties_into_outer_model_class
|
||||
if sym.node.metadata.get('django', {}).get('generated_init'):
|
||||
return redefine_model_init
|
||||
return redefine_and_typecheck_model_init
|
||||
|
||||
def get_method_hook(self, fullname: str
|
||||
) -> Optional[Callable[[MethodContext], Type]]:
|
||||
manager_classes = self._get_current_manager_bases()
|
||||
class_fullname, _, method_name = fullname.rpartition('.')
|
||||
if class_fullname in manager_classes and method_name == 'create':
|
||||
return redefine_and_typecheck_model_create
|
||||
|
||||
if fullname in {'django.apps.registry.Apps.get_model',
|
||||
'django.db.migrations.state.StateApps.get_model'}:
|
||||
return determine_model_cls_from_string_for_migrations
|
||||
|
||||
Reference in New Issue
Block a user