support for models.Model.objects, abstract mixins

This commit is contained in:
Maxim Kurnikov
2018-11-14 02:33:50 +03:00
parent 9a68263257
commit 41cc79b957
13 changed files with 200 additions and 103 deletions

View File

@@ -6,7 +6,8 @@ from .fields import (AutoField as AutoField,
CharField as CharField, CharField as CharField,
Field as Field, Field as Field,
SlugField as SlugField, SlugField as SlugField,
TextField as TextField) TextField as TextField,
BooleanField as BooleanField)
from .fields.related import (ForeignKey as ForeignKey, from .fields.related import (ForeignKey as ForeignKey,
OneToOneField as OneToOneField) OneToOneField as OneToOneField)

View File

@@ -1,38 +1,55 @@
from mypy.checker import TypeChecker from typing import Dict, Optional, Type, Tuple, NamedTuple
from mypy.nodes import SymbolTableNode, Var
from mypy.nodes import SymbolTableNode, Var, Expression, MemberExpr
from mypy.plugin import FunctionContext
from mypy.types import Instance
MODEL_CLASS_FULLNAME = 'django.db.models.base.Model'
QUERYSET_CLASS_FULLNAME = 'django.db.models.query.QuerySet'
FOREIGN_KEY_FULLNAME = 'django.db.models.fields.related.ForeignKey'
ONETOONE_FIELD_FULLNAME = 'django.db.models.fields.related.OneToOneField'
def is_class_variable(symbol_table_node: SymbolTableNode) -> bool: def create_new_symtable_node(name: str, kind: int, instance: Instance) -> SymbolTableNode:
# MDEF: class member definition new_var = Var(name, instance)
is_class_variable = symbol_table_node.kind == 2 and type(symbol_table_node.node) == Var new_var.info = instance.type
if not is_class_variable:
return False
return True return SymbolTableNode(kind, new_var,
plugin_generated=True)
def lookup_django_model(mypy_api: TypeChecker, fullname: str) -> SymbolTableNode: Argument = NamedTuple('Argument', fields=[
module, _, model_name = fullname.rpartition('.') ('arg', Expression),
try: ('arg_type', Type)
return mypy_api.modules[module].names[model_name] ])
except KeyError:
return mypy_api.lookup_qualified('typing.Any')
# return mypy_api.modules['typing'].names['Any']
def get_app_model(model_name: str) -> str: def get_call_signature_or_none(ctx: FunctionContext) -> Optional[Dict[str, Argument]]:
import os arg_names = ctx.context.arg_names
os.environ.setdefault('SITE_URL', 'https://localhost')
os.environ.setdefault('DJANGO_SETTINGS_MODULE', 'server._config.settings.local')
import django result: Dict[str, Argument] = {}
django.setup() positional_args_only = []
positional_arg_types_only = []
for arg, arg_name, arg_type in zip(ctx.args, arg_names, ctx.arg_types):
if arg_name is None:
positional_args_only.append(arg)
positional_arg_types_only.append(arg_type)
continue
from django.apps import apps if len(arg) == 0 or len(arg_type) == 0:
continue
try: result[arg_name] = (arg[0], arg_type[0])
app_name, model_name = model_name.rsplit('.', maxsplit=1)
model = apps.get_model(app_name, model_name) callee = ctx.context.callee
return model.__module__ + '.' + model_name if '__init__' not in callee.node.names:
except ValueError: return None
return model_name
init_type = callee.node.names['__init__'].type
arg_names = init_type.arg_names[1:]
for arg, arg_name, arg_type in zip(positional_args_only,
arg_names[:len(positional_args_only)],
positional_arg_types_only):
result[arg_name] = (arg[0], arg_type[0])
return result

View File

