add support for OneToOneField, more stubs

This commit is contained in:
Maxim Kurnikov
2018-11-13 17:43:21 +03:00
parent 155dc4e049
commit dc4f606f63
9 changed files with 97 additions and 238 deletions

View File

@@ -7,6 +7,10 @@ from .fields import (AutoField as AutoField,
Field as Field,
SlugField as SlugField,
TextField as TextField)
from .fields.related import (ForeignKey as ForeignKey)
from .deletion import CASCADE as CASCADE
from .fields.related import (ForeignKey as ForeignKey,
OneToOneField as OneToOneField)
from .deletion import (CASCADE as CASCADE,
SET_DEFAULT as SET_DEFAULT,
SET_NULL as SET_NULL,
DO_NOTHING as DO_NOTHING)
from .query import QuerySet as QuerySet

View File

@@ -1,6 +1,16 @@
from typing import Any
class ModelBase(type):
pass
class Model(metaclass=ModelBase):
pass
class DoesNotExist(Exception):
pass
def __init__(self, **kwargs) -> None: ...
def delete(self,
using: Any = ...,
keep_parents: bool = ...) -> None: ...

View File

@@ -1,2 +1,7 @@
def CASCADE(collector, field, sub_objs, using):
...
def CASCADE(collector, field, sub_objs, using): ...
def SET_NULL(collector, field, sub_objs, using): ...
def SET_DEFAULT(collector, field, sub_objs, using): ...
def DO_NOTHING(collector, field, sub_objs, using): ...

View File

@@ -10,5 +10,15 @@ class ForeignKey(Field, Generic[_T]):
def __init__(self,
to: Union[Type[_T], str],
on_delete: Any,
related_name: str = ...,
**kwargs): ...
def __get__(self, instance, owner) -> _T: ...
class OneToOneField(Field, Generic[_T]):
def __init__(self,
to: Union[Type[_T], str],
on_delete: Any,
related_name: str = ...,
**kwargs): ...
def __get__(self, instance, owner) -> _T: ...

View File

@@ -1,18 +0,0 @@
from typing import Set
import dataclasses
def get_default_base_models():
return {'django.db.models.base.Model'}
@dataclasses.dataclass
class DjangoModelsRegistry(object):
base_models: Set[str] = dataclasses.field(default_factory=get_default_base_models)
def __contains__(self, item: str) -> bool:
return item in self.base_models
def __iter__(self):
return iter(self.base_models)

View File

