mirror of
https://github.com/davidhalter/django-stubs.git
synced 2025-12-06 20:24:31 +08:00
support for models.Model.objects, abstract mixins
This commit is contained in:
@@ -6,7 +6,8 @@ from .fields import (AutoField as AutoField,
|
||||
CharField as CharField,
|
||||
Field as Field,
|
||||
SlugField as SlugField,
|
||||
TextField as TextField)
|
||||
TextField as TextField,
|
||||
BooleanField as BooleanField)
|
||||
|
||||
from .fields.related import (ForeignKey as ForeignKey,
|
||||
OneToOneField as OneToOneField)
|
||||
|
||||
@@ -1,38 +1,55 @@
|
||||
from mypy.checker import TypeChecker
|
||||
from mypy.nodes import SymbolTableNode, Var
|
||||
from typing import Dict, Optional, Type, Tuple, NamedTuple
|
||||
|
||||
from mypy.nodes import SymbolTableNode, Var, Expression, MemberExpr
|
||||
from mypy.plugin import FunctionContext
|
||||
from mypy.types import Instance
|
||||
|
||||
MODEL_CLASS_FULLNAME = 'django.db.models.base.Model'
|
||||
QUERYSET_CLASS_FULLNAME = 'django.db.models.query.QuerySet'
|
||||
FOREIGN_KEY_FULLNAME = 'django.db.models.fields.related.ForeignKey'
|
||||
ONETOONE_FIELD_FULLNAME = 'django.db.models.fields.related.OneToOneField'
|
||||
|
||||
|
||||
def is_class_variable(symbol_table_node: SymbolTableNode) -> bool:
|
||||
# MDEF: class member definition
|
||||
is_class_variable = symbol_table_node.kind == 2 and type(symbol_table_node.node) == Var
|
||||
if not is_class_variable:
|
||||
return False
|
||||
def create_new_symtable_node(name: str, kind: int, instance: Instance) -> SymbolTableNode:
|
||||
new_var = Var(name, instance)
|
||||
new_var.info = instance.type
|
||||
|
||||
return True
|
||||
return SymbolTableNode(kind, new_var,
|
||||
plugin_generated=True)
|
||||
|
||||
|
||||
def lookup_django_model(mypy_api: TypeChecker, fullname: str) -> SymbolTableNode:
|
||||
module, _, model_name = fullname.rpartition('.')
|
||||
try:
|
||||
return mypy_api.modules[module].names[model_name]
|
||||
except KeyError:
|
||||
return mypy_api.lookup_qualified('typing.Any')
|
||||
# return mypy_api.modules['typing'].names['Any']
|
||||
Argument = NamedTuple('Argument', fields=[
|
||||
('arg', Expression),
|
||||
('arg_type', Type)
|
||||
])
|
||||
|
||||
|
||||
def get_app_model(model_name: str) -> str:
|
||||
import os
|
||||
os.environ.setdefault('SITE_URL', 'https://localhost')
|
||||
os.environ.setdefault('DJANGO_SETTINGS_MODULE', 'server._config.settings.local')
|
||||
def get_call_signature_or_none(ctx: FunctionContext) -> Optional[Dict[str, Argument]]:
|
||||
arg_names = ctx.context.arg_names
|
||||
|
||||
import django
|
||||
django.setup()
|
||||
result: Dict[str, Argument] = {}
|
||||
positional_args_only = []
|
||||
positional_arg_types_only = []
|
||||
for arg, arg_name, arg_type in zip(ctx.args, arg_names, ctx.arg_types):
|
||||
if arg_name is None:
|
||||
positional_args_only.append(arg)
|
||||
positional_arg_types_only.append(arg_type)
|
||||
continue
|
||||
|
||||
from django.apps import apps
|
||||
if len(arg) == 0 or len(arg_type) == 0:
|
||||
continue
|
||||
|
||||
try:
|
||||
app_name, model_name = model_name.rsplit('.', maxsplit=1)
|
||||
model = apps.get_model(app_name, model_name)
|
||||
return model.__module__ + '.' + model_name
|
||||
except ValueError:
|
||||
return model_name
|
||||
result[arg_name] = (arg[0], arg_type[0])
|
||||
|
||||
callee = ctx.context.callee
|
||||
if '__init__' not in callee.node.names:
|
||||
return None
|
||||
|
||||
init_type = callee.node.names['__init__'].type
|
||||
arg_names = init_type.arg_names[1:]
|
||||
for arg, arg_name, arg_type in zip(positional_args_only,
|
||||
arg_names[:len(positional_args_only)],
|
||||
positional_arg_types_only):
|
||||
result[arg_name] = (arg[0], arg_type[0])
|
||||
|
||||
return result
|
||||
|
||||
44
mypy_django_plugin/plugin.py
Normal file
44
mypy_django_plugin/plugin.py
Normal file
@@ -0,0 +1,44 @@
|
||||
from typing import Callable, Optional
|
||||
|
||||
from mypy.plugin import Plugin, FunctionContext, ClassDefContext
|
||||
from mypy.types import Type
|
||||
|
||||
from mypy_django_plugin import helpers
|
||||
from mypy_django_plugin.plugins.objects_queryset import set_objects_queryset_to_model_class
|
||||
from mypy_django_plugin.plugins.postgres_fields import determine_type_of_array_field
|
||||
from mypy_django_plugin.plugins.related_fields import set_related_name_instance_for_onetoonefield, \
|
||||
set_related_name_manager_for_foreign_key, set_fieldname_attrs_for_related_fields
|
||||
|
||||
|
||||
base_model_classes = {helpers.MODEL_CLASS_FULLNAME}
|
||||
|
||||
|
||||
def transform_model_class(ctx: ClassDefContext) -> None:
|
||||
base_model_classes.add(ctx.cls.fullname)
|
||||
|
||||
set_fieldname_attrs_for_related_fields(ctx)
|
||||
set_objects_queryset_to_model_class(ctx)
|
||||
|
||||
|
||||
class DjangoPlugin(Plugin):
|
||||
def get_function_hook(self, fullname: str
|
||||
) -> Optional[Callable[[FunctionContext], Type]]:
|
||||
if fullname == helpers.FOREIGN_KEY_FULLNAME:
|
||||
return set_related_name_manager_for_foreign_key
|
||||
|
||||
if fullname == helpers.ONETOONE_FIELD_FULLNAME:
|
||||
return set_related_name_instance_for_onetoonefield
|
||||
|
||||
if fullname == 'django.contrib.postgres.fields.array.ArrayField':
|
||||
return determine_type_of_array_field
|
||||
return None
|
||||
|
||||
def get_base_class_hook(self, fullname: str
|
||||
) -> Optional[Callable[[ClassDefContext], None]]:
|
||||
if fullname in base_model_classes:
|
||||
return transform_model_class
|
||||
return None
|
||||
|
||||
|
||||
def plugin(version):
|
||||
return DjangoPlugin
|
||||
27
mypy_django_plugin/plugins/objects_queryset.py
Normal file
27
mypy_django_plugin/plugins/objects_queryset.py
Normal file
@@ -0,0 +1,27 @@
|
||||
from typing import cast
|
||||
|
||||
from mypy.nodes import MDEF, AssignmentStmt
|
||||
from mypy.plugin import ClassDefContext
|
||||
from mypy.semanal import SemanticAnalyzerPass2
|
||||
from mypy.types import Instance
|
||||
|
||||
from mypy_django_plugin import helpers
|
||||
|
||||
|
||||
def set_objects_queryset_to_model_class(ctx: ClassDefContext) -> None:
|
||||
if 'objects' in ctx.cls.info.names:
|
||||
return
|
||||
api = cast(SemanticAnalyzerPass2, ctx.api)
|
||||
|
||||
metaclass_node = ctx.cls.info.names.get('Meta')
|
||||
if metaclass_node is not None:
|
||||
for stmt in metaclass_node.node.defn.defs.body:
|
||||
if (isinstance(stmt, AssignmentStmt) and len(stmt.lvalues) == 1
|
||||
and stmt.lvalues[0].name == 'abstract'):
|
||||
is_abstract = api.parse_bool(stmt.rvalue)
|
||||
if is_abstract:
|
||||
return
|
||||
|
||||
typ = api.named_type_or_none(helpers.QUERYSET_CLASS_FULLNAME, args=[Instance(ctx.cls.info, [])])
|
||||
new_objects_node = helpers.create_new_symtable_node('objects', MDEF, instance=typ)
|
||||
ctx.cls.info.names['objects'] = new_objects_node
|
||||
@@ -1,25 +1,14 @@
|
||||
from typing import Optional, Callable
|
||||
|
||||
from mypy.plugin import Plugin, FunctionContext
|
||||
from mypy.plugin import FunctionContext
|
||||
from mypy.types import Type
|
||||
|
||||
from mypy_django_plugin import helpers
|
||||
|
||||
|
||||
def determine_type_of_array_field(ctx: FunctionContext) -> Type:
|
||||
assert 'base_field' in ctx.context.arg_names
|
||||
base_field_arg_index = ctx.context.arg_names.index('base_field')
|
||||
base_field_arg_type = ctx.arg_types[base_field_arg_index][0]
|
||||
signature = helpers.get_call_signature_or_none(ctx)
|
||||
if signature is None:
|
||||
return ctx.default_return_type
|
||||
|
||||
_, base_field_arg_type = signature['base_field']
|
||||
return ctx.api.named_generic_type(ctx.context.callee.fullname,
|
||||
args=[base_field_arg_type.type.names['__get__'].type.ret_type])
|
||||
|
||||
|
||||
class PostgresFieldsPlugin(Plugin):
|
||||
def get_function_hook(self, fullname: str
|
||||
) -> Optional[Callable[[FunctionContext], Type]]:
|
||||
if fullname == 'django.contrib.postgres.fields.array.ArrayField':
|
||||
return determine_type_of_array_field
|
||||
return None
|
||||
|
||||
|
||||
def plugin(version):
|
||||
return PostgresFieldsPlugin
|
||||
|
||||
@@ -1,38 +1,33 @@
|
||||
from typing import Optional, Callable, cast
|
||||
from typing import Optional, cast
|
||||
|
||||
from mypy.checker import TypeChecker
|
||||
from mypy.nodes import Var, MDEF, SymbolTableNode, TypeInfo, SymbolTable
|
||||
from mypy.plugin import Plugin, FunctionContext, ClassDefContext
|
||||
from mypy.semanal import SemanticAnalyzerPass2
|
||||
from mypy.types import Type, CallableType, Instance
|
||||
from mypy.nodes import TypeInfo, SymbolTable, MDEF, AssignmentStmt, MemberExpr
|
||||
from mypy.plugin import FunctionContext, ClassDefContext
|
||||
from mypy.types import Type, CallableType, Instance, AnyType
|
||||
|
||||
FOREIGN_KEY_FULLNAME = 'django.db.models.fields.related.ForeignKey'
|
||||
ONETOONE_FIELD_FULLNAME = 'django.db.models.fields.related.OneToOneField'
|
||||
from mypy_django_plugin import helpers
|
||||
|
||||
|
||||
def extract_to_value_type(ctx: FunctionContext) -> Optional[Instance]:
|
||||
assert 'to' in ctx.context.arg_names
|
||||
to_arg_value = ctx.arg_types[ctx.context.arg_names.index('to')][0]
|
||||
if not isinstance(to_arg_value, CallableType):
|
||||
signature = helpers.get_call_signature_or_none(ctx)
|
||||
if signature is None or 'to' not in signature:
|
||||
return None
|
||||
|
||||
return to_arg_value.ret_type
|
||||
arg, arg_type = signature['to']
|
||||
if not isinstance(arg_type, CallableType):
|
||||
return None
|
||||
|
||||
return arg_type.ret_type
|
||||
|
||||
|
||||
def extract_related_name_value(ctx: FunctionContext) -> str:
|
||||
return ctx.context.args[ctx.context.arg_names.index('related_name')].value
|
||||
|
||||
|
||||
def create_new_symtable_node_for_class_member(name: str, instance: Instance) -> SymbolTableNode:
|
||||
new_var = Var(name, instance)
|
||||
new_var.info = instance.type
|
||||
|
||||
return SymbolTableNode(MDEF, new_var, plugin_generated=True)
|
||||
|
||||
|
||||
def add_new_class_member(klass_typeinfo: TypeInfo, name: str, new_member_instance: Instance) -> None:
|
||||
klass_typeinfo.names[name] = create_new_symtable_node_for_class_member(name,
|
||||
instance=new_member_instance)
|
||||
klass_typeinfo.names[name] = helpers.create_new_symtable_node(name,
|
||||
kind=MDEF,
|
||||
instance=new_member_instance)
|
||||
|
||||
|
||||
def set_related_name_manager_for_foreign_key(ctx: FunctionContext) -> Type:
|
||||
@@ -47,8 +42,12 @@ def set_related_name_manager_for_foreign_key(ctx: FunctionContext) -> Type:
|
||||
return ctx.default_return_type
|
||||
|
||||
related_name = extract_related_name_value(ctx)
|
||||
queryset_type = api.named_generic_type('django.db.models.QuerySet',
|
||||
queryset_type = api.named_generic_type(helpers.QUERYSET_CLASS_FULLNAME,
|
||||
args=[Instance(outer_class_info, [])])
|
||||
if isinstance(referred_to, AnyType):
|
||||
# referred_to defined as string, which is unsupported for now
|
||||
return ctx.default_return_type
|
||||
|
||||
add_new_class_member(referred_to.type,
|
||||
related_name, queryset_type)
|
||||
return ctx.default_return_type
|
||||
@@ -75,35 +74,20 @@ def set_fieldname_attrs_for_related_fields(ctx: ClassDefContext) -> None:
|
||||
api = ctx.api
|
||||
|
||||
new_symtable_nodes = SymbolTable()
|
||||
for (name, symtable_node), assignment_stmt in zip(ctx.cls.info.names.items(), ctx.cls.defs.body):
|
||||
rvalue_callee = assignment_stmt.rvalue.callee
|
||||
if rvalue_callee.fullname in {FOREIGN_KEY_FULLNAME, ONETOONE_FIELD_FULLNAME}:
|
||||
for (name, symtable_node), stmt in zip(ctx.cls.info.names.items(), ctx.cls.defs.body):
|
||||
if not isinstance(stmt, AssignmentStmt):
|
||||
continue
|
||||
if not hasattr(stmt.rvalue, 'callee'):
|
||||
continue
|
||||
|
||||
rvalue_callee = stmt.rvalue.callee
|
||||
if rvalue_callee.fullname in {helpers.FOREIGN_KEY_FULLNAME,
|
||||
helpers.ONETOONE_FIELD_FULLNAME}:
|
||||
name += '_id'
|
||||
new_node = create_new_symtable_node_for_class_member(name,
|
||||
instance=api.named_type('__builtins__.int'))
|
||||
new_node = helpers.create_new_symtable_node(name,
|
||||
kind=MDEF,
|
||||
instance=api.named_type('__builtins__.int'))
|
||||
new_symtable_nodes[name] = new_node
|
||||
|
||||
for name, node in new_symtable_nodes.items():
|
||||
ctx.cls.info.names[name] = node
|
||||
|
||||
|
||||
class RelatedFieldsPlugin(Plugin):
|
||||
def get_function_hook(self, fullname: str
|
||||
) -> Optional[Callable[[FunctionContext], Type]]:
|
||||
if fullname == 'django.db.models.fields.related.ForeignKey':
|
||||
return set_related_name_manager_for_foreign_key
|
||||
|
||||
if fullname == 'django.db.models.fields.related.OneToOneField':
|
||||
return set_related_name_instance_for_onetoonefield
|
||||
|
||||
return None
|
||||
|
||||
def get_base_class_hook(self, fullname: str
|
||||
) -> Optional[Callable[[ClassDefContext], None]]:
|
||||
if fullname == 'django.db.models.base.Model':
|
||||
return set_fieldname_attrs_for_related_fields
|
||||
return None
|
||||
|
||||
|
||||
def plugin(version):
|
||||
return RelatedFieldsPlugin
|
||||
|
||||
@@ -3,6 +3,5 @@
|
||||
testpaths = test
|
||||
python_files = test*.py
|
||||
addopts =
|
||||
; -nauto
|
||||
--tb=native
|
||||
-v
|
||||
2
setup.py
2
setup.py
@@ -18,7 +18,7 @@ setup(
|
||||
author_email="maxim.kurnikov@gmail.com",
|
||||
version="0.1.0",
|
||||
license='BSD',
|
||||
install_requires='Django>=2.1.1',
|
||||
install_requires=['Django>=2.1.1'],
|
||||
packages=['mypy_django_plugin']
|
||||
# package_data=find_stubs('django-stubs')
|
||||
)
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
[mypy]
|
||||
plugins =
|
||||
mypy_django_plugin.plugins.postgres_fields,
|
||||
mypy_django_plugin.plugins.related_fields
|
||||
mypy_django_plugin.plugin
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
[case testBasicModelFields]
|
||||
from django.db import models
|
||||
|
||||
|
||||
class User(models.Model):
|
||||
id = models.AutoField(primary_key=True)
|
||||
small_int = models.SmallIntegerField()
|
||||
|
||||
@@ -56,3 +56,16 @@ class Profile(models.Model):
|
||||
profile = Profile()
|
||||
reveal_type(profile.user_id) # E: Revealed type is 'builtins.int'
|
||||
[out]
|
||||
|
||||
[case testToParameterKeywordMaybeAbsent]
|
||||
from django.db import models
|
||||
|
||||
class User(models.Model):
|
||||
pass
|
||||
|
||||
class Profile(models.Model):
|
||||
user = models.OneToOneField(User, on_delete=models.CASCADE, related_name='profile')
|
||||
|
||||
reveal_type(User().profile) # E: Revealed type is 'main.Profile'
|
||||
[out]
|
||||
|
||||
|
||||
24
test/test-data/check-objects-queryset.test
Normal file
24
test/test-data/check-objects-queryset.test
Normal file
@@ -0,0 +1,24 @@
|
||||
[case testEveryModelClassHasDefaultObjectsQuerySetAvailableAsAttribute]
|
||||
from django.db import models
|
||||
|
||||
class User(models.Model):
|
||||
pass
|
||||
|
||||
reveal_type(User.objects) # E: Revealed type is 'django.db.models.query.QuerySet[main.User]'
|
||||
[out]
|
||||
|
||||
[case testGetReturnsModelInstanceIfInheritedFromAbstractMixin]
|
||||
from django.db import models
|
||||
|
||||
class ModelMixin(models.Model):
|
||||
class Meta:
|
||||
abstract = True
|
||||
|
||||
class User(ModelMixin):
|
||||
pass
|
||||
|
||||
reveal_type(ModelMixin.objects)
|
||||
reveal_type(User.objects.get()) # E: Revealed type is 'main.User*'
|
||||
[out]
|
||||
main:10: error: Revealed type is 'Any'
|
||||
main:10: error: "Type[ModelMixin]" has no attribute "objects"
|
||||
@@ -14,9 +14,10 @@ MYPY_INI_PATH = ROOT_DIR / 'test' / 'plugins.ini'
|
||||
|
||||
class DjangoTestSuite(DataSuite):
|
||||
files = [
|
||||
'check-objects-queryset.test',
|
||||
'check-model-fields.test',
|
||||
'check-postgres-fields.test',
|
||||
'check-model-relations.test',
|
||||
'check-model-relations.test'
|
||||
]
|
||||
data_prefix = str(TEST_DATA_DIR)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user