mirror of
https://github.com/davidhalter/django-stubs.git
synced 2025-12-07 04:34:29 +08:00
support for models.Model.objects, abstract mixins
This commit is contained in:
@@ -6,7 +6,8 @@ from .fields import (AutoField as AutoField,
|
|||||||
CharField as CharField,
|
CharField as CharField,
|
||||||
Field as Field,
|
Field as Field,
|
||||||
SlugField as SlugField,
|
SlugField as SlugField,
|
||||||
TextField as TextField)
|
TextField as TextField,
|
||||||
|
BooleanField as BooleanField)
|
||||||
|
|
||||||
from .fields.related import (ForeignKey as ForeignKey,
|
from .fields.related import (ForeignKey as ForeignKey,
|
||||||
OneToOneField as OneToOneField)
|
OneToOneField as OneToOneField)
|
||||||
|
|||||||
@@ -1,38 +1,55 @@
|
|||||||
from mypy.checker import TypeChecker
|
from typing import Dict, Optional, Type, Tuple, NamedTuple
|
||||||
from mypy.nodes import SymbolTableNode, Var
|
|
||||||
|
from mypy.nodes import SymbolTableNode, Var, Expression, MemberExpr
|
||||||
|
from mypy.plugin import FunctionContext
|
||||||
|
from mypy.types import Instance
|
||||||
|
|
||||||
|
MODEL_CLASS_FULLNAME = 'django.db.models.base.Model'
|
||||||
|
QUERYSET_CLASS_FULLNAME = 'django.db.models.query.QuerySet'
|
||||||
|
FOREIGN_KEY_FULLNAME = 'django.db.models.fields.related.ForeignKey'
|
||||||
|
ONETOONE_FIELD_FULLNAME = 'django.db.models.fields.related.OneToOneField'
|
||||||
|
|
||||||
|
|
||||||
def is_class_variable(symbol_table_node: SymbolTableNode) -> bool:
|
def create_new_symtable_node(name: str, kind: int, instance: Instance) -> SymbolTableNode:
|
||||||
# MDEF: class member definition
|
new_var = Var(name, instance)
|
||||||
is_class_variable = symbol_table_node.kind == 2 and type(symbol_table_node.node) == Var
|
new_var.info = instance.type
|
||||||
if not is_class_variable:
|
|
||||||
return False
|
|
||||||
|
|
||||||
return True
|
return SymbolTableNode(kind, new_var,
|
||||||
|
plugin_generated=True)
|
||||||
|
|
||||||
|
|
||||||
def lookup_django_model(mypy_api: TypeChecker, fullname: str) -> SymbolTableNode:
|
Argument = NamedTuple('Argument', fields=[
|
||||||
module, _, model_name = fullname.rpartition('.')
|
('arg', Expression),
|
||||||
try:
|
('arg_type', Type)
|
||||||
return mypy_api.modules[module].names[model_name]
|
])
|
||||||
except KeyError:
|
|
||||||
return mypy_api.lookup_qualified('typing.Any')
|
|
||||||
# return mypy_api.modules['typing'].names['Any']
|
|
||||||
|
|
||||||
|
|
||||||
def get_app_model(model_name: str) -> str:
|
def get_call_signature_or_none(ctx: FunctionContext) -> Optional[Dict[str, Argument]]:
|
||||||
import os
|
arg_names = ctx.context.arg_names
|
||||||
os.environ.setdefault('SITE_URL', 'https://localhost')
|
|
||||||
os.environ.setdefault('DJANGO_SETTINGS_MODULE', 'server._config.settings.local')
|
|
||||||
|
|
||||||
import django
|
result: Dict[str, Argument] = {}
|
||||||
django.setup()
|
positional_args_only = []
|
||||||
|
positional_arg_types_only = []
|
||||||
|
for arg, arg_name, arg_type in zip(ctx.args, arg_names, ctx.arg_types):
|
||||||
|
if arg_name is None:
|
||||||
|
positional_args_only.append(arg)
|
||||||
|
positional_arg_types_only.append(arg_type)
|
||||||
|
continue
|
||||||
|
|
||||||
from django.apps import apps
|
if len(arg) == 0 or len(arg_type) == 0:
|
||||||
|
continue
|
||||||
|
|
||||||
try:
|
result[arg_name] = (arg[0], arg_type[0])
|
||||||
app_name, model_name = model_name.rsplit('.', maxsplit=1)
|
|
||||||
model = apps.get_model(app_name, model_name)
|
callee = ctx.context.callee
|
||||||
return model.__module__ + '.' + model_name
|
if '__init__' not in callee.node.names:
|
||||||
except ValueError:
|
return None
|
||||||
return model_name
|
|
||||||
|
init_type = callee.node.names['__init__'].type
|
||||||
|
arg_names = init_type.arg_names[1:]
|
||||||
|
for arg, arg_name, arg_type in zip(positional_args_only,
|
||||||
|
arg_names[:len(positional_args_only)],
|
||||||
|
positional_arg_types_only):
|
||||||
|
result[arg_name] = (arg[0], arg_type[0])
|
||||||
|
|
||||||
|
return result
|
||||||
|
|||||||
44
mypy_django_plugin/plugin.py
Normal file
44
mypy_django_plugin/plugin.py
Normal file
@@ -0,0 +1,44 @@
|
|||||||
|
from typing import Callable, Optional
|
||||||
|
|
||||||
|
from mypy.plugin import Plugin, FunctionContext, ClassDefContext
|
||||||
|
from mypy.types import Type
|
||||||
|
|
||||||
|
from mypy_django_plugin import helpers
|
||||||
|
from mypy_django_plugin.plugins.objects_queryset import set_objects_queryset_to_model_class
|
||||||
|
from mypy_django_plugin.plugins.postgres_fields import determine_type_of_array_field
|
||||||
|
from mypy_django_plugin.plugins.related_fields import set_related_name_instance_for_onetoonefield, \
|
||||||
|
set_related_name_manager_for_foreign_key, set_fieldname_attrs_for_related_fields
|
||||||
|
|
||||||
|
|
||||||
|
base_model_classes = {helpers.MODEL_CLASS_FULLNAME}
|
||||||
|
|
||||||
|
|
||||||
|
def transform_model_class(ctx: ClassDefContext) -> None:
|
||||||
|
base_model_classes.add(ctx.cls.fullname)
|
||||||
|
|
||||||
|
set_fieldname_attrs_for_related_fields(ctx)
|
||||||
|
set_objects_queryset_to_model_class(ctx)
|
||||||
|
|
||||||
|
|
||||||
|
class DjangoPlugin(Plugin):
|
||||||
|
def get_function_hook(self, fullname: str
|
||||||
|
) -> Optional[Callable[[FunctionContext], Type]]:
|
||||||
|
if fullname == helpers.FOREIGN_KEY_FULLNAME:
|
||||||
|
return set_related_name_manager_for_foreign_key
|
||||||
|
|
||||||
|
if fullname == helpers.ONETOONE_FIELD_FULLNAME:
|
||||||
|
return set_related_name_instance_for_onetoonefield
|
||||||
|
|
||||||
|
if fullname == 'django.contrib.postgres.fields.array.ArrayField':
|
||||||
|
return determine_type_of_array_field
|
||||||
|
return None
|
||||||
|
|
||||||
|
def get_base_class_hook(self, fullname: str
|
||||||
|
) -> Optional[Callable[[ClassDefContext], None]]:
|
||||||
|
if fullname in base_model_classes:
|
||||||
|
return transform_model_class
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def plugin(version):
|
||||||
|
return DjangoPlugin
|
||||||
27
mypy_django_plugin/plugins/objects_queryset.py
Normal file
27
mypy_django_plugin/plugins/objects_queryset.py
Normal file
@@ -0,0 +1,27 @@
|
|||||||
|
from typing import cast
|
||||||
|
|
||||||
|
from mypy.nodes import MDEF, AssignmentStmt
|
||||||
|
from mypy.plugin import ClassDefContext
|
||||||
|
from mypy.semanal import SemanticAnalyzerPass2
|
||||||
|
from mypy.types import Instance
|
||||||
|
|
||||||
|
from mypy_django_plugin import helpers
|
||||||
|
|
||||||
|
|
||||||
|
def set_objects_queryset_to_model_class(ctx: ClassDefContext) -> None:
|
||||||
|
if 'objects' in ctx.cls.info.names:
|
||||||
|
return
|
||||||
|
api = cast(SemanticAnalyzerPass2, ctx.api)
|
||||||
|
|
||||||
|
metaclass_node = ctx.cls.info.names.get('Meta')
|
||||||
|
if metaclass_node is not None:
|
||||||
|
for stmt in metaclass_node.node.defn.defs.body:
|
||||||
|
if (isinstance(stmt, AssignmentStmt) and len(stmt.lvalues) == 1
|
||||||
|
and stmt.lvalues[0].name == 'abstract'):
|
||||||
|
is_abstract = api.parse_bool(stmt.rvalue)
|
||||||
|
if is_abstract:
|
||||||
|
return
|
||||||
|
|
||||||
|
typ = api.named_type_or_none(helpers.QUERYSET_CLASS_FULLNAME, args=[Instance(ctx.cls.info, [])])
|
||||||
|
new_objects_node = helpers.create_new_symtable_node('objects', MDEF, instance=typ)
|
||||||
|
ctx.cls.info.names['objects'] = new_objects_node
|
||||||
@@ -1,25 +1,14 @@
|
|||||||
from typing import Optional, Callable
|
from mypy.plugin import FunctionContext
|
||||||
|
|
||||||
from mypy.plugin import Plugin, FunctionContext
|
|
||||||
from mypy.types import Type
|
from mypy.types import Type
|
||||||
|
|
||||||
|
from mypy_django_plugin import helpers
|
||||||
|
|
||||||
|
|
||||||
def determine_type_of_array_field(ctx: FunctionContext) -> Type:
|
def determine_type_of_array_field(ctx: FunctionContext) -> Type:
|
||||||
assert 'base_field' in ctx.context.arg_names
|
signature = helpers.get_call_signature_or_none(ctx)
|
||||||
base_field_arg_index = ctx.context.arg_names.index('base_field')
|
if signature is None:
|
||||||
base_field_arg_type = ctx.arg_types[base_field_arg_index][0]
|
return ctx.default_return_type
|
||||||
|
|
||||||
|
_, base_field_arg_type = signature['base_field']
|
||||||
return ctx.api.named_generic_type(ctx.context.callee.fullname,
|
return ctx.api.named_generic_type(ctx.context.callee.fullname,
|
||||||
args=[base_field_arg_type.type.names['__get__'].type.ret_type])
|
args=[base_field_arg_type.type.names['__get__'].type.ret_type])
|
||||||
|
|
||||||
|
|
||||||
class PostgresFieldsPlugin(Plugin):
|
|
||||||
def get_function_hook(self, fullname: str
|
|
||||||
) -> Optional[Callable[[FunctionContext], Type]]:
|
|
||||||
if fullname == 'django.contrib.postgres.fields.array.ArrayField':
|
|
||||||
return determine_type_of_array_field
|
|
||||||
return None
|
|
||||||
|
|
||||||
|
|
||||||
def plugin(version):
|
|
||||||
return PostgresFieldsPlugin
|
|
||||||
|
|||||||
@@ -1,37 +1,32 @@
|
|||||||
from typing import Optional, Callable, cast
|
from typing import Optional, cast
|
||||||
|
|
||||||
from mypy.checker import TypeChecker
|
from mypy.checker import TypeChecker
|
||||||
from mypy.nodes import Var, MDEF, SymbolTableNode, TypeInfo, SymbolTable
|
from mypy.nodes import TypeInfo, SymbolTable, MDEF, AssignmentStmt, MemberExpr
|
||||||
from mypy.plugin import Plugin, FunctionContext, ClassDefContext
|
from mypy.plugin import FunctionContext, ClassDefContext
|
||||||
from mypy.semanal import SemanticAnalyzerPass2
|
from mypy.types import Type, CallableType, Instance, AnyType
|
||||||
from mypy.types import Type, CallableType, Instance
|
|
||||||
|
|
||||||
FOREIGN_KEY_FULLNAME = 'django.db.models.fields.related.ForeignKey'
|
from mypy_django_plugin import helpers
|
||||||
ONETOONE_FIELD_FULLNAME = 'django.db.models.fields.related.OneToOneField'
|
|
||||||
|
|
||||||
|
|
||||||
def extract_to_value_type(ctx: FunctionContext) -> Optional[Instance]:
|
def extract_to_value_type(ctx: FunctionContext) -> Optional[Instance]:
|
||||||
assert 'to' in ctx.context.arg_names
|
signature = helpers.get_call_signature_or_none(ctx)
|
||||||
to_arg_value = ctx.arg_types[ctx.context.arg_names.index('to')][0]
|
if signature is None or 'to' not in signature:
|
||||||
if not isinstance(to_arg_value, CallableType):
|
|
||||||
return None
|
return None
|
||||||
|
|
||||||
return to_arg_value.ret_type
|
arg, arg_type = signature['to']
|
||||||
|
if not isinstance(arg_type, CallableType):
|
||||||
|
return None
|
||||||
|
|
||||||
|
return arg_type.ret_type
|
||||||
|
|
||||||
|
|
||||||
def extract_related_name_value(ctx: FunctionContext) -> str:
|
def extract_related_name_value(ctx: FunctionContext) -> str:
|
||||||
return ctx.context.args[ctx.context.arg_names.index('related_name')].value
|
return ctx.context.args[ctx.context.arg_names.index('related_name')].value
|
||||||
|
|
||||||
|
|
||||||
def create_new_symtable_node_for_class_member(name: str, instance: Instance) -> SymbolTableNode:
|
|
||||||
new_var = Var(name, instance)
|
|
||||||
new_var.info = instance.type
|
|
||||||
|
|
||||||
return SymbolTableNode(MDEF, new_var, plugin_generated=True)
|
|
||||||
|
|
||||||
|
|
||||||
def add_new_class_member(klass_typeinfo: TypeInfo, name: str, new_member_instance: Instance) -> None:
|
def add_new_class_member(klass_typeinfo: TypeInfo, name: str, new_member_instance: Instance) -> None:
|
||||||
klass_typeinfo.names[name] = create_new_symtable_node_for_class_member(name,
|
klass_typeinfo.names[name] = helpers.create_new_symtable_node(name,
|
||||||
|
kind=MDEF,
|
||||||
instance=new_member_instance)
|
instance=new_member_instance)
|
||||||
|
|
||||||
|
|
||||||
@@ -47,8 +42,12 @@ def set_related_name_manager_for_foreign_key(ctx: FunctionContext) -> Type:
|
|||||||
return ctx.default_return_type
|
return ctx.default_return_type
|
||||||
|
|
||||||
related_name = extract_related_name_value(ctx)
|
related_name = extract_related_name_value(ctx)
|
||||||
queryset_type = api.named_generic_type('django.db.models.QuerySet',
|
queryset_type = api.named_generic_type(helpers.QUERYSET_CLASS_FULLNAME,
|
||||||
args=[Instance(outer_class_info, [])])
|
args=[Instance(outer_class_info, [])])
|
||||||
|
if isinstance(referred_to, AnyType):
|
||||||
|
# referred_to defined as string, which is unsupported for now
|
||||||
|
return ctx.default_return_type
|
||||||
|
|
||||||
add_new_class_member(referred_to.type,
|
add_new_class_member(referred_to.type,
|
||||||
related_name, queryset_type)
|
related_name, queryset_type)
|
||||||
return ctx.default_return_type
|
return ctx.default_return_type
|
||||||
@@ -75,35 +74,20 @@ def set_fieldname_attrs_for_related_fields(ctx: ClassDefContext) -> None:
|
|||||||
api = ctx.api
|
api = ctx.api
|
||||||
|
|
||||||
new_symtable_nodes = SymbolTable()
|
new_symtable_nodes = SymbolTable()
|
||||||
for (name, symtable_node), assignment_stmt in zip(ctx.cls.info.names.items(), ctx.cls.defs.body):
|
for (name, symtable_node), stmt in zip(ctx.cls.info.names.items(), ctx.cls.defs.body):
|
||||||
rvalue_callee = assignment_stmt.rvalue.callee
|
if not isinstance(stmt, AssignmentStmt):
|
||||||
if rvalue_callee.fullname in {FOREIGN_KEY_FULLNAME, ONETOONE_FIELD_FULLNAME}:
|
continue
|
||||||
|
if not hasattr(stmt.rvalue, 'callee'):
|
||||||
|
continue
|
||||||
|
|
||||||
|
rvalue_callee = stmt.rvalue.callee
|
||||||
|
if rvalue_callee.fullname in {helpers.FOREIGN_KEY_FULLNAME,
|
||||||
|
helpers.ONETOONE_FIELD_FULLNAME}:
|
||||||
name += '_id'
|
name += '_id'
|
||||||
new_node = create_new_symtable_node_for_class_member(name,
|
new_node = helpers.create_new_symtable_node(name,
|
||||||
|
kind=MDEF,
|
||||||
instance=api.named_type('__builtins__.int'))
|
instance=api.named_type('__builtins__.int'))
|
||||||
new_symtable_nodes[name] = new_node
|
new_symtable_nodes[name] = new_node
|
||||||
|
|
||||||
for name, node in new_symtable_nodes.items():
|
for name, node in new_symtable_nodes.items():
|
||||||
ctx.cls.info.names[name] = node
|
ctx.cls.info.names[name] = node
|
||||||
|
|
||||||
|
|
||||||
class RelatedFieldsPlugin(Plugin):
|
|
||||||
def get_function_hook(self, fullname: str
|
|
||||||
) -> Optional[Callable[[FunctionContext], Type]]:
|
|
||||||
if fullname == 'django.db.models.fields.related.ForeignKey':
|
|
||||||
return set_related_name_manager_for_foreign_key
|
|
||||||
|
|
||||||
if fullname == 'django.db.models.fields.related.OneToOneField':
|
|
||||||
return set_related_name_instance_for_onetoonefield
|
|
||||||
|
|
||||||
return None
|
|
||||||
|
|
||||||
def get_base_class_hook(self, fullname: str
|
|
||||||
) -> Optional[Callable[[ClassDefContext], None]]:
|
|
||||||
if fullname == 'django.db.models.base.Model':
|
|
||||||
return set_fieldname_attrs_for_related_fields
|
|
||||||
return None
|
|
||||||
|
|
||||||
|
|
||||||
def plugin(version):
|
|
||||||
return RelatedFieldsPlugin
|
|
||||||
|
|||||||
@@ -3,6 +3,5 @@
|
|||||||
testpaths = test
|
testpaths = test
|
||||||
python_files = test*.py
|
python_files = test*.py
|
||||||
addopts =
|
addopts =
|
||||||
; -nauto
|
|
||||||
--tb=native
|
--tb=native
|
||||||
-v
|
-v
|
||||||
2
setup.py
2
setup.py
@@ -18,7 +18,7 @@ setup(
|
|||||||
author_email="maxim.kurnikov@gmail.com",
|
author_email="maxim.kurnikov@gmail.com",
|
||||||
version="0.1.0",
|
version="0.1.0",
|
||||||
license='BSD',
|
license='BSD',
|
||||||
install_requires='Django>=2.1.1',
|
install_requires=['Django>=2.1.1'],
|
||||||
packages=['mypy_django_plugin']
|
packages=['mypy_django_plugin']
|
||||||
# package_data=find_stubs('django-stubs')
|
# package_data=find_stubs('django-stubs')
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -1,4 +1,3 @@
|
|||||||
[mypy]
|
[mypy]
|
||||||
plugins =
|
plugins =
|
||||||
mypy_django_plugin.plugins.postgres_fields,
|
mypy_django_plugin.plugin
|
||||||
mypy_django_plugin.plugins.related_fields
|
|
||||||
|
|||||||
@@ -1,7 +1,6 @@
|
|||||||
[case testBasicModelFields]
|
[case testBasicModelFields]
|
||||||
from django.db import models
|
from django.db import models
|
||||||
|
|
||||||
|
|
||||||
class User(models.Model):
|
class User(models.Model):
|
||||||
id = models.AutoField(primary_key=True)
|
id = models.AutoField(primary_key=True)
|
||||||
small_int = models.SmallIntegerField()
|
small_int = models.SmallIntegerField()
|
||||||
|
|||||||
@@ -56,3 +56,16 @@ class Profile(models.Model):
|
|||||||
profile = Profile()
|
profile = Profile()
|
||||||
reveal_type(profile.user_id) # E: Revealed type is 'builtins.int'
|
reveal_type(profile.user_id) # E: Revealed type is 'builtins.int'
|
||||||
[out]
|
[out]
|
||||||
|
|
||||||
|
[case testToParameterKeywordMaybeAbsent]
|
||||||
|
from django.db import models
|
||||||
|
|
||||||
|
class User(models.Model):
|
||||||
|
pass
|
||||||
|
|
||||||
|
class Profile(models.Model):
|
||||||
|
user = models.OneToOneField(User, on_delete=models.CASCADE, related_name='profile')
|
||||||
|
|
||||||
|
reveal_type(User().profile) # E: Revealed type is 'main.Profile'
|
||||||
|
[out]
|
||||||
|
|
||||||
|
|||||||
24
test/test-data/check-objects-queryset.test
Normal file
24
test/test-data/check-objects-queryset.test
Normal file
@@ -0,0 +1,24 @@
|
|||||||
|
[case testEveryModelClassHasDefaultObjectsQuerySetAvailableAsAttribute]
|
||||||
|
from django.db import models
|
||||||
|
|
||||||
|
class User(models.Model):
|
||||||
|
pass
|
||||||
|
|
||||||
|
reveal_type(User.objects) # E: Revealed type is 'django.db.models.query.QuerySet[main.User]'
|
||||||
|
[out]
|
||||||
|
|
||||||
|
[case testGetReturnsModelInstanceIfInheritedFromAbstractMixin]
|
||||||
|
from django.db import models
|
||||||
|
|
||||||
|
class ModelMixin(models.Model):
|
||||||
|
class Meta:
|
||||||
|
abstract = True
|
||||||
|
|
||||||
|
class User(ModelMixin):
|
||||||
|
pass
|
||||||
|
|
||||||
|
reveal_type(ModelMixin.objects)
|
||||||
|
reveal_type(User.objects.get()) # E: Revealed type is 'main.User*'
|
||||||
|
[out]
|
||||||
|
main:10: error: Revealed type is 'Any'
|
||||||
|
main:10: error: "Type[ModelMixin]" has no attribute "objects"
|
||||||
@@ -14,9 +14,10 @@ MYPY_INI_PATH = ROOT_DIR / 'test' / 'plugins.ini'
|
|||||||
|
|
||||||
class DjangoTestSuite(DataSuite):
|
class DjangoTestSuite(DataSuite):
|
||||||
files = [
|
files = [
|
||||||
|
'check-objects-queryset.test',
|
||||||
'check-model-fields.test',
|
'check-model-fields.test',
|
||||||
'check-postgres-fields.test',
|
'check-postgres-fields.test',
|
||||||
'check-model-relations.test',
|
'check-model-relations.test'
|
||||||
]
|
]
|
||||||
data_prefix = str(TEST_DATA_DIR)
|
data_prefix = str(TEST_DATA_DIR)
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user