From 5d0ee40ada6668d27762f5fc23b6bb57d32bcd5b Mon Sep 17 00:00:00 2001 From: Maxim Kurnikov Date: Sun, 24 Mar 2019 02:54:10 +0300 Subject: [PATCH] Fix errors in db.models.expressions and db.models.functions.* (#54) * fix errors at db.models.expressions and db.models.functions.* * catch KeyError if QuerySet has not been loaded --- django-stubs/db/models/expressions.pyi | 90 +++++++----------- .../db/models/functions/comparison.pyi | 50 +--------- django-stubs/db/models/functions/text.pyi | 92 +++---------------- django-stubs/db/models/functions/window.pyi | 67 +++----------- mypy_django_plugin/main.py | 15 ++- 5 files changed, 73 insertions(+), 241 deletions(-) diff --git a/django-stubs/db/models/expressions.pyi b/django-stubs/db/models/expressions.pyi index 180738a..6f9d62c 100644 --- a/django-stubs/db/models/expressions.pyi +++ b/django-stubs/db/models/expressions.pyi @@ -1,20 +1,18 @@ -from collections import OrderedDict from datetime import datetime, timedelta -from typing import Any, Callable, Dict, Iterator, List, Optional, Sequence, Set, Tuple, Type, Union, TypeVar +from typing import Any, Callable, Dict, Iterator, List, Optional, Sequence, Set, Tuple, Type, TypeVar, Union from django.db.models.lookups import Lookup from django.db.models.sql.compiler import SQLCompiler from django.db.models import Q, QuerySet -from django.db.models.fields import Field, FloatField -from django.db.models.sql import Query +from django.db.models.fields import Field _OutputField = Union[Field, str] class SQLiteNumericMixin: def as_sqlite(self, compiler: SQLCompiler, connection: Any, **extra_context: Any) -> Tuple[str, List[float]]: ... -_SelfCombinable = TypeVar("_SelfCombinable", bound="Combinable") +_Self = TypeVar("_Self") class Combinable: ADD: str = ... @@ -27,22 +25,20 @@ class Combinable: BITOR: str = ... BITLEFTSHIFT: str = ... BITRIGHTSHIFT: str = ... - def __neg__(self: _SelfCombinable) -> _SelfCombinable: ... - def __add__( - self: _SelfCombinable, other: Optional[Union[timedelta, Combinable, float, str]] - ) -> _SelfCombinable: ... - def __sub__(self: _SelfCombinable, other: Union[timedelta, Combinable, float]) -> _SelfCombinable: ... - def __mul__(self: _SelfCombinable, other: Union[timedelta, Combinable, float]) -> _SelfCombinable: ... - def __truediv__(self: _SelfCombinable, other: Union[Combinable, float]) -> _SelfCombinable: ... - def __itruediv__(self: _SelfCombinable, other: Union[Combinable, float]) -> _SelfCombinable: ... - def __mod__(self: _SelfCombinable, other: Union[int, Combinable]) -> _SelfCombinable: ... - def __pow__(self: _SelfCombinable, other: Union[float, Combinable]) -> _SelfCombinable: ... - def __and__(self: _SelfCombinable, other: Combinable) -> _SelfCombinable: ... - def bitand(self: _SelfCombinable, other: int) -> _SelfCombinable: ... - def bitleftshift(self: _SelfCombinable, other: int) -> _SelfCombinable: ... - def bitrightshift(self: _SelfCombinable, other: int) -> _SelfCombinable: ... - def __or__(self: _SelfCombinable, other: Combinable) -> _SelfCombinable: ... - def bitor(self: _SelfCombinable, other: int) -> _SelfCombinable: ... + def __neg__(self: _Self) -> _Self: ... + def __add__(self: _Self, other: Optional[Union[timedelta, Combinable, float, str]]) -> _Self: ... + def __sub__(self: _Self, other: Union[timedelta, Combinable, float]) -> _Self: ... + def __mul__(self: _Self, other: Union[timedelta, Combinable, float]) -> _Self: ... + def __truediv__(self: _Self, other: Union[Combinable, float]) -> _Self: ... + def __itruediv__(self: _Self, other: Union[Combinable, float]) -> _Self: ... + def __mod__(self: _Self, other: Union[int, Combinable]) -> _Self: ... + def __pow__(self: _Self, other: Union[float, Combinable]) -> _Self: ... + def __and__(self: _Self, other: Combinable) -> _Self: ... + def bitand(self: _Self, other: int) -> _Self: ... + def bitleftshift(self: _Self, other: int) -> _Self: ... + def bitrightshift(self: _Self, other: int) -> _Self: ... + def __or__(self: _Self, other: Combinable) -> _Self: ... + def bitor(self: _Self, other: int) -> _Self: ... def __radd__(self, other: Optional[Union[datetime, float, Combinable]]) -> Combinable: ... def __rsub__(self, other: Union[float, Combinable]) -> Combinable: ... def __rmul__(self, other: Union[float, Combinable]) -> Combinable: ... @@ -52,43 +48,45 @@ class Combinable: def __rand__(self, other: Any) -> Combinable: ... def __ror__(self, other: Any) -> Combinable: ... -_SelfBaseExpression = TypeVar("_SelfBaseExpression", bound="BaseExpression") - class BaseExpression: is_summary: bool = ... filterable: bool = ... window_compatible: bool = ... + output_field: Any def __init__(self, output_field: Optional[_OutputField] = ...) -> None: ... def get_db_converters(self, connection: Any) -> List[Callable]: ... def get_source_expressions(self) -> List[Any]: ... - def set_source_expressions(self, exprs: List[Any]) -> None: ... + def set_source_expressions(self, exprs: Sequence[Combinable]) -> None: ... + @property def contains_aggregate(self) -> bool: ... + @property def contains_over_clause(self) -> bool: ... + @property def contains_column_references(self) -> bool: ... def resolve_expression( - self, + self: _Self, query: Any = ..., allow_joins: bool = ..., reuse: Optional[Set[str]] = ..., summarize: bool = ..., for_save: bool = ..., - ) -> BaseExpression: ... + ) -> _Self: ... @property def field(self) -> Field: ... @property def output_field(self) -> Field: ... + @property def convert_value(self) -> Callable: ... def get_lookup(self, lookup: str) -> Optional[Type[Lookup]]: ... def get_transform(self, name: str) -> Optional[Type[Expression]]: ... def relabeled_clone(self, change_map: Dict[Optional[str], str]) -> Expression: ... def copy(self) -> BaseExpression: ... - def get_group_by_cols(self: _SelfBaseExpression) -> List[_SelfBaseExpression]: ... + def get_group_by_cols(self: _Self) -> List[_Self]: ... def get_source_fields(self) -> List[Optional[Field]]: ... def asc(self, **kwargs: Any) -> Expression: ... def desc(self, **kwargs: Any) -> Expression: ... def reverse_ordering(self): ... def flatten(self) -> Iterator[Expression]: ... - def __hash__(self) -> int: ... def deconstruct(self) -> Any: ... def as_sqlite(self, compiler: SQLCompiler, connection: Any) -> Any: ... def as_sql(self, compiler: SQLCompiler, connection: Any, **extra_context: Any) -> Any: ... @@ -105,28 +103,18 @@ class CombinedExpression(SQLiteNumericMixin, Expression): def __init__( self, lhs: Combinable, connector: str, rhs: Combinable, output_field: Optional[_OutputField] = ... ) -> None: ... - def get_source_expressions(self) -> Union[List[Combinable], List[SQLiteNumericMixin]]: ... - def set_source_expressions(self, exprs: List[Combinable]) -> None: ... - def resolve_expression( - self, - query: Any = ..., - allow_joins: bool = ..., - reuse: Optional[Set[str]] = ..., - summarize: bool = ..., - for_save: bool = ..., - ) -> CombinedExpression: ... class F(Combinable): name: str def __init__(self, name: str): ... def resolve_expression( - self, + self: _Self, query: Any = ..., allow_joins: bool = ..., reuse: Optional[Set[str]] = ..., summarize: bool = ..., for_save: bool = ..., - ) -> Expression: ... + ) -> _Self: ... def asc(self, **kwargs) -> OrderBy: ... def desc(self, **kwargs) -> OrderBy: ... def deconstruct(self) -> Any: ... @@ -141,8 +129,6 @@ class Subquery(Expression): def __init__(self, queryset: QuerySet, output_field: Optional[_OutputField] = ..., **extra: Any) -> None: ... class Exists(Subquery): - extra: Dict[Any, Any] - template: str = ... negated: bool = ... def __init__(self, *args: Any, negated: bool = ..., **kwargs: Any) -> None: ... def __invert__(self) -> Exists: ... @@ -162,30 +148,19 @@ class Value(Expression): def __init__(self, value: Any, output_field: Optional[_OutputField] = ...) -> None: ... class RawSQL(Expression): - output_field: Field params: List[Any] sql: str def __init__(self, sql: str, params: Sequence[Any], output_field: Optional[_OutputField] = ...) -> None: ... class Func(SQLiteNumericMixin, Expression): function: str = ... + name: str = ... template: str = ... arg_joiner: str = ... arity: int = ... - source_expressions: List[Expression] = ... + source_expressions: List[Combinable] = ... extra: Dict[Any, Any] = ... def __init__(self, *expressions: Any, output_field: Optional[_OutputField] = ..., **extra: Any) -> None: ... - def get_source_expressions(self) -> List[Combinable]: ... - def set_source_expressions(self, exprs: List[Expression]) -> None: ... - def resolve_expression( - self, - query: Query = ..., - allow_joins: bool = ..., - reuse: Optional[Set[Any]] = ..., - summarize: bool = ..., - for_save: bool = ..., - ) -> Func: ... - def copy(self) -> Func: ... class When(Expression): template: str = ... @@ -205,8 +180,6 @@ class Case(Expression): class ExpressionWrapper(Expression): def __init__(self, expression: Union[Q, Combinable], output_field: _OutputField): ... - def set_source_expressions(self, exprs: Sequence[Expression]) -> None: ... - def get_source_expressions(self) -> List[Expression]: ... class Col(Expression): def __init__(self, alias: str, target: str, output_field: Optional[_OutputField] = ...): ... @@ -214,8 +187,7 @@ class Col(Expression): class ExpressionList(Func): def __init__(self, *expressions: Union[BaseExpression, Combinable], **extra: Any) -> None: ... -class Random(Expression): - output_field: FloatField +class Random(Expression): ... class Ref(Expression): def __init__(self, refs: str, source: Expression): ... diff --git a/django-stubs/db/models/functions/comparison.pyi b/django-stubs/db/models/functions/comparison.pyi index fdd9b8d..5a1ec22 100644 --- a/django-stubs/db/models/functions/comparison.pyi +++ b/django-stubs/db/models/functions/comparison.pyi @@ -1,51 +1,11 @@ -from datetime import date -from decimal import Decimal -from typing import Any, Callable, Dict, List, Union - -from django.db.models.expressions import Combinable, Expression +from typing import Any, Union from django.db.models import Func from django.db.models.fields import Field class Cast(Func): - contains_aggregate: bool - convert_value: Callable - extra: Dict[Any, Any] - is_summary: bool - output_field: Field - source_expressions: List[Combinable] - function: str = ... - template: str = ... - def __init__(self, expression: Union[date, Decimal, Expression, str], output_field: Union[str, Field]) -> None: ... + def __init__(self, expression: Any, output_field: Union[str, Field]) -> None: ... -class Coalesce(Func): - contains_aggregate: bool - convert_value: Callable - extra: Dict[Any, Any] - is_summary: bool - output_field: Field - source_expressions: List[Combinable] - function: str = ... - def __init__(self, *expressions: Any, **extra: Any) -> None: ... - -class Greatest(Func): - contains_aggregate: bool - contains_over_clause: bool - convert_value: Callable - extra: Dict[Any, Any] - is_summary: bool - output_field: Field - source_expressions: List[Combinable] - function: str = ... - def __init__(self, *expressions: Any, **extra: Any) -> None: ... - -class Least(Func): - contains_aggregate: bool - contains_over_clause: bool - convert_value: Callable - extra: Dict[Any, Any] - is_summary: bool - output_field: Field - source_expressions: List[Combinable] - function: str = ... - def __init__(self, *expressions: Any, **extra: Any) -> None: ... +class Coalesce(Func): ... +class Greatest(Func): ... +class Least(Func): ... diff --git a/django-stubs/db/models/functions/text.pyi b/django-stubs/db/models/functions/text.pyi index d66f352..20fe6ca 100644 --- a/django-stubs/db/models/functions/text.pyi +++ b/django-stubs/db/models/functions/text.pyi @@ -1,105 +1,47 @@ -from typing import Any, List, Optional, Tuple, Union, Callable +from typing import Any, List, Optional, Tuple, Union from django.db.backends.sqlite3.base import DatabaseWrapper from django.db.models.expressions import Combinable, Expression, Value from django.db.models.sql.compiler import SQLCompiler from django.db.models import Func, Transform -from django.db.models.fields import Field -class BytesToCharFieldConversionMixin: - def convert_value( - self, value: str, expression: BytesToCharFieldConversionMixin, connection: DatabaseWrapper - ) -> str: ... - -class Chr(Transform): - contains_aggregate: bool - lookup_name: str = ... +class BytesToCharFieldConversionMixin: ... +class Chr(Transform): ... class ConcatPair(Func): - contains_aggregate: bool def coalesce(self) -> ConcatPair: ... -class Concat(Func): - contains_aggregate: bool - convert_value: Callable +class Concat(Func): ... class Left(Func): - contains_aggregate: bool - contains_over_clause: bool - convert_value: Callable - output_field: Field def __init__(self, expression: str, length: Union[Value, int], **extra: Any) -> None: ... def get_substr(self) -> Substr: ... def use_substr( self, compiler: SQLCompiler, connection: DatabaseWrapper, **extra_context: Any ) -> Tuple[str, List[int]]: ... - as_oracle: Any = ... - as_sqlite: Any = ... -class Length(Transform): - contains_aggregate: bool - convert_value: Callable - lookup_name: str = ... - -class Lower(Transform): - contains_aggregate: bool - contains_column_references: bool - contains_over_clause: bool - convert_value: Callable - lookup_name: str = ... +class Length(Transform): ... +class Lower(Transform): ... class LPad(BytesToCharFieldConversionMixin, Func): - contains_aggregate: bool - convert_value: Callable def __init__(self, expression: str, length: Union[Length, int], fill_text: Value = ..., **extra: Any) -> None: ... -class LTrim(Transform): - contains_aggregate: bool - convert_value: Callable - lookup_name: str = ... - -class Ord(Transform): - contains_aggregate: bool - convert_value: Callable - lookup_name: str = ... - def as_mysql(self, compiler: Any, connection: Any, **extra_context: Any): ... +class LTrim(Transform): ... +class Ord(Transform): ... class Repeat(BytesToCharFieldConversionMixin, Func): - contains_aggregate: bool - convert_value: Callable - output_field: django.db.models.fields.CharField def __init__(self, expression: Union[Value, str], number: Union[Length, int], **extra: Any) -> None: ... - def as_oracle(self, compiler: Any, connection: Any, **extra_context: Any): ... class Replace(Func): - contains_aggregate: bool - contains_over_clause: bool - convert_value: Callable - output_field: Field def __init__(self, expression: Combinable, text: Value, replacement: Value = ..., **extra: Any) -> None: ... class Right(Left): ... - -class RPad(LPad): - output_field: Field - -class RTrim(Transform): - contains_aggregate: bool - convert_value: Callable - lookup_name: str = ... - -class StrIndex(Func): - contains_aggregate: bool - convert_value: Callable - output_field: Any = ... - def as_postgresql(self, compiler: Any, connection: Any): ... +class RPad(LPad): ... +class RTrim(Transform): ... +class StrIndex(Func): ... class Substr(Func): - contains_aggregate: bool - contains_over_clause: bool - convert_value: Callable - output_field: Field def __init__( self, expression: Union[Expression, str], @@ -108,13 +50,5 @@ class Substr(Func): **extra: Any ) -> None: ... -class Trim(Transform): - contains_aggregate: bool - convert_value: Callable - lookup_name: str = ... - -class Upper(Transform): - contains_aggregate: bool - contains_over_clause: bool - convert_value: Callable - lookup_name: str = ... +class Trim(Transform): ... +class Upper(Transform): ... diff --git a/django-stubs/db/models/functions/window.pyi b/django-stubs/db/models/functions/window.pyi index e4454d9..789a4fd 100644 --- a/django-stubs/db/models/functions/window.pyi +++ b/django-stubs/db/models/functions/window.pyi @@ -1,69 +1,26 @@ -from typing import Any, Optional, Dict, List +from typing import Any, Optional from django.db.models import Func -class CumeDist(Func): - function: str = ... - name: str = ... - output_field: Any = ... - window_compatible: bool = ... - -class DenseRank(Func): - name: str = ... - output_field: Any = ... - window_compatible: bool = ... - -class FirstValue(Func): - name: str = ... - window_compatible: bool = ... +class CumeDist(Func): ... +class DenseRank(Func): ... +class FirstValue(Func): ... class LagLeadFunction(Func): - window_compatible: bool = ... def __init__( self, expression: Optional[str], offset: int = ..., default: Optional[int] = ..., **extra: Any ) -> None: ... -class Lag(LagLeadFunction): - function: str = ... - name: str = ... - -class LastValue(Func): - arity: int = ... - function: str = ... - name: str = ... - window_compatible: bool = ... - -class Lead(LagLeadFunction): - function: str = ... - name: str = ... +class Lag(LagLeadFunction): ... +class LastValue(Func): ... +class Lead(LagLeadFunction): ... class NthValue(Func): - function: str = ... - name: str = ... - window_compatible: bool = ... - def __init__(self, expression: Optional[str], nth: int = ..., **extra: Any) -> Any: ... + def __init__(self, expression: Optional[str], nth: int = ..., **extra: Any) -> None: ... class Ntile(Func): - function: str = ... - name: str = ... - output_field: Any = ... - window_compatible: bool = ... - def __init__(self, num_buckets: int = ..., **extra: Any) -> Any: ... + def __init__(self, num_buckets: int = ..., **extra: Any) -> None: ... -class PercentRank(Func): - function: str = ... - name: str = ... - output_field: Any = ... - window_compatible: bool = ... - -class Rank(Func): - function: str = ... - name: str = ... - output_field: Any = ... - window_compatible: bool = ... - -class RowNumber(Func): - function: str = ... - name: str = ... - output_field: Any = ... - window_compatible: bool = ... +class PercentRank(Func): ... +class Rank(Func): ... +class RowNumber(Func): ... diff --git a/mypy_django_plugin/main.py b/mypy_django_plugin/main.py index 8247e18..fbbf27d 100644 --- a/mypy_django_plugin/main.py +++ b/mypy_django_plugin/main.py @@ -84,14 +84,23 @@ def determine_proper_manager_type(ctx: FunctionContext) -> Type: def set_first_generic_param_as_default_for_second(fullname: str, ctx: AnalyzeTypeContext) -> Type: if not ctx.type.args: - return ctx.api.named_type(fullname, [AnyType(TypeOfAny.explicit), - AnyType(TypeOfAny.explicit)]) + try: + return ctx.api.named_type(fullname, [AnyType(TypeOfAny.explicit), + AnyType(TypeOfAny.explicit)]) + except KeyError: + # really should never happen + return AnyType(TypeOfAny.explicit) + args = ctx.type.args if len(args) == 1: args = [args[0], args[0]] analyzed_args = [ctx.api.analyze_type(arg) for arg in args] - return ctx.api.named_type(fullname, analyzed_args) + try: + return ctx.api.named_type(fullname, analyzed_args) + except KeyError: + # really should never happen + return AnyType(TypeOfAny.explicit) def return_user_model_hook(ctx: FunctionContext) -> Type: