diff --git a/django-stubs/db/models/expressions.pyi b/django-stubs/db/models/expressions.pyi index 0dbcf0b..def0086 100644 --- a/django-stubs/db/models/expressions.pyi +++ b/django-stubs/db/models/expressions.pyi @@ -52,7 +52,7 @@ class BaseExpression: is_summary: bool = ... filterable: bool = ... window_compatible: bool = ... - output_field: Any + output_field: Field def __init__(self, output_field: Optional[_OutputField] = ...) -> None: ... def get_db_converters(self, connection: Any) -> List[Callable]: ... def get_source_expressions(self) -> List[Any]: ... @@ -74,8 +74,6 @@ class BaseExpression: @property def field(self) -> Field: ... @property - def output_field(self) -> Field: ... - @property def convert_value(self) -> Callable: ... def get_lookup(self, lookup: str) -> Optional[Type[Lookup]]: ... def get_transform(self, name: str) -> Optional[Type[Expression]]: ... diff --git a/django-stubs/template/response.pyi b/django-stubs/template/response.pyi index e68bd8d..69aee5d 100644 --- a/django-stubs/template/response.pyi +++ b/django-stubs/template/response.pyi @@ -14,6 +14,7 @@ from django.http import HttpResponse class ContentNotRenderedError(Exception): ... class SimpleTemplateResponse(HttpResponse): + content: Any = ... closed: bool cookies: SimpleCookie status_code: int @@ -35,15 +36,10 @@ class SimpleTemplateResponse(HttpResponse): @property def rendered_content(self) -> str: ... def add_post_render_callback(self, callback: Callable) -> None: ... - content: Any = ... def render(self) -> SimpleTemplateResponse: ... @property def is_rendered(self) -> bool: ... def __iter__(self) -> Any: ... - @property - def content(self): ... - @content.setter - def content(self, value: Any) -> None: ... class TemplateResponse(SimpleTemplateResponse): client: Client diff --git a/mypy_django_plugin/lib/__init__.py b/mypy_django_plugin/lib/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/mypy_django_plugin/lib/fullnames.py b/mypy_django_plugin/lib/fullnames.py new file mode 100644 index 0000000..f686acf --- /dev/null +++ b/mypy_django_plugin/lib/fullnames.py @@ -0,0 +1,29 @@ + +MODEL_CLASS_FULLNAME = 'django.db.models.base.Model' +FIELD_FULLNAME = 'django.db.models.fields.Field' +CHAR_FIELD_FULLNAME = 'django.db.models.fields.CharField' +ARRAY_FIELD_FULLNAME = 'django.contrib.postgres.fields.array.ArrayField' +AUTO_FIELD_FULLNAME = 'django.db.models.fields.AutoField' +GENERIC_FOREIGN_KEY_FULLNAME = 'django.contrib.contenttypes.fields.GenericForeignKey' +FOREIGN_KEY_FULLNAME = 'django.db.models.fields.related.ForeignKey' +ONETOONE_FIELD_FULLNAME = 'django.db.models.fields.related.OneToOneField' +MANYTOMANY_FIELD_FULLNAME = 'django.db.models.fields.related.ManyToManyField' +DUMMY_SETTINGS_BASE_CLASS = 'django.conf._DjangoConfLazyObject' + +QUERYSET_CLASS_FULLNAME = 'django.db.models.query.QuerySet' +BASE_MANAGER_CLASS_FULLNAME = 'django.db.models.manager.BaseManager' +MANAGER_CLASS_FULLNAME = 'django.db.models.manager.Manager' +RELATED_MANAGER_CLASS_FULLNAME = 'django.db.models.manager.RelatedManager' + +BASEFORM_CLASS_FULLNAME = 'django.forms.forms.BaseForm' +FORM_CLASS_FULLNAME = 'django.forms.forms.Form' +MODELFORM_CLASS_FULLNAME = 'django.forms.models.ModelForm' + +FORM_MIXIN_CLASS_FULLNAME = 'django.views.generic.edit.FormMixin' + +MANAGER_CLASSES = { + MANAGER_CLASS_FULLNAME, + RELATED_MANAGER_CLASS_FULLNAME, + BASE_MANAGER_CLASS_FULLNAME, + QUERYSET_CLASS_FULLNAME +} \ No newline at end of file diff --git a/mypy_django_plugin/helpers.py b/mypy_django_plugin/lib/helpers.py similarity index 83% rename from mypy_django_plugin/helpers.py rename to mypy_django_plugin/lib/helpers.py index abdc086..ac7e92e 100644 --- a/mypy_django_plugin/helpers.py +++ b/mypy_django_plugin/lib/helpers.py @@ -12,38 +12,11 @@ from mypy.types import ( AnyType, Instance, NoneTyp, TupleType, Type, TypedDictType, TypeOfAny, TypeVarType, UnionType, ) +from mypy_django_plugin.lib import metadata, fullnames + if typing.TYPE_CHECKING: from mypy.checker import TypeChecker -MODEL_CLASS_FULLNAME = 'django.db.models.base.Model' -FIELD_FULLNAME = 'django.db.models.fields.Field' -CHAR_FIELD_FULLNAME = 'django.db.models.fields.CharField' -ARRAY_FIELD_FULLNAME = 'django.contrib.postgres.fields.array.ArrayField' -AUTO_FIELD_FULLNAME = 'django.db.models.fields.AutoField' -GENERIC_FOREIGN_KEY_FULLNAME = 'django.contrib.contenttypes.fields.GenericForeignKey' -FOREIGN_KEY_FULLNAME = 'django.db.models.fields.related.ForeignKey' -ONETOONE_FIELD_FULLNAME = 'django.db.models.fields.related.OneToOneField' -MANYTOMANY_FIELD_FULLNAME = 'django.db.models.fields.related.ManyToManyField' -DUMMY_SETTINGS_BASE_CLASS = 'django.conf._DjangoConfLazyObject' - -QUERYSET_CLASS_FULLNAME = 'django.db.models.query.QuerySet' -BASE_MANAGER_CLASS_FULLNAME = 'django.db.models.manager.BaseManager' -MANAGER_CLASS_FULLNAME = 'django.db.models.manager.Manager' -RELATED_MANAGER_CLASS_FULLNAME = 'django.db.models.manager.RelatedManager' - -BASEFORM_CLASS_FULLNAME = 'django.forms.forms.BaseForm' -FORM_CLASS_FULLNAME = 'django.forms.forms.Form' -MODELFORM_CLASS_FULLNAME = 'django.forms.models.ModelForm' - -FORM_MIXIN_CLASS_FULLNAME = 'django.views.generic.edit.FormMixin' - -MANAGER_CLASSES = { - MANAGER_CLASS_FULLNAME, - RELATED_MANAGER_CLASS_FULLNAME, - BASE_MANAGER_CLASS_FULLNAME, - QUERYSET_CLASS_FULLNAME -} - def get_models_file(app_name: str, all_modules: typing.Dict[str, MypyFile]) -> Optional[MypyFile]: models_module = '.'.join([app_name, 'models']) @@ -208,10 +181,10 @@ def iter_over_assignments(class_or_module: typing.Union[ClassDef, MypyFile] def extract_field_setter_type(tp: Instance) -> Optional[Type]: """ Extract __set__ value of a field. """ - if tp.type.has_base(FIELD_FULLNAME): + if tp.type.has_base(fullnames.FIELD_FULLNAME): return tp.args[0] # GenericForeignKey - if tp.type.has_base(GENERIC_FOREIGN_KEY_FULLNAME): + if tp.type.has_base(fullnames.GENERIC_FOREIGN_KEY_FULLNAME): return AnyType(TypeOfAny.special_form) return None @@ -220,39 +193,19 @@ def extract_field_getter_type(tp: Type) -> Optional[Type]: """ Extract return type of __get__ of subclass of Field""" if not isinstance(tp, Instance): return None - if tp.type.has_base(FIELD_FULLNAME): + if tp.type.has_base(fullnames.FIELD_FULLNAME): return tp.args[1] # GenericForeignKey - if tp.type.has_base(GENERIC_FOREIGN_KEY_FULLNAME): + if tp.type.has_base(fullnames.GENERIC_FOREIGN_KEY_FULLNAME): return AnyType(TypeOfAny.special_form) return None -def get_django_metadata(model_info: TypeInfo) -> Dict[str, typing.Any]: - return model_info.metadata.setdefault('django', {}) - - -def get_related_field_primary_key_names(base_model: TypeInfo) -> typing.List[str]: - return get_django_metadata(base_model).setdefault('related_field_primary_keys', []) - - -def get_fields_metadata(model: TypeInfo) -> Dict[str, typing.Any]: - return get_django_metadata(model).setdefault('fields', {}) - - -def get_lookups_metadata(model: TypeInfo) -> Dict[str, typing.Any]: - return get_django_metadata(model).setdefault('lookups', {}) - - -def get_related_managers_metadata(model: TypeInfo) -> Dict[str, typing.Any]: - return get_django_metadata(model).setdefault('related_managers', {}) - - def extract_explicit_set_type_of_model_primary_key(model: TypeInfo) -> Optional[Type]: """ If field with primary_key=True is set on the model, extract its __set__ type. """ - for field_name, props in get_fields_metadata(model).items(): + for field_name, props in metadata.get_fields_metadata(model).items(): is_primary_key = props.get('primary_key', False) if is_primary_key: return extract_field_setter_type(model.names[field_name].type) @@ -260,7 +213,7 @@ def extract_explicit_set_type_of_model_primary_key(model: TypeInfo) -> Optional[ def extract_primary_key_type_for_get(model: TypeInfo) -> Optional[Type]: - for field_name, props in get_fields_metadata(model).items(): + for field_name, props in metadata.get_fields_metadata(model).items(): is_primary_key = props.get('primary_key', False) if is_primary_key: return extract_field_getter_type(model.names[field_name].type) @@ -312,13 +265,13 @@ def get_assigned_value_for_class(type_info: TypeInfo, name: str) -> Optional[Exp def is_field_nullable(model: TypeInfo, field_name: str) -> bool: - return get_fields_metadata(model).get(field_name, {}).get('null', False) + return metadata.get_fields_metadata(model).get(field_name, {}).get('null', False) def is_foreign_key_like(t: Type) -> bool: if not isinstance(t, Instance): return False - return has_any_of_bases(t.type, (FOREIGN_KEY_FULLNAME, ONETOONE_FIELD_FULLNAME)) + return has_any_of_bases(t.type, (fullnames.FOREIGN_KEY_FULLNAME, fullnames.ONETOONE_FIELD_FULLNAME)) def build_class_with_annotated_fields(api: 'TypeChecker', base: Type, fields: 'OrderedDict[str, Type]', @@ -408,7 +361,7 @@ def iter_call_assignments(klass: ClassDef) -> typing.Iterator[typing.Tuple[Lvalu def get_related_manager_type_from_metadata(model_info: TypeInfo, related_manager_name: str, api: CheckerPluginInterface) -> Optional[Instance]: - related_manager_metadata = get_related_managers_metadata(model_info) + related_manager_metadata = metadata.get_related_managers_metadata(model_info) if not related_manager_metadata: return None @@ -435,7 +388,7 @@ def get_related_manager_type_from_metadata(model_info: TypeInfo, related_manager def get_primary_key_field_name(model_info: TypeInfo) -> Optional[str]: for base in model_info.mro: - fields = get_fields_metadata(base) + fields = metadata.get_fields_metadata(base) for field_name, field_props in fields.items(): is_primary_key = field_props.get('primary_key', False) if is_primary_key: diff --git a/mypy_django_plugin/lookups.py b/mypy_django_plugin/lib/lookups.py similarity index 96% rename from mypy_django_plugin/lookups.py rename to mypy_django_plugin/lib/lookups.py index a8eb54a..27035ba 100644 --- a/mypy_django_plugin/lookups.py +++ b/mypy_django_plugin/lib/lookups.py @@ -5,7 +5,7 @@ from mypy.nodes import TypeInfo from mypy.plugin import CheckerPluginInterface from mypy.types import Instance, Type -from mypy_django_plugin import helpers +from mypy_django_plugin.lib import metadata, helpers @dataclasses.dataclass @@ -138,12 +138,12 @@ def get_actual_field_name_for_lookup_field(lookup: str, model_type_info: TypeInf If it's not, return the original lookup. """ - lookups_metadata = helpers.get_lookups_metadata(model_type_info) + lookups_metadata = metadata.get_lookups_metadata(model_type_info) lookup_metadata = lookups_metadata.get(lookup) if lookup_metadata is None: # If not found on current model, look in all bases for their lookup metadata for base in model_type_info.mro: - lookups_metadata = helpers.get_lookups_metadata(base) + lookups_metadata = metadata.get_lookups_metadata(base) lookup_metadata = lookups_metadata.get(lookup) if lookup_metadata: break diff --git a/mypy_django_plugin/lib/metadata.py b/mypy_django_plugin/lib/metadata.py new file mode 100644 index 0000000..670f769 --- /dev/null +++ b/mypy_django_plugin/lib/metadata.py @@ -0,0 +1,23 @@ +from typing import Any, Dict, List + +from mypy.nodes import TypeInfo + + +def get_django_metadata(model_info: TypeInfo) -> Dict[str, Any]: + return model_info.metadata.setdefault('django', {}) + + +def get_related_field_primary_key_names(base_model: TypeInfo) -> List[str]: + return get_django_metadata(base_model).setdefault('related_field_primary_keys', []) + + +def get_fields_metadata(model: TypeInfo) -> Dict[str, Any]: + return get_django_metadata(model).setdefault('fields', {}) + + +def get_lookups_metadata(model: TypeInfo) -> Dict[str, Any]: + return get_django_metadata(model).setdefault('lookups', {}) + + +def get_related_managers_metadata(model: TypeInfo) -> Dict[str, Any]: + return get_django_metadata(model).setdefault('related_managers', {}) diff --git a/mypy_django_plugin/main.py b/mypy_django_plugin/main.py index f31b853..4c77652 100644 --- a/mypy_django_plugin/main.py +++ b/mypy_django_plugin/main.py @@ -9,7 +9,7 @@ from mypy.plugin import ( ) from mypy.types import AnyType, Instance, Type, TypeOfAny -from mypy_django_plugin import helpers +from mypy_django_plugin.lib import metadata, fullnames, helpers from mypy_django_plugin.config import Config from mypy_django_plugin.transformers import fields, init_create from mypy_django_plugin.transformers.forms import ( @@ -33,27 +33,27 @@ from mypy_django_plugin.transformers.settings import ( def transform_model_class(ctx: ClassDefContext, ignore_missing_model_attributes: bool) -> None: try: - sym = ctx.api.lookup_fully_qualified(helpers.MODEL_CLASS_FULLNAME) + sym = ctx.api.lookup_fully_qualified(fullnames.MODEL_CLASS_FULLNAME) except KeyError: # models.Model is not loaded, skip metadata model write pass else: if sym is not None and isinstance(sym.node, TypeInfo): - helpers.get_django_metadata(sym.node)['model_bases'][ctx.cls.fullname] = 1 + metadata.get_django_metadata(sym.node)['model_bases'][ctx.cls.fullname] = 1 process_model_class(ctx, ignore_missing_model_attributes) def transform_manager_class(ctx: ClassDefContext) -> None: - sym = ctx.api.lookup_fully_qualified_or_none(helpers.MANAGER_CLASS_FULLNAME) + sym = ctx.api.lookup_fully_qualified_or_none(fullnames.MANAGER_CLASS_FULLNAME) if sym is not None and isinstance(sym.node, TypeInfo): - helpers.get_django_metadata(sym.node)['manager_bases'][ctx.cls.fullname] = 1 + metadata.get_django_metadata(sym.node)['manager_bases'][ctx.cls.fullname] = 1 def transform_form_class(ctx: ClassDefContext) -> None: - sym = ctx.api.lookup_fully_qualified_or_none(helpers.BASEFORM_CLASS_FULLNAME) + sym = ctx.api.lookup_fully_qualified_or_none(fullnames.BASEFORM_CLASS_FULLNAME) if sym is not None and isinstance(sym.node, TypeInfo): - helpers.get_django_metadata(sym.node)['baseform_bases'][ctx.cls.fullname] = 1 + metadata.get_django_metadata(sym.node)['baseform_bases'][ctx.cls.fullname] = 1 make_meta_nested_class_inherit_from_any(ctx) @@ -67,16 +67,16 @@ def determine_proper_manager_type(ctx: FunctionContext) -> Type: # not in class return ret outer_model_info = api.tscope.classes[0] - if not outer_model_info.has_base(helpers.MODEL_CLASS_FULLNAME): + if not outer_model_info.has_base(fullnames.MODEL_CLASS_FULLNAME): return ret if not isinstance(ret, Instance): return ret has_manager_base = False for i, base in enumerate(ret.type.bases): - if base.type.fullname() in {helpers.MANAGER_CLASS_FULLNAME, - helpers.RELATED_MANAGER_CLASS_FULLNAME, - helpers.BASE_MANAGER_CLASS_FULLNAME}: + if base.type.fullname() in {fullnames.MANAGER_CLASS_FULLNAME, + fullnames.RELATED_MANAGER_CLASS_FULLNAME, + fullnames.BASE_MANAGER_CLASS_FULLNAME}: has_manager_base = True break @@ -118,7 +118,7 @@ def return_type_for_id_field(ctx: AttributeContext) -> Type: def transform_form_view(ctx: ClassDefContext) -> None: form_class_value = helpers.get_assigned_value_for_class(ctx.cls.info, 'form_class') if isinstance(form_class_value, NameExpr): - helpers.get_django_metadata(ctx.cls.info)['form_class'] = form_class_value.fullname + metadata.get_django_metadata(ctx.cls.info)['form_class'] = form_class_value.fullname class DjangoPlugin(Plugin): @@ -137,36 +137,36 @@ class DjangoPlugin(Plugin): self.django_settings_module = os.environ['DJANGO_SETTINGS_MODULE'] def _get_current_model_bases(self) -> Dict[str, int]: - model_sym = self.lookup_fully_qualified(helpers.MODEL_CLASS_FULLNAME) + model_sym = self.lookup_fully_qualified(fullnames.MODEL_CLASS_FULLNAME) if model_sym is not None and isinstance(model_sym.node, TypeInfo): - return (helpers.get_django_metadata(model_sym.node) - .setdefault('model_bases', {helpers.MODEL_CLASS_FULLNAME: 1})) + return (metadata.get_django_metadata(model_sym.node) + .setdefault('model_bases', {fullnames.MODEL_CLASS_FULLNAME: 1})) else: return {} def _get_current_manager_bases(self) -> Dict[str, int]: - model_sym = self.lookup_fully_qualified(helpers.MANAGER_CLASS_FULLNAME) + model_sym = self.lookup_fully_qualified(fullnames.MANAGER_CLASS_FULLNAME) if model_sym is not None and isinstance(model_sym.node, TypeInfo): - return (helpers.get_django_metadata(model_sym.node) - .setdefault('manager_bases', {helpers.MANAGER_CLASS_FULLNAME: 1})) + return (metadata.get_django_metadata(model_sym.node) + .setdefault('manager_bases', {fullnames.MANAGER_CLASS_FULLNAME: 1})) else: return {} def _get_current_form_bases(self) -> Dict[str, int]: - model_sym = self.lookup_fully_qualified(helpers.BASEFORM_CLASS_FULLNAME) + model_sym = self.lookup_fully_qualified(fullnames.BASEFORM_CLASS_FULLNAME) if model_sym is not None and isinstance(model_sym.node, TypeInfo): - return (helpers.get_django_metadata(model_sym.node) - .setdefault('baseform_bases', {helpers.BASEFORM_CLASS_FULLNAME: 1, - helpers.FORM_CLASS_FULLNAME: 1, - helpers.MODELFORM_CLASS_FULLNAME: 1})) + return (metadata.get_django_metadata(model_sym.node) + .setdefault('baseform_bases', {fullnames.BASEFORM_CLASS_FULLNAME: 1, + fullnames.FORM_CLASS_FULLNAME: 1, + fullnames.MODELFORM_CLASS_FULLNAME: 1})) else: return {} def _get_current_queryset_bases(self) -> Dict[str, int]: - model_sym = self.lookup_fully_qualified(helpers.QUERYSET_CLASS_FULLNAME) + model_sym = self.lookup_fully_qualified(fullnames.QUERYSET_CLASS_FULLNAME) if model_sym is not None and isinstance(model_sym.node, TypeInfo): - return (helpers.get_django_metadata(model_sym.node) - .setdefault('queryset_bases', {helpers.QUERYSET_CLASS_FULLNAME: 1})) + return (metadata.get_django_metadata(model_sym.node) + .setdefault('queryset_bases', {fullnames.QUERYSET_CLASS_FULLNAME: 1})) else: return {} @@ -205,10 +205,10 @@ class DjangoPlugin(Plugin): info = self._get_typeinfo_or_none(fullname) if info: - if info.has_base(helpers.FIELD_FULLNAME): - return fields.adjust_return_type_of_field_instantiation + if info.has_base(fullnames.FIELD_FULLNAME): + return fields.process_field_instantiation - if helpers.get_django_metadata(info).get('generated_init'): + if metadata.get_django_metadata(info).get('generated_init'): return init_create.redefine_and_typecheck_model_init def get_method_hook(self, fullname: str @@ -217,22 +217,22 @@ class DjangoPlugin(Plugin): if method_name == 'get_form_class': info = self._get_typeinfo_or_none(class_name) - if info and info.has_base(helpers.FORM_MIXIN_CLASS_FULLNAME): + if info and info.has_base(fullnames.FORM_MIXIN_CLASS_FULLNAME): return extract_proper_type_for_get_form_class if method_name == 'get_form': info = self._get_typeinfo_or_none(class_name) - if info and info.has_base(helpers.FORM_MIXIN_CLASS_FULLNAME): + if info and info.has_base(fullnames.FORM_MIXIN_CLASS_FULLNAME): return extract_proper_type_for_get_form if method_name == 'values': model_info = self._get_typeinfo_or_none(class_name) - if model_info and model_info.has_base(helpers.QUERYSET_CLASS_FULLNAME): + if model_info and model_info.has_base(fullnames.QUERYSET_CLASS_FULLNAME): return extract_proper_type_for_queryset_values if method_name == 'values_list': model_info = self._get_typeinfo_or_none(class_name) - if model_info and model_info.has_base(helpers.QUERYSET_CLASS_FULLNAME): + if model_info and model_info.has_base(fullnames.QUERYSET_CLASS_FULLNAME): return extract_proper_type_queryset_values_list if fullname in {'django.apps.registry.Apps.get_model', @@ -254,11 +254,11 @@ class DjangoPlugin(Plugin): if fullname in self._get_current_manager_bases(): return transform_manager_class - if fullname in self._get_current_form_bases(): - return transform_form_class + # if fullname in self._get_current_form_bases(): + # return transform_form_class info = self._get_typeinfo_or_none(fullname) - if info and info.has_base(helpers.FORM_MIXIN_CLASS_FULLNAME): + if info and info.has_base(fullnames.FORM_MIXIN_CLASS_FULLNAME): return transform_form_view return None @@ -266,7 +266,7 @@ class DjangoPlugin(Plugin): def get_attribute_hook(self, fullname: str ) -> Optional[Callable[[AttributeContext], Type]]: class_name, _, attr_name = fullname.rpartition('.') - if class_name == helpers.DUMMY_SETTINGS_BASE_CLASS: + if class_name == fullnames.DUMMY_SETTINGS_BASE_CLASS: return partial(get_type_of_setting, setting_name=attr_name, settings_modules=self._get_settings_modules_in_order_of_priority(), @@ -278,7 +278,7 @@ class DjangoPlugin(Plugin): model_info = self._get_typeinfo_or_none(class_name) if model_info: - related_managers = helpers.get_related_managers_metadata(model_info) + related_managers = metadata.get_related_managers_metadata(model_info) if attr_name in related_managers: return partial(determine_type_of_related_manager, related_manager_name=attr_name) diff --git a/mypy_django_plugin/transformers/fields.py b/mypy_django_plugin/transformers/fields.py index 85d2df7..9ba2960 100644 --- a/mypy_django_plugin/transformers/fields.py +++ b/mypy_django_plugin/transformers/fields.py @@ -7,7 +7,7 @@ from mypy.types import ( AnyType, CallableType, Instance, TupleType, Type, UnionType, ) -from mypy_django_plugin import helpers +from mypy_django_plugin.lib import metadata, fullnames, helpers def extract_referred_to_type(ctx: FunctionContext) -> Optional[Instance]: @@ -43,9 +43,9 @@ def extract_referred_to_type(ctx: FunctionContext) -> Optional[Instance]: referred_to_type = arg_type.ret_type if not isinstance(referred_to_type, Instance): return None - if not referred_to_type.type.has_base(helpers.MODEL_CLASS_FULLNAME): + if not referred_to_type.type.has_base(fullnames.MODEL_CLASS_FULLNAME): ctx.api.msg.fail(f'to= parameter value must be ' - f'a subclass of {helpers.MODEL_CLASS_FULLNAME}', + f'a subclass of {fullnames.MODEL_CLASS_FULLNAME!r}', context=ctx.context) return None @@ -118,26 +118,27 @@ def transform_into_proper_return_type(ctx: FunctionContext) -> Type: if not isinstance(default_return_type, Instance): return default_return_type - if helpers.has_any_of_bases(default_return_type.type, (helpers.FOREIGN_KEY_FULLNAME, - helpers.ONETOONE_FIELD_FULLNAME, - helpers.MANYTOMANY_FIELD_FULLNAME)): + if helpers.has_any_of_bases(default_return_type.type, (fullnames.FOREIGN_KEY_FULLNAME, + fullnames.ONETOONE_FIELD_FULLNAME, + fullnames.MANYTOMANY_FIELD_FULLNAME)): return fill_descriptor_types_for_related_field(ctx) - if default_return_type.type.has_base(helpers.ARRAY_FIELD_FULLNAME): + if default_return_type.type.has_base(fullnames.ARRAY_FIELD_FULLNAME): return determine_type_of_array_field(ctx) return set_descriptor_types_for_field(ctx) -def adjust_return_type_of_field_instantiation(ctx: FunctionContext) -> Type: - record_field_properties_into_outer_model_class(ctx) +def process_field_instantiation(ctx: FunctionContext) -> Type: + # Parse __init__ parameters of field into corresponding Model's metadata + parse_field_init_arguments_into_model_metadata(ctx) return transform_into_proper_return_type(ctx) -def record_field_properties_into_outer_model_class(ctx: FunctionContext) -> None: +def parse_field_init_arguments_into_model_metadata(ctx: FunctionContext) -> None: api = cast(TypeChecker, ctx.api) outer_model = api.scope.active_class() - if outer_model is None or not outer_model.has_base(helpers.MODEL_CLASS_FULLNAME): + if outer_model is None or not outer_model.has_base(fullnames.MODEL_CLASS_FULLNAME): # outside models.Model class, undetermined return @@ -149,7 +150,7 @@ def record_field_properties_into_outer_model_class(ctx: FunctionContext) -> None if field_name is None: return - fields_metadata = helpers.get_fields_metadata(outer_model) + fields_metadata = metadata.get_fields_metadata(outer_model) # primary key is_primary_key = False diff --git a/mypy_django_plugin/transformers/forms.py b/mypy_django_plugin/transformers/forms.py index 2afc519..891c8a1 100644 --- a/mypy_django_plugin/transformers/forms.py +++ b/mypy_django_plugin/transformers/forms.py @@ -1,7 +1,7 @@ from mypy.plugin import ClassDefContext, MethodContext from mypy.types import CallableType, Instance, NoneTyp, Type, TypeType -from mypy_django_plugin import helpers +from mypy_django_plugin.lib import metadata, helpers def make_meta_nested_class_inherit_from_any(ctx: ClassDefContext) -> None: @@ -19,7 +19,7 @@ def extract_proper_type_for_get_form(ctx: MethodContext) -> Type: form_class_type = helpers.get_argument_type_by_name(ctx, 'form_class') if form_class_type is None or isinstance(form_class_type, NoneTyp): # extract from specified form_class in metadata - form_class_fullname = helpers.get_django_metadata(object_type.type).get('form_class', None) + form_class_fullname = metadata.get_django_metadata(object_type.type).get('form_class', None) if not form_class_fullname: return ctx.default_return_type @@ -39,7 +39,7 @@ def extract_proper_type_for_get_form_class(ctx: MethodContext) -> Type: if not isinstance(object_type, Instance): return ctx.default_return_type - form_class_fullname = helpers.get_django_metadata(object_type.type).get('form_class', None) + form_class_fullname = metadata.get_django_metadata(object_type.type).get('form_class', None) if not form_class_fullname: return ctx.default_return_type diff --git a/mypy_django_plugin/transformers/init_create.py b/mypy_django_plugin/transformers/init_create.py index 4e4f97d..54f4458 100644 --- a/mypy_django_plugin/transformers/init_create.py +++ b/mypy_django_plugin/transformers/init_create.py @@ -5,13 +5,13 @@ from mypy.nodes import TypeInfo, Var from mypy.plugin import FunctionContext, MethodContext from mypy.types import AnyType, Instance, Type, TypeOfAny -from mypy_django_plugin import helpers +from mypy_django_plugin.lib import metadata, fullnames, helpers def extract_base_pointer_args(model: TypeInfo) -> Set[str]: pointer_args: Set[str] = set() for base in model.bases: - if base.type.has_base(helpers.MODEL_CLASS_FULLNAME): + if base.type.has_base(fullnames.MODEL_CLASS_FULLNAME): parent_name = base.type.name().lower() pointer_args.add(f'{parent_name}_ptr') pointer_args.add(f'{parent_name}_ptr_id') @@ -105,7 +105,7 @@ def redefine_and_typecheck_model_create(ctx: MethodContext) -> Type: def extract_choices_type(model: TypeInfo, field_name: str) -> Optional[str]: - field_metadata = helpers.get_fields_metadata(model).get(field_name, {}) + field_metadata = metadata.get_fields_metadata(model).get(field_name, {}) if 'choices' in field_metadata: return field_metadata['choices'] return None @@ -146,8 +146,8 @@ def extract_expected_types(ctx: FunctionContext, model: TypeInfo, if field_type is None: continue - if helpers.has_any_of_bases(typ.type, (helpers.FOREIGN_KEY_FULLNAME, - helpers.ONETOONE_FIELD_FULLNAME)): + if helpers.has_any_of_bases(typ.type, (fullnames.FOREIGN_KEY_FULLNAME, + fullnames.ONETOONE_FIELD_FULLNAME)): related_primary_key_type = AnyType(TypeOfAny.implementation_artifact) # in case it's optional, we need Instance type referred_to_model = typ.args[1] @@ -156,7 +156,7 @@ def extract_expected_types(ctx: FunctionContext, model: TypeInfo, referred_to_model = helpers.make_required(typ.args[1]) if isinstance(referred_to_model, Instance) and referred_to_model.type.has_base( - helpers.MODEL_CLASS_FULLNAME): + fullnames.MODEL_CLASS_FULLNAME): pk_type = helpers.extract_explicit_set_type_of_model_primary_key(referred_to_model.type) if not pk_type: # extract set type of AutoField @@ -170,11 +170,11 @@ def extract_expected_types(ctx: FunctionContext, model: TypeInfo, expected_types[name + '_id'] = related_primary_key_type - field_metadata = helpers.get_fields_metadata(model).get(name, {}) + field_metadata = metadata.get_fields_metadata(model).get(name, {}) if field_type: # related fields could be None in __init__ (but should be specified before save()) - if helpers.has_any_of_bases(typ.type, (helpers.FOREIGN_KEY_FULLNAME, - helpers.ONETOONE_FIELD_FULLNAME)) and is_init: + if helpers.has_any_of_bases(typ.type, (fullnames.FOREIGN_KEY_FULLNAME, + fullnames.ONETOONE_FIELD_FULLNAME)) and is_init: field_type = helpers.make_optional(field_type) # if primary_key=True and default specified @@ -184,7 +184,7 @@ def extract_expected_types(ctx: FunctionContext, model: TypeInfo, # if CharField(blank=True,...) and not nullable, then field can be None in __init__ elif ( - helpers.has_any_of_bases(typ.type, (helpers.CHAR_FIELD_FULLNAME,)) and is_init and + helpers.has_any_of_bases(typ.type, (fullnames.CHAR_FIELD_FULLNAME,)) and is_init and field_metadata.get('blank', False) and not field_metadata.get('null', False) ): field_type = helpers.make_optional(field_type) diff --git a/mypy_django_plugin/transformers/migrations.py b/mypy_django_plugin/transformers/migrations.py index b6baad8..462178a 100644 --- a/mypy_django_plugin/transformers/migrations.py +++ b/mypy_django_plugin/transformers/migrations.py @@ -5,7 +5,7 @@ from mypy.nodes import Expression, StrExpr, TypeInfo from mypy.plugin import MethodContext from mypy.types import Instance, Type, TypeType -from mypy_django_plugin import helpers +from mypy_django_plugin.lib import helpers def get_string_value_from_expr(expr: Expression) -> Optional[str]: diff --git a/mypy_django_plugin/transformers/models.py b/mypy_django_plugin/transformers/models.py index 099855b..45cf868 100644 --- a/mypy_django_plugin/transformers/models.py +++ b/mypy_django_plugin/transformers/models.py @@ -11,7 +11,7 @@ from mypy.plugins.common import add_method from mypy.semanal import SemanticAnalyzerPass2 from mypy.types import AnyType, Instance, NoneTyp, TypeOfAny -from mypy_django_plugin import helpers +from mypy_django_plugin.lib import metadata, fullnames, helpers @dataclasses.dataclass @@ -55,8 +55,8 @@ def iter_over_one_to_n_related_fields(klass: ClassDef) -> Iterator[Tuple[NameExp for lvalue, rvalue in helpers.iter_call_assignments(klass): if (isinstance(lvalue, NameExpr) and isinstance(rvalue.callee, MemberExpr)): - if rvalue.callee.fullname in {helpers.FOREIGN_KEY_FULLNAME, - helpers.ONETOONE_FIELD_FULLNAME}: + if rvalue.callee.fullname in {fullnames.FOREIGN_KEY_FULLNAME, + fullnames.ONETOONE_FIELD_FULLNAME}: yield lvalue, rvalue @@ -97,7 +97,7 @@ class AddDefaultObjectsManager(ModelClassInitializer): callee_expr = callee_expr.analyzed.expr if isinstance(callee_expr, (MemberExpr, NameExpr)) \ and isinstance(callee_expr.node, TypeInfo) \ - and callee_expr.node.has_base(helpers.BASE_MANAGER_CLASS_FULLNAME): + and callee_expr.node.has_base(fullnames.BASE_MANAGER_CLASS_FULLNAME): managers.append((manager_name, callee_expr.node)) return managers @@ -115,7 +115,7 @@ class AddDefaultObjectsManager(ModelClassInitializer): # abstract models do not need 'objects' queryset return None - first_manager_type = self.api.named_type_or_none(helpers.MANAGER_CLASS_FULLNAME, + first_manager_type = self.api.named_type_or_none(fullnames.MANAGER_CLASS_FULLNAME, args=[Instance(self.model_classdef.info, [])]) self.add_new_manager('objects', manager_type=first_manager_type) @@ -148,18 +148,21 @@ class AddRelatedManagers(ModelClassInitializer): self.add_new_node_to_model_class(manager_name, self.api.builtin_type('builtins.object')) # save name in metadata for use in get_attribute_hook later - related_managers_metadata = helpers.get_related_managers_metadata(self.model_classdef.info) + related_managers_metadata = metadata.get_related_managers_metadata(self.model_classdef.info) related_managers_metadata[manager_name] = related_field_type_data def run(self) -> None: for module_name, module_file in self.api.modules.items(): for model_defn in helpers.iter_over_classdefs(module_file): - for lvalue, rvalue in helpers.iter_call_assignments(model_defn): - if is_related_field(rvalue, module_file): + if not model_defn.info: + self.api.defer() + + for lvalue, field_init in helpers.iter_call_assignments(model_defn): + if is_related_field(field_init, module_file): try: - referenced_model_fullname = extract_ref_to_fullname(rvalue, - module_file=module_file, - all_modules=self.api.modules) + referenced_model_fullname = extract_referenced_model_fullname(field_init, + module_file=module_file, + all_modules=self.api.modules) except helpers.SelfReference: referenced_model_fullname = model_defn.fullname @@ -168,39 +171,33 @@ class AddRelatedManagers(ModelClassInitializer): if self.model_classdef.fullname == referenced_model_fullname: related_name = model_defn.name.lower() + '_set' - if 'related_name' in rvalue.arg_names: - related_name_expr = rvalue.args[rvalue.arg_names.index('related_name')] + if 'related_name' in field_init.arg_names: + related_name_expr = field_init.args[field_init.arg_names.index('related_name')] if not isinstance(related_name_expr, StrExpr): + # not string 'related_name=' not yet supported continue related_name = related_name_expr.value if related_name == '+': # No backwards relation is desired continue - if 'related_query_name' in rvalue.arg_names: - related_query_name_expr = rvalue.args[rvalue.arg_names.index('related_query_name')] - if not isinstance(related_query_name_expr, StrExpr): - related_query_name = None - else: + # Default related_query_name to related_name + related_query_name = related_name + if 'related_query_name' in field_init.arg_names: + related_query_name_expr = field_init.args[field_init.arg_names.index('related_query_name')] + if isinstance(related_query_name_expr, StrExpr): related_query_name = related_query_name_expr.value + else: + # not string 'related_query_name=' is not yet supported + related_query_name = None # TODO: Handle defaulting to model name if related_name is not set - else: - # No related_query_name specified, default to related_name - related_query_name = related_name - # field_type_data = get_related_field_type(rvalue, self.api, defn.info) - # if typ is None: - # continue - - # TODO: recursively serialize types, or just https://github.com/python/mypy/issues/6506 # as long as Model is not a Generic, one level depth is fine - if rvalue.callee.name in {'ForeignKey', 'ManyToManyField'}: + if field_init.callee.name in {'ForeignKey', 'ManyToManyField'}: field_type_data = { - 'manager': helpers.RELATED_MANAGER_CLASS_FULLNAME, + 'manager': fullnames.RELATED_MANAGER_CLASS_FULLNAME, 'of': [model_defn.info.fullname()] } - # return api.named_type_or_none(helpers.RELATED_MANAGER_CLASS_FULLNAME, - # args=[Instance(related_model_typ, [])]) else: field_type_data = { 'manager': model_defn.info.fullname(), @@ -211,7 +208,7 @@ class AddRelatedManagers(ModelClassInitializer): if related_query_name is not None: # Only create related_query_name if it is a string literal - helpers.get_lookups_metadata(self.model_classdef.info)[related_query_name] = { + metadata.get_lookups_metadata(self.model_classdef.info)[related_query_name] = { 'related_query_name_target': related_name } @@ -219,20 +216,18 @@ class AddRelatedManagers(ModelClassInitializer): def get_related_field_type(rvalue: CallExpr, related_model_typ: TypeInfo) -> Dict[str, Any]: if rvalue.callee.name in {'ForeignKey', 'ManyToManyField'}: return { - 'manager': helpers.RELATED_MANAGER_CLASS_FULLNAME, + 'manager': fullnames.RELATED_MANAGER_CLASS_FULLNAME, 'of': [related_model_typ.fullname()] } - # return api.named_type_or_none(helpers.RELATED_MANAGER_CLASS_FULLNAME, - # args=[Instance(related_model_typ, [])]) else: return { 'manager': related_model_typ.fullname(), 'of': [] } - # return Instance(related_model_typ, []) def is_related_field(expr: CallExpr, module_file: MypyFile) -> bool: + """ Checks whether current CallExpr represents any supported RelatedField subclass""" if isinstance(expr.callee, MemberExpr) and isinstance(expr.callee.expr, NameExpr): module = module_file.names.get(expr.callee.expr.name) if module \ @@ -244,12 +239,15 @@ def is_related_field(expr: CallExpr, module_file: MypyFile) -> bool: return False -def extract_ref_to_fullname(rvalue_expr: CallExpr, - module_file: MypyFile, all_modules: Dict[str, MypyFile]) -> Optional[str]: +def extract_referenced_model_fullname(rvalue_expr: CallExpr, + module_file: MypyFile, + all_modules: Dict[str, MypyFile]) -> Optional[str]: + """ Returns fullname of a Model referenced in "to=" argument of the CallExpr""" if 'to' in rvalue_expr.arg_names: to_expr = rvalue_expr.args[rvalue_expr.arg_names.index('to')] else: to_expr = rvalue_expr.args[0] + if isinstance(to_expr, NameExpr): return module_file.names[to_expr.name].fullname elif isinstance(to_expr, StrExpr): diff --git a/mypy_django_plugin/transformers/queryset.py b/mypy_django_plugin/transformers/queryset.py index d3b0e33..993d61c 100644 --- a/mypy_django_plugin/transformers/queryset.py +++ b/mypy_django_plugin/transformers/queryset.py @@ -8,8 +8,8 @@ from mypy.plugin import ( ) from mypy.types import AnyType, Instance, Type, TypeOfAny -from mypy_django_plugin import helpers -from mypy_django_plugin.lookups import ( +from mypy_django_plugin.lib import helpers +from mypy_django_plugin.lib.lookups import ( LookupException, RelatedModelNode, resolve_lookup, ) diff --git a/mypy_django_plugin/transformers/related.py b/mypy_django_plugin/transformers/related.py index 46d2e02..cb0c962 100644 --- a/mypy_django_plugin/transformers/related.py +++ b/mypy_django_plugin/transformers/related.py @@ -4,7 +4,7 @@ from mypy.checkmember import AttributeContext from mypy.nodes import TypeInfo from mypy.types import AnyType, Instance, Type, TypeOfAny, UnionType -from mypy_django_plugin import helpers +from mypy_django_plugin.lib import fullnames, helpers def _extract_referred_to_type_info(typ: Union[UnionType, Instance]) -> Optional[TypeInfo]: @@ -22,7 +22,7 @@ def extract_and_return_primary_key_of_bound_related_field_parameter(ctx: Attribu if not isinstance(ctx.default_attr_type, Instance) or not (ctx.default_attr_type.type.fullname() == 'builtins.int'): return ctx.default_attr_type - if not isinstance(ctx.type, Instance) or not ctx.type.type.has_base(helpers.MODEL_CLASS_FULLNAME): + if not isinstance(ctx.type, Instance) or not ctx.type.type.has_base(fullnames.MODEL_CLASS_FULLNAME): return ctx.default_attr_type field_name = ctx.context.name.split('_')[0] diff --git a/mypy_django_plugin/transformers/settings.py b/mypy_django_plugin/transformers/settings.py index 1e72485..afd89b4 100644 --- a/mypy_django_plugin/transformers/settings.py +++ b/mypy_django_plugin/transformers/settings.py @@ -5,7 +5,7 @@ from mypy.checkmember import AttributeContext from mypy.nodes import NameExpr, StrExpr, SymbolTableNode, TypeInfo from mypy.types import AnyType, Instance, Type, TypeOfAny, TypeType -from mypy_django_plugin import helpers +from mypy_django_plugin.lib import helpers if TYPE_CHECKING: from mypy.checker import TypeChecker diff --git a/test-data/typecheck/fields.test b/test-data/typecheck/fields.test index 9af5571..741e40d 100644 --- a/test-data/typecheck/fields.test +++ b/test-data/typecheck/fields.test @@ -1,27 +1,3 @@ -[CASE array_field_descriptor_access] -from django.db import models -from django.contrib.postgres.fields import ArrayField - -class User(models.Model): - array = ArrayField(base_field=models.Field()) - -user = User() -reveal_type(user.array) # N: Revealed type is 'builtins.list*[Any]' -[/CASE] - -[CASE array_field_base_field_parsed_into_generic_typevar] -from django.db import models -from django.contrib.postgres.fields import ArrayField - -class User(models.Model): - members = ArrayField(base_field=models.IntegerField()) - members_as_text = ArrayField(base_field=models.CharField(max_length=255)) - -user = User() -reveal_type(user.members) # N: Revealed type is 'builtins.list*[builtins.int]' -reveal_type(user.members_as_text) # N: Revealed type is 'builtins.list*[builtins.str]' -[/CASE] - [CASE test_model_fields_classes_present_as_primitives] from django.db import models @@ -142,3 +118,27 @@ MyModel(notnulltext="") MyModel().notnulltext = None # E: Incompatible types in assignment (expression has type "None", variable has type "Union[str, int, Combinable]") reveal_type(MyModel().notnulltext) # N: Revealed type is 'builtins.str*' [/CASE] + +[CASE array_field_descriptor_access] +from django.db import models +from django.contrib.postgres.fields import ArrayField + +class User(models.Model): + array = ArrayField(base_field=models.Field()) + +user = User() +reveal_type(user.array) # N: Revealed type is 'builtins.list*[Any]' +[/CASE] + +[CASE array_field_base_field_parsed_into_generic_typevar] +from django.db import models +from django.contrib.postgres.fields import ArrayField + +class User(models.Model): + members = ArrayField(base_field=models.IntegerField()) + members_as_text = ArrayField(base_field=models.CharField(max_length=255)) + +user = User() +reveal_type(user.members) # N: Revealed type is 'builtins.list*[builtins.int]' +reveal_type(user.members_as_text) # N: Revealed type is 'builtins.list*[builtins.str]' +[/CASE] diff --git a/test-data/typecheck/related_fields.test b/test-data/typecheck/related_fields.test index 57e2299..4145305 100644 --- a/test-data/typecheck/related_fields.test +++ b/test-data/typecheck/related_fields.test @@ -173,6 +173,7 @@ class User(models.Model): [CASE models_triple_circular_reference] from myapp.models import App +reveal_type(App().owner) # N: Revealed type is 'myapp.models.user.User' reveal_type(App().owner.profile) # N: Revealed type is 'myapp.models.profile.Profile' [file myapp/__init__.py]