mirror of
https://github.com/davidhalter/django-stubs.git
synced 2026-05-24 09:18:41 +08:00
add proper __init__, create() support
This commit is contained in:
@@ -1,12 +1,12 @@
|
||||
from typing import Optional, Tuple, cast
|
||||
|
||||
from mypy.checker import TypeChecker
|
||||
from mypy.nodes import Expression, ListExpr, NameExpr, StrExpr, TupleExpr, TypeInfo
|
||||
from mypy.nodes import StrExpr, TypeInfo
|
||||
from mypy.plugin import FunctionContext
|
||||
from mypy.types import AnyType, CallableType, Instance, TupleType, Type as MypyType, UnionType
|
||||
from mypy.types import AnyType, CallableType, Instance, Type as MypyType, UnionType
|
||||
|
||||
from mypy_django_plugin_newsemanal.context import DjangoContext
|
||||
from mypy_django_plugin_newsemanal.lib import fullnames, helpers, metadata
|
||||
from mypy_django_plugin_newsemanal.django.context import DjangoContext
|
||||
from mypy_django_plugin_newsemanal.lib import fullnames, helpers
|
||||
|
||||
|
||||
def extract_referred_to_type(ctx: FunctionContext, django_context: DjangoContext) -> Optional[Instance]:
|
||||
|
||||
@@ -1,35 +1,69 @@
|
||||
from typing import cast
|
||||
from typing import List, Tuple, Type, Union
|
||||
|
||||
from mypy.checker import TypeChecker
|
||||
from mypy.nodes import Argument, Var, ARG_NAMED
|
||||
from mypy.plugin import FunctionContext
|
||||
from mypy.types import Type as MypyType, Instance
|
||||
from django.db.models.base import Model
|
||||
from mypy.plugin import FunctionContext, MethodContext
|
||||
from mypy.types import Instance, Type as MypyType
|
||||
|
||||
from mypy_django_plugin_newsemanal.context import DjangoContext
|
||||
from mypy_django_plugin_newsemanal.lib import helpers
|
||||
from mypy_django_plugin_newsemanal.django.context import DjangoContext
|
||||
|
||||
|
||||
def get_actual_types(ctx: Union[MethodContext, FunctionContext], expected_keys: List[str]) -> List[Tuple[str, MypyType]]:
|
||||
actual_types = []
|
||||
# positionals
|
||||
for pos, (actual_name, actual_type) in enumerate(zip(ctx.arg_names[0], ctx.arg_types[0])):
|
||||
if actual_name is None:
|
||||
if ctx.callee_arg_names[0] == 'kwargs':
|
||||
# unpacked dict as kwargs is not supported
|
||||
continue
|
||||
actual_name = expected_keys[pos]
|
||||
actual_types.append((actual_name, actual_type))
|
||||
# kwargs
|
||||
if len(ctx.callee_arg_names) > 1:
|
||||
for actual_name, actual_type in zip(ctx.arg_names[1], ctx.arg_types[1]):
|
||||
if actual_name is None:
|
||||
# unpacked dict as kwargs is not supported
|
||||
continue
|
||||
actual_types.append((actual_name, actual_type))
|
||||
return actual_types
|
||||
|
||||
|
||||
def typecheck_model_method(ctx: Union[FunctionContext, MethodContext], django_context: DjangoContext,
|
||||
model_cls: Type[Model], method: str) -> MypyType:
|
||||
expected_types = django_context.get_expected_types(ctx.api, model_cls, method)
|
||||
expected_keys = [key for key in expected_types.keys() if key != 'pk']
|
||||
|
||||
for actual_name, actual_type in get_actual_types(ctx, expected_keys):
|
||||
if actual_name not in expected_types:
|
||||
ctx.api.fail('Unexpected attribute "{}" for model "{}"'.format(actual_name,
|
||||
model_cls.__name__),
|
||||
ctx.context)
|
||||
continue
|
||||
ctx.api.check_subtype(actual_type, expected_types[actual_name],
|
||||
ctx.context,
|
||||
'Incompatible type for "{}" of "{}"'.format(actual_name,
|
||||
model_cls.__name__),
|
||||
'got', 'expected')
|
||||
|
||||
return ctx.default_return_type
|
||||
|
||||
|
||||
def redefine_and_typecheck_model_init(ctx: FunctionContext, django_context: DjangoContext) -> MypyType:
|
||||
assert isinstance(ctx.default_return_type, Instance)
|
||||
|
||||
api = cast(TypeChecker, ctx.api)
|
||||
model_fullname = ctx.default_return_type.type.fullname()
|
||||
model_cls = django_context.get_model_class_by_fullname(model_fullname)
|
||||
if model_cls is None:
|
||||
return ctx.default_return_type
|
||||
|
||||
model_info = ctx.default_return_type.type
|
||||
model_cls = django_context.get_model_class_by_fullname(model_info.fullname())
|
||||
return typecheck_model_method(ctx, django_context, model_cls, '__init__')
|
||||
|
||||
# expected_types = {}
|
||||
# for field in model_cls._meta.get_fields():
|
||||
# field_fullname = helpers.get_class_fullname(field.__class__)
|
||||
# field_info = api.lookup_typeinfo(field_fullname)
|
||||
# field_set_type = helpers.get_private_descriptor_type(field_info, '_pyi_private_set_type',
|
||||
# is_nullable=False)
|
||||
# field_kwarg = Argument(variable=Var(field.attname, field_set_type),
|
||||
# type_annotation=field_set_type,
|
||||
# initializer=None,
|
||||
# kind=ARG_NAMED)
|
||||
# expected_types[field.attname] = field_set_type
|
||||
# for field_name, field in model_cls._meta.fields_map.items():
|
||||
# print()
|
||||
|
||||
# print()
|
||||
return ctx.default_return_type
|
||||
def redefine_and_typecheck_model_create(ctx: MethodContext, django_context: DjangoContext) -> MypyType:
|
||||
isinstance(ctx.default_return_type, Instance)
|
||||
|
||||
model_fullname = ctx.default_return_type.type.fullname()
|
||||
model_cls = django_context.get_model_class_by_fullname(model_fullname)
|
||||
if model_cls is None:
|
||||
return ctx.default_return_type
|
||||
|
||||
return typecheck_model_method(ctx, django_context, model_cls, 'create')
|
||||
|
||||
@@ -1,18 +1,15 @@
|
||||
import dataclasses
|
||||
from abc import ABCMeta, abstractmethod
|
||||
from typing import Optional, Type, cast
|
||||
from typing import cast
|
||||
|
||||
from django.db.models.base import Model
|
||||
from django.db.models.fields.related import ForeignKey
|
||||
from mypy.newsemanal.semanal import NewSemanticAnalyzer
|
||||
from mypy.nodes import ARG_NAMED_OPT, Argument, ClassDef, MDEF, SymbolTableNode, TypeInfo, Var
|
||||
from mypy.nodes import ClassDef, MDEF, SymbolTableNode, TypeInfo, Var
|
||||
from mypy.plugin import ClassDefContext
|
||||
from mypy.plugins import common
|
||||
from mypy.types import AnyType, Instance, NoneType, Type as MypyType, UnionType
|
||||
from mypy.types import Instance
|
||||
|
||||
from django.contrib.postgres.fields import ArrayField
|
||||
from django.db.models.fields import CharField, Field
|
||||
from mypy_django_plugin_newsemanal.context import DjangoContext
|
||||
from django.db.models.fields import Field
|
||||
from mypy_django_plugin_newsemanal.django.context import DjangoContext
|
||||
from mypy_django_plugin_newsemanal.lib import helpers
|
||||
from mypy_django_plugin_newsemanal.transformers import fields
|
||||
from mypy_django_plugin_newsemanal.transformers.fields import get_field_descriptor_types
|
||||
@@ -52,101 +49,6 @@ class ModelClassInitializer(metaclass=ABCMeta):
|
||||
var.is_initialized_in_class = True
|
||||
var.is_inferred = True
|
||||
self.model_classdef.info.names[name] = SymbolTableNode(MDEF, var, plugin_generated=True)
|
||||
# assert self.model_classdef.info == self.api.type
|
||||
# self.api.add_symbol_table_node(name, SymbolTableNode(MDEF, var, plugin_generated=True))
|
||||
|
||||
def convert_any_to_type(self, typ: MypyType, referred_to_type: MypyType) -> MypyType:
|
||||
if isinstance(typ, UnionType):
|
||||
converted_items = []
|
||||
for item in typ.items:
|
||||
converted_items.append(self.convert_any_to_type(item, referred_to_type))
|
||||
return UnionType.make_union(converted_items,
|
||||
line=typ.line, column=typ.column)
|
||||
if isinstance(typ, Instance):
|
||||
args = []
|
||||
for default_arg in typ.args:
|
||||
if isinstance(default_arg, AnyType):
|
||||
args.append(referred_to_type)
|
||||
else:
|
||||
args.append(default_arg)
|
||||
return helpers.reparametrize_instance(typ, args)
|
||||
|
||||
if isinstance(typ, AnyType):
|
||||
return referred_to_type
|
||||
|
||||
return typ
|
||||
|
||||
def get_field_set_type(self, field: Field, method: str) -> MypyType:
|
||||
target_field = field
|
||||
if isinstance(field, ForeignKey):
|
||||
target_field = field.target_field
|
||||
field_fullname = helpers.get_class_fullname(target_field.__class__)
|
||||
field_info = self.lookup_typeinfo_or_incomplete_defn_error(field_fullname)
|
||||
field_set_type = helpers.get_private_descriptor_type(field_info, '_pyi_private_set_type',
|
||||
is_nullable=self.get_field_nullability(field, method))
|
||||
if isinstance(target_field, ArrayField):
|
||||
argument_field_type = self.get_field_set_type(target_field.base_field, method)
|
||||
field_set_type = self.convert_any_to_type(field_set_type, argument_field_type)
|
||||
return field_set_type
|
||||
|
||||
def get_field_nullability(self, field: Field, method: Optional[str]) -> bool:
|
||||
nullable = field.null
|
||||
if not nullable and isinstance(field, CharField) and field.blank:
|
||||
return True
|
||||
if method == '__init__':
|
||||
if field.primary_key or isinstance(field, ForeignKey):
|
||||
return True
|
||||
return nullable
|
||||
|
||||
def get_field_kind(self, field: Field, method: str):
|
||||
if method == '__init__':
|
||||
# all arguments are optional in __init__
|
||||
return ARG_NAMED_OPT
|
||||
|
||||
def get_primary_key_field(self, model_cls: Type[Model]) -> Field:
|
||||
for field in model_cls._meta.get_fields():
|
||||
if isinstance(field, Field):
|
||||
if field.primary_key:
|
||||
return field
|
||||
raise ValueError('No primary key defined')
|
||||
|
||||
def make_field_kwarg(self, name: str, field: Field, method: str) -> Argument:
|
||||
field_set_type = self.get_field_set_type(field, method)
|
||||
kind = self.get_field_kind(field, method)
|
||||
field_kwarg = Argument(variable=Var(name, field_set_type),
|
||||
type_annotation=field_set_type,
|
||||
initializer=None,
|
||||
kind=kind)
|
||||
return field_kwarg
|
||||
|
||||
def get_field_kwargs(self, model_cls: Type[Model], method: str):
|
||||
field_kwargs = []
|
||||
if method == '__init__':
|
||||
# add primary key `pk`
|
||||
primary_key_field = self.get_primary_key_field(model_cls)
|
||||
field_kwarg = self.make_field_kwarg('pk', primary_key_field, method)
|
||||
field_kwargs.append(field_kwarg)
|
||||
|
||||
for field in model_cls._meta.get_fields():
|
||||
if isinstance(field, Field):
|
||||
field_kwarg = self.make_field_kwarg(field.attname, field, method)
|
||||
field_kwargs.append(field_kwarg)
|
||||
|
||||
if isinstance(field, ForeignKey):
|
||||
attname = field.name
|
||||
related_model_fullname = helpers.get_class_fullname(field.related_model)
|
||||
model_info = self.lookup_typeinfo_or_incomplete_defn_error(related_model_fullname)
|
||||
is_nullable = self.get_field_nullability(field, method)
|
||||
field_set_type = Instance(model_info, [])
|
||||
if is_nullable:
|
||||
field_set_type = helpers.make_optional(field_set_type)
|
||||
kind = self.get_field_kind(field, method)
|
||||
field_kwarg = Argument(variable=Var(attname, field_set_type),
|
||||
type_annotation=field_set_type,
|
||||
initializer=None,
|
||||
kind=kind)
|
||||
field_kwargs.append(field_kwarg)
|
||||
return field_kwargs
|
||||
|
||||
@abstractmethod
|
||||
def run(self) -> None:
|
||||
@@ -198,9 +100,9 @@ class AddRelatedModelsId(ModelClassInitializer):
|
||||
|
||||
for field in model_cls._meta.get_fields():
|
||||
if isinstance(field, ForeignKey):
|
||||
rel_primary_key_field = self.get_primary_key_field(field.related_model)
|
||||
rel_primary_key_field = self.django_context.get_primary_key_field(field.related_model)
|
||||
field_info = self.lookup_field_typeinfo_or_incomplete_defn_error(rel_primary_key_field)
|
||||
is_nullable = self.get_field_nullability(field, None)
|
||||
is_nullable = self.django_context.fields_context.get_field_nullability(field, None)
|
||||
set_type, get_type = get_field_descriptor_types(field_info, is_nullable)
|
||||
self.add_new_node_to_model_class(field.attname,
|
||||
Instance(field_info, [set_type, get_type]))
|
||||
@@ -228,16 +130,6 @@ class AddManagers(ModelClassInitializer):
|
||||
self.add_new_node_to_model_class('_default_manager', default_manager)
|
||||
|
||||
|
||||
class AddInitMethod(ModelClassInitializer):
|
||||
def run(self):
|
||||
model_cls = self.django_context.get_model_class_by_fullname(self.model_classdef.info.fullname())
|
||||
if model_cls is None:
|
||||
return
|
||||
|
||||
field_kwargs = self.get_field_kwargs(model_cls, '__init__')
|
||||
common.add_method(self.ctx, '__init__', field_kwargs, NoneType())
|
||||
|
||||
|
||||
def process_model_class(ctx: ClassDefContext,
|
||||
django_context: DjangoContext) -> None:
|
||||
initializers = [
|
||||
@@ -245,11 +137,10 @@ def process_model_class(ctx: ClassDefContext,
|
||||
AddDefaultPrimaryKey,
|
||||
AddRelatedModelsId,
|
||||
AddManagers,
|
||||
AddInitMethod
|
||||
]
|
||||
for initializer_cls in initializers:
|
||||
try:
|
||||
initializer_cls.from_ctx(ctx, django_context).run()
|
||||
except helpers.IncompleteDefnException:
|
||||
if not ctx.api.final_iteration:
|
||||
ctx.api.defer()
|
||||
ctx.api.defer()
|
||||
|
||||
@@ -2,7 +2,7 @@ from mypy.nodes import TypeInfo
|
||||
from mypy.plugin import FunctionContext
|
||||
from mypy.types import Type as MypyType, TypeType, Instance
|
||||
|
||||
from mypy_django_plugin_newsemanal.context import DjangoContext
|
||||
from mypy_django_plugin_newsemanal.django.context import DjangoContext
|
||||
from mypy_django_plugin_newsemanal.lib import helpers
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user