Fix errors in db.models.expressions and db.models.functions.* (#54)

* fix errors at db.models.expressions and db.models.functions.*

* catch KeyError if QuerySet has not been loaded
This commit is contained in:
Maxim Kurnikov
2019-03-24 02:54:10 +03:00
committed by GitHub
parent 77f15d7478
commit 5d0ee40ada
5 changed files with 73 additions and 241 deletions

View File

@@ -1,20 +1,18 @@
from collections import OrderedDict
from datetime import datetime, timedelta from datetime import datetime, timedelta
from typing import Any, Callable, Dict, Iterator, List, Optional, Sequence, Set, Tuple, Type, Union, TypeVar from typing import Any, Callable, Dict, Iterator, List, Optional, Sequence, Set, Tuple, Type, TypeVar, Union
from django.db.models.lookups import Lookup from django.db.models.lookups import Lookup
from django.db.models.sql.compiler import SQLCompiler from django.db.models.sql.compiler import SQLCompiler
from django.db.models import Q, QuerySet from django.db.models import Q, QuerySet
from django.db.models.fields import Field, FloatField from django.db.models.fields import Field
from django.db.models.sql import Query
_OutputField = Union[Field, str] _OutputField = Union[Field, str]
class SQLiteNumericMixin: class SQLiteNumericMixin:
def as_sqlite(self, compiler: SQLCompiler, connection: Any, **extra_context: Any) -> Tuple[str, List[float]]: ... def as_sqlite(self, compiler: SQLCompiler, connection: Any, **extra_context: Any) -> Tuple[str, List[float]]: ...
_SelfCombinable = TypeVar("_SelfCombinable", bound="Combinable") _Self = TypeVar("_Self")
class Combinable: class Combinable:
ADD: str = ... ADD: str = ...
@@ -27,22 +25,20 @@ class Combinable:
BITOR: str = ... BITOR: str = ...
BITLEFTSHIFT: str = ... BITLEFTSHIFT: str = ...
BITRIGHTSHIFT: str = ... BITRIGHTSHIFT: str = ...
def __neg__(self: _SelfCombinable) -> _SelfCombinable: ... def __neg__(self: _Self) -> _Self: ...
def __add__( def __add__(self: _Self, other: Optional[Union[timedelta, Combinable, float, str]]) -> _Self: ...
self: _SelfCombinable, other: Optional[Union[timedelta, Combinable, float, str]] def __sub__(self: _Self, other: Union[timedelta, Combinable, float]) -> _Self: ...
) -> _SelfCombinable: ... def __mul__(self: _Self, other: Union[timedelta, Combinable, float]) -> _Self: ...
def __sub__(self: _SelfCombinable, other: Union[timedelta, Combinable, float]) -> _SelfCombinable: ... def __truediv__(self: _Self, other: Union[Combinable, float]) -> _Self: ...
def __mul__(self: _SelfCombinable, other: Union[timedelta, Combinable, float]) -> _SelfCombinable: ... def __itruediv__(self: _Self, other: Union[Combinable, float]) -> _Self: ...
def __truediv__(self: _SelfCombinable, other: Union[Combinable, float]) -> _SelfCombinable: ... def __mod__(self: _Self, other: Union[int, Combinable]) -> _Self: ...
def __itruediv__(self: _SelfCombinable, other: Union[Combinable, float]) -> _SelfCombinable: ... def __pow__(self: _Self, other: Union[float, Combinable]) -> _Self: ...
def __mod__(self: _SelfCombinable, other: Union[int, Combinable]) -> _SelfCombinable: ... def __and__(self: _Self, other: Combinable) -> _Self: ...
def __pow__(self: _SelfCombinable, other: Union[float, Combinable]) -> _SelfCombinable: ... def bitand(self: _Self, other: int) -> _Self: ...
def __and__(self: _SelfCombinable, other: Combinable) -> _SelfCombinable: ... def bitleftshift(self: _Self, other: int) -> _Self: ...
def bitand(self: _SelfCombinable, other: int) -> _SelfCombinable: ... def bitrightshift(self: _Self, other: int) -> _Self: ...
def bitleftshift(self: _SelfCombinable, other: int) -> _SelfCombinable: ... def __or__(self: _Self, other: Combinable) -> _Self: ...
def bitrightshift(self: _SelfCombinable, other: int) -> _SelfCombinable: ... def bitor(self: _Self, other: int) -> _Self: ...
def __or__(self: _SelfCombinable, other: Combinable) -> _SelfCombinable: ...
def bitor(self: _SelfCombinable, other: int) -> _SelfCombinable: ...
def __radd__(self, other: Optional[Union[datetime, float, Combinable]]) -> Combinable: ... def __radd__(self, other: Optional[Union[datetime, float, Combinable]]) -> Combinable: ...
def __rsub__(self, other: Union[float, Combinable]) -> Combinable: ... def __rsub__(self, other: Union[float, Combinable]) -> Combinable: ...
def __rmul__(self, other: Union[float, Combinable]) -> Combinable: ... def __rmul__(self, other: Union[float, Combinable]) -> Combinable: ...
@@ -52,43 +48,45 @@ class Combinable:
def __rand__(self, other: Any) -> Combinable: ... def __rand__(self, other: Any) -> Combinable: ...
def __ror__(self, other: Any) -> Combinable: ... def __ror__(self, other: Any) -> Combinable: ...
_SelfBaseExpression = TypeVar("_SelfBaseExpression", bound="BaseExpression")
class BaseExpression: class BaseExpression:
is_summary: bool = ... is_summary: bool = ...
filterable: bool = ... filterable: bool = ...
window_compatible: bool = ... window_compatible: bool = ...
output_field: Any
def __init__(self, output_field: Optional[_OutputField] = ...) -> None: ... def __init__(self, output_field: Optional[_OutputField] = ...) -> None: ...
def get_db_converters(self, connection: Any) -> List[Callable]: ... def get_db_converters(self, connection: Any) -> List[Callable]: ...
def get_source_expressions(self) -> List[Any]: ... def get_source_expressions(self) -> List[Any]: ...
def set_source_expressions(self, exprs: List[Any]) -> None: ... def set_source_expressions(self, exprs: Sequence[Combinable]) -> None: ...
@property
def contains_aggregate(self) -> bool: ... def contains_aggregate(self) -> bool: ...
@property
def contains_over_clause(self) -> bool: ... def contains_over_clause(self) -> bool: ...
@property
def contains_column_references(self) -> bool: ... def contains_column_references(self) -> bool: ...
def resolve_expression( def resolve_expression(
self, self: _Self,
query: Any = ..., query: Any = ...,
allow_joins: bool = ..., allow_joins: bool = ...,
reuse: Optional[Set[str]] = ..., reuse: Optional[Set[str]] = ...,
summarize: bool = ..., summarize: bool = ...,
for_save: bool = ..., for_save: bool = ...,
) -> BaseExpression: ... ) -> _Self: ...
@property @property
def field(self) -> Field: ... def field(self) -> Field: ...
@property @property
def output_field(self) -> Field: ... def output_field(self) -> Field: ...
@property
def convert_value(self) -> Callable: ... def convert_value(self) -> Callable: ...
def get_lookup(self, lookup: str) -> Optional[Type[Lookup]]: ... def get_lookup(self, lookup: str) -> Optional[Type[Lookup]]: ...
def get_transform(self, name: str) -> Optional[Type[Expression]]: ... def get_transform(self, name: str) -> Optional[Type[Expression]]: ...
def relabeled_clone(self, change_map: Dict[Optional[str], str]) -> Expression: ... def relabeled_clone(self, change_map: Dict[Optional[str], str]) -> Expression: ...
def copy(self) -> BaseExpression: ... def copy(self) -> BaseExpression: ...
def get_group_by_cols(self: _SelfBaseExpression) -> List[_SelfBaseExpression]: ... def get_group_by_cols(self: _Self) -> List[_Self]: ...
def get_source_fields(self) -> List[Optional[Field]]: ... def get_source_fields(self) -> List[Optional[Field]]: ...
def asc(self, **kwargs: Any) -> Expression: ... def asc(self, **kwargs: Any) -> Expression: ...
def desc(self, **kwargs: Any) -> Expression: ... def desc(self, **kwargs: Any) -> Expression: ...
def reverse_ordering(self): ... def reverse_ordering(self): ...
def flatten(self) -> Iterator[Expression]: ... def flatten(self) -> Iterator[Expression]: ...
def __hash__(self) -> int: ...
def deconstruct(self) -> Any: ... def deconstruct(self) -> Any: ...
def as_sqlite(self, compiler: SQLCompiler, connection: Any) -> Any: ... def as_sqlite(self, compiler: SQLCompiler, connection: Any) -> Any: ...
def as_sql(self, compiler: SQLCompiler, connection: Any, **extra_context: Any) -> Any: ... def as_sql(self, compiler: SQLCompiler, connection: Any, **extra_context: Any) -> Any: ...
@@ -105,28 +103,18 @@ class CombinedExpression(SQLiteNumericMixin, Expression):
def __init__( def __init__(
self, lhs: Combinable, connector: str, rhs: Combinable, output_field: Optional[_OutputField] = ... self, lhs: Combinable, connector: str, rhs: Combinable, output_field: Optional[_OutputField] = ...
) -> None: ... ) -> None: ...
def get_source_expressions(self) -> Union[List[Combinable], List[SQLiteNumericMixin]]: ...
def set_source_expressions(self, exprs: List[Combinable]) -> None: ...
def resolve_expression(
self,
query: Any = ...,
allow_joins: bool = ...,
reuse: Optional[Set[str]] = ...,
summarize: bool = ...,
for_save: bool = ...,
) -> CombinedExpression: ...
class F(Combinable): class F(Combinable):
name: str name: str
def __init__(self, name: str): ... def __init__(self, name: str): ...
def resolve_expression( def resolve_expression(
self, self: _Self,
query: Any = ..., query: Any = ...,
allow_joins: bool = ..., allow_joins: bool = ...,
reuse: Optional[Set[str]] = ..., reuse: Optional[Set[str]] = ...,
summarize: bool = ..., summarize: bool = ...,
for_save: bool = ..., for_save: bool = ...,
) -> Expression: ... ) -> _Self: ...
def asc(self, **kwargs) -> OrderBy: ... def asc(self, **kwargs) -> OrderBy: ...
def desc(self, **kwargs) -> OrderBy: ... def desc(self, **kwargs) -> OrderBy: ...
def deconstruct(self) -> Any: ... def deconstruct(self) -> Any: ...
@@ -141,8 +129,6 @@ class Subquery(Expression):
def __init__(self, queryset: QuerySet, output_field: Optional[_OutputField] = ..., **extra: Any) -> None: ... def __init__(self, queryset: QuerySet, output_field: Optional[_OutputField] = ..., **extra: Any) -> None: ...
class Exists(Subquery): class Exists(Subquery):
extra: Dict[Any, Any]
template: str = ...
negated: bool = ... negated: bool = ...
def __init__(self, *args: Any, negated: bool = ..., **kwargs: Any) -> None: ... def __init__(self, *args: Any, negated: bool = ..., **kwargs: Any) -> None: ...
def __invert__(self) -> Exists: ... def __invert__(self) -> Exists: ...
@@ -162,30 +148,19 @@ class Value(Expression):
def __init__(self, value: Any, output_field: Optional[_OutputField] = ...) -> None: ... def __init__(self, value: Any, output_field: Optional[_OutputField] = ...) -> None: ...
class RawSQL(Expression): class RawSQL(Expression):
output_field: Field
params: List[Any] params: List[Any]
sql: str sql: str
def __init__(self, sql: str, params: Sequence[Any], output_field: Optional[_OutputField] = ...) -> None: ... def __init__(self, sql: str, params: Sequence[Any], output_field: Optional[_OutputField] = ...) -> None: ...
class Func(SQLiteNumericMixin, Expression): class Func(SQLiteNumericMixin, Expression):
function: str = ... function: str = ...
name: str = ...
template: str = ... template: str = ...
arg_joiner: str = ... arg_joiner: str = ...
arity: int = ... arity: int = ...
source_expressions: List[Expression] = ... source_expressions: List[Combinable] = ...
extra: Dict[Any, Any] = ... extra: Dict[Any, Any] = ...
def __init__(self, *expressions: Any, output_field: Optional[_OutputField] = ..., **extra: Any) -> None: ... def __init__(self, *expressions: Any, output_field: Optional[_OutputField] = ..., **extra: Any) -> None: ...
def get_source_expressions(self) -> List[Combinable]: ...
def set_source_expressions(self, exprs: List[Expression]) -> None: ...
def resolve_expression(
self,
query: Query = ...,
allow_joins: bool = ...,
reuse: Optional[Set[Any]] = ...,
summarize: bool = ...,
for_save: bool = ...,
) -> Func: ...
def copy(self) -> Func: ...
class When(Expression): class When(Expression):
template: str = ... template: str = ...
@@ -205,8 +180,6 @@ class Case(Expression):
class ExpressionWrapper(Expression): class ExpressionWrapper(Expression):
def __init__(self, expression: Union[Q, Combinable], output_field: _OutputField): ... def __init__(self, expression: Union[Q, Combinable], output_field: _OutputField): ...
def set_source_expressions(self, exprs: Sequence[Expression]) -> None: ...
def get_source_expressions(self) -> List[Expression]: ...
class Col(Expression): class Col(Expression):
def __init__(self, alias: str, target: str, output_field: Optional[_OutputField] = ...): ... def __init__(self, alias: str, target: str, output_field: Optional[_OutputField] = ...): ...
@@ -214,8 +187,7 @@ class Col(Expression):
class ExpressionList(Func): class ExpressionList(Func):
def __init__(self, *expressions: Union[BaseExpression, Combinable], **extra: Any) -> None: ... def __init__(self, *expressions: Union[BaseExpression, Combinable], **extra: Any) -> None: ...
class Random(Expression): class Random(Expression): ...
output_field: FloatField
class Ref(Expression): class Ref(Expression):
def __init__(self, refs: str, source: Expression): ... def __init__(self, refs: str, source: Expression): ...

View File

@@ -1,51 +1,11 @@
from datetime import date from typing import Any, Union
from decimal import Decimal
from typing import Any, Callable, Dict, List, Union
from django.db.models.expressions import Combinable, Expression
from django.db.models import Func from django.db.models import Func
from django.db.models.fields import Field from django.db.models.fields import Field
class Cast(Func): class Cast(Func):
contains_aggregate: bool def __init__(self, expression: Any, output_field: Union[str, Field]) -> None: ...
convert_value: Callable
extra: Dict[Any, Any]
is_summary: bool
output_field: Field
source_expressions: List[Combinable]
function: str = ...
template: str = ...
def __init__(self, expression: Union[date, Decimal, Expression, str], output_field: Union[str, Field]) -> None: ...
class Coalesce(Func): class Coalesce(Func): ...
contains_aggregate: bool class Greatest(Func): ...
convert_value: Callable class Least(Func): ...
extra: Dict[Any, Any]
is_summary: bool
output_field: Field
source_expressions: List[Combinable]
function: str = ...
def __init__(self, *expressions: Any, **extra: Any) -> None: ...
class Greatest(Func):
contains_aggregate: bool
contains_over_clause: bool
convert_value: Callable
extra: Dict[Any, Any]
is_summary: bool
output_field: Field
source_expressions: List[Combinable]
function: str = ...
def __init__(self, *expressions: Any, **extra: Any) -> None: ...
class Least(Func):
contains_aggregate: bool
contains_over_clause: bool
convert_value: Callable
extra: Dict[Any, Any]
is_summary: bool
output_field: Field
source_expressions: List[Combinable]
function: str = ...
def __init__(self, *expressions: Any, **extra: Any) -> None: ...

View File

@@ -1,105 +1,47 @@
from typing import Any, List, Optional, Tuple, Union, Callable from typing import Any, List, Optional, Tuple, Union
from django.db.backends.sqlite3.base import DatabaseWrapper from django.db.backends.sqlite3.base import DatabaseWrapper
from django.db.models.expressions import Combinable, Expression, Value from django.db.models.expressions import Combinable, Expression, Value
from django.db.models.sql.compiler import SQLCompiler from django.db.models.sql.compiler import SQLCompiler
from django.db.models import Func, Transform from django.db.models import Func, Transform
from django.db.models.fields import Field
class BytesToCharFieldConversionMixin: class BytesToCharFieldConversionMixin: ...
def convert_value( class Chr(Transform): ...
self, value: str, expression: BytesToCharFieldConversionMixin, connection: DatabaseWrapper
) -> str: ...
class Chr(Transform):
contains_aggregate: bool
lookup_name: str = ...
class ConcatPair(Func): class ConcatPair(Func):
contains_aggregate: bool
def coalesce(self) -> ConcatPair: ... def coalesce(self) -> ConcatPair: ...
class Concat(Func): class Concat(Func): ...
contains_aggregate: bool
convert_value: Callable
class Left(Func): class Left(Func):
contains_aggregate: bool
contains_over_clause: bool
convert_value: Callable
output_field: Field
def __init__(self, expression: str, length: Union[Value, int], **extra: Any) -> None: ... def __init__(self, expression: str, length: Union[Value, int], **extra: Any) -> None: ...
def get_substr(self) -> Substr: ... def get_substr(self) -> Substr: ...
def use_substr( def use_substr(
self, compiler: SQLCompiler, connection: DatabaseWrapper, **extra_context: Any self, compiler: SQLCompiler, connection: DatabaseWrapper, **extra_context: Any
) -> Tuple[str, List[int]]: ... ) -> Tuple[str, List[int]]: ...
as_oracle: Any = ...
as_sqlite: Any = ...
class Length(Transform): class Length(Transform): ...
contains_aggregate: bool class Lower(Transform): ...
convert_value: Callable
lookup_name: str = ...
class Lower(Transform):
contains_aggregate: bool
contains_column_references: bool
contains_over_clause: bool
convert_value: Callable
lookup_name: str = ...
class LPad(BytesToCharFieldConversionMixin, Func): class LPad(BytesToCharFieldConversionMixin, Func):
contains_aggregate: bool
convert_value: Callable
def __init__(self, expression: str, length: Union[Length, int], fill_text: Value = ..., **extra: Any) -> None: ... def __init__(self, expression: str, length: Union[Length, int], fill_text: Value = ..., **extra: Any) -> None: ...
class LTrim(Transform): class LTrim(Transform): ...
contains_aggregate: bool class Ord(Transform): ...
convert_value: Callable
lookup_name: str = ...
class Ord(Transform):
contains_aggregate: bool
convert_value: Callable
lookup_name: str = ...
def as_mysql(self, compiler: Any, connection: Any, **extra_context: Any): ...
class Repeat(BytesToCharFieldConversionMixin, Func): class Repeat(BytesToCharFieldConversionMixin, Func):
contains_aggregate: bool
convert_value: Callable
output_field: django.db.models.fields.CharField
def __init__(self, expression: Union[Value, str], number: Union[Length, int], **extra: Any) -> None: ... def __init__(self, expression: Union[Value, str], number: Union[Length, int], **extra: Any) -> None: ...
def as_oracle(self, compiler: Any, connection: Any, **extra_context: Any): ...
class Replace(Func): class Replace(Func):
contains_aggregate: bool
contains_over_clause: bool
convert_value: Callable
output_field: Field
def __init__(self, expression: Combinable, text: Value, replacement: Value = ..., **extra: Any) -> None: ... def __init__(self, expression: Combinable, text: Value, replacement: Value = ..., **extra: Any) -> None: ...
class Right(Left): ... class Right(Left): ...
class RPad(LPad): ...
class RPad(LPad): class RTrim(Transform): ...
output_field: Field class StrIndex(Func): ...
class RTrim(Transform):
contains_aggregate: bool
convert_value: Callable
lookup_name: str = ...
class StrIndex(Func):
contains_aggregate: bool
convert_value: Callable
output_field: Any = ...
def as_postgresql(self, compiler: Any, connection: Any): ...
class Substr(Func): class Substr(Func):
contains_aggregate: bool
contains_over_clause: bool
convert_value: Callable
output_field: Field
def __init__( def __init__(
self, self,
expression: Union[Expression, str], expression: Union[Expression, str],
@@ -108,13 +50,5 @@ class Substr(Func):
**extra: Any **extra: Any
) -> None: ... ) -> None: ...
class Trim(Transform): class Trim(Transform): ...
contains_aggregate: bool class Upper(Transform): ...
convert_value: Callable
lookup_name: str = ...
class Upper(Transform):
contains_aggregate: bool
contains_over_clause: bool
convert_value: Callable
lookup_name: str = ...

View File

@@ -1,69 +1,26 @@
from typing import Any, Optional, Dict, List from typing import Any, Optional
from django.db.models import Func from django.db.models import Func
class CumeDist(Func): class CumeDist(Func): ...
function: str = ... class DenseRank(Func): ...
name: str = ... class FirstValue(Func): ...
output_field: Any = ...
window_compatible: bool = ...
class DenseRank(Func):
name: str = ...
output_field: Any = ...
window_compatible: bool = ...
class FirstValue(Func):
name: str = ...
window_compatible: bool = ...
class LagLeadFunction(Func): class LagLeadFunction(Func):
window_compatible: bool = ...
def __init__( def __init__(
self, expression: Optional[str], offset: int = ..., default: Optional[int] = ..., **extra: Any self, expression: Optional[str], offset: int = ..., default: Optional[int] = ..., **extra: Any
) -> None: ... ) -> None: ...
class Lag(LagLeadFunction): class Lag(LagLeadFunction): ...
function: str = ... class LastValue(Func): ...
name: str = ... class Lead(LagLeadFunction): ...
class LastValue(Func):
arity: int = ...
function: str = ...
name: str = ...
window_compatible: bool = ...
class Lead(LagLeadFunction):
function: str = ...
name: str = ...
class NthValue(Func): class NthValue(Func):
function: str = ... def __init__(self, expression: Optional[str], nth: int = ..., **extra: Any) -> None: ...
name: str = ...
window_compatible: bool = ...
def __init__(self, expression: Optional[str], nth: int = ..., **extra: Any) -> Any: ...
class Ntile(Func): class Ntile(Func):
function: str = ... def __init__(self, num_buckets: int = ..., **extra: Any) -> None: ...
name: str = ...
output_field: Any = ...
window_compatible: bool = ...
def __init__(self, num_buckets: int = ..., **extra: Any) -> Any: ...
class PercentRank(Func): class PercentRank(Func): ...
function: str = ... class Rank(Func): ...
name: str = ... class RowNumber(Func): ...
output_field: Any = ...
window_compatible: bool = ...
class Rank(Func):
function: str = ...
name: str = ...
output_field: Any = ...
window_compatible: bool = ...
class RowNumber(Func):
function: str = ...
name: str = ...
output_field: Any = ...
window_compatible: bool = ...

View File

@@ -84,14 +84,23 @@ def determine_proper_manager_type(ctx: FunctionContext) -> Type:
def set_first_generic_param_as_default_for_second(fullname: str, ctx: AnalyzeTypeContext) -> Type: def set_first_generic_param_as_default_for_second(fullname: str, ctx: AnalyzeTypeContext) -> Type:
if not ctx.type.args: if not ctx.type.args:
return ctx.api.named_type(fullname, [AnyType(TypeOfAny.explicit), try:
AnyType(TypeOfAny.explicit)]) return ctx.api.named_type(fullname, [AnyType(TypeOfAny.explicit),
AnyType(TypeOfAny.explicit)])
except KeyError:
# really should never happen
return AnyType(TypeOfAny.explicit)
args = ctx.type.args args = ctx.type.args
if len(args) == 1: if len(args) == 1:
args = [args[0], args[0]] args = [args[0], args[0]]
analyzed_args = [ctx.api.analyze_type(arg) for arg in args] analyzed_args = [ctx.api.analyze_type(arg) for arg in args]
return ctx.api.named_type(fullname, analyzed_args) try:
return ctx.api.named_type(fullname, analyzed_args)
except KeyError:
# really should never happen
return AnyType(TypeOfAny.explicit)
def return_user_model_hook(ctx: FunctionContext) -> Type: def return_user_model_hook(ctx: FunctionContext) -> Type: