mirror of
https://github.com/davidhalter/django-stubs.git
synced 2026-05-25 17:58:41 +08:00
support for some Field -> builtin type mapping, OneToOneField, ForeignKey
This commit is contained in:
@@ -1,13 +0,0 @@
|
|||||||
# Django fields that has to= attribute used to retrieve original model
|
|
||||||
REFERENCING_DB_FIELDS = {
|
|
||||||
'django.db.models.fields.related.ForeignKey',
|
|
||||||
'django.db.models.fields.related.OneToOneField'
|
|
||||||
}
|
|
||||||
|
|
||||||
# mapping between field types and plain python types
|
|
||||||
DB_FIELDS_TO_TYPES = {
|
|
||||||
'django.db.models.fields.CharField': 'builtins.str',
|
|
||||||
'django.db.models.fields.TextField': 'builtins.str',
|
|
||||||
'django.db.models.fields.IntegerField': 'builtins.int',
|
|
||||||
'django.db.models.fields.FloatField': 'builtins.float'
|
|
||||||
}
|
|
||||||
@@ -16,4 +16,22 @@ def lookup_django_model(mypy_api: TypeChecker, fullname: str) -> SymbolTableNode
|
|||||||
try:
|
try:
|
||||||
return mypy_api.modules[module].names[model_name]
|
return mypy_api.modules[module].names[model_name]
|
||||||
except KeyError:
|
except KeyError:
|
||||||
return mypy_api.modules['django.db.models'].names['Model']
|
return mypy_api.modules['django.db.models'].names['Model']
|
||||||
|
|
||||||
|
|
||||||
|
def get_app_model(model_name: str) -> str:
|
||||||
|
import os
|
||||||
|
os.environ.setdefault('SITE_URL', 'https://localhost')
|
||||||
|
os.environ.setdefault('DJANGO_SETTINGS_MODULE', 'server._config.settings.local')
|
||||||
|
|
||||||
|
import django
|
||||||
|
django.setup()
|
||||||
|
|
||||||
|
from django.apps import apps
|
||||||
|
|
||||||
|
try:
|
||||||
|
app_name, model_name = model_name.rsplit('.', maxsplit=1)
|
||||||
|
model = apps.get_model(app_name, model_name)
|
||||||
|
return model.__module__ + '.' + model_name
|
||||||
|
except ValueError:
|
||||||
|
return model_name
|
||||||
|
|||||||
@@ -1,22 +1,18 @@
|
|||||||
import json
|
from typing import Set
|
||||||
from typing import Dict
|
|
||||||
|
|
||||||
import dataclasses
|
import dataclasses
|
||||||
|
|
||||||
|
|
||||||
@dataclasses.dataclass
|
|
||||||
class ModelInfo(object):
|
|
||||||
# class_name: str
|
|
||||||
related_managers: Dict[str, 'ModelInfo'] = dataclasses.field(default_factory=dict)
|
|
||||||
|
|
||||||
|
|
||||||
def get_default_base_models():
|
def get_default_base_models():
|
||||||
return {'django.db.models.base.Model': ModelInfo()}
|
return {'django.db.models.base.Model'}
|
||||||
|
|
||||||
|
|
||||||
@dataclasses.dataclass
|
@dataclasses.dataclass
|
||||||
class DjangoModelsRegistry(object):
|
class DjangoModelsRegistry(object):
|
||||||
base_models: Dict[str, ModelInfo] = dataclasses.field(default_factory=get_default_base_models)
|
base_models: Set[str] = dataclasses.field(default_factory=get_default_base_models)
|
||||||
|
|
||||||
def __contains__(self, item: str) -> bool:
|
def __contains__(self, item: str) -> bool:
|
||||||
return item in self.base_models
|
return item in self.base_models
|
||||||
|
|
||||||
|
def __iter__(self):
|
||||||
|
return iter(self.base_models)
|
||||||
|
|||||||
@@ -1,141 +0,0 @@
|
|||||||
import importlib
|
|
||||||
import inspect
|
|
||||||
from typing import Optional, Callable
|
|
||||||
|
|
||||||
from mypy.mypyc_hacks import TypeOfAny
|
|
||||||
from mypy.nodes import PassStmt, CallExpr, SymbolTableNode, MDEF, Var, ClassDef, Decorator, AssignmentStmt, StrExpr
|
|
||||||
from mypy.plugin import Plugin, ClassDefContext, AttributeContext
|
|
||||||
from mypy.types import AnyType, Type, Instance, CallableType
|
|
||||||
|
|
||||||
from mypy_django_plugin.constants import REFERENCING_DB_FIELDS, DB_FIELDS_TO_TYPES
|
|
||||||
from mypy_django_plugin.helpers import lookup_django_model
|
|
||||||
from mypy_django_plugin.model_classes import DjangoModelsRegistry, ModelInfo
|
|
||||||
|
|
||||||
|
|
||||||
model_registry = DjangoModelsRegistry()
|
|
||||||
|
|
||||||
|
|
||||||
def get_db_field_arguments(rvalue: CallExpr) -> inspect.BoundArguments:
|
|
||||||
modulename, _, classname = rvalue.callee.fullname.rpartition('.')
|
|
||||||
klass = getattr(importlib.import_module(modulename), classname)
|
|
||||||
bound_signature = inspect.signature(klass).bind(*rvalue.args)
|
|
||||||
|
|
||||||
return bound_signature
|
|
||||||
|
|
||||||
|
|
||||||
def process_meta_innerclass(class_def: ClassDef):
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
def get_app_model(model_name: str) -> str:
|
|
||||||
import os
|
|
||||||
os.environ.setdefault('SITE_URL', 'https://localhost')
|
|
||||||
os.environ.setdefault('DJANGO_SETTINGS_MODULE', 'server._config.settings.local')
|
|
||||||
|
|
||||||
import django
|
|
||||||
django.setup()
|
|
||||||
|
|
||||||
from django.apps import apps
|
|
||||||
|
|
||||||
try:
|
|
||||||
app_name, model_name = model_name.rsplit('.', maxsplit=1)
|
|
||||||
model = apps.get_model(app_name, model_name)
|
|
||||||
return model.__module__ + '.' + model_name
|
|
||||||
except ValueError:
|
|
||||||
return model_name
|
|
||||||
|
|
||||||
|
|
||||||
def base_class_callback(model_def_context: ClassDefContext) -> None:
|
|
||||||
# add new possible base models
|
|
||||||
for base_type_expr in model_def_context.cls.base_type_exprs:
|
|
||||||
if base_type_expr.fullname in model_registry:
|
|
||||||
model_registry.base_models[model_def_context.cls.fullname] = ModelInfo()
|
|
||||||
|
|
||||||
for definition in model_def_context.cls.defs.body:
|
|
||||||
if isinstance(definition, PassStmt):
|
|
||||||
continue
|
|
||||||
|
|
||||||
if isinstance(definition, ClassDef):
|
|
||||||
if definition.name == 'Meta':
|
|
||||||
process_meta_innerclass(definition)
|
|
||||||
continue
|
|
||||||
|
|
||||||
if isinstance(definition, AssignmentStmt):
|
|
||||||
rvalue = definition.rvalue
|
|
||||||
if hasattr(rvalue, 'callee') and rvalue.callee.fullname in REFERENCING_DB_FIELDS:
|
|
||||||
modulename, _, classname = rvalue.callee.fullname.rpartition('.')
|
|
||||||
klass = getattr(importlib.import_module(modulename), classname)
|
|
||||||
bound_signature = inspect.signature(klass).bind(*rvalue.args)
|
|
||||||
|
|
||||||
to_arg_value = bound_signature.arguments['to']
|
|
||||||
if isinstance(to_arg_value, StrExpr):
|
|
||||||
model_fullname = get_app_model(to_arg_value.value)
|
|
||||||
else:
|
|
||||||
model_fullname = bound_signature.arguments['to'].fullname
|
|
||||||
|
|
||||||
referred_model = model_fullname
|
|
||||||
rvalue.callee.node.metadata['base'] = referred_model
|
|
||||||
|
|
||||||
# if 'related_name' in rvalue.arg_names:
|
|
||||||
# related_name = rvalue.args[rvalue.arg_names.index('related_name')].value
|
|
||||||
#
|
|
||||||
# for name, context_global in model_def_context.api.globals.items():
|
|
||||||
# context_global: SymbolTableNode = context_global
|
|
||||||
# if context_global.fullname == referred_model:
|
|
||||||
# list_of_any = Instance(model_def_context.api.lookup_fully_qualified('typing.List').node,
|
|
||||||
# [AnyType(TypeOfAny.from_omitted_generics)])
|
|
||||||
# related_members_node = Var(related_name, list_of_any)
|
|
||||||
# context_global.node.names[related_name] = SymbolTableNode(MDEF, related_members_node)
|
|
||||||
|
|
||||||
# model_registry.base_models[referred_model].related_managers[related_name] = model_def_context.cls.fullname
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def related_manager_inference_callback(attr_context: AttributeContext) -> Type:
|
|
||||||
mypy_api = attr_context.api
|
|
||||||
|
|
||||||
default_attr_type = attr_context.default_attr_type
|
|
||||||
if not isinstance(default_attr_type, Instance):
|
|
||||||
return default_attr_type
|
|
||||||
|
|
||||||
attr_type_fullname = default_attr_type.type.fullname()
|
|
||||||
|
|
||||||
if attr_type_fullname in DB_FIELDS_TO_TYPES:
|
|
||||||
return mypy_api.named_type(DB_FIELDS_TO_TYPES[attr_type_fullname])
|
|
||||||
|
|
||||||
if 'base' in default_attr_type.type.metadata:
|
|
||||||
base_class = default_attr_type.type.metadata['base']
|
|
||||||
|
|
||||||
node = lookup_django_model(mypy_api, base_class).node
|
|
||||||
return Instance(node, [])
|
|
||||||
# mypy_api.lookup_qualified(base_class)
|
|
||||||
# app = base_class.split('.')[0]
|
|
||||||
# name = base_class.split('.')[-1]
|
|
||||||
# app_model = get_app_model(app + '.' + name)
|
|
||||||
|
|
||||||
# return mypy_api.named_type(name)
|
|
||||||
|
|
||||||
return AnyType(TypeOfAny.unannotated)
|
|
||||||
|
|
||||||
|
|
||||||
#
|
|
||||||
# def type_analyze_callback(context: AnalyzeTypeContext) -> Type:
|
|
||||||
# return context.type
|
|
||||||
|
|
||||||
|
|
||||||
class RelatedManagersDjangoPlugin(Plugin):
|
|
||||||
def get_base_class_hook(self, fullname: str
|
|
||||||
) -> Optional[Callable[[ClassDefContext], None]]:
|
|
||||||
# every time new class is created, this method is called with first base class in MRO
|
|
||||||
if fullname in model_registry:
|
|
||||||
return base_class_callback
|
|
||||||
|
|
||||||
return None
|
|
||||||
|
|
||||||
def get_attribute_hook(self, fullname: str
|
|
||||||
) -> Optional[Callable[[AttributeContext], Type]]:
|
|
||||||
return related_manager_inference_callback
|
|
||||||
|
|
||||||
|
|
||||||
def plugin(version):
|
|
||||||
return RelatedManagersDjangoPlugin
|
|
||||||
@@ -0,0 +1,77 @@
|
|||||||
|
from typing import Callable, Optional
|
||||||
|
|
||||||
|
from mypy.nodes import AssignmentStmt, CallExpr, RefExpr, StrExpr
|
||||||
|
from mypy.plugin import Plugin, ClassDefContext
|
||||||
|
|
||||||
|
from mypy_django_plugin.helpers import get_app_model
|
||||||
|
from mypy_django_plugin.model_classes import DjangoModelsRegistry
|
||||||
|
|
||||||
|
|
||||||
|
# fields which real type is inside to= expression
|
||||||
|
REFERENCING_DB_FIELDS = {
|
||||||
|
'django.db.models.fields.related.ForeignKey',
|
||||||
|
'django.db.models.fields.related.OneToOneField'
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def save_referred_to_model_in_metadata(rvalue: CallExpr) -> None:
|
||||||
|
to_arg_value = rvalue.args[rvalue.arg_names.index('to')]
|
||||||
|
if isinstance(to_arg_value, StrExpr):
|
||||||
|
referred_model_fullname = get_app_model(to_arg_value.value)
|
||||||
|
else:
|
||||||
|
referred_model_fullname = to_arg_value.fullname
|
||||||
|
|
||||||
|
rvalue.callee.node.metadata['base'] = referred_model_fullname
|
||||||
|
|
||||||
|
|
||||||
|
class CollectModelsInformation(object):
|
||||||
|
def __init__(self, model_registry: DjangoModelsRegistry):
|
||||||
|
self.model_registry = model_registry
|
||||||
|
|
||||||
|
def __call__(self, model_definition: ClassDefContext) -> None:
|
||||||
|
self.model_registry.base_models.add(model_definition.cls.fullname)
|
||||||
|
|
||||||
|
for member in model_definition.cls.defs.body:
|
||||||
|
if isinstance(member, AssignmentStmt):
|
||||||
|
if len(member.lvalues) > 1:
|
||||||
|
return None
|
||||||
|
|
||||||
|
arg_name = member.lvalues[0].name
|
||||||
|
arg_name_as_id = arg_name + '_id'
|
||||||
|
|
||||||
|
rvalue = member.rvalue
|
||||||
|
if isinstance(rvalue, CallExpr):
|
||||||
|
if not isinstance(rvalue.callee, RefExpr):
|
||||||
|
return None
|
||||||
|
|
||||||
|
if rvalue.callee.fullname in REFERENCING_DB_FIELDS:
|
||||||
|
if rvalue.callee.fullname == 'django.db.models.fields.related.ForeignKey':
|
||||||
|
model_definition.cls.info.names[arg_name_as_id] = \
|
||||||
|
model_definition.api.lookup_fully_qualified('builtins.int')
|
||||||
|
|
||||||
|
if rvalue.callee.fullname == 'django.db.models.fields.related.OneToOneField':
|
||||||
|
if 'related_name' in rvalue.arg_names:
|
||||||
|
referred_to_model = rvalue.args[rvalue.arg_names.index('to')]
|
||||||
|
related_arg_value = rvalue.args[rvalue.arg_names.index('related_name')].value
|
||||||
|
|
||||||
|
if isinstance(referred_to_model, StrExpr):
|
||||||
|
referred_model_fullname = get_app_model(referred_to_model.value)
|
||||||
|
else:
|
||||||
|
referred_model_fullname = referred_to_model.fullname
|
||||||
|
|
||||||
|
referred_model = model_definition.api.lookup_fully_qualified_or_none(referred_model_fullname)
|
||||||
|
referred_model.node.names[related_arg_value] = \
|
||||||
|
model_definition.api.lookup_fully_qualified_or_none(model_definition.cls.fullname)
|
||||||
|
|
||||||
|
return save_referred_to_model_in_metadata(rvalue)
|
||||||
|
|
||||||
|
|
||||||
|
class BaseDjangoModelsPlugin(Plugin):
|
||||||
|
model_registry = DjangoModelsRegistry()
|
||||||
|
|
||||||
|
def get_base_class_hook(self, fullname: str
|
||||||
|
) -> Optional[Callable[[ClassDefContext], None]]:
|
||||||
|
if fullname in self.model_registry:
|
||||||
|
return CollectModelsInformation(self.model_registry)
|
||||||
|
|
||||||
|
return None
|
||||||
@@ -0,0 +1,60 @@
|
|||||||
|
from typing import Optional, Callable
|
||||||
|
|
||||||
|
from mypy.plugin import AttributeContext
|
||||||
|
from mypy.types import Type, Instance
|
||||||
|
|
||||||
|
from mypy_django_plugin.helpers import lookup_django_model
|
||||||
|
from mypy_django_plugin.model_classes import DjangoModelsRegistry
|
||||||
|
from mypy_django_plugin.plugins.base import BaseDjangoModelsPlugin
|
||||||
|
|
||||||
|
# mapping between field types and plain python types
|
||||||
|
DB_FIELDS_TO_TYPES = {
|
||||||
|
'django.db.models.fields.CharField': 'builtins.str',
|
||||||
|
'django.db.models.fields.TextField': 'builtins.str',
|
||||||
|
'django.db.models.fields.BooleanField': 'builtins.bool',
|
||||||
|
# 'django.db.models.fields.NullBooleanField': 'typing.Optional[builtins.bool]',
|
||||||
|
'django.db.models.fields.IntegerField': 'builtins.int',
|
||||||
|
'django.db.models.fields.AutoField': 'builtins.int',
|
||||||
|
'django.db.models.fields.FloatField': 'builtins.float',
|
||||||
|
'django.contrib.postgres.fields.jsonb.JSONField': 'builtins.dict',
|
||||||
|
'django.contrib.postgres.fields.array.ArrayField': 'typing.Iterable'
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
class DetermineFieldPythonTypeCallback(object):
|
||||||
|
def __init__(self, models_registry: DjangoModelsRegistry):
|
||||||
|
self.models_registry = models_registry
|
||||||
|
|
||||||
|
def __call__(self, attr_context: AttributeContext) -> Type:
|
||||||
|
default_attr_type = attr_context.default_attr_type
|
||||||
|
|
||||||
|
if isinstance(default_attr_type, Instance):
|
||||||
|
attr_type_fullname = default_attr_type.type.fullname()
|
||||||
|
if attr_type_fullname in DB_FIELDS_TO_TYPES:
|
||||||
|
return attr_context.api.named_type(DB_FIELDS_TO_TYPES[attr_type_fullname])
|
||||||
|
|
||||||
|
if 'base' in default_attr_type.type.metadata:
|
||||||
|
referred_base_model = default_attr_type.type.metadata['base']
|
||||||
|
try:
|
||||||
|
node = lookup_django_model(attr_context.api, referred_base_model).node
|
||||||
|
return Instance(node, [])
|
||||||
|
except AssertionError as e:
|
||||||
|
print(e)
|
||||||
|
print('name to lookup:', referred_base_model)
|
||||||
|
pass
|
||||||
|
|
||||||
|
return default_attr_type
|
||||||
|
|
||||||
|
|
||||||
|
class FieldToPythonTypePlugin(BaseDjangoModelsPlugin):
|
||||||
|
def get_attribute_hook(self, fullname: str
|
||||||
|
) -> Optional[Callable[[AttributeContext], Type]]:
|
||||||
|
classname, _, attrname = fullname.rpartition('.')
|
||||||
|
if classname and classname in self.model_registry:
|
||||||
|
return DetermineFieldPythonTypeCallback(self.model_registry)
|
||||||
|
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def plugin(version):
|
||||||
|
return FieldToPythonTypePlugin
|
||||||
@@ -19,6 +19,6 @@ setup(
|
|||||||
version="0.1.0",
|
version="0.1.0",
|
||||||
license='BSD',
|
license='BSD',
|
||||||
install_requires='Django>=2.1.1',
|
install_requires='Django>=2.1.1',
|
||||||
packages=['django-stubs'],
|
packages=['mypy_django_plugin']
|
||||||
package_data=find_stubs('django-stubs')
|
# package_data=find_stubs('django-stubs')
|
||||||
)
|
)
|
||||||
|
|||||||
Reference in New Issue
Block a user