add plugin support for ArrayField base_field

This commit is contained in:
Maxim Kurnikov
2018-11-10 19:31:06 +03:00
parent 0bd4bc98fc
commit 0ab77f8f05
4 changed files with 26 additions and 11 deletions

View File

@@ -19,7 +19,7 @@ class ArrayField(CheckFieldDefaultMixin, Field, Generic[_T]):
from_db_value: Any = ... from_db_value: Any = ...
def __init__( def __init__(
self, base_field: _T, size: None = ..., **kwargs: Any self, base_field: Field, size: None = ..., **kwargs: Any
) -> None: ... ) -> None: ...
@property @property
def model(self): ... def model(self): ...

View File

@@ -1,11 +1,14 @@
from typing import List, Any from typing import List, Any, TypeVar, Generic
from django.contrib.postgres.fields.mixins import CheckFieldDefaultMixin from django.contrib.postgres.fields.mixins import CheckFieldDefaultMixin
from django.db.models import Field from django.db.models import Field
_T = TypeVar('_T', bound=Field)
class ArrayField(CheckFieldDefaultMixin, Field):
class ArrayField(CheckFieldDefaultMixin, Field, Generic[_T]):
def __init__(self, def __init__(self,
base_field: Field, base_field: Field,
**kwargs): ... **kwargs): ...
def __get__(self, instance, owner) -> List[Any]: ...
def __get__(self, instance, owner) -> List[_T]: ...

View File

@@ -1,14 +1,24 @@
from mypy.plugin import Plugin, ClassDefContext from typing import Optional, Callable
from mypy.plugin import Plugin, FunctionContext
from mypy.types import Type
def determine_type_of_array_field(context: ClassDefContext) -> None: def determine_type_of_array_field(ctx: FunctionContext) -> Type:
pass assert 'base_field' in ctx.context.arg_names
base_field_arg_index = ctx.context.arg_names.index('base_field')
base_field_arg_type = ctx.arg_types[base_field_arg_index][0]
return ctx.api.named_generic_type(ctx.context.callee.fullname,
args=[base_field_arg_type.type.names['__get__'].type.ret_type])
class PostgresFieldsPlugin(Plugin): class PostgresFieldsPlugin(Plugin):
def get_base_class_hook(self, fullname: str def get_function_hook(self, fullname: str
): ) -> Optional[Callable[[FunctionContext], Type]]:
return determine_type_of_array_field if fullname == 'django.contrib.postgres.fields.array.ArrayField':
return determine_type_of_array_field
return None
def plugin(version): def plugin(version):

View File

@@ -7,7 +7,9 @@ from django.contrib.postgres.fields import ArrayField
class User(models.Model): class User(models.Model):
members = ArrayField(base_field=models.IntegerField()) members = ArrayField(base_field=models.IntegerField())
members_as_text = ArrayField(base_field=models.CharField(max_length=255))
user = User() user = User()
reveal_type(user.members) # E: Revealed type is 'typing.List[int]' reveal_type(user.members) # E: Revealed type is 'builtins.list[builtins.int*]'
reveal_type(user.members_as_text) # E: Revealed type is 'builtins.list[builtins.str*]'
[out] [out]