From dc4f606f6344f244dad5c5d88e4f132fc8f06600 Mon Sep 17 00:00:00 2001 From: Maxim Kurnikov Date: Tue, 13 Nov 2018 17:43:21 +0300 Subject: [PATCH] add support for OneToOneField, more stubs --- django-stubs/db/models/__init__.pyi | 8 +- django-stubs/db/models/base.pyi | 12 +- django-stubs/db/models/deletion.pyi | 9 +- django-stubs/db/models/fields/related.pyi | 10 ++ mypy_django_plugin/model_classes.py | 18 -- mypy_django_plugin/plugins/callbacks.py | 164 ------------------ .../plugins/field_to_python_type.py | 30 ---- mypy_django_plugin/plugins/related_fields.py | 49 +++++- test/test-data/check-model-relations.test | 35 ++-- 9 files changed, 97 insertions(+), 238 deletions(-) delete mode 100644 mypy_django_plugin/model_classes.py delete mode 100644 mypy_django_plugin/plugins/callbacks.py delete mode 100644 mypy_django_plugin/plugins/field_to_python_type.py diff --git a/django-stubs/db/models/__init__.pyi b/django-stubs/db/models/__init__.pyi index 4f69bb8..e49c3c3 100644 --- a/django-stubs/db/models/__init__.pyi +++ b/django-stubs/db/models/__init__.pyi @@ -7,6 +7,10 @@ from .fields import (AutoField as AutoField, Field as Field, SlugField as SlugField, TextField as TextField) -from .fields.related import (ForeignKey as ForeignKey) -from .deletion import CASCADE as CASCADE +from .fields.related import (ForeignKey as ForeignKey, + OneToOneField as OneToOneField) +from .deletion import (CASCADE as CASCADE, + SET_DEFAULT as SET_DEFAULT, + SET_NULL as SET_NULL, + DO_NOTHING as DO_NOTHING) from .query import QuerySet as QuerySet \ No newline at end of file diff --git a/django-stubs/db/models/base.pyi b/django-stubs/db/models/base.pyi index 1b044c2..9e43777 100644 --- a/django-stubs/db/models/base.pyi +++ b/django-stubs/db/models/base.pyi @@ -1,6 +1,16 @@ +from typing import Any + + class ModelBase(type): pass class Model(metaclass=ModelBase): - pass \ No newline at end of file + class DoesNotExist(Exception): + pass + + def __init__(self, **kwargs) -> None: ... + + def delete(self, + using: Any = ..., + keep_parents: bool = ...) -> None: ... \ No newline at end of file diff --git a/django-stubs/db/models/deletion.pyi b/django-stubs/db/models/deletion.pyi index 45dc807..fce55fe 100644 --- a/django-stubs/db/models/deletion.pyi +++ b/django-stubs/db/models/deletion.pyi @@ -1,2 +1,7 @@ -def CASCADE(collector, field, sub_objs, using): - ... +def CASCADE(collector, field, sub_objs, using): ... + +def SET_NULL(collector, field, sub_objs, using): ... + +def SET_DEFAULT(collector, field, sub_objs, using): ... + +def DO_NOTHING(collector, field, sub_objs, using): ... diff --git a/django-stubs/db/models/fields/related.pyi b/django-stubs/db/models/fields/related.pyi index 763e7a7..c6a7ac4 100644 --- a/django-stubs/db/models/fields/related.pyi +++ b/django-stubs/db/models/fields/related.pyi @@ -10,5 +10,15 @@ class ForeignKey(Field, Generic[_T]): def __init__(self, to: Union[Type[_T], str], on_delete: Any, + related_name: str = ..., + **kwargs): ... + def __get__(self, instance, owner) -> _T: ... + + +class OneToOneField(Field, Generic[_T]): + def __init__(self, + to: Union[Type[_T], str], + on_delete: Any, + related_name: str = ..., **kwargs): ... def __get__(self, instance, owner) -> _T: ... diff --git a/mypy_django_plugin/model_classes.py b/mypy_django_plugin/model_classes.py deleted file mode 100644 index 5f0f0d5..0000000 --- a/mypy_django_plugin/model_classes.py +++ /dev/null @@ -1,18 +0,0 @@ -from typing import Set - -import dataclasses - - -def get_default_base_models(): - return {'django.db.models.base.Model'} - - -@dataclasses.dataclass -class DjangoModelsRegistry(object): - base_models: Set[str] = dataclasses.field(default_factory=get_default_base_models) - - def __contains__(self, item: str) -> bool: - return item in self.base_models - - def __iter__(self): - return iter(self.base_models) diff --git a/mypy_django_plugin/plugins/callbacks.py b/mypy_django_plugin/plugins/callbacks.py deleted file mode 100644 index d785a6e..0000000 --- a/mypy_django_plugin/plugins/callbacks.py +++ /dev/null @@ -1,164 +0,0 @@ -import dataclasses -from mypy import types, nodes -from mypy.nodes import CallExpr, StrExpr, AssignmentStmt, RefExpr -from mypy.plugin import AttributeContext, ClassDefContext, SemanticAnalyzerPluginInterface -from mypy.types import Type, Instance, AnyType, TypeOfAny - -from mypy_django_plugin.helpers import lookup_django_model, get_app_model -from mypy_django_plugin.model_classes import DjangoModelsRegistry - -# mapping between field types and plain python types -DB_FIELDS_TO_TYPES = { - 'django.db.models.fields.CharField': 'builtins.str', - 'django.db.models.fields.TextField': 'builtins.str', - 'django.db.models.fields.BooleanField': 'builtins.bool', - # 'django.db.models.fields.NullBooleanField': 'typing.Optional[builtins.bool]', - 'django.db.models.fields.IntegerField': 'builtins.int', - 'django.db.models.fields.AutoField': 'builtins.int', - 'django.db.models.fields.FloatField': 'builtins.float', - 'django.contrib.postgres.fields.jsonb.JSONField': 'builtins.dict', - 'django.contrib.postgres.fields.array.ArrayField': 'typing.Iterable' -} - - -# def get_queryset_of(type_fullname: str): -# return model_definition.api.lookup_fully_qualified_or_none('django.db.models.QuerySet') - - -@dataclasses.dataclass -class DjangoPluginApi(object): - mypy_api: SemanticAnalyzerPluginInterface - - def get_queryset_of(self, type_fullname: str = 'django.db.models.base.Model') -> types.Type: - queryset_sym = self.mypy_api.lookup_fully_qualified_or_none('django.db.models.QuerySet') - if not queryset_sym: - return AnyType(TypeOfAny.from_error) - - generic_arg = self.mypy_api.lookup_fully_qualified_or_none(type_fullname) - if not generic_arg: - return Instance(queryset_sym.node, [AnyType(TypeOfAny.from_error)]) - - return Instance(queryset_sym.node, [Instance(generic_arg.node, [])]) - - def generate_related_manager_assignment_stmt(self, - related_mngr_name: str, - queryset_argument_type_fullname: str) -> nodes.AssignmentStmt: - rvalue = nodes.TempNode(AnyType(TypeOfAny.special_form)) - assignment = nodes.AssignmentStmt(lvalues=[nodes.NameExpr(related_mngr_name)], - rvalue=rvalue, - new_syntax=True, - type=self.get_queryset_of(queryset_argument_type_fullname)) - return assignment - - -class DetermineFieldPythonTypeCallback(object): - def __init__(self, models_registry: DjangoModelsRegistry): - self.models_registry = models_registry - - def __call__(self, attr_context: AttributeContext) -> Type: - default_attr_type = attr_context.default_attr_type - - if isinstance(default_attr_type, Instance): - attr_type_fullname = default_attr_type.type.fullname() - if attr_type_fullname in DB_FIELDS_TO_TYPES: - return attr_context.api.named_type(DB_FIELDS_TO_TYPES[attr_type_fullname]) - - # if 'base' in default_attr_type.type.metadata: - # referred_base_model = default_attr_type.type.metadata['base'] - - if 'members' in attr_context.type.type.metadata: - arg_name = attr_context.context.name - if arg_name in attr_context.type.type.metadata['members']: - referred_base_model = attr_context.type.type.metadata['members'][arg_name] - - typ = lookup_django_model(attr_context.api, referred_base_model) - try: - return Instance(typ.node, []) - except AssertionError as e: - return typ.type - - return default_attr_type - - -# fields which real type is inside to= expression -REFERENCING_DB_FIELDS = { - 'django.db.models.fields.related.ForeignKey', - 'django.db.models.fields.related.OneToOneField' -} - - -def save_referred_to_model_in_metadata(rvalue: CallExpr) -> None: - to_arg_value = rvalue.args[rvalue.arg_names.index('to')] - if isinstance(to_arg_value, StrExpr): - referred_model_fullname = get_app_model(to_arg_value.value) - else: - referred_model_fullname = to_arg_value.fullname - - rvalue.callee.node.metadata['base'] = referred_model_fullname - - -class CollectModelsInformationCallback(object): - def __init__(self, model_registry: DjangoModelsRegistry): - self.model_registry = model_registry - - def __call__(self, model_definition: ClassDefContext) -> None: - self.model_registry.base_models.add(model_definition.cls.fullname) - plugin_api = DjangoPluginApi(mypy_api=model_definition.api) - - for member in model_definition.cls.defs.body: - if isinstance(member, AssignmentStmt): - if len(member.lvalues) > 1: - return None - - arg_name = member.lvalues[0].name - arg_name_as_id = arg_name + '_id' - - rvalue = member.rvalue - if isinstance(rvalue, CallExpr): - if not isinstance(rvalue.callee, RefExpr): - return None - - if rvalue.callee.fullname in REFERENCING_DB_FIELDS: - if rvalue.callee.fullname == 'django.db.models.fields.related.ForeignKey': - model_definition.cls.info.names[arg_name_as_id] = \ - model_definition.api.lookup_fully_qualified('builtins.int') - - referred_to_model = rvalue.args[rvalue.arg_names.index('to')] - if isinstance(referred_to_model, StrExpr): - referred_model_fullname = get_app_model(referred_to_model.value) - else: - referred_model_fullname = referred_to_model.fullname - - rvalue.callee.node.metadata['base'] = referred_model_fullname - - referred_model = model_definition.api.lookup_fully_qualified_or_none(referred_model_fullname) - - if 'related_name' in rvalue.arg_names: - related_arg_value = rvalue.args[rvalue.arg_names.index('related_name')].value - referred_model_class_def = referred_model.node.defn # type: nodes.ClassDef - - referred_model_class_def.defs.body.append( - plugin_api.generate_related_manager_assignment_stmt(related_arg_value, - model_definition.cls.fullname)) - - if rvalue.callee.fullname == 'django.db.models.fields.related.OneToOneField': - referred_to_model = rvalue.args[rvalue.arg_names.index('to')] - if isinstance(referred_to_model, StrExpr): - referred_model_fullname = get_app_model(referred_to_model.value) - else: - referred_model_fullname = referred_to_model.fullname - - referred_model = model_definition.api.lookup_fully_qualified_or_none(referred_model_fullname) - - if 'related_name' in rvalue.arg_names: - related_arg_value = rvalue.args[rvalue.arg_names.index('related_name')].value - referred_model.node.names[related_arg_value] = \ - model_definition.api.lookup_fully_qualified_or_none(model_definition.cls.fullname) - - rvalue.callee.node.metadata['base'] = referred_model_fullname - - rvalue.callee.node.metadata['name'] = arg_name - if 'members' not in model_definition.cls.info.metadata: - model_definition.cls.info.metadata['members'] = {} - - model_definition.cls.info.metadata['members'][arg_name] = rvalue.callee.node.metadata.get('base', None) diff --git a/mypy_django_plugin/plugins/field_to_python_type.py b/mypy_django_plugin/plugins/field_to_python_type.py deleted file mode 100644 index 6021408..0000000 --- a/mypy_django_plugin/plugins/field_to_python_type.py +++ /dev/null @@ -1,30 +0,0 @@ -from typing import Optional, Callable, Type - -from mypy.plugin import Plugin, ClassDefContext, AttributeContext - -from mypy_django_plugin.model_classes import DjangoModelsRegistry -from mypy_django_plugin.plugins.callbacks import CollectModelsInformationCallback, DetermineFieldPythonTypeCallback - - -class FieldToPythonTypePlugin(Plugin): - model_registry = DjangoModelsRegistry() - - def get_base_class_hook(self, fullname: str - ) -> Optional[Callable[[ClassDefContext], None]]: - if fullname in self.model_registry: - return CollectModelsInformationCallback(self.model_registry) - - return None - - def get_attribute_hook(self, fullname: str - ) -> Optional[Callable[[AttributeContext], Type]]: - # print(fullname) - classname, _, attrname = fullname.rpartition('.') - if classname and classname in self.model_registry: - return DetermineFieldPythonTypeCallback(self.model_registry) - - return None - - -def plugin(version): - return FieldToPythonTypePlugin diff --git a/mypy_django_plugin/plugins/related_fields.py b/mypy_django_plugin/plugins/related_fields.py index 41acb03..773f58b 100644 --- a/mypy_django_plugin/plugins/related_fields.py +++ b/mypy_django_plugin/plugins/related_fields.py @@ -1,21 +1,33 @@ -from typing import Optional, Callable +from typing import Optional, Callable, cast +from mypy.checker import TypeChecker from mypy.nodes import Var, MDEF, SymbolTableNode from mypy.plugin import Plugin, FunctionContext from mypy.types import Type, CallableType, Instance +def extract_to_value_type(ctx: FunctionContext) -> Optional[Type]: + assert 'to' in ctx.context.arg_names + to_arg_value = ctx.arg_types[ctx.context.arg_names.index('to')][0] + if not isinstance(to_arg_value, CallableType): + return None + + return to_arg_value.ret_type + + +def extract_related_name_value(ctx: FunctionContext) -> str: + return ctx.context.args[ctx.context.arg_names.index('related_name')].value + + def set_related_name_manager_for_foreign_key(ctx: FunctionContext) -> Type: if 'related_name' not in ctx.context.arg_names: return ctx.default_return_type - assert 'to' in ctx.context.arg_names - to_arg_value = ctx.arg_types[ctx.context.arg_names.index('to')][0] - if not isinstance(to_arg_value, CallableType): + referred_to = extract_to_value_type(ctx) + if not referred_to: return ctx.default_return_type - referred_to = to_arg_value.ret_type - related_name = ctx.context.args[ctx.context.arg_names.index('related_name')].value + related_name = extract_related_name_value(ctx) outer_class_info = ctx.api.tscope.classes[-1] queryset_type = ctx.api.named_generic_type('django.db.models.QuerySet', @@ -28,11 +40,36 @@ def set_related_name_manager_for_foreign_key(ctx: FunctionContext) -> Type: return ctx.default_return_type +def set_related_name_instance_for_onetoonefield(ctx: FunctionContext) -> Type: + if 'related_name' not in ctx.context.arg_names: + return ctx.default_return_type + + referred_to = extract_to_value_type(ctx) + if referred_to is None: + return ctx.default_return_type + + related_name = extract_related_name_value(ctx) + outer_class_info = ctx.api.tscope.classes[-1] + + api = cast(TypeChecker, ctx.api) + related_instance_type = api.named_type(outer_class_info.fullname()) + related_var = Var(related_name, related_instance_type) + related_var.info = related_instance_type.type + + referred_to.type.names[related_name] = SymbolTableNode(MDEF, related_var, + plugin_generated=True) + return ctx.default_return_type + + class RelatedFieldsPlugin(Plugin): def get_function_hook(self, fullname: str ) -> Optional[Callable[[FunctionContext], Type]]: if fullname == 'django.db.models.fields.related.ForeignKey': return set_related_name_manager_for_foreign_key + + if fullname == 'django.db.models.fields.related.OneToOneField': + return set_related_name_instance_for_onetoonefield + return None diff --git a/test/test-data/check-model-relations.test b/test/test-data/check-model-relations.test index 167458a..2b63be2 100644 --- a/test/test-data/check-model-relations.test +++ b/test/test-data/check-model-relations.test @@ -1,18 +1,4 @@ -[case testForeignKeyWithClass] -from django.db import models - -class Publisher(models.Model): - pass - -class Book(models.Model): - publisher = models.ForeignKey(to=Publisher, on_delete=models.CASCADE) - -book = Book() -reveal_type(book.publisher) # E: Revealed type is 'main.Publisher*' -[out] - - -[case testForeignKeyRelatedName] +[case testForeignKeyField] from django.db import models class Publisher(models.Model): @@ -22,6 +8,25 @@ class Book(models.Model): publisher = models.ForeignKey(to=Publisher, on_delete=models.CASCADE, related_name='books') +book = Book() +reveal_type(book.publisher) # E: Revealed type is 'main.Publisher*' + publisher = Publisher() reveal_type(publisher.books) # E: Revealed type is 'django.db.models.query.QuerySet[main.Book]' [out] + +[case testOneToOneField] +from django.db import models + +class User(models.Model): + pass + +class Profile(models.Model): + user = models.OneToOneField(to=User, on_delete=models.CASCADE, related_name='profile') + +profile = Profile() +reveal_type(profile.user) # E: Revealed type is 'main.User*' + +user = User() +reveal_type(user.profile) # E: Revealed type is 'main.Profile' +