diff --git a/mypy_django_plugin/transformers/models.py b/mypy_django_plugin/transformers/models.py index 73ccf93..fb83bfd 100644 --- a/mypy_django_plugin/transformers/models.py +++ b/mypy_django_plugin/transformers/models.py @@ -1,7 +1,7 @@ from typing import Dict, List, Optional, Type, Union, cast from django.db.models.base import Model -from django.db.models.fields import DateField, DateTimeField +from django.db.models.fields import DateField, DateTimeField, Field from django.db.models.fields.related import ForeignKey from django.db.models.fields.reverse_related import ManyToManyRel, ManyToOneRel, OneToOneRel from mypy.checker import TypeChecker @@ -69,7 +69,7 @@ class ModelClassInitializer: self.run_with_model_cls(model_cls) def run_with_model_cls(self, model_cls): - pass + raise NotImplementedError("Implement this in subclasses") class InjectAnyAsBaseForNestedMeta(ModelClassInitializer): @@ -95,15 +95,42 @@ class InjectAnyAsBaseForNestedMeta(ModelClassInitializer): class AddDefaultPrimaryKey(ModelClassInitializer): def run_with_model_cls(self, model_cls: Type[Model]) -> None: auto_field = model_cls._meta.auto_field - if auto_field and not self.model_classdef.info.has_readable_member(auto_field.attname): - # autogenerated field + if auto_field: + self.create_autofield( + auto_field=auto_field, + dest_name=auto_field.attname, + existing_field=not self.model_classdef.info.has_readable_member(auto_field.attname), + ) + + def create_autofield( + self, + auto_field: Field, + dest_name: str, + existing_field: bool, + ) -> None: + if existing_field: auto_field_fullname = helpers.get_class_fullname(auto_field.__class__) auto_field_info = self.lookup_typeinfo_or_incomplete_defn_error(auto_field_fullname) set_type, get_type = fields.get_field_descriptor_types( - auto_field_info, is_set_nullable=True, is_get_nullable=False + auto_field_info, + is_set_nullable=True, + is_get_nullable=False, + ) + + self.add_new_node_to_model_class(dest_name, Instance(auto_field_info, [set_type, get_type])) + + +class AddPrimaryKeyAlias(AddDefaultPrimaryKey): + def run_with_model_cls(self, model_cls: Type[Model]) -> None: + # We also need to override existing `pk` definition from `stubs`: + auto_field = model_cls._meta.pk + if auto_field: + self.create_autofield( + auto_field=auto_field, + dest_name="pk", + existing_field=self.model_classdef.info.has_readable_member(auto_field.name), ) - self.add_new_node_to_model_class(auto_field.attname, Instance(auto_field_info, [set_type, get_type])) class AddRelatedModelsId(ModelClassInitializer): @@ -117,7 +144,7 @@ class AddRelatedModelsId(ModelClassInitializer): if field_sym is not None and field_sym.node is not None: error_context = field_sym.node self.api.fail( - f"Cannot find model {field.related_model!r} " f"referenced in field {field.name!r} ", + f"Cannot find model {field.related_model!r} referenced in field {field.name!r}", ctx=error_context, ) self.add_new_node_to_model_class(field.attname, AnyType(TypeOfAny.explicit)) @@ -374,6 +401,7 @@ def process_model_class(ctx: ClassDefContext, django_context: DjangoContext) -> initializers = [ InjectAnyAsBaseForNestedMeta, AddDefaultPrimaryKey, + AddPrimaryKeyAlias, AddRelatedModelsId, AddManagers, AddDefaultManagerAttribute, diff --git a/tests/typecheck/models/test_primary_key.yml b/tests/typecheck/models/test_primary_key.yml new file mode 100644 index 0000000..1c989d6 --- /dev/null +++ b/tests/typecheck/models/test_primary_key.yml @@ -0,0 +1,66 @@ +- case: test_access_to_id_field_through_self_if_no_primary_key_defined + main: | + from myapp.models import MyModel + x = MyModel.objects.get(id=1) + reveal_type(x.id) # N: Revealed type is "builtins.int*" + reveal_type(x.pk) # N: Revealed type is "builtins.int*" + + MyModel.objects.get(pk=1) + installed_apps: + - myapp + files: + - path: myapp/__init__.py + - path: myapp/models.py + content: | + from django.db import models + class MyModel(models.Model): + def __str__(self): + reveal_type(self.id) # N: Revealed type is "builtins.int*" + reveal_type(self.pk) # N: Revealed type is "builtins.int*" + + +- case: test_access_to_id_field_through_self_if_primary_key_is_defined + main: | + from myapp.models import MyModel + x = MyModel.objects.get(id='a') + reveal_type(x.id) # N: Revealed type is "builtins.str*" + reveal_type(x.pk) # N: Revealed type is "builtins.str*" + + MyModel.objects.get(pk='a') + installed_apps: + - myapp + files: + - path: myapp/__init__.py + - path: myapp/models.py + content: | + from django.db import models + class MyModel(models.Model): + id = models.CharField(max_length=10, primary_key=True) + def __str__(self): + reveal_type(self.id) # N: Revealed type is "builtins.str*" + reveal_type(self.pk) # N: Revealed type is "builtins.str*" + + +- case: test_access_to_id_field_through_self_if_primary_key_has_different_name + main: | + from myapp.models import MyModel + x = MyModel.objects.get(primary='a') + reveal_type(x.primary) # N: Revealed type is "builtins.str*" + reveal_type(x.pk) # N: Revealed type is "builtins.str*" + x.id # E: "MyModel" has no attribute "id" + + MyModel.objects.get(pk='a') + MyModel.objects.get(id='a') # E: Cannot resolve keyword 'id' into field. Choices are: primary + installed_apps: + - myapp + files: + - path: myapp/__init__.py + - path: myapp/models.py + content: | + from django.db import models + class MyModel(models.Model): + primary = models.CharField(max_length=10, primary_key=True) + def __str__(self): + reveal_type(self.primary) # N: Revealed type is "builtins.str*" + reveal_type(self.pk) # N: Revealed type is "builtins.str*" + self.id # E: "MyModel" has no attribute "id"