diff --git a/mypy.ini b/mypy.ini new file mode 100644 index 0000000..0900f01 --- /dev/null +++ b/mypy.ini @@ -0,0 +1,4 @@ +[mypy] + +[mypy-mypy_django_plugin.monkeypatch.*] +ignore_errors = True \ No newline at end of file diff --git a/mypy_django_plugin/helpers.py b/mypy_django_plugin/helpers.py index 1ef5de5..c4ba9b0 100644 --- a/mypy_django_plugin/helpers.py +++ b/mypy_django_plugin/helpers.py @@ -1,9 +1,7 @@ import typing -from typing import Dict, Optional, NamedTuple +from typing import Dict, Optional -from mypy.nodes import Expression, StrExpr, MypyFile, TypeInfo -from mypy.plugin import FunctionContext -from mypy.types import Type, UnionType, NoneTyp +from mypy.nodes import StrExpr, MypyFile, TypeInfo, ImportedName, SymbolNode MODEL_CLASS_FULLNAME = 'django.db.models.base.Model' QUERYSET_CLASS_FULLNAME = 'django.db.models.query.QuerySet' @@ -11,63 +9,14 @@ FOREIGN_KEY_FULLNAME = 'django.db.models.fields.related.ForeignKey' ONETOONE_FIELD_FULLNAME = 'django.db.models.fields.related.OneToOneField' DUMMY_SETTINGS_BASE_CLASS = 'django.conf._DjangoConfLazyObject' -Argument = NamedTuple('Argument', fields=[ - ('arg', Expression), - ('arg_type', Type) -]) - - -def get_call_signature_or_none(ctx: FunctionContext) -> Optional[Dict[str, Argument]]: - result: Dict[str, Argument] = {} - positional_args_only = [] - positional_arg_types_only = [] - for arg, arg_name, arg_type in zip(ctx.args, ctx.arg_names, ctx.arg_types): - if arg_name is None: - positional_args_only.append(arg) - positional_arg_types_only.append(arg_type) - continue - - if len(arg) == 0 or len(arg_type) == 0: - continue - - result[arg_name] = (arg[0], arg_type[0]) - - callee = ctx.context.callee - if '__init__' not in callee.node.names: - return None - - init_type = callee.node.names['__init__'].type - arg_names = init_type.arg_names[1:] - for arg, arg_name, arg_type in zip(positional_args_only, - arg_names[:len(positional_args_only)], - positional_arg_types_only): - result[arg_name] = (arg[0], arg_type[0]) - - return result - - -def make_optional(typ: Type) -> Type: - return UnionType.make_simplified_union([typ, NoneTyp()]) - - -def make_required(typ: Type) -> Type: - if not isinstance(typ, UnionType): - return typ - items = [item for item in typ.items if not isinstance(item, NoneTyp)] - return UnionType.make_union(items) - - -def get_obj_type_name(typ: typing.Type) -> str: - return typ.__module__ + '.' + typ.__qualname__ - def get_models_file(app_name: str, all_modules: typing.Dict[str, MypyFile]) -> Optional[MypyFile]: models_module = '.'.join([app_name, 'models']) return all_modules.get(models_module) -def get_model_type_from_string(expr: StrExpr, - all_modules: Dict[str, MypyFile]) -> Optional[TypeInfo]: +def get_model_fullname_from_string(expr: StrExpr, + all_modules: Dict[str, MypyFile]) -> Optional[str]: app_name, model_name = expr.value.split('.') models_file = get_models_file(app_name, all_modules) @@ -75,7 +24,26 @@ def get_model_type_from_string(expr: StrExpr, # not imported so far, not supported return None sym = models_file.names.get(model_name) - if not sym or not isinstance(sym.node, TypeInfo): - # no such model found in the app / node is not a class definition + if not sym: + return None + + if isinstance(sym.node, TypeInfo): + return sym.node.fullname() + elif isinstance(sym.node, ImportedName): + return sym.node.target_fullname + else: + return None + + +def lookup_fully_qualified_generic(name: str, all_modules: Dict[str, MypyFile]) -> Optional[SymbolNode]: + if '.' not in name: + return None + module, cls_name = name.rsplit('.', 1) + + module_file = all_modules.get(module) + if module_file is None: + return None + sym = module_file.names.get(cls_name) + if sym is None: return None return sym.node diff --git a/mypy_django_plugin/main.py b/mypy_django_plugin/main.py index 10b7d24..e838160 100644 --- a/mypy_django_plugin/main.py +++ b/mypy_django_plugin/main.py @@ -42,9 +42,6 @@ class DjangoPlugin(Plugin): helpers.ONETOONE_FIELD_FULLNAME}: return extract_to_parameter_as_get_ret_type_for_related_field - # if fullname == helpers.ONETOONE_FIELD_FULLNAME: - # return OneToOneFieldHook(settings=self.django_settings) - if fullname == 'django.contrib.postgres.fields.array.ArrayField': return determine_type_of_array_field return None diff --git a/mypy_django_plugin/plugins/models.py b/mypy_django_plugin/plugins/models.py index 3859a1b..7ecd1c8 100644 --- a/mypy_django_plugin/plugins/models.py +++ b/mypy_django_plugin/plugins/models.py @@ -1,7 +1,7 @@ -from typing import cast, Iterator, Tuple, Optional +from typing import cast, Iterator, Tuple, Optional, Dict from mypy.nodes import ClassDef, AssignmentStmt, CallExpr, MemberExpr, StrExpr, NameExpr, MDEF, TypeInfo, Var, SymbolTableNode, \ - Lvalue, Expression, Statement + Lvalue, Expression, MypyFile from mypy.plugin import ClassDefContext from mypy.semanal import SemanticAnalyzerPass2 from mypy.types import Instance @@ -30,12 +30,11 @@ def iter_over_assignments(klass: ClassDef) -> Iterator[Tuple[Lvalue, Expression] def iter_call_assignments(klass: ClassDef) -> Iterator[Tuple[Lvalue, CallExpr]]: for lvalue, rvalue in iter_over_assignments(klass): - if not isinstance(rvalue, CallExpr): - continue - yield lvalue, rvalue + if isinstance(rvalue, CallExpr): + yield lvalue, rvalue -def iter_over_one_to_n_related_fields(klass: ClassDef, api: SemanticAnalyzerPass2) -> Iterator[Tuple[NameExpr, CallExpr]]: +def iter_over_one_to_n_related_fields(klass: ClassDef) -> Iterator[Tuple[NameExpr, CallExpr]]: for lvalue, rvalue in iter_call_assignments(klass): if (isinstance(lvalue, NameExpr) and isinstance(rvalue.callee, MemberExpr)): @@ -67,7 +66,7 @@ def is_abstract_model(ctx: ClassDefContext) -> bool: def set_fieldname_attrs_for_related_fields(ctx: ClassDefContext) -> None: api = ctx.api - for lvalue, rvalue in iter_over_one_to_n_related_fields(ctx.cls, api): + for lvalue, rvalue in iter_over_one_to_n_related_fields(ctx.cls): property_name = lvalue.name + '_id' add_new_var_node_to_class(ctx.cls.info, property_name, typ=api.named_type('__builtins__.int')) @@ -112,68 +111,67 @@ def inject_any_as_base_for_nested_class_meta(ctx: ClassDefContext) -> None: meta_node.fallback_to_any = True -def is_model_defn(defn: Statement, api: SemanticAnalyzerPass2) -> bool: - if not isinstance(defn, ClassDef): - return False - - for base_type_expr in defn.base_type_exprs: - # api.accept(base_type_expr) - fullname = getattr(base_type_expr, 'fullname', None) - if fullname == helpers.MODEL_CLASS_FULLNAME: - return True - return False - - -def iter_over_models(ctx: ClassDefContext) -> Iterator[ClassDef]: - for module_name, module_file in ctx.api.modules.items(): - for defn in module_file.defs: - if is_model_defn(defn, api=cast(SemanticAnalyzerPass2, ctx.api)): - yield defn - - -def extract_to_value_or_none(field_expr: CallExpr, ctx: ClassDefContext) -> Optional[TypeInfo]: - if 'to' in field_expr.arg_names: - ref_expr = field_expr.args[field_expr.arg_names.index('to')] - else: - # first positional argument - ref_expr = field_expr.args[0] - - if isinstance(ref_expr, StrExpr): - model_typeinfo = helpers.get_model_type_from_string(ref_expr, - all_modules=ctx.api.modules) - return model_typeinfo - elif isinstance(ref_expr, NameExpr): - return ref_expr.node +def iter_over_classdefs(module_file: MypyFile) -> Iterator[ClassDef]: + for defn in module_file.defs: + if isinstance(defn, ClassDef): + yield defn def get_related_field_type(rvalue: CallExpr, api: SemanticAnalyzerPass2, related_model_typ: TypeInfo) -> Optional[Instance]: - if rvalue.callee.fullname == helpers.FOREIGN_KEY_FULLNAME: + if rvalue.callee.name == 'ForeignKey': return api.named_type_or_none(helpers.QUERYSET_CLASS_FULLNAME, args=[Instance(related_model_typ, [])]) else: return Instance(related_model_typ, []) -def add_related_managers(ctx: ClassDefContext) -> None: - for model_defn in iter_over_models(ctx): - for _, rvalue in iter_over_one_to_n_related_fields(model_defn, ctx.api): - if 'related_name' not in rvalue.arg_names: - # positional related_name is not supported yet - return - related_name = rvalue.args[rvalue.arg_names.index('related_name')].value - ref_to_typ = extract_to_value_or_none(rvalue, ctx) - if ref_to_typ is not None: - if ref_to_typ.fullname() == ctx.cls.info.fullname(): - typ = get_related_field_type(rvalue, ctx.api, - related_model_typ=model_defn.info) - if typ is None: - return - add_new_var_node_to_class(ctx.cls.info, related_name, typ) +def is_related_field(expr: CallExpr, module_file: MypyFile) -> bool: + if isinstance(expr.callee, MemberExpr) and isinstance(expr.callee.expr, NameExpr): + module = module_file.names[expr.callee.expr.name] + if module.fullname == 'django.db.models' and expr.callee.name in {'ForeignKey', 'OneToOneField'}: + return True + return False + + +def extract_ref_to_fullname(rvalue_expr: CallExpr, + module_file: MypyFile, all_modules: Dict[str, MypyFile]) -> Optional[str]: + if 'to' in rvalue_expr.arg_names: + to_expr = rvalue_expr.args[rvalue_expr.arg_names.index('to')] + else: + to_expr = rvalue_expr.args[0] + if isinstance(to_expr, NameExpr): + return module_file.names[to_expr.name].fullname + elif isinstance(to_expr, StrExpr): + typ_fullname = helpers.get_model_fullname_from_string(to_expr, all_modules) + if typ_fullname is None: + return None + return typ_fullname + return None + + +def add_related_managers(ctx: ClassDefContext): + api = cast(SemanticAnalyzerPass2, ctx.api) + for module_name, module_file in ctx.api.modules.items(): + for defn in iter_over_classdefs(module_file): + for lvalue, rvalue in iter_call_assignments(defn): + if is_related_field(rvalue, module_file): + ref_to_fullname = extract_ref_to_fullname(rvalue, module_file=module_file, + all_modules=api.modules) + if ctx.cls.fullname == ref_to_fullname: + if 'related_name' in rvalue.arg_names: + related_name_expr = rvalue.args[rvalue.arg_names.index('related_name')] + if not isinstance(related_name_expr, StrExpr): + return None + related_name = related_name_expr.value + typ = get_related_field_type(rvalue, api, defn.info) + if typ is None: + return None + add_new_var_node_to_class(ctx.cls.info, related_name, typ) def process_model_class(ctx: ClassDefContext) -> None: - # add_related_managers(ctx) + add_related_managers(ctx) inject_any_as_base_for_nested_class_meta(ctx) set_fieldname_attrs_for_related_fields(ctx) add_int_id_attribute_if_primary_key_true_is_not_present(ctx) diff --git a/mypy_django_plugin/plugins/related_fields.py b/mypy_django_plugin/plugins/related_fields.py index 2a3c72c..fa89b75 100644 --- a/mypy_django_plugin/plugins/related_fields.py +++ b/mypy_django_plugin/plugins/related_fields.py @@ -2,7 +2,7 @@ import typing from typing import Optional, cast from mypy.checker import TypeChecker -from mypy.nodes import StrExpr +from mypy.nodes import StrExpr, TypeInfo from mypy.plugin import FunctionContext from mypy.types import Type, CallableType, Instance, AnyType, TypeOfAny @@ -18,10 +18,11 @@ def fill_typevars_with_any(instance: Instance) -> Type: def get_valid_to_value_or_none(ctx: FunctionContext) -> Optional[Instance]: + api = cast(TypeChecker, ctx.api) if 'to' not in ctx.arg_names: # shouldn't happen, invalid code - ctx.api.msg.fail(f'to= parameter must be set for {ctx.context.callee.fullname}', - context=ctx.context) + api.msg.fail(f'to= parameter must be set for {ctx.context.callee.fullname}', + context=ctx.context) return None arg_type = ctx.arg_types[ctx.arg_names.index('to')][0] @@ -30,13 +31,19 @@ def get_valid_to_value_or_none(ctx: FunctionContext) -> Optional[Instance]: if not isinstance(to_arg_expr, StrExpr): # not string, not supported return None - model_info = helpers.get_model_type_from_string(to_arg_expr, - all_modules=cast(TypeChecker, ctx.api).modules) - if model_info is None: + model_fullname = helpers.get_model_fullname_from_string(to_arg_expr, + all_modules=api.modules) + if model_fullname is None: + return None + model_info = helpers.lookup_fully_qualified_generic(model_fullname, + all_modules=api.modules) + if model_info is None or not isinstance(model_info, TypeInfo): return None return Instance(model_info, []) referred_to_type = arg_type.ret_type + if not isinstance(referred_to_type, Instance): + return None for base in referred_to_type.type.bases: if base.type.fullname() == helpers.MODEL_CLASS_FULLNAME: break diff --git a/mypy_django_plugin/plugins/settings.py b/mypy_django_plugin/plugins/settings.py index 08bffa3..ef3f24c 100644 --- a/mypy_django_plugin/plugins/settings.py +++ b/mypy_django_plugin/plugins/settings.py @@ -1,12 +1,10 @@ -from typing import cast, List +from typing import cast, List, Optional from mypy.nodes import Var, Context, SymbolNode, SymbolTableNode from mypy.plugin import ClassDefContext from mypy.semanal import SemanticAnalyzerPass2 from mypy.types import Instance, UnionType, NoneTyp, Type -from mypy_django_plugin import helpers - def get_error_context(node: SymbolNode) -> Context: context = Context() @@ -18,10 +16,24 @@ def filter_out_nones(typ: UnionType) -> List[Type]: return [item for item in typ.items if not isinstance(item, NoneTyp)] -def copy_sym_of_instance(sym: SymbolTableNode) -> SymbolTableNode: - copied = sym.copy() - copied.node.info = sym.type.type - return copied +def make_sym_copy_of_setting(sym: SymbolTableNode) -> Optional[SymbolTableNode]: + if isinstance(sym.type, Instance): + copied = sym.copy() + copied.node.info = sym.type.type + return copied + elif isinstance(sym.type, UnionType): + instances = filter_out_nones(sym.type) + if len(instances) > 1: + # plain unions not supported yet + return None + typ = instances[0] + if isinstance(typ, Instance): + copied = sym.copy() + copied.node.info = typ.type + return copied + return None + else: + return None def add_settings_to_django_conf_object(ctx: ClassDefContext, @@ -33,25 +45,24 @@ def add_settings_to_django_conf_object(ctx: ClassDefContext, settings_file = api.modules[settings_module] for name, sym in settings_file.names.items(): if name.isupper() and isinstance(sym.node, Var): - if isinstance(sym.type, Instance): - copied = sym.copy() - copied.node.info = sym.type.type - ctx.cls.info.names[name] = copied - - elif isinstance(sym.type, UnionType): - instances = filter_out_nones(sym.type) - if len(instances) > 1: - # plain unions not supported yet + if sym.type is not None: + copied = make_sym_copy_of_setting(sym) + if copied is None: continue - typ = instances[0] - if isinstance(typ, Instance): - copied = sym.copy() - copied.node.info = typ.type - ctx.cls.info.names[name] = copied + ctx.cls.info.names[name] = copied + else: + context = Context() + module, node_name = sym.node.fullname().rsplit('.', 1) + module_file = api.modules.get(module) + if module_file is None: + return None + context.set_line(sym.node) + api.msg.report(f"Need type annotation for '{sym.node.name()}'", context, + severity='error', file=module_file.path) class DjangoConfSettingsInitializerHook(object): - def __init__(self, settings_module: str): + def __init__(self, settings_module: Optional[str]): self.settings_module = settings_module def __call__(self, ctx: ClassDefContext) -> None: diff --git a/pytest.ini b/pytest.ini index 3a56481..9a4bb98 100644 --- a/pytest.ini +++ b/pytest.ini @@ -5,5 +5,6 @@ python_files = test*.py addopts = --tb=native --mypy-ini-file=./test-data/plugins.ini + --mypy-no-cache -s -v \ No newline at end of file diff --git a/test-data/typecheck/related_fields.test b/test-data/typecheck/related_fields.test index 97dbb85..22a4106 100644 --- a/test-data/typecheck/related_fields.test +++ b/test-data/typecheck/related_fields.test @@ -119,7 +119,7 @@ class App(models.Model): def method(self) -> None: reveal_type(self.views) # E: Revealed type is 'django.db.models.query.QuerySet[main.View]' -[case models_related_managers_work_with_direct_model_inheritance_and_with_inheritance_from_other_model] +[CASE models_related_managers_work_with_direct_model_inheritance_and_with_inheritance_from_other_model] from django.db.models import Model from django.db import models @@ -134,4 +134,62 @@ class View2(View): reveal_type(App().views) # E: Revealed type is 'django.db.models.query.QuerySet[main.View]' reveal_type(App().views2) # E: Revealed type is 'django.db.models.query.QuerySet[main.View2]' -[out] \ No newline at end of file +[out] + +[CASE models_imported_inside_init_file] +from django.db import models +from myapp.models import App +class View(models.Model): + app = models.ForeignKey(to='myapp.App', related_name='views', on_delete=models.CASCADE) +reveal_type(View().app.views) # E: Revealed type is 'django.db.models.query.QuerySet[main.View]' + +[file myapp/__init__.py] +[file myapp/models/__init__.py] +from .app import App +[file myapp/models/app.py] +from django.db import models +class App(models.Model): + pass + +[CASE models_imported_inside_init_file_one_to_one_field] +from django.db import models +from myapp.models import User +class Profile(models.Model): + user = models.OneToOneField(to='myapp.User', related_name='profile', on_delete=models.CASCADE) +reveal_type(Profile().user.profile) # E: Revealed type is 'main.Profile' + +[file myapp/__init__.py] +[file myapp/models/__init__.py] +from .user import User +[file myapp/models/user.py] +from django.db import models +class User(models.Model): + pass + +[CASE models_triple_circular_reference] +from myapp.models import App +reveal_type(App().owner.profile) # E: Revealed type is 'myapp.models.profile.Profile' + +[file myapp/__init__.py] +[file myapp/models/__init__.py] +from .user import User +from .profile import Profile +from .app import App + +[file myapp/models/user.py] +from django.db import models +class User(models.Model): + pass + +[file myapp/models/profile.py] +from django.db import models +from myapp.models import User +class Profile(models.Model): + user = models.OneToOneField(to='myapp.User', related_name='profile', on_delete=models.CASCADE) + +[file myapp/models/app.py] +from django.db import models +class App(models.Model): + owner = models.ForeignKey(to='myapp.User', on_delete=models.CASCADE, related_name='apps') + + diff --git a/test-data/typecheck/settings.test b/test-data/typecheck/settings.test index 9ba185e..9d12d18 100644 --- a/test-data/typecheck/settings.test +++ b/test-data/typecheck/settings.test @@ -36,17 +36,47 @@ ROOT_DIR = Path(__file__) from django.conf import settings class Class: pass -reveal_type(settings.MYSETTING) # E: Revealed type is 'builtins.int' reveal_type(settings.REGISTRY) # E: Revealed type is 'Union[main.Class, None]' +reveal_type(settings.LIST) # E: Revealed type is 'Any' +reveal_type(settings.BASE_LIST) # E: Revealed type is 'Any' +[out] +main:5: error: "LazySettings" has no attribute "LIST" +main:6: error: "LazySettings" has no attribute "BASE_LIST" +mysettings:4: error: Need type annotation for 'LIST' +base:6: error: Need type annotation for 'BASE_LIST' [env DJANGO_SETTINGS_MODULE=mysettings] [file mysettings.py] -from typing import TYPE_CHECKING from typing import Optional +from base import * + +LIST = ['1', '2'] + +[file base.py] +from typing import TYPE_CHECKING, Optional +if TYPE_CHECKING: + from main import Class + +REGISTRY: Optional['Class'] = None +BASE_LIST = ['3', '4'] + +[CASE test_circular_dependency_in_settings_works_if_settings_have_annotations] +from django.conf import settings +class Class: + pass +reveal_type(settings.MYSETTING) # E: Revealed type is 'builtins.int' +reveal_type(settings.REGISTRY) # E: Revealed type is 'Union[main.Class, None]' +reveal_type(settings.LIST) # E: Revealed type is 'builtins.list[builtins.str]' +[out] + +[env DJANGO_SETTINGS_MODULE=mysettings] +[file mysettings.py] +from typing import TYPE_CHECKING, Optional, List if TYPE_CHECKING: from main import Class MYSETTING = 1122 REGISTRY: Optional['Class'] = None +LIST: List[str] = ['1', '2']