split helpers into smaller files

This commit is contained in:
Maxim Kurnikov
2019-07-12 15:09:51 +03:00
parent a9c1bcbbc6
commit 9c5a6be9a7
18 changed files with 199 additions and 200 deletions

View File

@@ -52,7 +52,7 @@ class BaseExpression:
is_summary: bool = ... is_summary: bool = ...
filterable: bool = ... filterable: bool = ...
window_compatible: bool = ... window_compatible: bool = ...
output_field: Any output_field: Field
def __init__(self, output_field: Optional[_OutputField] = ...) -> None: ... def __init__(self, output_field: Optional[_OutputField] = ...) -> None: ...
def get_db_converters(self, connection: Any) -> List[Callable]: ... def get_db_converters(self, connection: Any) -> List[Callable]: ...
def get_source_expressions(self) -> List[Any]: ... def get_source_expressions(self) -> List[Any]: ...
@@ -74,8 +74,6 @@ class BaseExpression:
@property @property
def field(self) -> Field: ... def field(self) -> Field: ...
@property @property
def output_field(self) -> Field: ...
@property
def convert_value(self) -> Callable: ... def convert_value(self) -> Callable: ...
def get_lookup(self, lookup: str) -> Optional[Type[Lookup]]: ... def get_lookup(self, lookup: str) -> Optional[Type[Lookup]]: ...
def get_transform(self, name: str) -> Optional[Type[Expression]]: ... def get_transform(self, name: str) -> Optional[Type[Expression]]: ...

View File

@@ -14,6 +14,7 @@ from django.http import HttpResponse
class ContentNotRenderedError(Exception): ... class ContentNotRenderedError(Exception): ...
class SimpleTemplateResponse(HttpResponse): class SimpleTemplateResponse(HttpResponse):
content: Any = ...
closed: bool closed: bool
cookies: SimpleCookie cookies: SimpleCookie
status_code: int status_code: int
@@ -35,15 +36,10 @@ class SimpleTemplateResponse(HttpResponse):
@property @property
def rendered_content(self) -> str: ... def rendered_content(self) -> str: ...
def add_post_render_callback(self, callback: Callable) -> None: ... def add_post_render_callback(self, callback: Callable) -> None: ...
content: Any = ...
def render(self) -> SimpleTemplateResponse: ... def render(self) -> SimpleTemplateResponse: ...
@property @property
def is_rendered(self) -> bool: ... def is_rendered(self) -> bool: ...
def __iter__(self) -> Any: ... def __iter__(self) -> Any: ...
@property
def content(self): ...
@content.setter
def content(self, value: Any) -> None: ...
class TemplateResponse(SimpleTemplateResponse): class TemplateResponse(SimpleTemplateResponse):
client: Client client: Client

View File

View File

@@ -0,0 +1,29 @@
MODEL_CLASS_FULLNAME = 'django.db.models.base.Model'
FIELD_FULLNAME = 'django.db.models.fields.Field'
CHAR_FIELD_FULLNAME = 'django.db.models.fields.CharField'
ARRAY_FIELD_FULLNAME = 'django.contrib.postgres.fields.array.ArrayField'
AUTO_FIELD_FULLNAME = 'django.db.models.fields.AutoField'
GENERIC_FOREIGN_KEY_FULLNAME = 'django.contrib.contenttypes.fields.GenericForeignKey'
FOREIGN_KEY_FULLNAME = 'django.db.models.fields.related.ForeignKey'
ONETOONE_FIELD_FULLNAME = 'django.db.models.fields.related.OneToOneField'
MANYTOMANY_FIELD_FULLNAME = 'django.db.models.fields.related.ManyToManyField'
DUMMY_SETTINGS_BASE_CLASS = 'django.conf._DjangoConfLazyObject'
QUERYSET_CLASS_FULLNAME = 'django.db.models.query.QuerySet'
BASE_MANAGER_CLASS_FULLNAME = 'django.db.models.manager.BaseManager'
MANAGER_CLASS_FULLNAME = 'django.db.models.manager.Manager'
RELATED_MANAGER_CLASS_FULLNAME = 'django.db.models.manager.RelatedManager'
BASEFORM_CLASS_FULLNAME = 'django.forms.forms.BaseForm'
FORM_CLASS_FULLNAME = 'django.forms.forms.Form'
MODELFORM_CLASS_FULLNAME = 'django.forms.models.ModelForm'
FORM_MIXIN_CLASS_FULLNAME = 'django.views.generic.edit.FormMixin'
MANAGER_CLASSES = {
MANAGER_CLASS_FULLNAME,
RELATED_MANAGER_CLASS_FULLNAME,
BASE_MANAGER_CLASS_FULLNAME,
QUERYSET_CLASS_FULLNAME
}

View File