@@ -0,0 +1,44 @@
from typing import Callable, Optional
from mypy.plugin import Plugin, FunctionContext, ClassDefContext
from mypy.types import Type
from mypy_django_plugin import helpers
from mypy_django_plugin.plugins.objects_queryset import set_objects_queryset_to_model_class
from mypy_django_plugin.plugins.postgres_fields import determine_type_of_array_field
from mypy_django_plugin.plugins.related_fields import set_related_name_instance_for_onetoonefield, \
set_related_name_manager_for_foreign_key, set_fieldname_attrs_for_related_fields
base_model_classes = {helpers.MODEL_CLASS_FULLNAME}
def transform_model_class(ctx: ClassDefContext) -> None:
base_model_classes.add(ctx.cls.fullname)
set_fieldname_attrs_for_related_fields(ctx)
set_objects_queryset_to_model_class(ctx)
class DjangoPlugin(Plugin):
def get_function_hook(self, fullname: str
) -> Optional[Callable[[FunctionContext], Type]]:
if fullname == helpers.FOREIGN_KEY_FULLNAME:
return set_related_name_manager_for_foreign_key
if fullname == helpers.ONETOONE_FIELD_FULLNAME:
return set_related_name_instance_for_onetoonefield
if fullname == 'django.contrib.postgres.fields.array.ArrayField':
return determine_type_of_array_field
return None
def get_base_class_hook(self, fullname: str
) -> Optional[Callable[[ClassDefContext], None]]:
if fullname in base_model_classes:
return transform_model_class
return None
def plugin(version):
return DjangoPlugin

View File

@@ -0,0 +1,27 @@
from typing import cast
from mypy.nodes import MDEF, AssignmentStmt
from mypy.plugin import ClassDefContext
from mypy.semanal import SemanticAnalyzerPass2
from mypy.types import Instance
from mypy_django_plugin import helpers
def set_objects_queryset_to_model_class(ctx: ClassDefContext) -> None:
if 'objects' in ctx.cls.info.names:
return
api = cast(SemanticAnalyzerPass2, ctx.api)
metaclass_node = ctx.cls.info.names.get('Meta')
if metaclass_node is not None:
for stmt in metaclass_node.node.defn.defs.body:
if (isinstance(stmt, AssignmentStmt) and len(stmt.lvalues) == 1
and stmt.lvalues[0].name == 'abstract'):
is_abstract = api.parse_bool(stmt.rvalue)
if is_abstract:
return
typ = api.named_type_or_none(helpers.QUERYSET_CLASS_FULLNAME, args=[Instance(ctx.cls.info, [])])
new_objects_node = helpers.create_new_symtable_node('objects', MDEF, instance=typ)
ctx.cls.info.names['objects'] = new_objects_node

View File

@@ -1,25 +1,14 @@
from typing import Optional, Callable from mypy.plugin import FunctionContext
from mypy.plugin import Plugin, FunctionContext
from mypy.types import Type from mypy.types import Type
from mypy_django_plugin import helpers
def determine_type_of_array_field(ctx: FunctionContext) -> Type: def determine_type_of_array_field(ctx: FunctionContext) -> Type:
assert 'base_field' in ctx.context.arg_names signature = helpers.get_call_signature_or_none(ctx)
base_field_arg_index = ctx.context.arg_names.index('base_field') if signature is None:
base_field_arg_type = ctx.arg_types[base_field_arg_index][0] return ctx.default_return_type
_, base_field_arg_type = signature['base_field']
return ctx.api.named_generic_type(ctx.context.callee.fullname, return ctx.api.named_generic_type(ctx.context.callee.fullname,
args=[base_field_arg_type.type.names['__get__'].type.ret_type]) args=[base_field_arg_type.type.names['__get__'].type.ret_type])
class PostgresFieldsPlugin(Plugin):
def get_function_hook(self, fullname: str
) -> Optional[Callable[[FunctionContext], Type]]:
if fullname == 'django.contrib.postgres.fields.array.ArrayField':
return determine_type_of_array_field
return None
def plugin(version):
return PostgresFieldsPlugin

View File

@@ -1,37 +1,32 @@
from typing import Optional, Callable, cast from typing import Optional, cast
from mypy.checker import TypeChecker from mypy.checker import TypeChecker
from mypy.nodes import Var, MDEF, SymbolTableNode, TypeInfo, SymbolTable from mypy.nodes import TypeInfo, SymbolTable, MDEF, AssignmentStmt, MemberExpr
from mypy.plugin import Plugin, FunctionContext, ClassDefContext from mypy.plugin import FunctionContext, ClassDefContext
from mypy.semanal import SemanticAnalyzerPass2 from mypy.types import Type, CallableType, Instance, AnyType
from mypy.types import Type, CallableType, Instance
FOREIGN_KEY_FULLNAME = 'django.db.models.fields.related.ForeignKey' from mypy_django_plugin import helpers
ONETOONE_FIELD_FULLNAME = 'django.db.models.fields.related.OneToOneField'
def extract_to_value_type(ctx: FunctionContext) -> Optional[Instance]: def extract_to_value_type(ctx: FunctionContext) -> Optional[Instance]:
assert 'to' in ctx.context.arg_names signature = helpers.get_call_signature_or_none(ctx)
to_arg_value = ctx.arg_types[ctx.context.arg_names.index('to')][0] if signature is None or 'to' not in signature:
if not isinstance(to_arg_value, CallableType):
return None return None
return to_arg_value.ret_type arg, arg_type = signature['to']
if not isinstance(arg_type, CallableType):
return None
return arg_type.ret_type
def extract_related_name_value(ctx: FunctionContext) -> str: def extract_related_name_value(ctx: FunctionContext) -> str:
return ctx.context.args[ctx.context.arg_names.index('related_name')].value return ctx.context.args[ctx.context.arg_names.index('related_name')].value
def create_new_symtable_node_for_class_member(name: str, instance: Instance) -> SymbolTableNode:
new_var = Var(name, instance)
new_var.info = instance.type
return SymbolTableNode(MDEF, new_var, plugin_generated=True)
def add_new_class_member(klass_typeinfo: TypeInfo, name: str, new_member_instance: Instance) -> None: def add_new_class_member(klass_typeinfo: TypeInfo, name: str, new_member_instance: Instance) -> None:
klass_typeinfo.names[name] = create_new_symtable_node_for_class_member(name, klass_typeinfo.names[name] = helpers.create_new_symtable_node(name,
kind=MDEF,
instance=new_member_instance) instance=new_member_instance)
@@ -47,8 +42,12 @@ def set_related_name_manager_for_foreign_key(ctx: FunctionContext) -> Type:
return ctx.default_return_type return ctx.default_return_type
related_name = extract_related_name_value(ctx) related_name = extract_related_name_value(ctx)
queryset_type = api.named_generic_type('django.db.models.QuerySet', queryset_type = api.named_generic_type(helpers.QUERYSET_CLASS_FULLNAME,
args=[Instance(outer_class_info, [])]) args=[Instance(outer_class_info, [])])
if isinstance(referred_to, AnyType):
# referred_to defined as string, which is unsupported for now
return ctx.default_return_type
add_new_class_member(referred_to.type, add_new_class_member(referred_to.type,
related_name, queryset_type) related_name, queryset_type)
return ctx.default_return_type return ctx.default_return_type
@@ -75,35 +74,20 @@ def set_fieldname_attrs_for_related_fields(ctx: ClassDefContext) -> None:
api = ctx.api api = ctx.api
new_symtable_nodes = SymbolTable() new_symtable_nodes = SymbolTable()
for (name, symtable_node), assignment_stmt in zip(ctx.cls.info.names.items(), ctx.cls.defs.body): for (name, symtable_node), stmt in zip(ctx.cls.info.names.items(), ctx.cls.defs.body):
rvalue_callee = assignment_stmt.rvalue.callee if not isinstance(stmt, AssignmentStmt):
if rvalue_callee.fullname in {FOREIGN_KEY_FULLNAME, ONETOONE_FIELD_FULLNAME}: continue
if not hasattr(stmt.rvalue, 'callee'):
continue
rvalue_callee = stmt.rvalue.callee
if rvalue_callee.fullname in {helpers.FOREIGN_KEY_FULLNAME,
helpers.ONETOONE_FIELD_FULLNAME}:
name += '_id' name += '_id'
new_node = create_new_symtable_node_for_class_member(name, new_node = helpers.create_new_symtable_node(name,
kind=MDEF,
instance=api.named_type('__builtins__.int')) instance=api.named_type('__builtins__.int'))
new_symtable_nodes[name] = new_node new_symtable_nodes[name] = new_node
for name, node in new_symtable_nodes.items(): for name, node in new_symtable_nodes.items():
ctx.cls.info.names[name] = node ctx.cls.info.names[name] = node
class RelatedFieldsPlugin(Plugin):
def get_function_hook(self, fullname: str
) -> Optional[Callable[[FunctionContext], Type]]:
if fullname == 'django.db.models.fields.related.ForeignKey':
return set_related_name_manager_for_foreign_key
if fullname == 'django.db.models.fields.related.OneToOneField':
return set_related_name_instance_for_onetoonefield
return None
def get_base_class_hook(self, fullname: str
) -> Optional[Callable[[ClassDefContext], None]]:
if fullname == 'django.db.models.base.Model':
return set_fieldname_attrs_for_related_fields
return None
def plugin(version):
return RelatedFieldsPlugin

