From 3d7b517e59026b95d60bd6e2809f0463be8073ec Mon Sep 17 00:00:00 2001 From: Maxim Kurnikov Date: Thu, 4 Oct 2018 03:32:58 +0300 Subject: [PATCH] support for some Field -> builtin type mapping, OneToOneField, ForeignKey --- mypy_django_plugin/constants.py | 13 -- mypy_django_plugin/helpers.py | 20 ++- mypy_django_plugin/model_classes.py | 16 +- mypy_django_plugin/plugin.py | 141 ------------------ .../{tests => plugins}/__init__.py | 0 mypy_django_plugin/plugins/base.py | 77 ++++++++++ .../plugins/field_to_python_type.py | 60 ++++++++ setup.py | 4 +- 8 files changed, 164 insertions(+), 167 deletions(-) delete mode 100644 mypy_django_plugin/constants.py delete mode 100644 mypy_django_plugin/plugin.py rename mypy_django_plugin/{tests => plugins}/__init__.py (100%) create mode 100644 mypy_django_plugin/plugins/base.py create mode 100644 mypy_django_plugin/plugins/field_to_python_type.py diff --git a/mypy_django_plugin/constants.py b/mypy_django_plugin/constants.py deleted file mode 100644 index 4bb828f..0000000 --- a/mypy_django_plugin/constants.py +++ /dev/null @@ -1,13 +0,0 @@ -# Django fields that has to= attribute used to retrieve original model -REFERENCING_DB_FIELDS = { - 'django.db.models.fields.related.ForeignKey', - 'django.db.models.fields.related.OneToOneField' -} - -# mapping between field types and plain python types -DB_FIELDS_TO_TYPES = { - 'django.db.models.fields.CharField': 'builtins.str', - 'django.db.models.fields.TextField': 'builtins.str', - 'django.db.models.fields.IntegerField': 'builtins.int', - 'django.db.models.fields.FloatField': 'builtins.float' -} diff --git a/mypy_django_plugin/helpers.py b/mypy_django_plugin/helpers.py index 5ec5d15..234bbdd 100644 --- a/mypy_django_plugin/helpers.py +++ b/mypy_django_plugin/helpers.py @@ -16,4 +16,22 @@ def lookup_django_model(mypy_api: TypeChecker, fullname: str) -> SymbolTableNode try: return mypy_api.modules[module].names[model_name] except KeyError: - return mypy_api.modules['django.db.models'].names['Model'] \ No newline at end of file + return mypy_api.modules['django.db.models'].names['Model'] + + +def get_app_model(model_name: str) -> str: + import os + os.environ.setdefault('SITE_URL', 'https://localhost') + os.environ.setdefault('DJANGO_SETTINGS_MODULE', 'server._config.settings.local') + + import django + django.setup() + + from django.apps import apps + + try: + app_name, model_name = model_name.rsplit('.', maxsplit=1) + model = apps.get_model(app_name, model_name) + return model.__module__ + '.' + model_name + except ValueError: + return model_name diff --git a/mypy_django_plugin/model_classes.py b/mypy_django_plugin/model_classes.py index d2f6335..5f0f0d5 100644 --- a/mypy_django_plugin/model_classes.py +++ b/mypy_django_plugin/model_classes.py @@ -1,22 +1,18 @@ -import json -from typing import Dict +from typing import Set import dataclasses -@dataclasses.dataclass -class ModelInfo(object): - # class_name: str - related_managers: Dict[str, 'ModelInfo'] = dataclasses.field(default_factory=dict) - - def get_default_base_models(): - return {'django.db.models.base.Model': ModelInfo()} + return {'django.db.models.base.Model'} @dataclasses.dataclass class DjangoModelsRegistry(object): - base_models: Dict[str, ModelInfo] = dataclasses.field(default_factory=get_default_base_models) + base_models: Set[str] = dataclasses.field(default_factory=get_default_base_models) def __contains__(self, item: str) -> bool: return item in self.base_models + + def __iter__(self): + return iter(self.base_models) diff --git a/mypy_django_plugin/plugin.py b/mypy_django_plugin/plugin.py deleted file mode 100644 index a7aa2fd..0000000 --- a/mypy_django_plugin/plugin.py +++ /dev/null @@ -1,141 +0,0 @@ -import importlib -import inspect -from typing import Optional, Callable - -from mypy.mypyc_hacks import TypeOfAny -from mypy.nodes import PassStmt, CallExpr, SymbolTableNode, MDEF, Var, ClassDef, Decorator, AssignmentStmt, StrExpr -from mypy.plugin import Plugin, ClassDefContext, AttributeContext -from mypy.types import AnyType, Type, Instance, CallableType - -from mypy_django_plugin.constants import REFERENCING_DB_FIELDS, DB_FIELDS_TO_TYPES -from mypy_django_plugin.helpers import lookup_django_model -from mypy_django_plugin.model_classes import DjangoModelsRegistry, ModelInfo - - -model_registry = DjangoModelsRegistry() - - -def get_db_field_arguments(rvalue: CallExpr) -> inspect.BoundArguments: - modulename, _, classname = rvalue.callee.fullname.rpartition('.') - klass = getattr(importlib.import_module(modulename), classname) - bound_signature = inspect.signature(klass).bind(*rvalue.args) - - return bound_signature - - -def process_meta_innerclass(class_def: ClassDef): - pass - - -def get_app_model(model_name: str) -> str: - import os - os.environ.setdefault('SITE_URL', 'https://localhost') - os.environ.setdefault('DJANGO_SETTINGS_MODULE', 'server._config.settings.local') - - import django - django.setup() - - from django.apps import apps - - try: - app_name, model_name = model_name.rsplit('.', maxsplit=1) - model = apps.get_model(app_name, model_name) - return model.__module__ + '.' + model_name - except ValueError: - return model_name - - -def base_class_callback(model_def_context: ClassDefContext) -> None: - # add new possible base models - for base_type_expr in model_def_context.cls.base_type_exprs: - if base_type_expr.fullname in model_registry: - model_registry.base_models[model_def_context.cls.fullname] = ModelInfo() - - for definition in model_def_context.cls.defs.body: - if isinstance(definition, PassStmt): - continue - - if isinstance(definition, ClassDef): - if definition.name == 'Meta': - process_meta_innerclass(definition) - continue - - if isinstance(definition, AssignmentStmt): - rvalue = definition.rvalue - if hasattr(rvalue, 'callee') and rvalue.callee.fullname in REFERENCING_DB_FIELDS: - modulename, _, classname = rvalue.callee.fullname.rpartition('.') - klass = getattr(importlib.import_module(modulename), classname) - bound_signature = inspect.signature(klass).bind(*rvalue.args) - - to_arg_value = bound_signature.arguments['to'] - if isinstance(to_arg_value, StrExpr): - model_fullname = get_app_model(to_arg_value.value) - else: - model_fullname = bound_signature.arguments['to'].fullname - - referred_model = model_fullname - rvalue.callee.node.metadata['base'] = referred_model - - # if 'related_name' in rvalue.arg_names: - # related_name = rvalue.args[rvalue.arg_names.index('related_name')].value - # - # for name, context_global in model_def_context.api.globals.items(): - # context_global: SymbolTableNode = context_global - # if context_global.fullname == referred_model: - # list_of_any = Instance(model_def_context.api.lookup_fully_qualified('typing.List').node, - # [AnyType(TypeOfAny.from_omitted_generics)]) - # related_members_node = Var(related_name, list_of_any) - # context_global.node.names[related_name] = SymbolTableNode(MDEF, related_members_node) - - # model_registry.base_models[referred_model].related_managers[related_name] = model_def_context.cls.fullname - - - -def related_manager_inference_callback(attr_context: AttributeContext) -> Type: - mypy_api = attr_context.api - - default_attr_type = attr_context.default_attr_type - if not isinstance(default_attr_type, Instance): - return default_attr_type - - attr_type_fullname = default_attr_type.type.fullname() - - if attr_type_fullname in DB_FIELDS_TO_TYPES: - return mypy_api.named_type(DB_FIELDS_TO_TYPES[attr_type_fullname]) - - if 'base' in default_attr_type.type.metadata: - base_class = default_attr_type.type.metadata['base'] - - node = lookup_django_model(mypy_api, base_class).node - return Instance(node, []) - # mypy_api.lookup_qualified(base_class) - # app = base_class.split('.')[0] - # name = base_class.split('.')[-1] - # app_model = get_app_model(app + '.' + name) - - # return mypy_api.named_type(name) - - return AnyType(TypeOfAny.unannotated) - - -# -# def type_analyze_callback(context: AnalyzeTypeContext) -> Type: -# return context.type - - -class RelatedManagersDjangoPlugin(Plugin): - def get_base_class_hook(self, fullname: str - ) -> Optional[Callable[[ClassDefContext], None]]: - # every time new class is created, this method is called with first base class in MRO - if fullname in model_registry: - return base_class_callback - - return None - - def get_attribute_hook(self, fullname: str - ) -> Optional[Callable[[AttributeContext], Type]]: - return related_manager_inference_callback - - -def plugin(version): - return RelatedManagersDjangoPlugin diff --git a/mypy_django_plugin/tests/__init__.py b/mypy_django_plugin/plugins/__init__.py similarity index 100% rename from mypy_django_plugin/tests/__init__.py rename to mypy_django_plugin/plugins/__init__.py diff --git a/mypy_django_plugin/plugins/base.py b/mypy_django_plugin/plugins/base.py new file mode 100644 index 0000000..550c78f --- /dev/null +++ b/mypy_django_plugin/plugins/base.py @@ -0,0 +1,77 @@ +from typing import Callable, Optional + +from mypy.nodes import AssignmentStmt, CallExpr, RefExpr, StrExpr +from mypy.plugin import Plugin, ClassDefContext + +from mypy_django_plugin.helpers import get_app_model +from mypy_django_plugin.model_classes import DjangoModelsRegistry + + +# fields which real type is inside to= expression +REFERENCING_DB_FIELDS = { + 'django.db.models.fields.related.ForeignKey', + 'django.db.models.fields.related.OneToOneField' +} + + +def save_referred_to_model_in_metadata(rvalue: CallExpr) -> None: + to_arg_value = rvalue.args[rvalue.arg_names.index('to')] + if isinstance(to_arg_value, StrExpr): + referred_model_fullname = get_app_model(to_arg_value.value) + else: + referred_model_fullname = to_arg_value.fullname + + rvalue.callee.node.metadata['base'] = referred_model_fullname + + +class CollectModelsInformation(object): + def __init__(self, model_registry: DjangoModelsRegistry): + self.model_registry = model_registry + + def __call__(self, model_definition: ClassDefContext) -> None: + self.model_registry.base_models.add(model_definition.cls.fullname) + + for member in model_definition.cls.defs.body: + if isinstance(member, AssignmentStmt): + if len(member.lvalues) > 1: + return None + + arg_name = member.lvalues[0].name + arg_name_as_id = arg_name + '_id' + + rvalue = member.rvalue + if isinstance(rvalue, CallExpr): + if not isinstance(rvalue.callee, RefExpr): + return None + + if rvalue.callee.fullname in REFERENCING_DB_FIELDS: + if rvalue.callee.fullname == 'django.db.models.fields.related.ForeignKey': + model_definition.cls.info.names[arg_name_as_id] = \ + model_definition.api.lookup_fully_qualified('builtins.int') + + if rvalue.callee.fullname == 'django.db.models.fields.related.OneToOneField': + if 'related_name' in rvalue.arg_names: + referred_to_model = rvalue.args[rvalue.arg_names.index('to')] + related_arg_value = rvalue.args[rvalue.arg_names.index('related_name')].value + + if isinstance(referred_to_model, StrExpr): + referred_model_fullname = get_app_model(referred_to_model.value) + else: + referred_model_fullname = referred_to_model.fullname + + referred_model = model_definition.api.lookup_fully_qualified_or_none(referred_model_fullname) + referred_model.node.names[related_arg_value] = \ + model_definition.api.lookup_fully_qualified_or_none(model_definition.cls.fullname) + + return save_referred_to_model_in_metadata(rvalue) + + +class BaseDjangoModelsPlugin(Plugin): + model_registry = DjangoModelsRegistry() + + def get_base_class_hook(self, fullname: str + ) -> Optional[Callable[[ClassDefContext], None]]: + if fullname in self.model_registry: + return CollectModelsInformation(self.model_registry) + + return None diff --git a/mypy_django_plugin/plugins/field_to_python_type.py b/mypy_django_plugin/plugins/field_to_python_type.py new file mode 100644 index 0000000..8f71bab --- /dev/null +++ b/mypy_django_plugin/plugins/field_to_python_type.py @@ -0,0 +1,60 @@ +from typing import Optional, Callable + +from mypy.plugin import AttributeContext +from mypy.types import Type, Instance + +from mypy_django_plugin.helpers import lookup_django_model +from mypy_django_plugin.model_classes import DjangoModelsRegistry +from mypy_django_plugin.plugins.base import BaseDjangoModelsPlugin + +# mapping between field types and plain python types +DB_FIELDS_TO_TYPES = { + 'django.db.models.fields.CharField': 'builtins.str', + 'django.db.models.fields.TextField': 'builtins.str', + 'django.db.models.fields.BooleanField': 'builtins.bool', + # 'django.db.models.fields.NullBooleanField': 'typing.Optional[builtins.bool]', + 'django.db.models.fields.IntegerField': 'builtins.int', + 'django.db.models.fields.AutoField': 'builtins.int', + 'django.db.models.fields.FloatField': 'builtins.float', + 'django.contrib.postgres.fields.jsonb.JSONField': 'builtins.dict', + 'django.contrib.postgres.fields.array.ArrayField': 'typing.Iterable' +} + + +class DetermineFieldPythonTypeCallback(object): + def __init__(self, models_registry: DjangoModelsRegistry): + self.models_registry = models_registry + + def __call__(self, attr_context: AttributeContext) -> Type: + default_attr_type = attr_context.default_attr_type + + if isinstance(default_attr_type, Instance): + attr_type_fullname = default_attr_type.type.fullname() + if attr_type_fullname in DB_FIELDS_TO_TYPES: + return attr_context.api.named_type(DB_FIELDS_TO_TYPES[attr_type_fullname]) + + if 'base' in default_attr_type.type.metadata: + referred_base_model = default_attr_type.type.metadata['base'] + try: + node = lookup_django_model(attr_context.api, referred_base_model).node + return Instance(node, []) + except AssertionError as e: + print(e) + print('name to lookup:', referred_base_model) + pass + + return default_attr_type + + +class FieldToPythonTypePlugin(BaseDjangoModelsPlugin): + def get_attribute_hook(self, fullname: str + ) -> Optional[Callable[[AttributeContext], Type]]: + classname, _, attrname = fullname.rpartition('.') + if classname and classname in self.model_registry: + return DetermineFieldPythonTypeCallback(self.model_registry) + + return None + + +def plugin(version): + return FieldToPythonTypePlugin diff --git a/setup.py b/setup.py index 1c92b23..05c939e 100644 --- a/setup.py +++ b/setup.py @@ -19,6 +19,6 @@ setup( version="0.1.0", license='BSD', install_requires='Django>=2.1.1', - packages=['django-stubs'], - package_data=find_stubs('django-stubs') + packages=['mypy_django_plugin'] + # package_data=find_stubs('django-stubs') )