diff --git a/django-stubs/__init__.pyi b/django-stubs/__init__.pyi index 8d41882..83b267d 100644 --- a/django-stubs/__init__.pyi +++ b/django-stubs/__init__.pyi @@ -1,7 +1,12 @@ -from typing import Any +from typing import Any, NamedTuple from .utils.version import get_version as get_version VERSION: Any __version__: str def setup(set_prefix: bool = ...) -> None: ... + +# Used by mypy_django_plugin when returning a QuerySet row that is a NamedTuple where the field names are unknown +class _NamedTupleAnyAttr(NamedTuple): + def __getattr__(self, item: str) -> Any: ... + def __setattr__(self, item: str, value: Any) -> None: ... diff --git a/django-stubs/db/models/query.pyi b/django-stubs/db/models/query.pyi index 3c1bddb..b578190 100644 --- a/django-stubs/db/models/query.pyi +++ b/django-stubs/db/models/query.pyi @@ -97,8 +97,9 @@ class QuerySet(Generic[_T, _Row], Collection[_Row], Sized): def raw( self, raw_query: str, params: Any = ..., translations: Optional[Dict[str, str]] = ..., using: None = ... ) -> RawQuerySet: ... + # The type of values may be overridden to be more specific in the mypy plugin, depending on the fields param def values(self, *fields: Union[str, Combinable], **expressions: Any) -> QuerySet[_T, Dict[str, Any]]: ... - # The type of values_list is overridden to be more specific in the mypy django plugin + # The type of values_list may be overridden to be more specific in the mypy plugin, depending on the fields param def values_list( self, *fields: Union[str, Combinable], flat: bool = ..., named: bool = ... ) -> QuerySet[_T, Any]: ... diff --git a/mypy_django_plugin/helpers.py b/mypy_django_plugin/helpers.py index ee3a98b..1ec3e31 100644 --- a/mypy_django_plugin/helpers.py +++ b/mypy_django_plugin/helpers.py @@ -1,14 +1,16 @@ import typing -from typing import Dict, Optional +from collections import OrderedDict +from typing import Dict, Optional, cast -from mypy.checker import TypeChecker +from mypy.checker import TypeChecker, gen_unique_name +from mypy.mro import calculate_mro from mypy.nodes import ( AssignmentStmt, ClassDef, Expression, ImportedName, Lvalue, MypyFile, NameExpr, SymbolNode, TypeInfo, -) + SymbolTable, SymbolTableNode, Block, GDEF, MDEF, Var) from mypy.plugin import FunctionContext, MethodContext from mypy.types import ( AnyType, Instance, NoneTyp, Type, TypeOfAny, TypeVarType, UnionType, -) + TupleType, TypedDictType) MODEL_CLASS_FULLNAME = 'django.db.models.base.Model' FIELD_FULLNAME = 'django.db.models.fields.Field' @@ -211,7 +213,7 @@ def extract_field_setter_type(tp: Instance) -> Optional[Type]: return None -def extract_field_getter_type(tp: Instance) -> Optional[Type]: +def extract_field_getter_type(tp: Type) -> Optional[Type]: if not isinstance(tp, Instance): return None if tp.type.has_base(FIELD_FULLNAME): @@ -235,6 +237,10 @@ def get_fields_metadata(model: TypeInfo) -> Dict[str, typing.Any]: return get_django_metadata(model).setdefault('fields', {}) +def get_lookups_metadata(model: TypeInfo) -> Dict[str, typing.Any]: + return get_django_metadata(model).setdefault('lookups', {}) + + def extract_explicit_set_type_of_model_primary_key(model: TypeInfo) -> Optional[Type]: """ If field with primary_key=True is set on the model, extract its __set__ type. @@ -296,3 +302,84 @@ def get_assigned_value_for_class(type_info: TypeInfo, name: str) -> Optional[Exp if isinstance(lvalue, NameExpr) and lvalue.name == name: return rvalue return None + + +def is_field_nullable(model: TypeInfo, field_name: str) -> bool: + return get_fields_metadata(model).get(field_name, {}).get('null', False) + + +def is_foreign_key(t: Type) -> bool: + if not isinstance(t, Instance): + return False + return has_any_of_bases(t.type, (FOREIGN_KEY_FULLNAME, ONETOONE_FIELD_FULLNAME)) + + +def build_class_with_annotated_fields(api: TypeChecker, base: Type, fields: 'OrderedDict[str, Type]', + name: str) -> Instance: + """Build an Instance with `name` that contains the specified `fields` as attributes and extends `base`.""" + # Credit: This code is largely copied/modified from TypeChecker.intersect_instance_callable and + # NamedTupleAnalyzer.build_namedtuple_typeinfo + + cur_module = cast(MypyFile, api.scope.stack[0]) + gen_name = gen_unique_name(name, cur_module.names) + + cdef = ClassDef(name, Block([])) + cdef.fullname = cur_module.fullname() + '.' + gen_name + info = TypeInfo(SymbolTable(), cdef, cur_module.fullname()) + cdef.info = info + info.bases = [base] + + def add_field(var: Var, is_initialized_in_class: bool = False, + is_property: bool = False) -> None: + var.info = info + var.is_initialized_in_class = is_initialized_in_class + var.is_property = is_property + var._fullname = '%s.%s' % (info.fullname(), var.name()) + info.names[var.name()] = SymbolTableNode(MDEF, var) + + vars = [Var(item, typ) for item, typ in fields.items()] + for var in vars: + add_field(var, is_property=True) + + calculate_mro(info) + info.calculate_metaclass_type() + + cur_module.names[gen_name] = SymbolTableNode(GDEF, info, plugin_generated=True) + return Instance(info, []) + + +def make_named_tuple(api: TypeChecker, fields: 'OrderedDict[str, Type]', name: str) -> Type: + if not fields: + # No fields specified, so fallback to a subclass of NamedTuple that allows + # __getattr__ / __setattr__ for any attribute name. + fallback = api.named_generic_type('django._NamedTupleAnyAttr', []) + else: + fallback = build_class_with_annotated_fields( + api=api, + base=api.named_generic_type('typing.NamedTuple', []), + fields=fields, + name=name + ) + return TupleType(list(fields.values()), fallback=fallback) + + +def make_typeddict(api: TypeChecker, fields: 'OrderedDict[str, Type]', required_keys: typing.Set[str]) -> Type: + object_type = api.named_generic_type('mypy_extensions._TypedDict', []) + typed_dict_type = TypedDictType(fields, required_keys=required_keys, fallback=object_type) + return typed_dict_type + + +def make_tuple(api: TypeChecker, fields: typing.List[Type]) -> Type: + implicit_any = AnyType(TypeOfAny.special_form) + fallback = api.named_generic_type('builtins.tuple', [implicit_any]) + return TupleType(fields, fallback=fallback) + + +def get_private_descriptor_type(type_info: TypeInfo, private_field_name: str, is_nullable: bool) -> Type: + node = type_info.get(private_field_name).node + if isinstance(node, Var): + descriptor_type = node.type + if is_nullable: + descriptor_type = make_optional(descriptor_type) + return descriptor_type + return AnyType(TypeOfAny.unannotated) diff --git a/mypy_django_plugin/lookups.py b/mypy_django_plugin/lookups.py new file mode 100644 index 0000000..48b48d1 --- /dev/null +++ b/mypy_django_plugin/lookups.py @@ -0,0 +1,150 @@ +import dataclasses +from typing import Union, List + +from mypy.nodes import TypeInfo +from mypy.plugin import CheckerPluginInterface +from mypy.types import Type, Instance + +from mypy_django_plugin import helpers + + +@dataclasses.dataclass +class RelatedModelNode: + typ: Instance + is_nullable: bool + + +@dataclasses.dataclass +class FieldNode: + typ: Type + + +LookupNode = Union[RelatedModelNode, FieldNode] + + +class LookupException(Exception): + pass + + +def resolve_lookup(api: CheckerPluginInterface, model_type_info: TypeInfo, lookup: str) -> List[LookupNode]: + """Resolve a lookup str to a list of LookupNodes. + + Each node represents a part of the lookup (separated by "__"), in order. + Each node is the Model or Field that was resolved. + + Raises LookupException if there were any issues resolving the lookup. + """ + lookup_parts = lookup.split("__") + + nodes = [] + while lookup_parts: + lookup_part = lookup_parts.pop(0) + + if not nodes: + current_node = None + else: + current_node = nodes[-1] + + if current_node is None: + new_node = resolve_model_lookup(api, model_type_info, lookup_part) + elif isinstance(current_node, RelatedModelNode): + new_node = resolve_model_lookup(api, current_node.typ.type, lookup_part) + elif isinstance(current_node, FieldNode): + raise LookupException(f"Field lookups not yet supported for lookup {lookup}") + else: + raise LookupException(f"Unsupported node type: {type(current_node)}") + nodes.append(new_node) + return nodes + + +def resolve_model_lookup(api: CheckerPluginInterface, model_type_info: TypeInfo, + lookup: str) -> LookupNode: + """Resolve a lookup on the given model.""" + if lookup == 'pk': + # Primary keys are special-cased + primary_key_type = helpers.extract_primary_key_type_for_get(model_type_info) + if primary_key_type: + return FieldNode(primary_key_type) + else: + # No PK, use the get type for AutoField as PK type. + autofield_info = api.lookup_typeinfo('django.db.models.fields.AutoField') + pk_type = helpers.get_private_descriptor_type(autofield_info, '_pyi_private_get_type', + is_nullable=False) + return FieldNode(pk_type) + + field_name = get_actual_field_name_for_lookup_field(lookup, model_type_info) + + field_node = model_type_info.get(field_name) + if not field_node: + raise LookupException( + f'When resolving lookup "{lookup}", field "{field_name}" was not found in model {model_type_info.name()}') + + if field_name.endswith('_id'): + field_name_without_id = field_name.rstrip('_id') + foreign_key_field = model_type_info.get(field_name_without_id) + if foreign_key_field is not None and helpers.is_foreign_key(foreign_key_field.type): + # Hack: If field ends with '_id' and there is a model field without the '_id' suffix, then use that field. + field_node = foreign_key_field + field_name = field_name_without_id + + field_node_type = field_node.type + if field_node_type is None or not isinstance(field_node_type, Instance): + raise LookupException( + f'When resolving lookup "{lookup}", could not determine type for {model_type_info.name()}.{field_name}') + + if helpers.is_foreign_key(field_node_type): + field_type = helpers.extract_field_getter_type(field_node_type) + is_nullable = helpers.is_optional(field_type) + if is_nullable: + field_type = helpers.make_required(field_type) + + if isinstance(field_type, Instance): + return RelatedModelNode(typ=field_type, is_nullable=is_nullable) + else: + raise LookupException(f"Not an instance for field {field_type} lookup {lookup}") + + field_type = helpers.extract_field_getter_type(field_node_type) + + if field_type: + return FieldNode(typ=field_type) + else: + # Not a Field + if field_name == 'id': + # If no 'id' field was fouond, use an int + return FieldNode(api.named_generic_type('builtins.int', [])) + + related_manager_arg = None + if field_node_type.type.has_base(helpers.RELATED_MANAGER_CLASS_FULLNAME): + related_manager_arg = field_node_type.args[0] + + if related_manager_arg is not None: + # Reverse relation + return RelatedModelNode(typ=related_manager_arg, is_nullable=True) + raise LookupException( + f'When resolving lookup "{lookup}", could not determine type for {model_type_info.name()}.{field_name}') + + +def get_actual_field_name_for_lookup_field(lookup: str, model_type_info: TypeInfo) -> str: + """Attempt to find out the real field name if this lookup is a related_query_name (for reverse relations). + + If it's not, return the original lookup. + """ + lookups_metadata = helpers.get_lookups_metadata(model_type_info) + lookup_metadata = lookups_metadata.get(lookup) + if lookup_metadata is None: + # If not found on current model, look in all bases for their lookup metadata + for base in model_type_info.mro: + lookups_metadata = helpers.get_lookups_metadata(base) + lookup_metadata = lookups_metadata.get(lookup) + if lookup_metadata: + break + if not lookup_metadata: + lookup_metadata = {} + related_name = lookup_metadata.get('related_query_name_target', None) + if related_name: + # If the lookup is a related lookup, then look at the field specified by related_name. + # This is to support if related_query_name is set and differs from. + field_name = related_name + else: + field_name = lookup + return field_name diff --git a/mypy_django_plugin/main.py b/mypy_django_plugin/main.py index fbbf27d..f004dfa 100644 --- a/mypy_django_plugin/main.py +++ b/mypy_django_plugin/main.py @@ -1,6 +1,5 @@ -from functools import partial - import os +from functools import partial from typing import Callable, Dict, Optional, Union, cast from mypy.checker import TypeChecker @@ -23,6 +22,7 @@ from mypy_django_plugin.transformers.migrations import ( determine_model_cls_from_string_for_migrations, get_string_value_from_expr, ) from mypy_django_plugin.transformers.models import process_model_class +from mypy_django_plugin.transformers.queryset import extract_proper_type_for_values_and_values_list from mypy_django_plugin.transformers.settings import ( AddSettingValuesToDjangoConfObject, get_settings_metadata, ) @@ -165,7 +165,7 @@ def extract_and_return_primary_key_of_bound_related_field_parameter(ctx: Attribu if primary_key_type: return primary_key_type - is_nullable = helpers.get_fields_metadata(ctx.type.type).get(field_name, {}).get('null', False) + is_nullable = helpers.is_field_nullable(ctx.type.type, field_name) if is_nullable: return helpers.make_optional(ctx.default_attr_type) @@ -292,7 +292,10 @@ class DjangoPlugin(Plugin): if self.django_settings_module: settings_modules.append(self.django_settings_module) - monkeypatch.add_modules_as_a_source_seed_files(settings_modules) + auto_imports = ['mypy_extensions'] + auto_imports.extend(settings_modules) + + monkeypatch.add_modules_as_a_source_seed_files(auto_imports) monkeypatch.inject_modules_as_dependencies_for_django_conf_settings(settings_modules) def _get_current_model_bases(self) -> Dict[str, int]: @@ -359,10 +362,10 @@ class DjangoPlugin(Plugin): if sym and isinstance(sym.node, TypeInfo) and sym.node.has_base(helpers.FORM_MIXIN_CLASS_FULLNAME): return extract_proper_type_for_get_form - if method_name == 'values_list': + if method_name in ('values', 'values_list'): sym = self.lookup_fully_qualified(class_name) if sym and isinstance(sym.node, TypeInfo) and sym.node.has_base(helpers.QUERYSET_CLASS_FULLNAME): - return extract_proper_type_for_values_list + return partial(extract_proper_type_for_values_and_values_list, method_name) if fullname in {'django.apps.registry.Apps.get_model', 'django.db.migrations.state.StateApps.get_model'}: diff --git a/mypy_django_plugin/transformers/fields.py b/mypy_django_plugin/transformers/fields.py index f780580..bb96d0d 100644 --- a/mypy_django_plugin/transformers/fields.py +++ b/mypy_django_plugin/transformers/fields.py @@ -1,10 +1,10 @@ from typing import Optional, cast from mypy.checker import TypeChecker -from mypy.nodes import ListExpr, NameExpr, StrExpr, TupleExpr, TypeInfo, Var +from mypy.nodes import ListExpr, NameExpr, StrExpr, TupleExpr, TypeInfo from mypy.plugin import FunctionContext from mypy.types import ( - AnyType, CallableType, Instance, TupleType, Type, TypeOfAny, UnionType, + AnyType, CallableType, Instance, TupleType, Type, UnionType, ) from mypy_django_plugin import helpers @@ -88,23 +88,13 @@ def fill_descriptor_types_for_related_field(ctx: FunctionContext) -> Type: return helpers.reparametrize_instance(ctx.default_return_type, new_args=args) -def get_private_descriptor_type(type_info: TypeInfo, private_field_name: str, is_nullable: bool) -> Type: - node = type_info.get(private_field_name).node - if isinstance(node, Var): - descriptor_type = node.type - if is_nullable: - descriptor_type = helpers.make_optional(descriptor_type) - return descriptor_type - return AnyType(TypeOfAny.unannotated) - - def set_descriptor_types_for_field(ctx: FunctionContext) -> Instance: default_return_type = cast(Instance, ctx.default_return_type) is_nullable = helpers.parse_bool(helpers.get_argument_by_name(ctx, 'null')) - set_type = get_private_descriptor_type(default_return_type.type, '_pyi_private_set_type', - is_nullable=is_nullable) - get_type = get_private_descriptor_type(default_return_type.type, '_pyi_private_get_type', - is_nullable=is_nullable) + set_type = helpers.get_private_descriptor_type(default_return_type.type, '_pyi_private_set_type', + is_nullable=is_nullable) + get_type = helpers.get_private_descriptor_type(default_return_type.type, '_pyi_private_get_type', + is_nullable=is_nullable) return helpers.reparametrize_instance(default_return_type, [set_type, get_type]) diff --git a/mypy_django_plugin/transformers/init_create.py b/mypy_django_plugin/transformers/init_create.py index 7b8e982..4e4f97d 100644 --- a/mypy_django_plugin/transformers/init_create.py +++ b/mypy_django_plugin/transformers/init_create.py @@ -6,7 +6,6 @@ from mypy.plugin import FunctionContext, MethodContext from mypy.types import AnyType, Instance, Type, TypeOfAny from mypy_django_plugin import helpers -from mypy_django_plugin.transformers.fields import get_private_descriptor_type def extract_base_pointer_args(model: TypeInfo) -> Set[str]: @@ -162,8 +161,8 @@ def extract_expected_types(ctx: FunctionContext, model: TypeInfo, if not pk_type: # extract set type of AutoField autofield_info = api.lookup_typeinfo('django.db.models.fields.AutoField') - pk_type = get_private_descriptor_type(autofield_info, '_pyi_private_set_type', - is_nullable=is_nullable) + pk_type = helpers.get_private_descriptor_type(autofield_info, '_pyi_private_set_type', + is_nullable=is_nullable) related_primary_key_type = pk_type if is_init: @@ -185,8 +184,8 @@ def extract_expected_types(ctx: FunctionContext, model: TypeInfo, # if CharField(blank=True,...) and not nullable, then field can be None in __init__ elif ( - helpers.has_any_of_bases(typ.type, (helpers.CHAR_FIELD_FULLNAME,)) and is_init and - field_metadata.get('blank', False) and not field_metadata.get('null', False) + helpers.has_any_of_bases(typ.type, (helpers.CHAR_FIELD_FULLNAME,)) and is_init and + field_metadata.get('blank', False) and not field_metadata.get('null', False) ): field_type = helpers.make_optional(field_type) diff --git a/mypy_django_plugin/transformers/models.py b/mypy_django_plugin/transformers/models.py index 5db4dbb..ff872f4 100644 --- a/mypy_django_plugin/transformers/models.py +++ b/mypy_django_plugin/transformers/models.py @@ -172,17 +172,35 @@ class AddRelatedManagers(ModelClassInitializer): ref_to_fullname = module_name + '.' + exc.model_cls_name if self.model_classdef.fullname == ref_to_fullname: - related_manager_name = defn.name.lower() + '_set' + related_name = defn.name.lower() + '_set' if 'related_name' in rvalue.arg_names: related_name_expr = rvalue.args[rvalue.arg_names.index('related_name')] if not isinstance(related_name_expr, StrExpr): - return None - related_manager_name = related_name_expr.value + continue + related_name = related_name_expr.value + if related_name == '+': + # No backwards relation is desired + continue + if 'related_query_name' in rvalue.arg_names: + related_query_name_expr = rvalue.args[rvalue.arg_names.index('related_query_name')] + if not isinstance(related_query_name_expr, StrExpr): + related_query_name = None + else: + related_query_name = related_query_name_expr.value + # TODO: Handle defaulting to model name if related_name is not set + else: + # No related_query_name specified, default to related_name + related_query_name = related_name typ = get_related_field_type(rvalue, self.api, defn.info) if typ is None: - return None - self.add_new_node_to_model_class(related_manager_name, typ) + continue + self.add_new_node_to_model_class(related_name, typ) + if related_query_name is not None: + # Only create related_query_name if it is a string literal + helpers.get_lookups_metadata(self.model_classdef.info)[related_query_name] = { + 'related_query_name_target': related_name + } def iter_over_classdefs(module_file: MypyFile) -> Iterator[ClassDef]: diff --git a/mypy_django_plugin/transformers/queryset.py b/mypy_django_plugin/transformers/queryset.py new file mode 100644 index 0000000..33a6af0 --- /dev/null +++ b/mypy_django_plugin/transformers/queryset.py @@ -0,0 +1,139 @@ +from collections import OrderedDict +from typing import Union, List, cast, Optional + +from mypy.checker import TypeChecker +from mypy.nodes import StrExpr, TypeInfo +from mypy.plugin import MethodContext, CheckerPluginInterface +from mypy.types import Type, Instance, AnyType, TypeOfAny + +from mypy_django_plugin import helpers +from mypy_django_plugin.lookups import resolve_lookup, RelatedModelNode, LookupException + + +def extract_proper_type_for_values_and_values_list(method_name: str, ctx: MethodContext) -> Type: + api = cast(TypeChecker, ctx.api) + + object_type = ctx.type + if not isinstance(object_type, Instance): + return ctx.default_return_type + + ret = ctx.default_return_type + + any_type = AnyType(TypeOfAny.implementation_artifact) + fields_arg_expr = ctx.args[ctx.callee_arg_names.index('fields')] + + model_arg: Union[AnyType, Type] = ret.args[0] if len(ret.args) > 0 else any_type + + column_names: List[Optional[str]] = [] + column_types: OrderedDict[str, Type] = OrderedDict() + + fill_column_types = True + + if len(fields_arg_expr) == 0: + # values_list/values with no args is not yet supported, so default to Any types for field types + # It should in the future include all model fields, "extra" fields and "annotated" fields + fill_column_types = False + + if isinstance(model_arg, Instance): + model_type_info = model_arg.type + else: + model_type_info = None + + # Figure out each field name passed to fields + has_dynamic_column_names = False + for field_expr in fields_arg_expr: + if isinstance(field_expr, StrExpr): + field_name = field_expr.value + column_names.append(field_name) + # Default to any type + column_types[field_name] = any_type + + if model_type_info: + resolved_lookup_type = resolve_values_lookup(ctx.api, model_type_info, field_name) + if resolved_lookup_type is not None: + column_types[field_name] = resolved_lookup_type + else: + # Dynamic field names are partially supported for values_list, but not values + column_names.append(None) + has_dynamic_column_names = True + + if method_name == 'values_list': + flat = helpers.parse_bool(helpers.get_argument_by_name(ctx, 'flat')) + named = helpers.parse_bool(helpers.get_argument_by_name(ctx, 'named')) + + if named and flat: + api.fail("'flat' and 'named' can't be used together.", ctx.context) + return ret + elif named: + if fill_column_types and not has_dynamic_column_names: + row_arg = helpers.make_named_tuple(api, fields=column_types, name="Row") + else: + row_arg = helpers.make_named_tuple(api, fields=OrderedDict(), name="Row") + elif flat: + if len(ctx.args[0]) > 1: + api.fail("'flat' is not valid when values_list is called with more than one field.", ctx.context) + return ret + if fill_column_types and not has_dynamic_column_names: + # Grab first element + row_arg = column_types[column_names[0]] + else: + row_arg = any_type + else: + if fill_column_types: + args = [ + # Fallback to Any if the column name is unknown (e.g. dynamic) + column_types.get(column_name, any_type) if column_name is not None else any_type + for column_name in column_names + ] + else: + args = [any_type] + row_arg = helpers.make_tuple(api, fields=args) + elif method_name == 'values': + expression_arg_names = ctx.arg_names[ctx.callee_arg_names.index('expressions')] + for expression_name in expression_arg_names: + # Arbitrary additional annotation expressions are supported, but they all have type Any for now + column_names.append(expression_name) + column_types[expression_name] = any_type + + if fill_column_types and not has_dynamic_column_names: + row_arg = helpers.make_typeddict(api, fields=column_types, required_keys=set()) + else: + return ctx.default_return_type + else: + raise Exception(f"extract_proper_type_for_values_list doesn't support method {method_name}") + + new_type_args = [model_arg, row_arg] + return helpers.reparametrize_instance(ret, new_type_args) + + +def resolve_values_lookup(api: CheckerPluginInterface, model_type_info: TypeInfo, lookup: str) -> Optional[Type]: + """Resolves a values/values_list lookup if possible, to a Type.""" + try: + nodes = resolve_lookup(api, model_type_info, lookup) + except LookupException: + nodes = [] + + if not nodes: + return None + + make_optional = False + + for node in nodes: + if isinstance(node, RelatedModelNode) and node.is_nullable: + # All lookups following a relation which is nullable should be optional + make_optional = True + + node = nodes[-1] + + node_type = node.typ + if isinstance(node, RelatedModelNode): + # Related models used in values/values_list get resolved to the primary key of the related model. + # So, we lookup the pk of that model. + pk_lookup_nodes = resolve_lookup(api, node_type.type, "pk") + if not pk_lookup_nodes: + return None + node_type = pk_lookup_nodes[0].typ + if make_optional: + return helpers.make_optional(node_type) + else: + return node_type diff --git a/test-data/typecheck/queryset.test b/test-data/typecheck/queryset.test index 89926af..555682e 100644 --- a/test-data/typecheck/queryset.test +++ b/test-data/typecheck/queryset.test @@ -9,13 +9,15 @@ class Blog(models.Model): class BlogQuerySet(models.QuerySet[Blog]): pass +class Entry(models.Model): + blog = models.ForeignKey(Blog, on_delete=models.CASCADE, related_name="entries") + title = models.CharField(max_length=100) + + # Test that second type argument gets filled automatically blog_qs: models.QuerySet[Blog] reveal_type(blog_qs) # E: Revealed type is 'django.db.models.query.QuerySet[main.Blog, main.Blog]' -Blog.objects.values_list('id', flat=True, named=True) # E: 'flat' and 'named' can't be used together. -Blog.objects.values_list('id', 'extra_arg', flat=True) # E: 'flat' is not valid when values_list is called with more than one field. - reveal_type(Blog.objects.in_bulk([1])) # E: Revealed type is 'builtins.dict[Any, main.Blog*]' reveal_type(Blog.objects.in_bulk()) # E: Revealed type is 'builtins.dict[Any, main.Blog*]' reveal_type(Blog.objects.in_bulk(['beatles_blog'], field_name='name')) # E: Revealed type is 'builtins.dict[Any, main.Blog*]' @@ -23,9 +25,9 @@ reveal_type(Blog.objects.in_bulk(['beatles_blog'], field_name='name')) # E: Reve # When ANDing QuerySets, the left-side's _Row parameter is used reveal_type(Blog.objects.all() & Blog.objects.values()) # E: Revealed type is 'django.db.models.query.QuerySet[main.Blog*, main.Blog*]' reveal_type(Blog.objects.values() & Blog.objects.values()) # E: Revealed type is 'django.db.models.query.QuerySet[main.Blog*, builtins.dict*[builtins.str, Any]]' -reveal_type(Blog.objects.values_list('id', 'name') & Blog.objects.values()) # E: Revealed type is 'django.db.models.query.QuerySet[main.Blog*, builtins.tuple*[Any]]' -reveal_type(Blog.objects.values_list('id', 'name', named=True) & Blog.objects.values()) # E: Revealed type is 'django.db.models.query.QuerySet[main.Blog*, typing.NamedTuple*]' -reveal_type(Blog.objects.values_list('id', flat=True) & Blog.objects.values()) # E: Revealed type is 'django.db.models.query.QuerySet[main.Blog*, Any]' +reveal_type(Blog.objects.values_list('id', 'name') & Blog.objects.values()) # E: Revealed type is 'django.db.models.query.QuerySet[main.Blog*, Tuple[builtins.int, builtins.str]]' +reveal_type(Blog.objects.values_list('id', 'name', named=True) & Blog.objects.values()) # E: Revealed type is 'django.db.models.query.QuerySet[main.Blog*, Tuple[builtins.int, builtins.str, fallback=main.Row]]' +reveal_type(Blog.objects.values_list('id', flat=True) & Blog.objects.values()) # E: Revealed type is 'django.db.models.query.QuerySet[main.Blog*, builtins.int*]' # .dates / .datetimes reveal_type(Blog.objects.dates("created_at", "day")) # E: Revealed type is 'django.db.models.query.QuerySet[main.Blog*, datetime.date]' @@ -57,41 +59,203 @@ reveal_type(values_qs.in_bulk()) # E: Revealed type is 'builtins.dict[Any, main. values_list_qs = Blog.objects.values_list('id', 'name') -reveal_type(values_list_qs) # E: Revealed type is 'django.db.models.query.QuerySet[main.Blog*, builtins.tuple[Any]]' -reveal_type(values_list_qs.all()) # E: Revealed type is 'django.db.models.query.QuerySet[main.Blog*, builtins.tuple*[Any]]' -reveal_type(values_list_qs.get(id=1)) # E: Revealed type is 'builtins.tuple*[Any]' -reveal_type(iter(values_list_qs)) # E: Revealed type is 'typing.Iterator[builtins.tuple*[Any]]' -reveal_type(values_list_qs.iterator()) # E: Revealed type is 'typing.Iterator[builtins.tuple*[Any]]' -reveal_type(values_list_qs.first()) # E: Revealed type is 'Union[builtins.tuple*[Any], None]' -reveal_type(values_list_qs.earliest()) # E: Revealed type is 'builtins.tuple*[Any]' -reveal_type(values_list_qs[0]) # E: Revealed type is 'builtins.tuple*[Any]' -reveal_type(values_list_qs[:9]) # E: Revealed type is 'django.db.models.query.QuerySet[main.Blog*, builtins.tuple*[Any]]' +reveal_type(values_list_qs) # E: Revealed type is 'django.db.models.query.QuerySet[main.Blog*, Tuple[builtins.int, builtins.str]]' +reveal_type(values_list_qs.all()) # E: Revealed type is 'django.db.models.query.QuerySet[main.Blog*, Tuple[builtins.int, builtins.str]]' +reveal_type(values_list_qs.get(id=1)) # E: Revealed type is 'Tuple[builtins.int, builtins.str]' +reveal_type(iter(values_list_qs)) # E: Revealed type is 'typing.Iterator[Tuple[builtins.int, builtins.str]]' +reveal_type(values_list_qs.iterator()) # E: Revealed type is 'typing.Iterator[Tuple[builtins.int, builtins.str]]' +reveal_type(values_list_qs.first()) # E: Revealed type is 'Union[Tuple[builtins.int, builtins.str], None]' +reveal_type(values_list_qs.earliest()) # E: Revealed type is 'Tuple[builtins.int, builtins.str]' +reveal_type(values_list_qs[0]) # E: Revealed type is 'Tuple[builtins.int, builtins.str]' +reveal_type(values_list_qs[:9]) # E: Revealed type is 'django.db.models.query.QuerySet[main.Blog*, Tuple[builtins.int, builtins.str]]' reveal_type(values_list_qs.in_bulk()) # E: Revealed type is 'builtins.dict[Any, main.Blog*]' flat_values_list_qs = Blog.objects.values_list('id', flat=True) -reveal_type(flat_values_list_qs) # E: Revealed type is 'django.db.models.query.QuerySet[main.Blog*, Any]' -reveal_type(flat_values_list_qs.all()) # E: Revealed type is 'django.db.models.query.QuerySet[main.Blog*, Any]' -reveal_type(flat_values_list_qs.get(id=1)) # E: Revealed type is 'Any' -reveal_type(iter(flat_values_list_qs)) # E: Revealed type is 'typing.Iterator[Any]' -reveal_type(flat_values_list_qs.iterator()) # E: Revealed type is 'typing.Iterator[Any]' -reveal_type(flat_values_list_qs.first()) # E: Revealed type is 'Union[Any, None]' -reveal_type(flat_values_list_qs.earliest()) # E: Revealed type is 'Any' -reveal_type(flat_values_list_qs[0]) # E: Revealed type is 'Any' -reveal_type(flat_values_list_qs[:9]) # E: Revealed type is 'django.db.models.query.QuerySet[main.Blog*, Any]' +reveal_type(flat_values_list_qs) # E: Revealed type is 'django.db.models.query.QuerySet[main.Blog*, builtins.int]' +reveal_type(flat_values_list_qs.all()) # E: Revealed type is 'django.db.models.query.QuerySet[main.Blog*, builtins.int*]' +reveal_type(flat_values_list_qs.get(id=1)) # E: Revealed type is 'builtins.int*' +reveal_type(iter(flat_values_list_qs)) # E: Revealed type is 'typing.Iterator[builtins.int*]' +reveal_type(flat_values_list_qs.iterator()) # E: Revealed type is 'typing.Iterator[builtins.int*]' +reveal_type(flat_values_list_qs.first()) # E: Revealed type is 'Union[builtins.int*, None]' +reveal_type(flat_values_list_qs.earliest()) # E: Revealed type is 'builtins.int*' +reveal_type(flat_values_list_qs[0]) # E: Revealed type is 'builtins.int*' +reveal_type(flat_values_list_qs[:9]) # E: Revealed type is 'django.db.models.query.QuerySet[main.Blog*, builtins.int*]' reveal_type(flat_values_list_qs.in_bulk()) # E: Revealed type is 'builtins.dict[Any, main.Blog*]' named_values_list_qs = Blog.objects.values_list('id', named=True) -reveal_type(named_values_list_qs) # E: Revealed type is 'django.db.models.query.QuerySet[main.Blog*, typing.NamedTuple]' -reveal_type(named_values_list_qs.all()) # E: Revealed type is 'django.db.models.query.QuerySet[main.Blog*, typing.NamedTuple*]' -reveal_type(named_values_list_qs.get(id=1)) # E: Revealed type is 'typing.NamedTuple*' -reveal_type(iter(named_values_list_qs)) # E: Revealed type is 'typing.Iterator[typing.NamedTuple*]' -reveal_type(named_values_list_qs.iterator()) # E: Revealed type is 'typing.Iterator[typing.NamedTuple*]' -reveal_type(named_values_list_qs.first()) # E: Revealed type is 'Union[typing.NamedTuple*, None]' -reveal_type(named_values_list_qs.earliest()) # E: Revealed type is 'typing.NamedTuple*' -reveal_type(named_values_list_qs[0]) # E: Revealed type is 'typing.NamedTuple*' -reveal_type(named_values_list_qs[:9]) # E: Revealed type is 'django.db.models.query.QuerySet[main.Blog*, typing.NamedTuple*]' +reveal_type(named_values_list_qs) # E: Revealed type is 'django.db.models.query.QuerySet[main.Blog*, Tuple[builtins.int, fallback=main.Row1]]' +reveal_type(named_values_list_qs.all()) # E: Revealed type is 'django.db.models.query.QuerySet[main.Blog*, Tuple[builtins.int, fallback=main.Row1]]' +reveal_type(named_values_list_qs.get(id=1)) # E: Revealed type is 'Tuple[builtins.int, fallback=main.Row1]' +reveal_type(iter(named_values_list_qs)) # E: Revealed type is 'typing.Iterator[Tuple[builtins.int, fallback=main.Row1]]' +reveal_type(named_values_list_qs.iterator()) # E: Revealed type is 'typing.Iterator[Tuple[builtins.int, fallback=main.Row1]]' +reveal_type(named_values_list_qs.first()) # E: Revealed type is 'Union[Tuple[builtins.int, fallback=main.Row1], None]' +reveal_type(named_values_list_qs.earliest()) # E: Revealed type is 'Tuple[builtins.int, fallback=main.Row1]' +reveal_type(named_values_list_qs[0]) # E: Revealed type is 'Tuple[builtins.int, fallback=main.Row1]' +reveal_type(named_values_list_qs[:9]) # E: Revealed type is 'django.db.models.query.QuerySet[main.Blog*, Tuple[builtins.int, fallback=main.Row1]]' reveal_type(named_values_list_qs.in_bulk()) # E: Revealed type is 'builtins.dict[Any, main.Blog*]' [out] + +[CASE test_queryset_values_list_custom_primary_key] +from django.db import models + +class Blog(models.Model): + primary_uuid = models.UUIDField(primary_key=True) + +class Entry(models.Model): + blog = models.ForeignKey(Blog, on_delete=models.CASCADE, related_name="entries") + +# Blog has a primary key field specified, so no automatic 'id' field is expected to exist +reveal_type(Blog.objects.values_list('id', flat=True).get()) # E: Revealed type is 'Any' + +# Access Blog's pk (which is UUID field) +reveal_type(Blog.objects.values_list('pk', flat=True).get()) # E: Revealed type is 'uuid.UUID*' + +# Accessing PK of model pointed to by foreign key +reveal_type(Entry.objects.values_list('blog', flat=True).get()) # E: Revealed type is 'uuid.UUID*' +# Alternative way of accessing PK of model pointed to by foreign key +reveal_type(Entry.objects.values_list('blog_id', flat=True).get()) # E: Revealed type is 'uuid.UUID*' +# Yet another (more explicit) way of accessing PK of related model +reveal_type(Entry.objects.values_list('blog__pk', flat=True).get()) # E: Revealed type is 'uuid.UUID*' + +# Blog has a primary key field specified, so no automatic 'id' field is expected to exist +reveal_type(Entry.objects.values_list('blog__id', flat=True).get()) # E: Revealed type is 'Any' + +[CASE test_queryset_values_list] +from django.db import models + +class Blog(models.Model): + name = models.CharField(max_length=100) + created_at = models.DateTimeField() + +class Entry(models.Model): + blog = models.ForeignKey(Blog, on_delete=models.CASCADE, related_name="entries") + nullable_blog = models.ForeignKey(Blog, on_delete=models.CASCADE, related_name="+", null=True) + blog_with_related_query_name = models.ForeignKey(Blog, on_delete=models.CASCADE, related_query_name="my_related_query_name") + title = models.CharField(max_length=100) + +class BlogChild(Blog): + child_field = models.CharField(max_length=100) + +# Emulate at type-check time the errors that Django reports +Blog.objects.values_list('id', flat=True, named=True) # E: 'flat' and 'named' can't be used together. +Blog.objects.values_list('id', 'created_at', flat=True) # E: 'flat' is not valid when values_list is called with more than one field. + +# values_list where parameter types are all known +reveal_type(Blog.objects.values_list('id', 'created_at').get()) # E: Revealed type is 'Tuple[builtins.int, datetime.datetime]' +tup = Blog.objects.values_list('id', 'created_at').get() +reveal_type(tup[0]) # E: Revealed type is 'builtins.int' +reveal_type(tup[1]) # E: Revealed type is 'datetime.datetime' +tup[2] # E: Tuple index out of range + +# values_list returning namedtuple +reveal_type(Blog.objects.values_list('id', 'created_at', named=True).get()) # E: Revealed type is 'Tuple[builtins.int, datetime.datetime, fallback=main.Row]' + +# Invalid lookups produce Any type rather than giving errors. +reveal_type(Blog.objects.values_list('id', 'invalid_lookup').get()) # E: Revealed type is 'Tuple[builtins.int, Any]' +reveal_type(Blog.objects.values_list('entries_id', flat=True).get()) # E: Revealed type is 'Any' +reveal_type(Blog.objects.values_list('entries__foo', flat=True).get()) # E: Revealed type is 'Any' +reveal_type(Blog.objects.values_list('+', flat=True).get()) # E: Revealed type is 'Any' + +# Foreign key +reveal_type(Entry.objects.values_list('blog', flat=True).get()) # E: Revealed type is 'builtins.int*' +reveal_type(Entry.objects.values_list('blog__id', flat=True).get()) # E: Revealed type is 'builtins.int*' +reveal_type(Entry.objects.values_list('blog__pk', flat=True).get()) # E: Revealed type is 'builtins.int*' +reveal_type(Entry.objects.values_list('blog_id', flat=True).get()) # E: Revealed type is 'builtins.int*' + +# Foreign key (nullable=True) +reveal_type(Entry.objects.values_list('nullable_blog', flat=True).get()) # E: Revealed type is 'Union[builtins.int, None]' +reveal_type(Entry.objects.values_list('nullable_blog_id', flat=True).get()) # E: Revealed type is 'Union[builtins.int, None]' +reveal_type(Entry.objects.values_list('nullable_blog__id', flat=True).get()) # E: Revealed type is 'Union[builtins.int, None]' +reveal_type(Entry.objects.values_list('nullable_blog__pk', flat=True).get()) # E: Revealed type is 'Union[builtins.int, None]' + +# Reverse relation of ForeignKey +reveal_type(Blog.objects.values_list('entries', flat=True).get()) # E: Revealed type is 'Union[builtins.int, None]' +reveal_type(Blog.objects.values_list('entries__id', flat=True).get()) # E: Revealed type is 'Union[builtins.int, None]' +reveal_type(Blog.objects.values_list('entries__title', flat=True).get()) # E: Revealed type is 'Union[builtins.str, None]' + +# Reverse relation of ForeignKey (with related_query_name set) +reveal_type(Blog.objects.values_list('my_related_query_name__id', flat=True).get()) # E: Revealed type is 'Union[builtins.int, None]' + +# Basic inheritance +reveal_type(BlogChild.objects.values_list('id', 'created_at', 'child_field').get()) # E: Revealed type is 'Tuple[builtins.int, datetime.datetime, builtins.str]' + + + +[CASE test_queryset_values_list_and_values_behavior_with_no_fields_specified_and_accessing_unknown_attributes] +from django.db import models + +class Blog(models.Model): + name = models.CharField(max_length=100) + created_at = models.DateTimeField() + +row_named = Blog.objects.values_list('id', 'created_at', named=True).get() +reveal_type(row_named.id) # E: Revealed type is 'builtins.int' +reveal_type(row_named.created_at) # E: Revealed type is 'datetime.datetime' +row_named.non_existent_field # E: "Row" has no attribute "non_existent_field" + + +# When no fields are specified, fallback to Any +row_named_no_fields = Blog.objects.values_list(named=True).get() +reveal_type(row_named_no_fields) # E: Revealed type is 'Tuple[, fallback=django._NamedTupleAnyAttr]' + +# Don't complain about access to any attribute for now +reveal_type(row_named_no_fields.non_existent_field) # E: Revealed type is 'Any' +row_named_no_fields.non_existent_field = 1 + +# It should still behave like a NamedTuple +reveal_type(row_named_no_fields._asdict()) # E: Revealed type is 'builtins.dict[builtins.str, Any]' + + +dict_row = Blog.objects.values('id', 'created_at').get() +reveal_type(dict_row["id"]) # E: Revealed type is 'builtins.int' +reveal_type(dict_row["created_at"]) # E: Revealed type is 'datetime.datetime' +dict_row["non_existent_field"] # E: 'non_existent_field' is not a valid TypedDict key; expected one of ('id', 'created_at') +dict_row.pop('created_at') +dict_row.pop('non_existent_field') # E: 'non_existent_field' is not a valid TypedDict key; expected one of ('id', 'created_at') + +row_dict_no_fields = Blog.objects.values().get() +reveal_type(row_dict_no_fields) # E: Revealed type is 'builtins.dict*[builtins.str, Any]' +reveal_type(row_dict_no_fields["non_existent_field"]) # E: Revealed type is 'Any' + +[CASE values_with_annotate_inside_the_expressions] +from django.db import models +from django.db.models.functions import Lower, Upper + +class Publisher(models.Model): + pass + +class Book(models.Model): + name = models.CharField(max_length=100) + publisher = models.ForeignKey(to=Publisher, on_delete=models.CASCADE, related_name='books') + +reveal_type(Publisher().books.values('name', lower_name=Lower('name'), upper_name=Upper('name'))) # E: Revealed type is 'django.db.models.query.QuerySet[main.Book*, TypedDict({'name'?: builtins.str, 'lower_name'?: Any, 'upper_name'?: Any})]' + + +[CASE values_and_values_list_some_dynamic_fields] +from django.db import models + +class Publisher(models.Model): + pass + +class Book(models.Model): + name = models.CharField(max_length=100) + publisher = models.ForeignKey(to=Publisher, on_delete=models.CASCADE, related_name='books') + +some_dynamic_field = 'publisher' + +# Correct Tuple field types should be filled in when string literal is used, while Any is used for dynamic fields +reveal_type(Publisher().books.values_list('name', some_dynamic_field)) # E: Revealed type is 'django.db.models.query.QuerySet[main.Book*, Tuple[builtins.str, Any]]' + +# Flat with dynamic fields (there is only 1), means of course Any +reveal_type(Publisher().books.values_list(some_dynamic_field, flat=True)) # E: Revealed type is 'django.db.models.query.QuerySet[main.Book*, Any]' + +# A NamedTuple with a fallback to Any could be implemented, but for now that's unsupported, so all +# fields on the NamedTuple are Any for now +reveal_type(Publisher().books.values_list('name', some_dynamic_field, named=True).name) # E: Revealed type is 'Any' + +# A TypedDict with a fallback to Any could be implemented, but for now that's unsupported, +# so an ordinary Dict is used for now. +reveal_type(Publisher().books.values(some_dynamic_field, 'name')) # E: Revealed type is 'django.db.models.query.QuerySet[main.Book*, builtins.dict[builtins.str, Any]]' diff --git a/test-data/typecheck/related_fields.test b/test-data/typecheck/related_fields.test index b367cd4..0e80c74 100644 --- a/test-data/typecheck/related_fields.test +++ b/test-data/typecheck/related_fields.test @@ -287,3 +287,23 @@ class Publisher(models.Model): pass reveal_type(Book().publisher) # E: Revealed type is 'main.Publisher*' [out] + +[CASE test_foreign_key_field_without_backwards_relation] +from django.db import models + +class Publisher(models.Model): + pass + +class Book(models.Model): + publisher = models.ForeignKey(to=Publisher, on_delete=models.CASCADE, + related_name='+') + publisher2 = models.ForeignKey(to=Publisher, on_delete=models.CASCADE, + related_name='books2') + +book = Book() +reveal_type(book.publisher) # E: Revealed type is 'main.Publisher*' + +publisher = Publisher() +reveal_type(publisher.books) # E: Revealed type is 'Any' +reveal_type(publisher.books2) # E: Revealed type is 'django.db.models.manager.RelatedManager[main.Book]' +