diff --git a/django-stubs-generated/contrib/postgres/fields/array.pyi b/django-stubs-generated/contrib/postgres/fields/array.pyi index b926f5a..85c5212 100644 --- a/django-stubs-generated/contrib/postgres/fields/array.pyi +++ b/django-stubs-generated/contrib/postgres/fields/array.pyi @@ -19,7 +19,7 @@ class ArrayField(CheckFieldDefaultMixin, Field, Generic[_T]): from_db_value: Any = ... def __init__( - self, base_field: _T, size: None = ..., **kwargs: Any + self, base_field: Field, size: None = ..., **kwargs: Any ) -> None: ... @property def model(self): ... diff --git a/django-stubs/contrib/postgres/fields/array.pyi b/django-stubs/contrib/postgres/fields/array.pyi index c348fc0..56f0bcf 100644 --- a/django-stubs/contrib/postgres/fields/array.pyi +++ b/django-stubs/contrib/postgres/fields/array.pyi @@ -1,11 +1,14 @@ -from typing import List, Any +from typing import List, Any, TypeVar, Generic from django.contrib.postgres.fields.mixins import CheckFieldDefaultMixin from django.db.models import Field +_T = TypeVar('_T', bound=Field) -class ArrayField(CheckFieldDefaultMixin, Field): + +class ArrayField(CheckFieldDefaultMixin, Field, Generic[_T]): def __init__(self, base_field: Field, **kwargs): ... - def __get__(self, instance, owner) -> List[Any]: ... \ No newline at end of file + + def __get__(self, instance, owner) -> List[_T]: ... diff --git a/mypy_django_plugin/plugins/postgres_fields.py b/mypy_django_plugin/plugins/postgres_fields.py index 629382e..7b4512c 100644 --- a/mypy_django_plugin/plugins/postgres_fields.py +++ b/mypy_django_plugin/plugins/postgres_fields.py @@ -1,14 +1,24 @@ -from mypy.plugin import Plugin, ClassDefContext +from typing import Optional, Callable + +from mypy.plugin import Plugin, FunctionContext +from mypy.types import Type -def determine_type_of_array_field(context: ClassDefContext) -> None: - pass +def determine_type_of_array_field(ctx: FunctionContext) -> Type: + assert 'base_field' in ctx.context.arg_names + base_field_arg_index = ctx.context.arg_names.index('base_field') + base_field_arg_type = ctx.arg_types[base_field_arg_index][0] + + return ctx.api.named_generic_type(ctx.context.callee.fullname, + args=[base_field_arg_type.type.names['__get__'].type.ret_type]) class PostgresFieldsPlugin(Plugin): - def get_base_class_hook(self, fullname: str - ): - return determine_type_of_array_field + def get_function_hook(self, fullname: str + ) -> Optional[Callable[[FunctionContext], Type]]: + if fullname == 'django.contrib.postgres.fields.array.ArrayField': + return determine_type_of_array_field + return None def plugin(version): diff --git a/test/test-data/check-postgres-fields.test b/test/test-data/check-postgres-fields.test index 9c86871..e469399 100644 --- a/test/test-data/check-postgres-fields.test +++ b/test/test-data/check-postgres-fields.test @@ -7,7 +7,9 @@ from django.contrib.postgres.fields import ArrayField class User(models.Model): members = ArrayField(base_field=models.IntegerField()) + members_as_text = ArrayField(base_field=models.CharField(max_length=255)) user = User() -reveal_type(user.members) # E: Revealed type is 'typing.List[int]' +reveal_type(user.members) # E: Revealed type is 'builtins.list[builtins.int*]' +reveal_type(user.members_as_text) # E: Revealed type is 'builtins.list[builtins.str*]' [out] \ No newline at end of file