@@ -12,38 +12,11 @@ from mypy.types import (
AnyType, Instance, NoneTyp, TupleType, Type, TypedDictType, TypeOfAny, TypeVarType, UnionType, AnyType, Instance, NoneTyp, TupleType, Type, TypedDictType, TypeOfAny, TypeVarType, UnionType,
) )
from mypy_django_plugin.lib import metadata, fullnames
if typing.TYPE_CHECKING: if typing.TYPE_CHECKING:
from mypy.checker import TypeChecker from mypy.checker import TypeChecker
MODEL_CLASS_FULLNAME = 'django.db.models.base.Model'
FIELD_FULLNAME = 'django.db.models.fields.Field'
CHAR_FIELD_FULLNAME = 'django.db.models.fields.CharField'
ARRAY_FIELD_FULLNAME = 'django.contrib.postgres.fields.array.ArrayField'
AUTO_FIELD_FULLNAME = 'django.db.models.fields.AutoField'
GENERIC_FOREIGN_KEY_FULLNAME = 'django.contrib.contenttypes.fields.GenericForeignKey'
FOREIGN_KEY_FULLNAME = 'django.db.models.fields.related.ForeignKey'
ONETOONE_FIELD_FULLNAME = 'django.db.models.fields.related.OneToOneField'
MANYTOMANY_FIELD_FULLNAME = 'django.db.models.fields.related.ManyToManyField'
DUMMY_SETTINGS_BASE_CLASS = 'django.conf._DjangoConfLazyObject'
QUERYSET_CLASS_FULLNAME = 'django.db.models.query.QuerySet'
BASE_MANAGER_CLASS_FULLNAME = 'django.db.models.manager.BaseManager'
MANAGER_CLASS_FULLNAME = 'django.db.models.manager.Manager'
RELATED_MANAGER_CLASS_FULLNAME = 'django.db.models.manager.RelatedManager'
BASEFORM_CLASS_FULLNAME = 'django.forms.forms.BaseForm'
FORM_CLASS_FULLNAME = 'django.forms.forms.Form'
MODELFORM_CLASS_FULLNAME = 'django.forms.models.ModelForm'
FORM_MIXIN_CLASS_FULLNAME = 'django.views.generic.edit.FormMixin'
MANAGER_CLASSES = {
MANAGER_CLASS_FULLNAME,
RELATED_MANAGER_CLASS_FULLNAME,
BASE_MANAGER_CLASS_FULLNAME,
QUERYSET_CLASS_FULLNAME
}
def get_models_file(app_name: str, all_modules: typing.Dict[str, MypyFile]) -> Optional[MypyFile]: def get_models_file(app_name: str, all_modules: typing.Dict[str, MypyFile]) -> Optional[MypyFile]:
models_module = '.'.join([app_name, 'models']) models_module = '.'.join([app_name, 'models'])
@@ -208,10 +181,10 @@ def iter_over_assignments(class_or_module: typing.Union[ClassDef, MypyFile]
def extract_field_setter_type(tp: Instance) -> Optional[Type]: def extract_field_setter_type(tp: Instance) -> Optional[Type]:
""" Extract __set__ value of a field. """ """ Extract __set__ value of a field. """
if tp.type.has_base(FIELD_FULLNAME): if tp.type.has_base(fullnames.FIELD_FULLNAME):
return tp.args[0] return tp.args[0]
# GenericForeignKey # GenericForeignKey
if tp.type.has_base(GENERIC_FOREIGN_KEY_FULLNAME): if tp.type.has_base(fullnames.GENERIC_FOREIGN_KEY_FULLNAME):
return AnyType(TypeOfAny.special_form) return AnyType(TypeOfAny.special_form)
return None return None
@@ -220,39 +193,19 @@ def extract_field_getter_type(tp: Type) -> Optional[Type]:
""" Extract return type of __get__ of subclass of Field""" """ Extract return type of __get__ of subclass of Field"""
if not isinstance(tp, Instance): if not isinstance(tp, Instance):
return None return None
if tp.type.has_base(FIELD_FULLNAME): if tp.type.has_base(fullnames.FIELD_FULLNAME):
return tp.args[1] return tp.args[1]
# GenericForeignKey # GenericForeignKey
if tp.type.has_base(GENERIC_FOREIGN_KEY_FULLNAME): if tp.type.has_base(fullnames.GENERIC_FOREIGN_KEY_FULLNAME):
return AnyType(TypeOfAny.special_form) return AnyType(TypeOfAny.special_form)
return None return None
def get_django_metadata(model_info: TypeInfo) -> Dict[str, typing.Any]:
return model_info.metadata.setdefault('django', {})
def get_related_field_primary_key_names(base_model: TypeInfo) -> typing.List[str]:
return get_django_metadata(base_model).setdefault('related_field_primary_keys', [])
def get_fields_metadata(model: TypeInfo) -> Dict[str, typing.Any]:
return get_django_metadata(model).setdefault('fields', {})
def get_lookups_metadata(model: TypeInfo) -> Dict[str, typing.Any]:
return get_django_metadata(model).setdefault('lookups', {})
def get_related_managers_metadata(model: TypeInfo) -> Dict[str, typing.Any]:
return get_django_metadata(model).setdefault('related_managers', {})
def extract_explicit_set_type_of_model_primary_key(model: TypeInfo) -> Optional[Type]: def extract_explicit_set_type_of_model_primary_key(model: TypeInfo) -> Optional[Type]:
""" """
If field with primary_key=True is set on the model, extract its __set__ type. If field with primary_key=True is set on the model, extract its __set__ type.
""" """
for field_name, props in get_fields_metadata(model).items(): for field_name, props in metadata.get_fields_metadata(model).items():
is_primary_key = props.get('primary_key', False) is_primary_key = props.get('primary_key', False)
if is_primary_key: if is_primary_key:
return extract_field_setter_type(model.names[field_name].type) return extract_field_setter_type(model.names[field_name].type)
@@ -260,7 +213,7 @@ def extract_explicit_set_type_of_model_primary_key(model: TypeInfo) -> Optional[
def extract_primary_key_type_for_get(model: TypeInfo) -> Optional[Type]: def extract_primary_key_type_for_get(model: TypeInfo) -> Optional[Type]:
for field_name, props in get_fields_metadata(model).items(): for field_name, props in metadata.get_fields_metadata(model).items():
is_primary_key = props.get('primary_key', False) is_primary_key = props.get('primary_key', False)
if is_primary_key: if is_primary_key:
return extract_field_getter_type(model.names[field_name].type) return extract_field_getter_type(model.names[field_name].type)
@@ -312,13 +265,13 @@ def get_assigned_value_for_class(type_info: TypeInfo, name: str) -> Optional[Exp
def is_field_nullable(model: TypeInfo, field_name: str) -> bool: def is_field_nullable(model: TypeInfo, field_name: str) -> bool:
return get_fields_metadata(model).get(field_name, {}).get('null', False) return metadata.get_fields_metadata(model).get(field_name, {}).get('null', False)
def is_foreign_key_like(t: Type) -> bool: def is_foreign_key_like(t: Type) -> bool:
if not isinstance(t, Instance): if not isinstance(t, Instance):
return False return False
return has_any_of_bases(t.type, (FOREIGN_KEY_FULLNAME, ONETOONE_FIELD_FULLNAME)) return has_any_of_bases(t.type, (fullnames.FOREIGN_KEY_FULLNAME, fullnames.ONETOONE_FIELD_FULLNAME))
def build_class_with_annotated_fields(api: 'TypeChecker', base: Type, fields: 'OrderedDict[str, Type]', def build_class_with_annotated_fields(api: 'TypeChecker', base: Type, fields: 'OrderedDict[str, Type]',
@@ -408,7 +361,7 @@ def iter_call_assignments(klass: ClassDef) -> typing.Iterator[typing.Tuple[Lvalu
def get_related_manager_type_from_metadata(model_info: TypeInfo, related_manager_name: str, def get_related_manager_type_from_metadata(model_info: TypeInfo, related_manager_name: str,
api: CheckerPluginInterface) -> Optional[Instance]: api: CheckerPluginInterface) -> Optional[Instance]:
related_manager_metadata = get_related_managers_metadata(model_info) related_manager_metadata = metadata.get_related_managers_metadata(model_info)
if not related_manager_metadata: if not related_manager_metadata:
return None return None
@@ -435,7 +388,7 @@ def get_related_manager_type_from_metadata(model_info: TypeInfo, related_manager
def get_primary_key_field_name(model_info: TypeInfo) -> Optional[str]: def get_primary_key_field_name(model_info: TypeInfo) -> Optional[str]:
for base in model_info.mro: for base in model_info.mro:
fields = get_fields_metadata(base) fields = metadata.get_fields_metadata(base)
for field_name, field_props in fields.items(): for field_name, field_props in fields.items():
is_primary_key = field_props.get('primary_key', False) is_primary_key = field_props.get('primary_key', False)
if is_primary_key: if is_primary_key:

View File

@@ -5,7 +5,7 @@ from mypy.nodes import TypeInfo
from mypy.plugin import CheckerPluginInterface from mypy.plugin import CheckerPluginInterface
from mypy.types import Instance, Type from mypy.types import Instance, Type
from mypy_django_plugin import helpers from mypy_django_plugin.lib import metadata, helpers
@dataclasses.dataclass @dataclasses.dataclass
@@ -138,12 +138,12 @@ def get_actual_field_name_for_lookup_field(lookup: str, model_type_info: TypeInf
If it's not, return the original lookup. If it's not, return the original lookup.
""" """
lookups_metadata = helpers.get_lookups_metadata(model_type_info) lookups_metadata = metadata.get_lookups_metadata(model_type_info)
lookup_metadata = lookups_metadata.get(lookup) lookup_metadata = lookups_metadata.get(lookup)
if lookup_metadata is None: if lookup_metadata is None:
# If not found on current model, look in all bases for their lookup metadata # If not found on current model, look in all bases for their lookup metadata
for base in model_type_info.mro: for base in model_type_info.mro:
lookups_metadata = helpers.get_lookups_metadata(base) lookups_metadata = metadata.get_lookups_metadata(base)
lookup_metadata = lookups_metadata.get(lookup) lookup_metadata = lookups_metadata.get(lookup)
if lookup_metadata: if lookup_metadata:
break break

View File

@@ -0,0 +1,23 @@
from typing import Any, Dict, List
from mypy.nodes import TypeInfo
def get_django_metadata(model_info: TypeInfo) -> Dict[str, Any]:
return model_info.metadata.setdefault('django', {})
def get_related_field_primary_key_names(base_model: TypeInfo) -> List[str]:
return get_django_metadata(base_model).setdefault('related_field_primary_keys', [])
def get_fields_metadata(model: TypeInfo) -> Dict[str, Any]:
return get_django_metadata(model).setdefault('fields', {})
def get_lookups_metadata(model: TypeInfo) -> Dict[str, Any]:
return get_django_metadata(model).setdefault('lookups', {})
def get_related_managers_metadata(model: TypeInfo) -> Dict[str, Any]:
return get_django_metadata(model).setdefault('related_managers', {})

View File

@@ -9,7 +9,7 @@ from mypy.plugin import (
) )
from mypy.types import AnyType, Instance, Type, TypeOfAny from mypy.types import AnyType, Instance, Type, TypeOfAny
from mypy_django_plugin import helpers from mypy_django_plugin.lib import metadata, fullnames, helpers
from mypy_django_plugin.config import Config from mypy_django_plugin.config import Config
from mypy_django_plugin.transformers import fields, init_create from mypy_django_plugin.transformers import fields, init_create
from mypy_django_plugin.transformers.forms import ( from mypy_django_plugin.transformers.forms import (
@@ -33,27 +33,27 @@ from mypy_django_plugin.transformers.settings import (
def transform_model_class(ctx: ClassDefContext, ignore_missing_model_attributes: bool) -> None: def transform_model_class(ctx: ClassDefContext, ignore_missing_model_attributes: bool) -> None:
try: try:
sym = ctx.api.lookup_fully_qualified(helpers.MODEL_CLASS_FULLNAME) sym = ctx.api.lookup_fully_qualified(fullnames.MODEL_CLASS_FULLNAME)
except KeyError: except KeyError:
# models.Model is not loaded, skip metadata model write # models.Model is not loaded, skip metadata model write
pass pass
else: else:
if sym is not None and isinstance(sym.node, TypeInfo): if sym is not None and isinstance(sym.node, TypeInfo):
helpers.get_django_metadata(sym.node)['model_bases'][ctx.cls.fullname] = 1 metadata.get_django_metadata(sym.node)['model_bases'][ctx.cls.fullname] = 1
process_model_class(ctx, ignore_missing_model_attributes) process_model_class(ctx, ignore_missing_model_attributes)
def transform_manager_class(ctx: ClassDefContext) -> None: def transform_manager_class(ctx: ClassDefContext) -> None:
sym = ctx.api.lookup_fully_qualified_or_none(helpers.MANAGER_CLASS_FULLNAME) sym = ctx.api.lookup_fully_qualified_or_none(fullnames.MANAGER_CLASS_FULLNAME)
if sym is not None and isinstance(sym.node, TypeInfo): if sym is not None and isinstance(sym.node, TypeInfo):
helpers.get_django_metadata(sym.node)['manager_bases'][ctx.cls.fullname] = 1 metadata.get_django_metadata(sym.node)['manager_bases'][ctx.cls.fullname] = 1
def transform_form_class(ctx: ClassDefContext) -> None: def transform_form_class(ctx: ClassDefContext) -> None:
sym = ctx.api.lookup_fully_qualified_or_none(helpers.BASEFORM_CLASS_FULLNAME) sym = ctx.api.lookup_fully_qualified_or_none(fullnames.BASEFORM_CLASS_FULLNAME)
if sym is not None and isinstance(sym.node, TypeInfo): if sym is not None and isinstance(sym.node, TypeInfo):
helpers.get_django_metadata(sym.node)['baseform_bases'][ctx.cls.fullname] = 1 metadata.get_django_metadata(sym.node)['baseform_bases'][ctx.cls.fullname] = 1
make_meta_nested_class_inherit_from_any(ctx) make_meta_nested_class_inherit_from_any(ctx)
@@ -67,16 +67,16 @@ def determine_proper_manager_type(ctx: FunctionContext) -> Type:
# not in class # not in class
return ret return ret
outer_model_info = api.tscope.classes[0] outer_model_info = api.tscope.classes[0]
if not outer_model_info.has_base(helpers.MODEL_CLASS_FULLNAME): if not outer_model_info.has_base(fullnames.MODEL_CLASS_FULLNAME):
return ret return ret
if not isinstance(ret, Instance): if not isinstance(ret, Instance):
return ret return ret
has_manager_base = False has_manager_base = False
for i, base in enumerate(ret.type.bases): for i, base in enumerate(ret.type.bases):
if base.type.fullname() in {helpers.MANAGER_CLASS_FULLNAME, if base.type.fullname() in {fullnames.MANAGER_CLASS_FULLNAME,
helpers.RELATED_MANAGER_CLASS_FULLNAME, fullnames.RELATED_MANAGER_CLASS_FULLNAME,
helpers.BASE_MANAGER_CLASS_FULLNAME}: fullnames.BASE_MANAGER_CLASS_FULLNAME}:
has_manager_base = True has_manager_base = True
break break
@@ -118,7 +118,7 @@ def return_type_for_id_field(ctx: AttributeContext) -> Type:
def transform_form_view(ctx: ClassDefContext) -> None: def transform_form_view(ctx: ClassDefContext) -> None:
form_class_value = helpers.get_assigned_value_for_class(ctx.cls.info, 'form_class') form_class_value = helpers.get_assigned_value_for_class(ctx.cls.info, 'form_class')
if isinstance(form_class_value, NameExpr): if isinstance(form_class_value, NameExpr):
helpers.get_django_metadata(ctx.cls.info)['form_class'] = form_class_value.fullname metadata.get_django_metadata(ctx.cls.info)['form_class'] = form_class_value.fullname
class DjangoPlugin(Plugin): class DjangoPlugin(Plugin):
@@ -137,36 +137,36 @@ class DjangoPlugin(Plugin):
self.django_settings_module = os.environ['DJANGO_SETTINGS_MODULE'] self.django_settings_module = os.environ['DJANGO_SETTINGS_MODULE']
def _get_current_model_bases(self) -> Dict[str, int]: def _get_current_model_bases(self) -> Dict[str, int]:
model_sym = self.lookup_fully_qualified(helpers.MODEL_CLASS_FULLNAME) model_sym = self.lookup_fully_qualified(fullnames.MODEL_CLASS_FULLNAME)
if model_sym is not None and isinstance(model_sym.node, TypeInfo): if model_sym is not None and isinstance(model_sym.node, TypeInfo):
return (helpers.get_django_metadata(model_sym.node) return (metadata.get_django_metadata(model_sym.node)
.setdefault('model_bases', {helpers.MODEL_CLASS_FULLNAME: 1})) .setdefault('model_bases', {fullnames.MODEL_CLASS_FULLNAME: 1}))
else: else:
return {} return {}
def _get_current_manager_bases(self) -> Dict[str, int]: def _get_current_manager_bases(self) -> Dict[str, int]:
model_sym = self.lookup_fully_qualified(helpers.MANAGER_CLASS_FULLNAME) model_sym = self.lookup_fully_qualified(fullnames.MANAGER_CLASS_FULLNAME)
if model_sym is not None and isinstance(model_sym.node, TypeInfo): if model_sym is not None and isinstance(model_sym.node, TypeInfo):
return (helpers.get_django_metadata(model_sym.node) return (metadata.get_django_metadata(model_sym.node)
.setdefault('manager_bases', {helpers.MANAGER_CLASS_FULLNAME: 1})) .setdefault('manager_bases', {fullnames.MANAGER_CLASS_FULLNAME: 1}))
else: else:
return {} return {}
def _get_current_form_bases(self) -> Dict[str, int]: def _get_current_form_bases(self) -> Dict[str, int]:
model_sym = self.lookup_fully_qualified(helpers.BASEFORM_CLASS_FULLNAME) model_sym = self.lookup_fully_qualified(fullnames.BASEFORM_CLASS_FULLNAME)
if model_sym is not None and isinstance(model_sym.node, TypeInfo): if model_sym is not None and isinstance(model_sym.node, TypeInfo):
return (helpers.get_django_metadata(model_sym.node) return (metadata.get_django_metadata(model_sym.node)
.setdefault('baseform_bases', {helpers.BASEFORM_CLASS_FULLNAME: 1, .setdefault('baseform_bases', {fullnames.BASEFORM_CLASS_FULLNAME: 1,
helpers.FORM_CLASS_FULLNAME: 1, fullnames.FORM_CLASS_FULLNAME: 1,
helpers.MODELFORM_CLASS_FULLNAME: 1})) fullnames.MODELFORM_CLASS_FULLNAME: 1}))
else: else:
return {} return {}
def _get_current_queryset_bases(self) -> Dict[str, int]: def _get_current_queryset_bases(self) -> Dict[str, int]:
model_sym = self.lookup_fully_qualified(helpers.QUERYSET_CLASS_FULLNAME) model_sym = self.lookup_fully_qualified(fullnames.QUERYSET_CLASS_FULLNAME)
if model_sym is not None and isinstance(model_sym.node, TypeInfo): if model_sym is not None and isinstance(model_sym.node, TypeInfo):
return (helpers.get_django_metadata(model_sym.node) return (metadata.get_django_metadata(model_sym.node)
.setdefault('queryset_bases', {helpers.QUERYSET_CLASS_FULLNAME: 1})) .setdefault('queryset_bases', {fullnames.QUERYSET_CLASS_FULLNAME: 1}))
else: else:
return {} return {}
@@ -205,10 +205,10 @@ class DjangoPlugin(Plugin):
info = self._get_typeinfo_or_none(fullname) info = self._get_typeinfo_or_none(fullname)
if info: if info:
if info.has_base(helpers.FIELD_FULLNAME): if info.has_base(fullnames.FIELD_FULLNAME):
return fields.adjust_return_type_of_field_instantiation return fields.process_field_instantiation
if helpers.get_django_metadata(info).get('generated_init'): if metadata.get_django_metadata(info).get('generated_init'):
return init_create.redefine_and_typecheck_model_init return init_create.redefine_and_typecheck_model_init
def get_method_hook(self, fullname: str def get_method_hook(self, fullname: str
@@ -217,22 +217,22 @@ class DjangoPlugin(Plugin):
if method_name == 'get_form_class': if method_name == 'get_form_class':
info = self._get_typeinfo_or_none(class_name) info = self._get_typeinfo_or_none(class_name)
if info and info.has_base(helpers.FORM_MIXIN_CLASS_FULLNAME): if info and info.has_base(fullnames.FORM_MIXIN_CLASS_FULLNAME):
return extract_proper_type_for_get_form_class return extract_proper_type_for_get_form_class
if method_name == 'get_form': if method_name == 'get_form':
info = self._get_typeinfo_or_none(class_name) info = self._get_typeinfo_or_none(class_name)
if info and info.has_base(helpers.FORM_MIXIN_CLASS_FULLNAME): if info and info.has_base(fullnames.FORM_MIXIN_CLASS_FULLNAME):
return extract_proper_type_for_get_form return extract_proper_type_for_get_form
if method_name == 'values': if method_name == 'values':
model_info = self._get_typeinfo_or_none(class_name) model_info = self._get_typeinfo_or_none(class_name)
if model_info and model_info.has_base(helpers.QUERYSET_CLASS_FULLNAME): if model_info and model_info.has_base(fullnames.QUERYSET_CLASS_FULLNAME):
return extract_proper_type_for_queryset_values return extract_proper_type_for_queryset_values
if method_name == 'values_list': if method_name == 'values_list':
model_info = self._get_typeinfo_or_none(class_name) model_info = self._get_typeinfo_or_none(class_name)
if model_info and model_info.has_base(helpers.QUERYSET_CLASS_FULLNAME): if model_info and model_info.has_base(fullnames.QUERYSET_CLASS_FULLNAME):
return extract_proper_type_queryset_values_list return extract_proper_type_queryset_values_list
if fullname in {'django.apps.registry.Apps.get_model', if fullname in {'django.apps.registry.Apps.get_model',
@@ -254,11 +254,11 @@ class DjangoPlugin(Plugin):
if fullname in self._get_current_manager_bases(): if fullname in self._get_current_manager_bases():
return transform_manager_class return transform_manager_class
if fullname in self._get_current_form_bases(): # if fullname in self._get_current_form_bases():
return transform_form_class # return transform_form_class
info = self._get_typeinfo_or_none(fullname) info = self._get_typeinfo_or_none(fullname)
if info and info.has_base(helpers.FORM_MIXIN_CLASS_FULLNAME): if info and info.has_base(fullnames.FORM_MIXIN_CLASS_FULLNAME):
return transform_form_view return transform_form_view
return None return None
@@ -266,7 +266,7 @@ class DjangoPlugin(Plugin):
def get_attribute_hook(self, fullname: str def get_attribute_hook(self, fullname: str
) -> Optional[Callable[[AttributeContext], Type]]: ) -> Optional[Callable[[AttributeContext], Type]]:
class_name, _, attr_name = fullname.rpartition('.') class_name, _, attr_name = fullname.rpartition('.')
if class_name == helpers.DUMMY_SETTINGS_BASE_CLASS: if class_name == fullnames.DUMMY_SETTINGS_BASE_CLASS:
return partial(get_type_of_setting, return partial(get_type_of_setting,
setting_name=attr_name, setting_name=attr_name,
settings_modules=self._get_settings_modules_in_order_of_priority(), settings_modules=self._get_settings_modules_in_order_of_priority(),
@@ -278,7 +278,7 @@ class DjangoPlugin(Plugin):
model_info = self._get_typeinfo_or_none(class_name) model_info = self._get_typeinfo_or_none(class_name)
if model_info: if model_info:
related_managers = helpers.get_related_managers_metadata(model_info) related_managers = metadata.get_related_managers_metadata(model_info)
if attr_name in related_managers: if attr_name in related_managers:
return partial(determine_type_of_related_manager, return partial(determine_type_of_related_manager,
related_manager_name=attr_name) related_manager_name=attr_name)

View File

@@ -7,7 +7,7 @@ from mypy.types import (
AnyType, CallableType, Instance, TupleType, Type, UnionType, AnyType, CallableType, Instance, TupleType, Type, UnionType,
) )
from mypy_django_plugin import helpers from mypy_django_plugin.lib import metadata, fullnames, helpers
def extract_referred_to_type(ctx: FunctionContext) -> Optional[Instance]: def extract_referred_to_type(ctx: FunctionContext) -> Optional[Instance]:
@@ -43,9 +43,9 @@ def extract_referred_to_type(ctx: FunctionContext) -> Optional[Instance]:
referred_to_type = arg_type.ret_type referred_to_type = arg_type.ret_type
if not isinstance(referred_to_type, Instance): if not isinstance(referred_to_type, Instance):
return None return None
if not referred_to_type.type.has_base(helpers.MODEL_CLASS_FULLNAME): if not referred_to_type.type.has_base(fullnames.MODEL_CLASS_FULLNAME):
ctx.api.msg.fail(f'to= parameter value must be ' ctx.api.msg.fail(f'to= parameter value must be '
f'a subclass of {helpers.MODEL_CLASS_FULLNAME}', f'a subclass of {fullnames.MODEL_CLASS_FULLNAME!r}',
context=ctx.context) context=ctx.context)
return None return None
@@ -118,26 +118,27 @@ def transform_into_proper_return_type(ctx: FunctionContext) -> Type:
if not isinstance(default_return_type, Instance): if not isinstance(default_return_type, Instance):
return default_return_type return default_return_type
if helpers.has_any_of_bases(default_return_type.type, (helpers.FOREIGN_KEY_FULLNAME, if helpers.has_any_of_bases(default_return_type.type, (fullnames.FOREIGN_KEY_FULLNAME,
helpers.ONETOONE_FIELD_FULLNAME, fullnames.ONETOONE_FIELD_FULLNAME,
helpers.MANYTOMANY_FIELD_FULLNAME)): fullnames.MANYTOMANY_FIELD_FULLNAME)):
return fill_descriptor_types_for_related_field(ctx) return fill_descriptor_types_for_related_field(ctx)
if default_return_type.type.has_base(helpers.ARRAY_FIELD_FULLNAME): if default_return_type.type.has_base(fullnames.ARRAY_FIELD_FULLNAME):
return determine_type_of_array_field(ctx) return determine_type_of_array_field(ctx)
return set_descriptor_types_for_field(ctx) return set_descriptor_types_for_field(ctx)
def adjust_return_type_of_field_instantiation(ctx: FunctionContext) -> Type: def process_field_instantiation(ctx: FunctionContext) -> Type:
record_field_properties_into_outer_model_class(ctx) # Parse __init__ parameters of field into corresponding Model's metadata
parse_field_init_arguments_into_model_metadata(ctx)
return transform_into_proper_return_type(ctx) return transform_into_proper_return_type(ctx)
def record_field_properties_into_outer_model_class(ctx: FunctionContext) -> None: def parse_field_init_arguments_into_model_metadata(ctx: FunctionContext) -> None:
api = cast(TypeChecker, ctx.api) api = cast(TypeChecker, ctx.api)
outer_model = api.scope.active_class() outer_model = api.scope.active_class()
if outer_model is None or not outer_model.has_base(helpers.MODEL_CLASS_FULLNAME): if outer_model is None or not outer_model.has_base(fullnames.MODEL_CLASS_FULLNAME):
# outside models.Model class, undetermined # outside models.Model class, undetermined
return return
@@ -149,7 +150,7 @@ def record_field_properties_into_outer_model_class(ctx: FunctionContext) -> None
if field_name is None: if field_name is None:
return return
fields_metadata = helpers.get_fields_metadata(outer_model) fields_metadata = metadata.get_fields_metadata(outer_model)
# primary key # primary key
is_primary_key = False is_primary_key = False

View File

@@ -1,7 +1,7 @@
from mypy.plugin import ClassDefContext, MethodContext from mypy.plugin import ClassDefContext, MethodContext
from mypy.types import CallableType, Instance, NoneTyp, Type, TypeType from mypy.types import CallableType, Instance, NoneTyp, Type, TypeType
from mypy_django_plugin import helpers from mypy_django_plugin.lib import metadata, helpers
def make_meta_nested_class_inherit_from_any(ctx: ClassDefContext) -> None: def make_meta_nested_class_inherit_from_any(ctx: ClassDefContext) -> None:
@@ -19,7 +19,7 @@ def extract_proper_type_for_get_form(ctx: MethodContext) -> Type:
form_class_type = helpers.get_argument_type_by_name(ctx, 'form_class') form_class_type = helpers.get_argument_type_by_name(ctx, 'form_class')
if form_class_type is None or isinstance(form_class_type, NoneTyp): if form_class_type is None or isinstance(form_class_type, NoneTyp):
# extract from specified form_class in metadata # extract from specified form_class in metadata
form_class_fullname = helpers.get_django_metadata(object_type.type).get('form_class', None) form_class_fullname = metadata.get_django_metadata(object_type.type).get('form_class', None)
if not form_class_fullname: if not form_class_fullname:
return ctx.default_return_type return ctx.default_return_type
@@ -39,7 +39,7 @@ def extract_proper_type_for_get_form_class(ctx: MethodContext) -> Type:
if not isinstance(object_type, Instance): if not isinstance(object_type, Instance):
return ctx.default_return_type return ctx.default_return_type
form_class_fullname = helpers.get_django_metadata(object_type.type).get('form_class', None) form_class_fullname = metadata.get_django_metadata(object_type.type).get('form_class', None)
if not form_class_fullname: if not form_class_fullname:
return ctx.default_return_type return ctx.default_return_type

View File

@@ -5,13 +5,13 @@ from mypy.nodes import TypeInfo, Var
from mypy.plugin import FunctionContext, MethodContext from mypy.plugin import FunctionContext, MethodContext
from mypy.types import AnyType, Instance, Type, TypeOfAny from mypy.types import AnyType, Instance, Type, TypeOfAny
from mypy_django_plugin import helpers from mypy_django_plugin.lib import metadata, fullnames, helpers
def extract_base_pointer_args(model: TypeInfo) -> Set[str]: def extract_base_pointer_args(model: TypeInfo) -> Set[str]:
pointer_args: Set[str] = set() pointer_args: Set[str] = set()
for base in model.bases: for base in model.bases:
if base.type.has_base(helpers.MODEL_CLASS_FULLNAME): if base.type.has_base(fullnames.MODEL_CLASS_FULLNAME):
parent_name = base.type.name().lower() parent_name = base.type.name().lower()
pointer_args.add(f'{parent_name}_ptr') pointer_args.add(f'{parent_name}_ptr')
pointer_args.add(f'{parent_name}_ptr_id') pointer_args.add(f'{parent_name}_ptr_id')
@@ -105,7 +105,7 @@ def redefine_and_typecheck_model_create(ctx: MethodContext) -> Type:
def extract_choices_type(model: TypeInfo, field_name: str) -> Optional[str]: def extract_choices_type(model: TypeInfo, field_name: str) -> Optional[str]:
field_metadata = helpers.get_fields_metadata(model).get(field_name, {}) field_metadata = metadata.get_fields_metadata(model).get(field_name, {})
if 'choices' in field_metadata: if 'choices' in field_metadata:
return field_metadata['choices'] return field_metadata['choices']
return None return None
@@ -146,8 +146,8 @@ def extract_expected_types(ctx: FunctionContext, model: TypeInfo,
if field_type is None: if field_type is None:
continue continue
if helpers.has_any_of_bases(typ.type, (helpers.FOREIGN_KEY_FULLNAME, if helpers.has_any_of_bases(typ.type, (fullnames.FOREIGN_KEY_FULLNAME,
helpers.ONETOONE_FIELD_FULLNAME)): fullnames.ONETOONE_FIELD_FULLNAME)):
related_primary_key_type = AnyType(TypeOfAny.implementation_artifact) related_primary_key_type = AnyType(TypeOfAny.implementation_artifact)
# in case it's optional, we need Instance type # in case it's optional, we need Instance type
referred_to_model = typ.args[1] referred_to_model = typ.args[1]
@@ -156,7 +156,7 @@ def extract_expected_types(ctx: FunctionContext, model: TypeInfo,
referred_to_model = helpers.make_required(typ.args[1]) referred_to_model = helpers.make_required(typ.args[1])
if isinstance(referred_to_model, Instance) and referred_to_model.type.has_base( if isinstance(referred_to_model, Instance) and referred_to_model.type.has_base(
helpers.MODEL_CLASS_FULLNAME): fullnames.MODEL_CLASS_FULLNAME):
pk_type = helpers.extract_explicit_set_type_of_model_primary_key(referred_to_model.type) pk_type = helpers.extract_explicit_set_type_of_model_primary_key(referred_to_model.type)
if not pk_type: if not pk_type:
# extract set type of AutoField # extract set type of AutoField
@@ -170,11 +170,11 @@ def extract_expected_types(ctx: FunctionContext, model: TypeInfo,
expected_types[name + '_id'] = related_primary_key_type expected_types[name + '_id'] = related_primary_key_type
field_metadata = helpers.get_fields_metadata(model).get(name, {}) field_metadata = metadata.get_fields_metadata(model).get(name, {})
if field_type: if field_type:
# related fields could be None in __init__ (but should be specified before save()) # related fields could be None in __init__ (but should be specified before save())
if helpers.has_any_of_bases(typ.type, (helpers.FOREIGN_KEY_FULLNAME, if helpers.has_any_of_bases(typ.type, (fullnames.FOREIGN_KEY_FULLNAME,
helpers.ONETOONE_FIELD_FULLNAME)) and is_init: fullnames.ONETOONE_FIELD_FULLNAME)) and is_init:
field_type = helpers.make_optional(field_type) field_type = helpers.make_optional(field_type)
# if primary_key=True and default specified # if primary_key=True and default specified
@@ -184,7 +184,7 @@ def extract_expected_types(ctx: FunctionContext, model: TypeInfo,
# if CharField(blank=True,...) and not nullable, then field can be None in __init__ # if CharField(blank=True,...) and not nullable, then field can be None in __init__
elif ( elif (
helpers.has_any_of_bases(typ.type, (helpers.CHAR_FIELD_FULLNAME,)) and is_init and helpers.has_any_of_bases(typ.type, (fullnames.CHAR_FIELD_FULLNAME,)) and is_init and
field_metadata.get('blank', False) and not field_metadata.get('null', False) field_metadata.get('blank', False) and not field_metadata.get('null', False)
): ):
field_type = helpers.make_optional(field_type) field_type = helpers.make_optional(field_type)

View File

@@ -5,7 +5,7 @@ from mypy.nodes import Expression, StrExpr, TypeInfo
from mypy.plugin import MethodContext from mypy.plugin import MethodContext
from mypy.types import Instance, Type, TypeType from mypy.types import Instance, Type, TypeType
from mypy_django_plugin import helpers from mypy_django_plugin.lib import helpers
def get_string_value_from_expr(expr: Expression) -> Optional[str]: def get_string_value_from_expr(expr: Expression) -> Optional[str]:

View File

@@ -11,7 +11,7 @@ from mypy.plugins.common import add_method
from mypy.semanal import SemanticAnalyzerPass2 from mypy.semanal import SemanticAnalyzerPass2
from mypy.types import AnyType, Instance, NoneTyp, TypeOfAny from mypy.types import AnyType, Instance, NoneTyp, TypeOfAny
from mypy_django_plugin import helpers from mypy_django_plugin.lib import metadata, fullnames, helpers
@dataclasses.dataclass @dataclasses.dataclass
@@ -55,8 +55,8 @@ def iter_over_one_to_n_related_fields(klass: ClassDef) -> Iterator[Tuple[NameExp
for lvalue, rvalue in helpers.iter_call_assignments(klass): for lvalue, rvalue in helpers.iter_call_assignments(klass):
if (isinstance(lvalue, NameExpr) if (isinstance(lvalue, NameExpr)
and isinstance(rvalue.callee, MemberExpr)): and isinstance(rvalue.callee, MemberExpr)):
if rvalue.callee.fullname in {helpers.FOREIGN_KEY_FULLNAME, if rvalue.callee.fullname in {fullnames.FOREIGN_KEY_FULLNAME,
helpers.ONETOONE_FIELD_FULLNAME}: fullnames.ONETOONE_FIELD_FULLNAME}:
yield lvalue, rvalue yield lvalue, rvalue
@@ -97,7 +97,7 @@ class AddDefaultObjectsManager(ModelClassInitializer):
callee_expr = callee_expr.analyzed.expr callee_expr = callee_expr.analyzed.expr
if isinstance(callee_expr, (MemberExpr, NameExpr)) \ if isinstance(callee_expr, (MemberExpr, NameExpr)) \
and isinstance(callee_expr.node, TypeInfo) \ and isinstance(callee_expr.node, TypeInfo) \
and callee_expr.node.has_base(helpers.BASE_MANAGER_CLASS_FULLNAME): and callee_expr.node.has_base(fullnames.BASE_MANAGER_CLASS_FULLNAME):
managers.append((manager_name, callee_expr.node)) managers.append((manager_name, callee_expr.node))
return managers return managers
@@ -115,7 +115,7 @@ class AddDefaultObjectsManager(ModelClassInitializer):
# abstract models do not need 'objects' queryset # abstract models do not need 'objects' queryset
return None return None
first_manager_type = self.api.named_type_or_none(helpers.MANAGER_CLASS_FULLNAME, first_manager_type = self.api.named_type_or_none(fullnames.MANAGER_CLASS_FULLNAME,
args=[Instance(self.model_classdef.info, [])]) args=[Instance(self.model_classdef.info, [])])
self.add_new_manager('objects', manager_type=first_manager_type) self.add_new_manager('objects', manager_type=first_manager_type)
@@ -148,16 +148,19 @@ class AddRelatedManagers(ModelClassInitializer):
self.add_new_node_to_model_class(manager_name, self.api.builtin_type('builtins.object')) self.add_new_node_to_model_class(manager_name, self.api.builtin_type('builtins.object'))
# save name in metadata for use in get_attribute_hook later # save name in metadata for use in get_attribute_hook later
related_managers_metadata = helpers.get_related_managers_metadata(self.model_classdef.info) related_managers_metadata = metadata.get_related_managers_metadata(self.model_classdef.info)
related_managers_metadata[manager_name] = related_field_type_data related_managers_metadata[manager_name] = related_field_type_data
def run(self) -> None: def run(self) -> None:
for module_name, module_file in self.api.modules.items(): for module_name, module_file in self.api.modules.items():
for model_defn in helpers.iter_over_classdefs(module_file): for model_defn in helpers.iter_over_classdefs(module_file):
for lvalue, rvalue in helpers.iter_call_assignments(model_defn): if not model_defn.info:
if is_related_field(rvalue, module_file): self.api.defer()
for lvalue, field_init in helpers.iter_call_assignments(model_defn):
if is_related_field(field_init, module_file):
try: try:
referenced_model_fullname = extract_ref_to_fullname(rvalue, referenced_model_fullname = extract_referenced_model_fullname(field_init,
module_file=module_file, module_file=module_file,
all_modules=self.api.modules) all_modules=self.api.modules)
except helpers.SelfReference: except helpers.SelfReference:
@@ -168,39 +171,33 @@ class AddRelatedManagers(ModelClassInitializer):
if self.model_classdef.fullname == referenced_model_fullname: if self.model_classdef.fullname == referenced_model_fullname:
related_name = model_defn.name.lower() + '_set' related_name = model_defn.name.lower() + '_set'
if 'related_name' in rvalue.arg_names: if 'related_name' in field_init.arg_names:
related_name_expr = rvalue.args[rvalue.arg_names.index('related_name')] related_name_expr = field_init.args[field_init.arg_names.index('related_name')]
if not isinstance(related_name_expr, StrExpr): if not isinstance(related_name_expr, StrExpr):
# not string 'related_name=' not yet supported
continue continue
related_name = related_name_expr.value related_name = related_name_expr.value
if related_name == '+': if related_name == '+':
# No backwards relation is desired # No backwards relation is desired
continue continue
if 'related_query_name' in rvalue.arg_names: # Default related_query_name to related_name
related_query_name_expr = rvalue.args[rvalue.arg_names.index('related_query_name')]
if not isinstance(related_query_name_expr, StrExpr):
related_query_name = None
else:
related_query_name = related_query_name_expr.value
# TODO: Handle defaulting to model name if related_name is not set
else:
# No related_query_name specified, default to related_name
related_query_name = related_name related_query_name = related_name
if 'related_query_name' in field_init.arg_names:
related_query_name_expr = field_init.args[field_init.arg_names.index('related_query_name')]
if isinstance(related_query_name_expr, StrExpr):
related_query_name = related_query_name_expr.value
else:
# not string 'related_query_name=' is not yet supported
related_query_name = None
# TODO: Handle defaulting to model name if related_name is not set
# field_type_data = get_related_field_type(rvalue, self.api, defn.info)
# if typ is None:
# continue
# TODO: recursively serialize types, or just https://github.com/python/mypy/issues/6506
# as long as Model is not a Generic, one level depth is fine # as long as Model is not a Generic, one level depth is fine
if rvalue.callee.name in {'ForeignKey', 'ManyToManyField'}: if field_init.callee.name in {'ForeignKey', 'ManyToManyField'}:
field_type_data = { field_type_data = {
'manager': helpers.RELATED_MANAGER_CLASS_FULLNAME, 'manager': fullnames.RELATED_MANAGER_CLASS_FULLNAME,
'of': [model_defn.info.fullname()] 'of': [model_defn.info.fullname()]
} }
# return api.named_type_or_none(helpers.RELATED_MANAGER_CLASS_FULLNAME,
# args=[Instance(related_model_typ, [])])
else: else:
field_type_data = { field_type_data = {
'manager': model_defn.info.fullname(), 'manager': model_defn.info.fullname(),
@@ -211,7 +208,7 @@ class AddRelatedManagers(ModelClassInitializer):
if related_query_name is not None: if related_query_name is not None:
# Only create related_query_name if it is a string literal # Only create related_query_name if it is a string literal
helpers.get_lookups_metadata(self.model_classdef.info)[related_query_name] = { metadata.get_lookups_metadata(self.model_classdef.info)[related_query_name] = {
'related_query_name_target': related_name 'related_query_name_target': related_name
} }
@@ -219,20 +216,18 @@ class AddRelatedManagers(ModelClassInitializer):
def get_related_field_type(rvalue: CallExpr, related_model_typ: TypeInfo) -> Dict[str, Any]: def get_related_field_type(rvalue: CallExpr, related_model_typ: TypeInfo) -> Dict[str, Any]:
if rvalue.callee.name in {'ForeignKey', 'ManyToManyField'}: if rvalue.callee.name in {'ForeignKey', 'ManyToManyField'}:
return { return {
'manager': helpers.RELATED_MANAGER_CLASS_FULLNAME, 'manager': fullnames.RELATED_MANAGER_CLASS_FULLNAME,
'of': [related_model_typ.fullname()] 'of': [related_model_typ.fullname()]
} }
# return api.named_type_or_none(helpers.RELATED_MANAGER_CLASS_FULLNAME,
# args=[Instance(related_model_typ, [])])
else: else:
return { return {
'manager': related_model_typ.fullname(), 'manager': related_model_typ.fullname(),
'of': [] 'of': []
} }
# return Instance(related_model_typ, [])
def is_related_field(expr: CallExpr, module_file: MypyFile) -> bool: def is_related_field(expr: CallExpr, module_file: MypyFile) -> bool:
""" Checks whether current CallExpr represents any supported RelatedField subclass"""
if isinstance(expr.callee, MemberExpr) and isinstance(expr.callee.expr, NameExpr): if isinstance(expr.callee, MemberExpr) and isinstance(expr.callee.expr, NameExpr):
module = module_file.names.get(expr.callee.expr.name) module = module_file.names.get(expr.callee.expr.name)
if module \ if module \
@@ -244,12 +239,15 @@ def is_related_field(expr: CallExpr, module_file: MypyFile) -> bool:
return False return False
def extract_ref_to_fullname(rvalue_expr: CallExpr, def extract_referenced_model_fullname(rvalue_expr: CallExpr,
module_file: MypyFile, all_modules: Dict[str, MypyFile]) -> Optional[str]: module_file: MypyFile,
all_modules: Dict[str, MypyFile]) -> Optional[str]:
""" Returns fullname of a Model referenced in "to=" argument of the CallExpr"""
if 'to' in rvalue_expr.arg_names: if 'to' in rvalue_expr.arg_names:
to_expr = rvalue_expr.args[rvalue_expr.arg_names.index('to')] to_expr = rvalue_expr.args[rvalue_expr.arg_names.index('to')]
else: else:
to_expr = rvalue_expr.args[0] to_expr = rvalue_expr.args[0]
if isinstance(to_expr, NameExpr): if isinstance(to_expr, NameExpr):
return module_file.names[to_expr.name].fullname return module_file.names[to_expr.name].fullname
elif isinstance(to_expr, StrExpr): elif isinstance(to_expr, StrExpr):

View File

@@ -8,8 +8,8 @@ from mypy.plugin import (
) )
from mypy.types import AnyType, Instance, Type, TypeOfAny from mypy.types import AnyType, Instance, Type, TypeOfAny
from mypy_django_plugin import helpers from mypy_django_plugin.lib import helpers
from mypy_django_plugin.lookups import ( from mypy_django_plugin.lib.lookups import (
LookupException, RelatedModelNode, resolve_lookup, LookupException, RelatedModelNode, resolve_lookup,
) )

View File

@@ -4,7 +4,7 @@ from mypy.checkmember import AttributeContext
from mypy.nodes import TypeInfo from mypy.nodes import TypeInfo
from mypy.types import AnyType, Instance, Type, TypeOfAny, UnionType from mypy.types import AnyType, Instance, Type, TypeOfAny, UnionType
from mypy_django_plugin import helpers from mypy_django_plugin.lib import fullnames, helpers
def _extract_referred_to_type_info(typ: Union[UnionType, Instance]) -> Optional[TypeInfo]: def _extract_referred_to_type_info(typ: Union[UnionType, Instance]) -> Optional[TypeInfo]:
@@ -22,7 +22,7 @@ def extract_and_return_primary_key_of_bound_related_field_parameter(ctx: Attribu
if not isinstance(ctx.default_attr_type, Instance) or not (ctx.default_attr_type.type.fullname() == 'builtins.int'): if not isinstance(ctx.default_attr_type, Instance) or not (ctx.default_attr_type.type.fullname() == 'builtins.int'):
return ctx.default_attr_type return ctx.default_attr_type
if not isinstance(ctx.type, Instance) or not ctx.type.type.has_base(helpers.MODEL_CLASS_FULLNAME): if not isinstance(ctx.type, Instance) or not ctx.type.type.has_base(fullnames.MODEL_CLASS_FULLNAME):
return ctx.default_attr_type return ctx.default_attr_type
field_name = ctx.context.name.split('_')[0] field_name = ctx.context.name.split('_')[0]

View File

@@ -5,7 +5,7 @@ from mypy.checkmember import AttributeContext
from mypy.nodes import NameExpr, StrExpr, SymbolTableNode, TypeInfo from mypy.nodes import NameExpr, StrExpr, SymbolTableNode, TypeInfo
from mypy.types import AnyType, Instance, Type, TypeOfAny, TypeType from mypy.types import AnyType, Instance, Type, TypeOfAny, TypeType
from mypy_django_plugin import helpers from mypy_django_plugin.lib import helpers
if TYPE_CHECKING: if TYPE_CHECKING:
from mypy.checker import TypeChecker from mypy.checker import TypeChecker

View File

@@ -1,27 +1,3 @@
[CASE array_field_descriptor_access]
from django.db import models
from django.contrib.postgres.fields import ArrayField
class User(models.Model):
array = ArrayField(base_field=models.Field())
user = User()
reveal_type(user.array) # N: Revealed type is 'builtins.list*[Any]'
[/CASE]
[CASE array_field_base_field_parsed_into_generic_typevar]
from django.db import models
from django.contrib.postgres.fields import ArrayField
class User(models.Model):
members = ArrayField(base_field=models.IntegerField())
members_as_text = ArrayField(base_field=models.CharField(max_length=255))
user = User()
reveal_type(user.members) # N: Revealed type is 'builtins.list*[builtins.int]'
reveal_type(user.members_as_text) # N: Revealed type is 'builtins.list*[builtins.str]'
[/CASE]
[CASE test_model_fields_classes_present_as_primitives] [CASE test_model_fields_classes_present_as_primitives]
from django.db import models from django.db import models
@@ -142,3 +118,27 @@ MyModel(notnulltext="")
MyModel().notnulltext = None # E: Incompatible types in assignment (expression has type "None", variable has type "Union[str, int, Combinable]") MyModel().notnulltext = None # E: Incompatible types in assignment (expression has type "None", variable has type "Union[str, int, Combinable]")
reveal_type(MyModel().notnulltext) # N: Revealed type is 'builtins.str*' reveal_type(MyModel().notnulltext) # N: Revealed type is 'builtins.str*'
[/CASE] [/CASE]
[CASE array_field_descriptor_access]
from django.db import models
from django.contrib.postgres.fields import ArrayField
class User(models.Model):
array = ArrayField(base_field=models.Field())
user = User()
reveal_type(user.array) # N: Revealed type is 'builtins.list*[Any]'
[/CASE]
[CASE array_field_base_field_parsed_into_generic_typevar]
from django.db import models
from django.contrib.postgres.fields import ArrayField
class User(models.Model):
members = ArrayField(base_field=models.IntegerField())
members_as_text = ArrayField(base_field=models.CharField(max_length=255))
user = User()
reveal_type(user.members) # N: Revealed type is 'builtins.list*[builtins.int]'
reveal_type(user.members_as_text) # N: Revealed type is 'builtins.list*[builtins.str]'
[/CASE]

View File

@@ -173,6 +173,7 @@ class User(models.Model):
[CASE models_triple_circular_reference] [CASE models_triple_circular_reference]
from myapp.models import App from myapp.models import App
reveal_type(App().owner) # N: Revealed type is 'myapp.models.user.User'
reveal_type(App().owner.profile) # N: Revealed type is 'myapp.models.profile.Profile' reveal_type(App().owner.profile) # N: Revealed type is 'myapp.models.profile.Profile'
[file myapp/__init__.py] [file myapp/__init__.py]