mirror of
https://github.com/davidhalter/django-stubs.git
synced 2025-12-08 04:54:48 +08:00
add support for OneToOneField, more stubs
This commit is contained in:
@@ -7,6 +7,10 @@ from .fields import (AutoField as AutoField,
|
|||||||
Field as Field,
|
Field as Field,
|
||||||
SlugField as SlugField,
|
SlugField as SlugField,
|
||||||
TextField as TextField)
|
TextField as TextField)
|
||||||
from .fields.related import (ForeignKey as ForeignKey)
|
from .fields.related import (ForeignKey as ForeignKey,
|
||||||
from .deletion import CASCADE as CASCADE
|
OneToOneField as OneToOneField)
|
||||||
|
from .deletion import (CASCADE as CASCADE,
|
||||||
|
SET_DEFAULT as SET_DEFAULT,
|
||||||
|
SET_NULL as SET_NULL,
|
||||||
|
DO_NOTHING as DO_NOTHING)
|
||||||
from .query import QuerySet as QuerySet
|
from .query import QuerySet as QuerySet
|
||||||
@@ -1,6 +1,16 @@
|
|||||||
|
from typing import Any
|
||||||
|
|
||||||
|
|
||||||
class ModelBase(type):
|
class ModelBase(type):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
class Model(metaclass=ModelBase):
|
class Model(metaclass=ModelBase):
|
||||||
|
class DoesNotExist(Exception):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
def __init__(self, **kwargs) -> None: ...
|
||||||
|
|
||||||
|
def delete(self,
|
||||||
|
using: Any = ...,
|
||||||
|
keep_parents: bool = ...) -> None: ...
|
||||||
@@ -1,2 +1,7 @@
|
|||||||
def CASCADE(collector, field, sub_objs, using):
|
def CASCADE(collector, field, sub_objs, using): ...
|
||||||
...
|
|
||||||
|
def SET_NULL(collector, field, sub_objs, using): ...
|
||||||
|
|
||||||
|
def SET_DEFAULT(collector, field, sub_objs, using): ...
|
||||||
|
|
||||||
|
def DO_NOTHING(collector, field, sub_objs, using): ...
|
||||||
|
|||||||
@@ -10,5 +10,15 @@ class ForeignKey(Field, Generic[_T]):
|
|||||||
def __init__(self,
|
def __init__(self,
|
||||||
to: Union[Type[_T], str],
|
to: Union[Type[_T], str],
|
||||||
on_delete: Any,
|
on_delete: Any,
|
||||||
|
related_name: str = ...,
|
||||||
|
**kwargs): ...
|
||||||
|
def __get__(self, instance, owner) -> _T: ...
|
||||||
|
|
||||||
|
|
||||||
|
class OneToOneField(Field, Generic[_T]):
|
||||||
|
def __init__(self,
|
||||||
|
to: Union[Type[_T], str],
|
||||||
|
on_delete: Any,
|
||||||
|
related_name: str = ...,
|
||||||
**kwargs): ...
|
**kwargs): ...
|
||||||
def __get__(self, instance, owner) -> _T: ...
|
def __get__(self, instance, owner) -> _T: ...
|
||||||
|
|||||||
@@ -1,18 +0,0 @@
|
|||||||
from typing import Set
|
|
||||||
|
|
||||||
import dataclasses
|
|
||||||
|
|
||||||
|
|
||||||
def get_default_base_models():
|
|
||||||
return {'django.db.models.base.Model'}
|
|
||||||
|
|
||||||
|
|
||||||
@dataclasses.dataclass
|
|
||||||
class DjangoModelsRegistry(object):
|
|
||||||
base_models: Set[str] = dataclasses.field(default_factory=get_default_base_models)
|
|
||||||
|
|
||||||
def __contains__(self, item: str) -> bool:
|
|
||||||
return item in self.base_models
|
|
||||||
|
|
||||||
def __iter__(self):
|
|
||||||
return iter(self.base_models)
|
|
||||||
@@ -1,164 +0,0 @@
|
|||||||
import dataclasses
|
|
||||||
from mypy import types, nodes
|
|
||||||
from mypy.nodes import CallExpr, StrExpr, AssignmentStmt, RefExpr
|
|
||||||
from mypy.plugin import AttributeContext, ClassDefContext, SemanticAnalyzerPluginInterface
|
|
||||||
from mypy.types import Type, Instance, AnyType, TypeOfAny
|
|
||||||
|
|
||||||
from mypy_django_plugin.helpers import lookup_django_model, get_app_model
|
|
||||||
from mypy_django_plugin.model_classes import DjangoModelsRegistry
|
|
||||||
|
|
||||||
# mapping between field types and plain python types
|
|
||||||
DB_FIELDS_TO_TYPES = {
|
|
||||||
'django.db.models.fields.CharField': 'builtins.str',
|
|
||||||
'django.db.models.fields.TextField': 'builtins.str',
|
|
||||||
'django.db.models.fields.BooleanField': 'builtins.bool',
|
|
||||||
# 'django.db.models.fields.NullBooleanField': 'typing.Optional[builtins.bool]',
|
|
||||||
'django.db.models.fields.IntegerField': 'builtins.int',
|
|
||||||
'django.db.models.fields.AutoField': 'builtins.int',
|
|
||||||
'django.db.models.fields.FloatField': 'builtins.float',
|
|
||||||
'django.contrib.postgres.fields.jsonb.JSONField': 'builtins.dict',
|
|
||||||
'django.contrib.postgres.fields.array.ArrayField': 'typing.Iterable'
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
# def get_queryset_of(type_fullname: str):
|
|
||||||
# return model_definition.api.lookup_fully_qualified_or_none('django.db.models.QuerySet')
|
|
||||||
|
|
||||||
|
|
||||||
@dataclasses.dataclass
|
|
||||||
class DjangoPluginApi(object):
|
|
||||||
mypy_api: SemanticAnalyzerPluginInterface
|
|
||||||
|
|
||||||
def get_queryset_of(self, type_fullname: str = 'django.db.models.base.Model') -> types.Type:
|
|
||||||
queryset_sym = self.mypy_api.lookup_fully_qualified_or_none('django.db.models.QuerySet')
|
|
||||||
if not queryset_sym:
|
|
||||||
return AnyType(TypeOfAny.from_error)
|
|
||||||
|
|
||||||
generic_arg = self.mypy_api.lookup_fully_qualified_or_none(type_fullname)
|
|
||||||
if not generic_arg:
|
|
||||||
return Instance(queryset_sym.node, [AnyType(TypeOfAny.from_error)])
|
|
||||||
|
|
||||||
return Instance(queryset_sym.node, [Instance(generic_arg.node, [])])
|
|
||||||
|
|
||||||
def generate_related_manager_assignment_stmt(self,
|
|
||||||
related_mngr_name: str,
|
|
||||||
queryset_argument_type_fullname: str) -> nodes.AssignmentStmt:
|
|
||||||
rvalue = nodes.TempNode(AnyType(TypeOfAny.special_form))
|
|
||||||
assignment = nodes.AssignmentStmt(lvalues=[nodes.NameExpr(related_mngr_name)],
|
|
||||||
rvalue=rvalue,
|
|
||||||
new_syntax=True,
|
|
||||||
type=self.get_queryset_of(queryset_argument_type_fullname))
|
|
||||||
return assignment
|
|
||||||
|
|
||||||
|
|
||||||
class DetermineFieldPythonTypeCallback(object):
|
|
||||||
def __init__(self, models_registry: DjangoModelsRegistry):
|
|
||||||
self.models_registry = models_registry
|
|
||||||
|
|
||||||
def __call__(self, attr_context: AttributeContext) -> Type:
|
|
||||||
default_attr_type = attr_context.default_attr_type
|
|
||||||
|
|
||||||
if isinstance(default_attr_type, Instance):
|
|
||||||
attr_type_fullname = default_attr_type.type.fullname()
|
|
||||||
if attr_type_fullname in DB_FIELDS_TO_TYPES:
|
|
||||||
return attr_context.api.named_type(DB_FIELDS_TO_TYPES[attr_type_fullname])
|
|
||||||
|
|
||||||
# if 'base' in default_attr_type.type.metadata:
|
|
||||||
# referred_base_model = default_attr_type.type.metadata['base']
|
|
||||||
|
|
||||||
if 'members' in attr_context.type.type.metadata:
|
|
||||||
arg_name = attr_context.context.name
|
|
||||||
if arg_name in attr_context.type.type.metadata['members']:
|
|
||||||
referred_base_model = attr_context.type.type.metadata['members'][arg_name]
|
|
||||||
|
|
||||||
typ = lookup_django_model(attr_context.api, referred_base_model)
|
|
||||||
try:
|
|
||||||
return Instance(typ.node, [])
|
|
||||||
except AssertionError as e:
|
|
||||||
return typ.type
|
|
||||||
|
|
||||||
return default_attr_type
|
|
||||||
|
|
||||||
|
|
||||||
# fields which real type is inside to= expression
|
|
||||||
REFERENCING_DB_FIELDS = {
|
|
||||||
'django.db.models.fields.related.ForeignKey',
|
|
||||||
'django.db.models.fields.related.OneToOneField'
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
def save_referred_to_model_in_metadata(rvalue: CallExpr) -> None:
|
|
||||||
to_arg_value = rvalue.args[rvalue.arg_names.index('to')]
|
|
||||||
if isinstance(to_arg_value, StrExpr):
|
|
||||||
referred_model_fullname = get_app_model(to_arg_value.value)
|
|
||||||
else:
|
|
||||||
referred_model_fullname = to_arg_value.fullname
|
|
||||||
|
|
||||||
rvalue.callee.node.metadata['base'] = referred_model_fullname
|
|
||||||
|
|
||||||
|
|
||||||
class CollectModelsInformationCallback(object):
|
|
||||||
def __init__(self, model_registry: DjangoModelsRegistry):
|
|
||||||
self.model_registry = model_registry
|
|
||||||
|
|
||||||
def __call__(self, model_definition: ClassDefContext) -> None:
|
|
||||||
self.model_registry.base_models.add(model_definition.cls.fullname)
|
|
||||||
plugin_api = DjangoPluginApi(mypy_api=model_definition.api)
|
|
||||||
|
|
||||||
for member in model_definition.cls.defs.body:
|
|
||||||
if isinstance(member, AssignmentStmt):
|
|
||||||
if len(member.lvalues) > 1:
|
|
||||||
return None
|
|
||||||
|
|
||||||
arg_name = member.lvalues[0].name
|
|
||||||
arg_name_as_id = arg_name + '_id'
|
|
||||||
|
|
||||||
rvalue = member.rvalue
|
|
||||||
if isinstance(rvalue, CallExpr):
|
|
||||||
if not isinstance(rvalue.callee, RefExpr):
|
|
||||||
return None
|
|
||||||
|
|
||||||
if rvalue.callee.fullname in REFERENCING_DB_FIELDS:
|
|
||||||
if rvalue.callee.fullname == 'django.db.models.fields.related.ForeignKey':
|
|
||||||
model_definition.cls.info.names[arg_name_as_id] = \
|
|
||||||
model_definition.api.lookup_fully_qualified('builtins.int')
|
|
||||||
|
|
||||||
referred_to_model = rvalue.args[rvalue.arg_names.index('to')]
|
|
||||||
if isinstance(referred_to_model, StrExpr):
|
|
||||||
referred_model_fullname = get_app_model(referred_to_model.value)
|
|
||||||
else:
|
|
||||||
referred_model_fullname = referred_to_model.fullname
|
|
||||||
|
|
||||||
rvalue.callee.node.metadata['base'] = referred_model_fullname
|
|
||||||
|
|
||||||
referred_model = model_definition.api.lookup_fully_qualified_or_none(referred_model_fullname)
|
|
||||||
|
|
||||||
if 'related_name' in rvalue.arg_names:
|
|
||||||
related_arg_value = rvalue.args[rvalue.arg_names.index('related_name')].value
|
|
||||||
referred_model_class_def = referred_model.node.defn # type: nodes.ClassDef
|
|
||||||
|
|
||||||
referred_model_class_def.defs.body.append(
|
|
||||||
plugin_api.generate_related_manager_assignment_stmt(related_arg_value,
|
|
||||||
model_definition.cls.fullname))
|
|
||||||
|
|
||||||
if rvalue.callee.fullname == 'django.db.models.fields.related.OneToOneField':
|
|
||||||
referred_to_model = rvalue.args[rvalue.arg_names.index('to')]
|
|
||||||
if isinstance(referred_to_model, StrExpr):
|
|
||||||
referred_model_fullname = get_app_model(referred_to_model.value)
|
|
||||||
else:
|
|
||||||
referred_model_fullname = referred_to_model.fullname
|
|
||||||
|
|
||||||
referred_model = model_definition.api.lookup_fully_qualified_or_none(referred_model_fullname)
|
|
||||||
|
|
||||||
if 'related_name' in rvalue.arg_names:
|
|
||||||
related_arg_value = rvalue.args[rvalue.arg_names.index('related_name')].value
|
|
||||||
referred_model.node.names[related_arg_value] = \
|
|
||||||
model_definition.api.lookup_fully_qualified_or_none(model_definition.cls.fullname)
|
|
||||||
|
|
||||||
rvalue.callee.node.metadata['base'] = referred_model_fullname
|
|
||||||
|
|
||||||
rvalue.callee.node.metadata['name'] = arg_name
|
|
||||||
if 'members' not in model_definition.cls.info.metadata:
|
|
||||||
model_definition.cls.info.metadata['members'] = {}
|
|
||||||
|
|
||||||
model_definition.cls.info.metadata['members'][arg_name] = rvalue.callee.node.metadata.get('base', None)
|
|
||||||
@@ -1,30 +0,0 @@
|
|||||||
from typing import Optional, Callable, Type
|
|
||||||
|
|
||||||
from mypy.plugin import Plugin, ClassDefContext, AttributeContext
|
|
||||||
|
|
||||||
from mypy_django_plugin.model_classes import DjangoModelsRegistry
|
|
||||||
from mypy_django_plugin.plugins.callbacks import CollectModelsInformationCallback, DetermineFieldPythonTypeCallback
|
|
||||||
|
|
||||||
|
|
||||||
class FieldToPythonTypePlugin(Plugin):
|
|
||||||
model_registry = DjangoModelsRegistry()
|
|
||||||
|
|
||||||
def get_base_class_hook(self, fullname: str
|
|
||||||
) -> Optional[Callable[[ClassDefContext], None]]:
|
|
||||||
if fullname in self.model_registry:
|
|
||||||
return CollectModelsInformationCallback(self.model_registry)
|
|
||||||
|
|
||||||
return None
|
|
||||||
|
|
||||||
def get_attribute_hook(self, fullname: str
|
|
||||||
) -> Optional[Callable[[AttributeContext], Type]]:
|
|
||||||
# print(fullname)
|
|
||||||
classname, _, attrname = fullname.rpartition('.')
|
|
||||||
if classname and classname in self.model_registry:
|
|
||||||
return DetermineFieldPythonTypeCallback(self.model_registry)
|
|
||||||
|
|
||||||
return None
|
|
||||||
|
|
||||||
|
|
||||||
def plugin(version):
|
|
||||||
return FieldToPythonTypePlugin
|
|
||||||
@@ -1,21 +1,33 @@
|
|||||||
from typing import Optional, Callable
|
from typing import Optional, Callable, cast
|
||||||
|
|
||||||
|
from mypy.checker import TypeChecker
|
||||||
from mypy.nodes import Var, MDEF, SymbolTableNode
|
from mypy.nodes import Var, MDEF, SymbolTableNode
|
||||||
from mypy.plugin import Plugin, FunctionContext
|
from mypy.plugin import Plugin, FunctionContext
|
||||||
from mypy.types import Type, CallableType, Instance
|
from mypy.types import Type, CallableType, Instance
|
||||||
|
|
||||||
|
|
||||||
|
def extract_to_value_type(ctx: FunctionContext) -> Optional[Type]:
|
||||||
|
assert 'to' in ctx.context.arg_names
|
||||||
|
to_arg_value = ctx.arg_types[ctx.context.arg_names.index('to')][0]
|
||||||
|
if not isinstance(to_arg_value, CallableType):
|
||||||
|
return None
|
||||||
|
|
||||||
|
return to_arg_value.ret_type
|
||||||
|
|
||||||
|
|
||||||
|
def extract_related_name_value(ctx: FunctionContext) -> str:
|
||||||
|
return ctx.context.args[ctx.context.arg_names.index('related_name')].value
|
||||||
|
|
||||||
|
|
||||||
def set_related_name_manager_for_foreign_key(ctx: FunctionContext) -> Type:
|
def set_related_name_manager_for_foreign_key(ctx: FunctionContext) -> Type:
|
||||||
if 'related_name' not in ctx.context.arg_names:
|
if 'related_name' not in ctx.context.arg_names:
|
||||||
return ctx.default_return_type
|
return ctx.default_return_type
|
||||||
|
|
||||||
assert 'to' in ctx.context.arg_names
|
referred_to = extract_to_value_type(ctx)
|
||||||
to_arg_value = ctx.arg_types[ctx.context.arg_names.index('to')][0]
|
if not referred_to:
|
||||||
if not isinstance(to_arg_value, CallableType):
|
|
||||||
return ctx.default_return_type
|
return ctx.default_return_type
|
||||||
|
|
||||||
referred_to = to_arg_value.ret_type
|
related_name = extract_related_name_value(ctx)
|
||||||
related_name = ctx.context.args[ctx.context.arg_names.index('related_name')].value
|
|
||||||
outer_class_info = ctx.api.tscope.classes[-1]
|
outer_class_info = ctx.api.tscope.classes[-1]
|
||||||
|
|
||||||
queryset_type = ctx.api.named_generic_type('django.db.models.QuerySet',
|
queryset_type = ctx.api.named_generic_type('django.db.models.QuerySet',
|
||||||
@@ -28,11 +40,36 @@ def set_related_name_manager_for_foreign_key(ctx: FunctionContext) -> Type:
|
|||||||
return ctx.default_return_type
|
return ctx.default_return_type
|
||||||
|
|
||||||
|
|
||||||
|
def set_related_name_instance_for_onetoonefield(ctx: FunctionContext) -> Type:
|
||||||
|
if 'related_name' not in ctx.context.arg_names:
|
||||||
|
return ctx.default_return_type
|
||||||
|
|
||||||
|
referred_to = extract_to_value_type(ctx)
|
||||||
|
if referred_to is None:
|
||||||
|
return ctx.default_return_type
|
||||||
|
|
||||||
|
related_name = extract_related_name_value(ctx)
|
||||||
|
outer_class_info = ctx.api.tscope.classes[-1]
|
||||||
|
|
||||||
|
api = cast(TypeChecker, ctx.api)
|
||||||
|
related_instance_type = api.named_type(outer_class_info.fullname())
|
||||||
|
related_var = Var(related_name, related_instance_type)
|
||||||
|
related_var.info = related_instance_type.type
|
||||||
|
|
||||||
|
referred_to.type.names[related_name] = SymbolTableNode(MDEF, related_var,
|
||||||
|
plugin_generated=True)
|
||||||
|
return ctx.default_return_type
|
||||||
|
|
||||||
|
|
||||||
class RelatedFieldsPlugin(Plugin):
|
class RelatedFieldsPlugin(Plugin):
|
||||||
def get_function_hook(self, fullname: str
|
def get_function_hook(self, fullname: str
|
||||||
) -> Optional[Callable[[FunctionContext], Type]]:
|
) -> Optional[Callable[[FunctionContext], Type]]:
|
||||||
if fullname == 'django.db.models.fields.related.ForeignKey':
|
if fullname == 'django.db.models.fields.related.ForeignKey':
|
||||||
return set_related_name_manager_for_foreign_key
|
return set_related_name_manager_for_foreign_key
|
||||||
|
|
||||||
|
if fullname == 'django.db.models.fields.related.OneToOneField':
|
||||||
|
return set_related_name_instance_for_onetoonefield
|
||||||
|
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -1,18 +1,4 @@
|
|||||||
[case testForeignKeyWithClass]
|
[case testForeignKeyField]
|
||||||
from django.db import models
|
|
||||||
|
|
||||||
class Publisher(models.Model):
|
|
||||||
pass
|
|
||||||
|
|
||||||
class Book(models.Model):
|
|
||||||
publisher = models.ForeignKey(to=Publisher, on_delete=models.CASCADE)
|
|
||||||
|
|
||||||
book = Book()
|
|
||||||
reveal_type(book.publisher) # E: Revealed type is 'main.Publisher*'
|
|
||||||
[out]
|
|
||||||
|
|
||||||
|
|
||||||
[case testForeignKeyRelatedName]
|
|
||||||
from django.db import models
|
from django.db import models
|
||||||
|
|
||||||
class Publisher(models.Model):
|
class Publisher(models.Model):
|
||||||
@@ -22,6 +8,25 @@ class Book(models.Model):
|
|||||||
publisher = models.ForeignKey(to=Publisher, on_delete=models.CASCADE,
|
publisher = models.ForeignKey(to=Publisher, on_delete=models.CASCADE,
|
||||||
related_name='books')
|
related_name='books')
|
||||||
|
|
||||||
|
book = Book()
|
||||||
|
reveal_type(book.publisher) # E: Revealed type is 'main.Publisher*'
|
||||||
|
|
||||||
publisher = Publisher()
|
publisher = Publisher()
|
||||||
reveal_type(publisher.books) # E: Revealed type is 'django.db.models.query.QuerySet[main.Book]'
|
reveal_type(publisher.books) # E: Revealed type is 'django.db.models.query.QuerySet[main.Book]'
|
||||||
[out]
|
[out]
|
||||||
|
|
||||||
|
[case testOneToOneField]
|
||||||
|
from django.db import models
|
||||||
|
|
||||||
|
class User(models.Model):
|
||||||
|
pass
|
||||||
|
|
||||||
|
class Profile(models.Model):
|
||||||
|
user = models.OneToOneField(to=User, on_delete=models.CASCADE, related_name='profile')
|
||||||
|
|
||||||
|
profile = Profile()
|
||||||
|
reveal_type(profile.user) # E: Revealed type is 'main.User*'
|
||||||
|
|
||||||
|
user = User()
|
||||||
|
reveal_type(user.profile) # E: Revealed type is 'main.Profile'
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user