mirror of
https://github.com/davidhalter/django-stubs.git
synced 2025-12-06 20:24:31 +08:00
add properly typed FOREIGN_KEY_FIELD_NAME_id fields to models
This commit is contained in:
@@ -2,10 +2,10 @@ import typing
|
||||
from typing import Dict, Optional
|
||||
|
||||
from mypy.checker import TypeChecker
|
||||
from mypy.nodes import AssignmentStmt, Expression, ImportedName, Lvalue, MypyFile, NameExpr, Statement, SymbolNode, TypeInfo, \
|
||||
ClassDef
|
||||
from mypy.nodes import AssignmentStmt, ClassDef, Expression, FuncDef, ImportedName, Lvalue, MypyFile, NameExpr, SymbolNode, \
|
||||
TypeInfo
|
||||
from mypy.plugin import FunctionContext
|
||||
from mypy.types import AnyType, Instance, Type, TypeOfAny, TypeVarType
|
||||
from mypy.types import AnyType, CallableType, Instance, Type, TypeOfAny, TypeVarType, UnionType
|
||||
|
||||
MODEL_CLASS_FULLNAME = 'django.db.models.base.Model'
|
||||
FIELD_FULLNAME = 'django.db.models.fields.Field'
|
||||
@@ -172,7 +172,8 @@ def get_setting_expr(api: TypeChecker, setting_name: str) -> Optional[Expression
|
||||
return None
|
||||
|
||||
|
||||
def iter_over_assignments(class_or_module: typing.Union[ClassDef, MypyFile]) -> typing.Iterator[typing.Tuple[Lvalue, Expression]]:
|
||||
def iter_over_assignments(
|
||||
class_or_module: typing.Union[ClassDef, MypyFile]) -> typing.Iterator[typing.Tuple[Lvalue, Expression]]:
|
||||
if isinstance(class_or_module, ClassDef):
|
||||
statements = class_or_module.defs.body
|
||||
else:
|
||||
@@ -185,3 +186,71 @@ def iter_over_assignments(class_or_module: typing.Union[ClassDef, MypyFile]) ->
|
||||
# not supported yet
|
||||
continue
|
||||
yield stmt.lvalues[0], stmt.rvalue
|
||||
|
||||
|
||||
def extract_field_setter_type(tp: Instance) -> Optional[Type]:
|
||||
if not isinstance(tp, Instance):
|
||||
return None
|
||||
if tp.type.has_base(FIELD_FULLNAME):
|
||||
set_method = tp.type.get_method('__set__')
|
||||
if isinstance(set_method, FuncDef) and isinstance(set_method.type, CallableType):
|
||||
if 'value' in set_method.type.arg_names:
|
||||
set_value_type = set_method.type.arg_types[set_method.type.arg_names.index('value')]
|
||||
if isinstance(set_value_type, Instance):
|
||||
set_value_type = fill_typevars(tp, set_value_type)
|
||||
return set_value_type
|
||||
elif isinstance(set_value_type, UnionType):
|
||||
items_no_typevars = []
|
||||
for item in set_value_type.items:
|
||||
if isinstance(item, Instance):
|
||||
item = fill_typevars(tp, item)
|
||||
items_no_typevars.append(item)
|
||||
return UnionType(items_no_typevars)
|
||||
|
||||
field_getter_type = extract_field_getter_type(tp)
|
||||
if field_getter_type:
|
||||
return field_getter_type
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def extract_field_getter_type(tp: Instance) -> Optional[Type]:
|
||||
if not isinstance(tp, Instance):
|
||||
return None
|
||||
if tp.type.has_base(FIELD_FULLNAME):
|
||||
get_method = tp.type.get_method('__get__')
|
||||
if isinstance(get_method, FuncDef) and isinstance(get_method.type, CallableType):
|
||||
return get_method.type.ret_type
|
||||
# GenericForeignKey
|
||||
if tp.type.has_base(GENERIC_FOREIGN_KEY_FULLNAME):
|
||||
return AnyType(TypeOfAny.special_form)
|
||||
return None
|
||||
|
||||
|
||||
def get_django_metadata(model: TypeInfo) -> Dict[str, typing.Any]:
|
||||
return model.metadata.setdefault('django', {})
|
||||
|
||||
|
||||
def get_related_field_primary_key_names(base_model: TypeInfo) -> typing.List[str]:
|
||||
django_metadata = get_django_metadata(base_model)
|
||||
return django_metadata.setdefault('related_field_primary_keys', [])
|
||||
|
||||
|
||||
def get_fields_metadata(model: TypeInfo) -> Dict[str, typing.Any]:
|
||||
return get_django_metadata(model).setdefault('fields', {})
|
||||
|
||||
|
||||
def extract_primary_key_type_for_set(model: TypeInfo) -> Optional[Type]:
|
||||
for field_name, props in get_fields_metadata(model).items():
|
||||
is_primary_key = props.get('primary_key', False)
|
||||
if is_primary_key:
|
||||
return extract_field_setter_type(model.names[field_name].type)
|
||||
return None
|
||||
|
||||
|
||||
def extract_primary_key_type_for_get(model: TypeInfo) -> Optional[Type]:
|
||||
for field_name, props in get_fields_metadata(model).items():
|
||||
is_primary_key = props.get('primary_key', False)
|
||||
if is_primary_key:
|
||||
return extract_field_getter_type(model.names[field_name].type)
|
||||
return None
|
||||
|
||||
@@ -4,8 +4,8 @@ from typing import Callable, Dict, Optional, cast
|
||||
from mypy.checker import TypeChecker
|
||||
from mypy.nodes import TypeInfo
|
||||
from mypy.options import Options
|
||||
from mypy.plugin import ClassDefContext, FunctionContext, MethodContext, Plugin
|
||||
from mypy.types import Instance, Type, TypeType
|
||||
from mypy.plugin import ClassDefContext, FunctionContext, MethodContext, Plugin, AttributeContext
|
||||
from mypy.types import Instance, Type, TypeType, AnyType, TypeOfAny
|
||||
|
||||
from mypy_django_plugin import helpers, monkeypatch
|
||||
from mypy_django_plugin.config import Config
|
||||
@@ -85,6 +85,27 @@ def return_user_model_hook(ctx: FunctionContext) -> Type:
|
||||
return TypeType(Instance(model_info, []))
|
||||
|
||||
|
||||
def extract_and_return_primary_key_of_bound_related_field_parameter(ctx: AttributeContext) -> Type:
|
||||
if not isinstance(ctx.default_attr_type, Instance) or not (ctx.default_attr_type.type.fullname() == 'builtins.int'):
|
||||
return ctx.default_attr_type
|
||||
|
||||
if not isinstance(ctx.type, Instance) or not ctx.type.type.has_base(helpers.MODEL_CLASS_FULLNAME):
|
||||
return ctx.default_attr_type
|
||||
|
||||
field_name = ctx.context.name.split('_')[0]
|
||||
sym = ctx.type.type.get(field_name)
|
||||
if sym and isinstance(sym.type, Instance) and len(sym.type.args) > 0:
|
||||
to_arg = sym.type.args[0]
|
||||
if isinstance(to_arg, AnyType):
|
||||
return AnyType(TypeOfAny.special_form)
|
||||
|
||||
model_type: TypeInfo = to_arg.type
|
||||
primary_key_type = helpers.extract_primary_key_type_for_get(model_type)
|
||||
if primary_key_type:
|
||||
return primary_key_type
|
||||
return ctx.default_attr_type
|
||||
|
||||
|
||||
class DjangoPlugin(Plugin):
|
||||
def __init__(self, options: Options) -> None:
|
||||
super().__init__(options)
|
||||
@@ -186,6 +207,14 @@ class DjangoPlugin(Plugin):
|
||||
|
||||
return None
|
||||
|
||||
def get_attribute_hook(self, fullname: str
|
||||
) -> Optional[Callable[[AttributeContext], Type]]:
|
||||
# sym = self.lookup_fully_qualified(helpers.MODEL_CLASS_FULLNAME)
|
||||
# if sym and isinstance(sym.node, TypeInfo):
|
||||
# if fullname.rpartition('.')[-1] in helpers.get_related_field_primary_key_names(sym.node):
|
||||
return extract_and_return_primary_key_of_bound_related_field_parameter
|
||||
|
||||
|
||||
|
||||
def plugin(version):
|
||||
return DjangoPlugin
|
||||
|
||||
@@ -1,11 +1,12 @@
|
||||
from typing import Any, Dict, Optional, Set, cast
|
||||
from typing import Dict, Optional, Set, cast
|
||||
|
||||
from mypy.checker import TypeChecker
|
||||
from mypy.nodes import FuncDef, TypeInfo, Var
|
||||
from mypy.nodes import TypeInfo, Var
|
||||
from mypy.plugin import FunctionContext, MethodContext
|
||||
from mypy.types import AnyType, CallableType, Instance, Type, TypeOfAny, UnionType
|
||||
from mypy.types import AnyType, Instance, Type, TypeOfAny, UnionType
|
||||
|
||||
from mypy_django_plugin import helpers
|
||||
from mypy_django_plugin.helpers import extract_field_setter_type, extract_primary_key_type_for_set, get_fields_metadata
|
||||
|
||||
|
||||
def extract_base_pointer_args(model: TypeInfo) -> Set[str]:
|
||||
@@ -103,46 +104,6 @@ def redefine_and_typecheck_model_create(ctx: MethodContext) -> Type:
|
||||
return ctx.default_return_type
|
||||
|
||||
|
||||
def extract_field_setter_type(tp: Instance) -> Optional[Type]:
|
||||
if not isinstance(tp, Instance):
|
||||
return None
|
||||
if tp.type.has_base(helpers.FIELD_FULLNAME):
|
||||
set_method = tp.type.get_method('__set__')
|
||||
if isinstance(set_method, FuncDef) and isinstance(set_method.type, CallableType):
|
||||
if 'value' in set_method.type.arg_names:
|
||||
set_value_type = set_method.type.arg_types[set_method.type.arg_names.index('value')]
|
||||
if isinstance(set_value_type, Instance):
|
||||
set_value_type = helpers.fill_typevars(tp, set_value_type)
|
||||
return set_value_type
|
||||
elif isinstance(set_value_type, UnionType):
|
||||
items_no_typevars = []
|
||||
for item in set_value_type.items:
|
||||
if isinstance(item, Instance):
|
||||
item = helpers.fill_typevars(tp, item)
|
||||
items_no_typevars.append(item)
|
||||
return UnionType(items_no_typevars)
|
||||
|
||||
get_method = tp.type.get_method('__get__')
|
||||
if isinstance(get_method, FuncDef) and isinstance(get_method.type, CallableType):
|
||||
return get_method.type.ret_type
|
||||
# GenericForeignKey
|
||||
if tp.type.has_base(helpers.GENERIC_FOREIGN_KEY_FULLNAME):
|
||||
return AnyType(TypeOfAny.special_form)
|
||||
return None
|
||||
|
||||
|
||||
def get_fields_metadata(model: TypeInfo) -> Dict[str, Any]:
|
||||
return model.metadata.setdefault('django', {}).setdefault('fields', {})
|
||||
|
||||
|
||||
def extract_primary_key_type(model: TypeInfo) -> Optional[Type]:
|
||||
for field_name, props in get_fields_metadata(model).items():
|
||||
is_primary_key = props.get('primary_key', False)
|
||||
if is_primary_key:
|
||||
return extract_field_setter_type(model.names[field_name].type)
|
||||
return None
|
||||
|
||||
|
||||
def extract_choices_type(model: TypeInfo, field_name: str) -> Optional[str]:
|
||||
field_metadata = get_fields_metadata(model).get(field_name, {})
|
||||
if 'choices' in field_metadata:
|
||||
@@ -153,7 +114,7 @@ def extract_choices_type(model: TypeInfo, field_name: str) -> Optional[str]:
|
||||
def extract_expected_types(ctx: FunctionContext, model: TypeInfo) -> Dict[str, Type]:
|
||||
expected_types: Dict[str, Type] = {}
|
||||
|
||||
primary_key_type = extract_primary_key_type(model)
|
||||
primary_key_type = extract_primary_key_type_for_set(model)
|
||||
if not primary_key_type:
|
||||
# no explicit primary key, set pk to Any and add id
|
||||
primary_key_type = AnyType(TypeOfAny.special_form)
|
||||
@@ -178,11 +139,13 @@ def extract_expected_types(ctx: FunctionContext, model: TypeInfo) -> Dict[str, T
|
||||
|
||||
if tp.type.fullname() in {helpers.FOREIGN_KEY_FULLNAME, helpers.ONETOONE_FIELD_FULLNAME}:
|
||||
ref_to_model = tp.args[0]
|
||||
primary_key_type = AnyType(TypeOfAny.special_form)
|
||||
if isinstance(ref_to_model, Instance) and ref_to_model.type.has_base(helpers.MODEL_CLASS_FULLNAME):
|
||||
primary_key_type = extract_primary_key_type(ref_to_model.type)
|
||||
if not primary_key_type:
|
||||
primary_key_type = AnyType(TypeOfAny.special_form)
|
||||
expected_types[name + '_id'] = primary_key_type
|
||||
typ = extract_primary_key_type_for_set(ref_to_model.type)
|
||||
if typ:
|
||||
primary_key_type = typ
|
||||
expected_types[name + '_id'] = primary_key_type
|
||||
|
||||
if field_type:
|
||||
expected_types[name] = field_type
|
||||
elif isinstance(sym.node.type, AnyType):
|
||||
|
||||
@@ -2,7 +2,7 @@ from abc import ABCMeta, abstractmethod
|
||||
from typing import Dict, Iterator, List, Optional, Tuple, cast
|
||||
|
||||
import dataclasses
|
||||
from mypy.nodes import ARG_STAR, ARG_STAR2, Argument, AssignmentStmt, CallExpr, ClassDef, Context, Expression, IndexExpr, \
|
||||
from mypy.nodes import ARG_STAR, ARG_STAR2, Argument, CallExpr, ClassDef, Context, Expression, IndexExpr, \
|
||||
Lvalue, MDEF, MemberExpr, MypyFile, NameExpr, StrExpr, SymbolTableNode, TypeInfo, Var
|
||||
from mypy.plugin import ClassDefContext
|
||||
from mypy.plugins.common import add_method
|
||||
@@ -74,8 +74,11 @@ def iter_over_one_to_n_related_fields(klass: ClassDef) -> Iterator[Tuple[NameExp
|
||||
class SetIdAttrsForRelatedFields(ModelClassInitializer):
|
||||
def run(self) -> None:
|
||||
for lvalue, rvalue in iter_over_one_to_n_related_fields(self.model_classdef):
|
||||
self.add_new_node_to_model_class(lvalue.name + '_id',
|
||||
typ=self.api.named_type('__builtins__.int'))
|
||||
# base_model_info = self.api.named_type('builtins.object').type
|
||||
# helpers.get_related_field_primary_key_names(base_model_info).append(node_name)
|
||||
node_name = lvalue.name + '_id'
|
||||
self.add_new_node_to_model_class(name=node_name,
|
||||
typ=self.api.builtin_type('builtins.int'))
|
||||
|
||||
|
||||
class InjectAnyAsBaseForNestedMeta(ModelClassInitializer):
|
||||
|
||||
@@ -22,13 +22,11 @@ class Publisher(models.Model):
|
||||
|
||||
class Book(models.Model):
|
||||
publisher = models.ForeignKey(to=Publisher, on_delete=models.CASCADE)
|
||||
class StylesheetError(Exception):
|
||||
pass
|
||||
owner = models.ForeignKey(db_column='model_id', to='db.Unknown', on_delete=models.CASCADE)
|
||||
|
||||
book = Book()
|
||||
reveal_type(book.publisher_id) # E: Revealed type is 'builtins.int'
|
||||
reveal_type(book.owner_id) # E: Revealed type is 'builtins.int'
|
||||
reveal_type(book.owner_id) # E: Revealed type is 'Any'
|
||||
|
||||
[CASE test_foreign_key_field_different_order_of_params]
|
||||
from django.db import models
|
||||
@@ -68,7 +66,7 @@ from django.db import models
|
||||
class Publisher(models.Model):
|
||||
pass
|
||||
|
||||
[CASE test_to_parameter_as_string_with_application_name__fallbacks_to_any_if_model_not_present_in_dependency_graph]
|
||||
[CASE test_to_parameter_as_string_with_application_name_fallbacks_to_any_if_model_not_present_in_dependency_graph]
|
||||
from django.db import models
|
||||
|
||||
class Book(models.Model):
|
||||
@@ -76,6 +74,9 @@ class Book(models.Model):
|
||||
|
||||
book = Book()
|
||||
reveal_type(book.publisher) # E: Revealed type is 'Any'
|
||||
reveal_type(book.publisher_id) # E: Revealed type is 'Any'
|
||||
Book(publisher_id=1)
|
||||
Book.objects.create(publisher_id=1)
|
||||
|
||||
[file myapp/__init__.py]
|
||||
[file myapp/models.py]
|
||||
@@ -247,3 +248,30 @@ class Publisher(models.Model):
|
||||
class Book(models.Model):
|
||||
publisher = models.ForeignKey(to=Publisher, on_delete=models.CASCADE)
|
||||
reveal_type(Publisher().book_set) # E: Revealed type is 'django.db.models.manager.RelatedManager[main.Book]'
|
||||
|
||||
[CASE underscore_id_attribute_has_set_type_of_primary_key_if_explicit]
|
||||
from django.db import models
|
||||
import datetime
|
||||
class Publisher(models.Model):
|
||||
mypk = models.CharField(max_length=100, primary_key=True)
|
||||
class Book(models.Model):
|
||||
publisher = models.ForeignKey(to=Publisher, on_delete=models.CASCADE)
|
||||
|
||||
reveal_type(Book().publisher_id) # E: Revealed type is 'builtins.str'
|
||||
Book(publisher_id=1)
|
||||
Book(publisher_id='hello')
|
||||
Book(publisher_id=datetime.datetime.now()) # E: Incompatible type for "publisher_id" of "Book" (got "datetime", expected "Union[str, int, Combinable]")
|
||||
Book.objects.create(publisher_id=1)
|
||||
Book.objects.create(publisher_id='hello')
|
||||
|
||||
class Publisher2(models.Model):
|
||||
mypk = models.IntegerField(primary_key=True)
|
||||
class Book2(models.Model):
|
||||
publisher = models.ForeignKey(to=Publisher2, on_delete=models.CASCADE)
|
||||
|
||||
reveal_type(Book2().publisher_id) # E: Revealed type is 'builtins.int'
|
||||
Book2(publisher_id=1)
|
||||
Book2(publisher_id='hello') # E: Incompatible type for "publisher_id" of "Book2" (got "str", expected "Union[int, Combinable, Literal['']]")
|
||||
Book2.objects.create(publisher_id=1)
|
||||
Book2.objects.create(publisher_id='hello') # E: Incompatible type for "publisher_id" of "Book2" (got "str", expected "Union[int, Combinable, Literal['']]")
|
||||
[out]
|
||||
|
||||
Reference in New Issue
Block a user