@@ -1,164 +0,0 @@
import dataclasses
from mypy import types, nodes
from mypy.nodes import CallExpr, StrExpr, AssignmentStmt, RefExpr
from mypy.plugin import AttributeContext, ClassDefContext, SemanticAnalyzerPluginInterface
from mypy.types import Type, Instance, AnyType, TypeOfAny
from mypy_django_plugin.helpers import lookup_django_model, get_app_model
from mypy_django_plugin.model_classes import DjangoModelsRegistry
# mapping between field types and plain python types
DB_FIELDS_TO_TYPES = {
'django.db.models.fields.CharField': 'builtins.str',
'django.db.models.fields.TextField': 'builtins.str',
'django.db.models.fields.BooleanField': 'builtins.bool',
# 'django.db.models.fields.NullBooleanField': 'typing.Optional[builtins.bool]',
'django.db.models.fields.IntegerField': 'builtins.int',
'django.db.models.fields.AutoField': 'builtins.int',
'django.db.models.fields.FloatField': 'builtins.float',
'django.contrib.postgres.fields.jsonb.JSONField': 'builtins.dict',
'django.contrib.postgres.fields.array.ArrayField': 'typing.Iterable'
}
# def get_queryset_of(type_fullname: str):
# return model_definition.api.lookup_fully_qualified_or_none('django.db.models.QuerySet')
@dataclasses.dataclass
class DjangoPluginApi(object):
mypy_api: SemanticAnalyzerPluginInterface
def get_queryset_of(self, type_fullname: str = 'django.db.models.base.Model') -> types.Type:
queryset_sym = self.mypy_api.lookup_fully_qualified_or_none('django.db.models.QuerySet')
if not queryset_sym:
return AnyType(TypeOfAny.from_error)
generic_arg = self.mypy_api.lookup_fully_qualified_or_none(type_fullname)
if not generic_arg:
return Instance(queryset_sym.node, [AnyType(TypeOfAny.from_error)])
return Instance(queryset_sym.node, [Instance(generic_arg.node, [])])
def generate_related_manager_assignment_stmt(self,
related_mngr_name: str,
queryset_argument_type_fullname: str) -> nodes.AssignmentStmt:
rvalue = nodes.TempNode(AnyType(TypeOfAny.special_form))
assignment = nodes.AssignmentStmt(lvalues=[nodes.NameExpr(related_mngr_name)],
rvalue=rvalue,
new_syntax=True,
type=self.get_queryset_of(queryset_argument_type_fullname))
return assignment
class DetermineFieldPythonTypeCallback(object):
def __init__(self, models_registry: DjangoModelsRegistry):
self.models_registry = models_registry
def __call__(self, attr_context: AttributeContext) -> Type:
default_attr_type = attr_context.default_attr_type
if isinstance(default_attr_type, Instance):
attr_type_fullname = default_attr_type.type.fullname()
if attr_type_fullname in DB_FIELDS_TO_TYPES:
return attr_context.api.named_type(DB_FIELDS_TO_TYPES[attr_type_fullname])
# if 'base' in default_attr_type.type.metadata:
# referred_base_model = default_attr_type.type.metadata['base']
if 'members' in attr_context.type.type.metadata:
arg_name = attr_context.context.name
if arg_name in attr_context.type.type.metadata['members']:
referred_base_model = attr_context.type.type.metadata['members'][arg_name]
typ = lookup_django_model(attr_context.api, referred_base_model)
try:
return Instance(typ.node, [])
except AssertionError as e:
return typ.type
return default_attr_type
# fields which real type is inside to= expression
REFERENCING_DB_FIELDS = {
'django.db.models.fields.related.ForeignKey',
'django.db.models.fields.related.OneToOneField'
}
def save_referred_to_model_in_metadata(rvalue: CallExpr) -> None:
to_arg_value = rvalue.args[rvalue.arg_names.index('to')]
if isinstance(to_arg_value, StrExpr):
referred_model_fullname = get_app_model(to_arg_value.value)
else:
referred_model_fullname = to_arg_value.fullname
rvalue.callee.node.metadata['base'] = referred_model_fullname
class CollectModelsInformationCallback(object):
def __init__(self, model_registry: DjangoModelsRegistry):
self.model_registry = model_registry
def __call__(self, model_definition: ClassDefContext) -> None:
self.model_registry.base_models.add(model_definition.cls.fullname)
plugin_api = DjangoPluginApi(mypy_api=model_definition.api)
for member in model_definition.cls.defs.body:
if isinstance(member, AssignmentStmt):
if len(member.lvalues) > 1:
return None
arg_name = member.lvalues[0].name
arg_name_as_id = arg_name + '_id'
rvalue = member.rvalue
if isinstance(rvalue, CallExpr):
if not isinstance(rvalue.callee, RefExpr):
return None
if rvalue.callee.fullname in REFERENCING_DB_FIELDS:
if rvalue.callee.fullname == 'django.db.models.fields.related.ForeignKey':
model_definition.cls.info.names[arg_name_as_id] = \
model_definition.api.lookup_fully_qualified('builtins.int')
referred_to_model = rvalue.args[rvalue.arg_names.index('to')]
if isinstance(referred_to_model, StrExpr):
referred_model_fullname = get_app_model(referred_to_model.value)
else:
referred_model_fullname = referred_to_model.fullname
rvalue.callee.node.metadata['base'] = referred_model_fullname
referred_model = model_definition.api.lookup_fully_qualified_or_none(referred_model_fullname)
if 'related_name' in rvalue.arg_names:
related_arg_value = rvalue.args[rvalue.arg_names.index('related_name')].value
referred_model_class_def = referred_model.node.defn # type: nodes.ClassDef
referred_model_class_def.defs.body.append(
plugin_api.generate_related_manager_assignment_stmt(related_arg_value,
model_definition.cls.fullname))
if rvalue.callee.fullname == 'django.db.models.fields.related.OneToOneField':
referred_to_model = rvalue.args[rvalue.arg_names.index('to')]
if isinstance(referred_to_model, StrExpr):
referred_model_fullname = get_app_model(referred_to_model.value)
else:
referred_model_fullname = referred_to_model.fullname
referred_model = model_definition.api.lookup_fully_qualified_or_none(referred_model_fullname)
if 'related_name' in rvalue.arg_names:
related_arg_value = rvalue.args[rvalue.arg_names.index('related_name')].value
referred_model.node.names[related_arg_value] = \
model_definition.api.lookup_fully_qualified_or_none(model_definition.cls.fullname)
rvalue.callee.node.metadata['base'] = referred_model_fullname
rvalue.callee.node.metadata['name'] = arg_name
if 'members' not in model_definition.cls.info.metadata:
model_definition.cls.info.metadata['members'] = {}
model_definition.cls.info.metadata['members'][arg_name] = rvalue.callee.node.metadata.get('base', None)

View File

@@ -1,30 +0,0 @@
from typing import Optional, Callable, Type
from mypy.plugin import Plugin, ClassDefContext, AttributeContext
from mypy_django_plugin.model_classes import DjangoModelsRegistry
from mypy_django_plugin.plugins.callbacks import CollectModelsInformationCallback, DetermineFieldPythonTypeCallback
class FieldToPythonTypePlugin(Plugin):
model_registry = DjangoModelsRegistry()
def get_base_class_hook(self, fullname: str
) -> Optional[Callable[[ClassDefContext], None]]:
if fullname in self.model_registry:
return CollectModelsInformationCallback(self.model_registry)
return None
def get_attribute_hook(self, fullname: str
) -> Optional[Callable[[AttributeContext], Type]]:
# print(fullname)
classname, _, attrname = fullname.rpartition('.')
if classname and classname in self.model_registry:
return DetermineFieldPythonTypeCallback(self.model_registry)
return None
def plugin(version):
return FieldToPythonTypePlugin

View File

@@ -1,21 +1,33 @@
from typing import Optional, Callable
from typing import Optional, Callable, cast
from mypy.checker import TypeChecker
from mypy.nodes import Var, MDEF, SymbolTableNode
from mypy.plugin import Plugin, FunctionContext
from mypy.types import Type, CallableType, Instance
def extract_to_value_type(ctx: FunctionContext) -> Optional[Type]:
assert 'to' in ctx.context.arg_names
to_arg_value = ctx.arg_types[ctx.context.arg_names.index('to')][0]
if not isinstance(to_arg_value, CallableType):
return None
return to_arg_value.ret_type
def extract_related_name_value(ctx: FunctionContext) -> str:
return ctx.context.args[ctx.context.arg_names.index('related_name')].value
def set_related_name_manager_for_foreign_key(ctx: FunctionContext) -> Type:
if 'related_name' not in ctx.context.arg_names:
return ctx.default_return_type
assert 'to' in ctx.context.arg_names
to_arg_value = ctx.arg_types[ctx.context.arg_names.index('to')][0]
if not isinstance(to_arg_value, CallableType):
referred_to = extract_to_value_type(ctx)
if not referred_to:
return ctx.default_return_type
referred_to = to_arg_value.ret_type
related_name = ctx.context.args[ctx.context.arg_names.index('related_name')].value
related_name = extract_related_name_value(ctx)
outer_class_info = ctx.api.tscope.classes[-1]
queryset_type = ctx.api.named_generic_type('django.db.models.QuerySet',
@@ -28,11 +40,36 @@ def set_related_name_manager_for_foreign_key(ctx: FunctionContext) -> Type:
return ctx.default_return_type
def set_related_name_instance_for_onetoonefield(ctx: FunctionContext) -> Type:
if 'related_name' not in ctx.context.arg_names:
return ctx.default_return_type
referred_to = extract_to_value_type(ctx)
if referred_to is None:
return ctx.default_return_type
related_name = extract_related_name_value(ctx)
outer_class_info = ctx.api.tscope.classes[-1]
api = cast(TypeChecker, ctx.api)
related_instance_type = api.named_type(outer_class_info.fullname())
related_var = Var(related_name, related_instance_type)
related_var.info = related_instance_type.type
referred_to.type.names[related_name] = SymbolTableNode(MDEF, related_var,
plugin_generated=True)
return ctx.default_return_type
class RelatedFieldsPlugin(Plugin):
def get_function_hook(self, fullname: str
) -> Optional[Callable[[FunctionContext], Type]]:
if fullname == 'django.db.models.fields.related.ForeignKey':
return set_related_name_manager_for_foreign_key
if fullname == 'django.db.models.fields.related.OneToOneField':
return set_related_name_instance_for_onetoonefield
return None

View File

@@ -1,18 +1,4 @@
[case testForeignKeyWithClass]
from django.db import models
class Publisher(models.Model):
pass
class Book(models.Model):
publisher = models.ForeignKey(to=Publisher, on_delete=models.CASCADE)
book = Book()
reveal_type(book.publisher) # E: Revealed type is 'main.Publisher*'
[out]
[case testForeignKeyRelatedName]
[case testForeignKeyField]
from django.db import models
class Publisher(models.Model):
@@ -22,6 +8,25 @@ class Book(models.Model):
publisher = models.ForeignKey(to=Publisher, on_delete=models.CASCADE,
related_name='books')
book = Book()
reveal_type(book.publisher) # E: Revealed type is 'main.Publisher*'
publisher = Publisher()
reveal_type(publisher.books) # E: Revealed type is 'django.db.models.query.QuerySet[main.Book]'
[out]
[case testOneToOneField]
from django.db import models
class User(models.Model):
pass
class Profile(models.Model):
user = models.OneToOneField(to=User, on_delete=models.CASCADE, related_name='profile')
profile = Profile()
reveal_type(profile.user) # E: Revealed type is 'main.User*'
user = User()
reveal_type(user.profile) # E: Revealed type is 'main.Profile'