View File

@@ -3,6 +3,5 @@
testpaths = test testpaths = test
python_files = test*.py python_files = test*.py
addopts = addopts =
; -nauto
--tb=native --tb=native
-v -v

View File

@@ -18,7 +18,7 @@ setup(
author_email="maxim.kurnikov@gmail.com", author_email="maxim.kurnikov@gmail.com",
version="0.1.0", version="0.1.0",
license='BSD', license='BSD',
install_requires='Django>=2.1.1', install_requires=['Django>=2.1.1'],
packages=['mypy_django_plugin'] packages=['mypy_django_plugin']
# package_data=find_stubs('django-stubs') # package_data=find_stubs('django-stubs')
) )

View File

@@ -1,4 +1,3 @@
[mypy] [mypy]
plugins = plugins =
mypy_django_plugin.plugins.postgres_fields, mypy_django_plugin.plugin
mypy_django_plugin.plugins.related_fields

View File

@@ -1,7 +1,6 @@
[case testBasicModelFields] [case testBasicModelFields]
from django.db import models from django.db import models
class User(models.Model): class User(models.Model):
id = models.AutoField(primary_key=True) id = models.AutoField(primary_key=True)
small_int = models.SmallIntegerField() small_int = models.SmallIntegerField()

View File

@@ -56,3 +56,16 @@ class Profile(models.Model):
profile = Profile() profile = Profile()
reveal_type(profile.user_id) # E: Revealed type is 'builtins.int' reveal_type(profile.user_id) # E: Revealed type is 'builtins.int'
[out] [out]
[case testToParameterKeywordMaybeAbsent]
from django.db import models
class User(models.Model):
pass
class Profile(models.Model):
user = models.OneToOneField(User, on_delete=models.CASCADE, related_name='profile')
reveal_type(User().profile) # E: Revealed type is 'main.Profile'
[out]

View File

@@ -0,0 +1,24 @@
[case testEveryModelClassHasDefaultObjectsQuerySetAvailableAsAttribute]
from django.db import models
class User(models.Model):
pass
reveal_type(User.objects) # E: Revealed type is 'django.db.models.query.QuerySet[main.User]'
[out]
[case testGetReturnsModelInstanceIfInheritedFromAbstractMixin]
from django.db import models
class ModelMixin(models.Model):
class Meta:
abstract = True
class User(ModelMixin):
pass
reveal_type(ModelMixin.objects)
reveal_type(User.objects.get()) # E: Revealed type is 'main.User*'
[out]
main:10: error: Revealed type is 'Any'
main:10: error: "Type[ModelMixin]" has no attribute "objects"

View File

@@ -14,9 +14,10 @@ MYPY_INI_PATH = ROOT_DIR / 'test' / 'plugins.ini'
class DjangoTestSuite(DataSuite): class DjangoTestSuite(DataSuite):
files = [ files = [
'check-objects-queryset.test',
'check-model-fields.test', 'check-model-fields.test',
'check-postgres-fields.test', 'check-postgres-fields.test',
'check-model-relations.test', 'check-model-relations.test'
] ]
data_prefix = str(TEST_DATA_DIR) data_prefix = str(TEST_DATA_DIR)