diff --git a/mypy_django_plugin/transformers/fields.py b/mypy_django_plugin/transformers/fields.py index a2f68fe..331a3f9 100644 --- a/mypy_django_plugin/transformers/fields.py +++ b/mypy_django_plugin/transformers/fields.py @@ -1,6 +1,6 @@ from typing import Optional, Tuple, cast -from django.db.models.fields import Field +from django.db.models.fields import AutoField, Field from django.db.models.fields.related import RelatedField from mypy.nodes import AssignmentStmt, NameExpr, TypeInfo from mypy.plugin import FunctionContext @@ -99,13 +99,26 @@ def fill_descriptor_types_for_related_field(ctx: FunctionContext, django_context ) -def get_field_descriptor_types(field_info: TypeInfo, is_nullable: bool) -> Tuple[MypyType, MypyType]: - set_type = helpers.get_private_descriptor_type(field_info, "_pyi_private_set_type", is_nullable=is_nullable) - get_type = helpers.get_private_descriptor_type(field_info, "_pyi_private_get_type", is_nullable=is_nullable) +def get_field_descriptor_types( + field_info: TypeInfo, *, is_set_nullable: bool, is_get_nullable: bool +) -> Tuple[MypyType, MypyType]: + set_type = helpers.get_private_descriptor_type(field_info, "_pyi_private_set_type", is_nullable=is_set_nullable) + get_type = helpers.get_private_descriptor_type(field_info, "_pyi_private_get_type", is_nullable=is_get_nullable) return set_type, get_type -def set_descriptor_types_for_field(ctx: FunctionContext) -> Instance: +def set_descriptor_types_for_field_callback(ctx: FunctionContext, django_context: DjangoContext) -> MypyType: + current_field = _get_current_field_from_assignment(ctx, django_context) + if current_field is not None: + if isinstance(current_field, AutoField): + return set_descriptor_types_for_field(ctx, is_set_nullable=True) + + return set_descriptor_types_for_field(ctx) + + +def set_descriptor_types_for_field( + ctx: FunctionContext, *, is_set_nullable: bool = False, is_get_nullable: bool = False +) -> Instance: default_return_type = cast(Instance, ctx.default_return_type) is_nullable = False @@ -113,7 +126,11 @@ def set_descriptor_types_for_field(ctx: FunctionContext) -> Instance: if null_expr is not None: is_nullable = helpers.parse_bool(null_expr) or False - set_type, get_type = get_field_descriptor_types(default_return_type.type, is_nullable) + set_type, get_type = get_field_descriptor_types( + default_return_type.type, + is_set_nullable=is_set_nullable or is_nullable, + is_get_nullable=is_get_nullable or is_nullable, + ) return helpers.reparametrize_instance(default_return_type, [set_type, get_type]) @@ -148,4 +165,4 @@ def transform_into_proper_return_type(ctx: FunctionContext, django_context: Djan if default_return_type.type.has_base(fullnames.ARRAY_FIELD_FULLNAME): return determine_type_of_array_field(ctx, django_context) - return set_descriptor_types_for_field(ctx) + return set_descriptor_types_for_field_callback(ctx, django_context) diff --git a/mypy_django_plugin/transformers/models.py b/mypy_django_plugin/transformers/models.py index 14d09f3..00ad398 100644 --- a/mypy_django_plugin/transformers/models.py +++ b/mypy_django_plugin/transformers/models.py @@ -97,7 +97,9 @@ class AddDefaultPrimaryKey(ModelClassInitializer): auto_field_fullname = helpers.get_class_fullname(auto_field.__class__) auto_field_info = self.lookup_typeinfo_or_incomplete_defn_error(auto_field_fullname) - set_type, get_type = fields.get_field_descriptor_types(auto_field_info, is_nullable=False) + set_type, get_type = fields.get_field_descriptor_types( + auto_field_info, is_set_nullable=True, is_get_nullable=False + ) self.add_new_node_to_model_class(auto_field.attname, Instance(auto_field_info, [set_type, get_type])) @@ -131,7 +133,9 @@ class AddRelatedModelsId(ModelClassInitializer): continue is_nullable = self.django_context.get_field_nullability(field, None) - set_type, get_type = get_field_descriptor_types(field_info, is_nullable) + set_type, get_type = get_field_descriptor_types( + field_info, is_set_nullable=is_nullable, is_get_nullable=is_nullable + ) self.add_new_node_to_model_class(field.attname, Instance(field_info, [set_type, get_type])) diff --git a/tests/typecheck/fields/test_nullable.yml b/tests/typecheck/fields/test_nullable.yml index e1fb2c6..f3ddaaa 100644 --- a/tests/typecheck/fields/test_nullable.yml +++ b/tests/typecheck/fields/test_nullable.yml @@ -1,3 +1,36 @@ +- case: autofield_can_be_set_to_none + main: | + from myapp.models import MyModel, MyModelExplicitPK + m = MyModel() + m.id = 3 + m.id = None + m2 = MyModel(id=None) + MyModel.objects.create(id=None) + MyModel.objects.all().update(id=None) # Should give an error since there's a not-null constraint + + def foo(a: int) -> bool: + return True + m2 = MyModel() + foo(m2.id) + + # At runtime, this would be an error, unless m.save() was called to populate the `id` field. + # but the plugin cannot catch this. + foo(m.id) + + exp = MyModelExplicitPK() + exp.id = 3 + exp.id = None + installed_apps: + - myapp + files: + - path: myapp/__init__.py + - path: myapp/models.py + content: | + from django.db import models + class MyModel(models.Model): + pass + class MyModelExplicitPK(models.Model): + id = models.AutoField(primary_key=True) - case: nullable_field_with_strict_optional_true main: | from myapp.models import MyModel