Fix errors in db.models.expressions and db.models.functions.* (#54)

* fix errors at db.models.expressions and db.models.functions.*

* catch KeyError if QuerySet has not been loaded
This commit is contained in:
Maxim Kurnikov
2019-03-24 02:54:10 +03:00
committed by GitHub
parent 77f15d7478
commit 5d0ee40ada
5 changed files with 73 additions and 241 deletions

View File

@@ -1,20 +1,18 @@
from collections import OrderedDict
from datetime import datetime, timedelta
from typing import Any, Callable, Dict, Iterator, List, Optional, Sequence, Set, Tuple, Type, Union, TypeVar
from typing import Any, Callable, Dict, Iterator, List, Optional, Sequence, Set, Tuple, Type, TypeVar, Union
from django.db.models.lookups import Lookup
from django.db.models.sql.compiler import SQLCompiler
from django.db.models import Q, QuerySet
from django.db.models.fields import Field, FloatField
from django.db.models.sql import Query
from django.db.models.fields import Field
_OutputField = Union[Field, str]
class SQLiteNumericMixin:
def as_sqlite(self, compiler: SQLCompiler, connection: Any, **extra_context: Any) -> Tuple[str, List[float]]: ...
_SelfCombinable = TypeVar("_SelfCombinable", bound="Combinable")
_Self = TypeVar("_Self")
class Combinable:
ADD: str = ...
@@ -27,22 +25,20 @@ class Combinable:
BITOR: str = ...
BITLEFTSHIFT: str = ...
BITRIGHTSHIFT: str = ...
def __neg__(self: _SelfCombinable) -> _SelfCombinable: ...
def __add__(
self: _SelfCombinable, other: Optional[Union[timedelta, Combinable, float, str]]
) -> _SelfCombinable: ...
def __sub__(self: _SelfCombinable, other: Union[timedelta, Combinable, float]) -> _SelfCombinable: ...
def __mul__(self: _SelfCombinable, other: Union[timedelta, Combinable, float]) -> _SelfCombinable: ...
def __truediv__(self: _SelfCombinable, other: Union[Combinable, float]) -> _SelfCombinable: ...
def __itruediv__(self: _SelfCombinable, other: Union[Combinable, float]) -> _SelfCombinable: ...
def __mod__(self: _SelfCombinable, other: Union[int, Combinable]) -> _SelfCombinable: ...
def __pow__(self: _SelfCombinable, other: Union[float, Combinable]) -> _SelfCombinable: ...
def __and__(self: _SelfCombinable, other: Combinable) -> _SelfCombinable: ...
def bitand(self: _SelfCombinable, other: int) -> _SelfCombinable: ...
def bitleftshift(self: _SelfCombinable, other: int) -> _SelfCombinable: ...
def bitrightshift(self: _SelfCombinable, other: int) -> _SelfCombinable: ...
def __or__(self: _SelfCombinable, other: Combinable) -> _SelfCombinable: ...
def bitor(self: _SelfCombinable, other: int) -> _SelfCombinable: ...
def __neg__(self: _Self) -> _Self: ...
def __add__(self: _Self, other: Optional[Union[timedelta, Combinable, float, str]]) -> _Self: ...
def __sub__(self: _Self, other: Union[timedelta, Combinable, float]) -> _Self: ...
def __mul__(self: _Self, other: Union[timedelta, Combinable, float]) -> _Self: ...
def __truediv__(self: _Self, other: Union[Combinable, float]) -> _Self: ...
def __itruediv__(self: _Self, other: Union[Combinable, float]) -> _Self: ...
def __mod__(self: _Self, other: Union[int, Combinable]) -> _Self: ...
def __pow__(self: _Self, other: Union[float, Combinable]) -> _Self: ...
def __and__(self: _Self, other: Combinable) -> _Self: ...
def bitand(self: _Self, other: int) -> _Self: ...
def bitleftshift(self: _Self, other: int) -> _Self: ...
def bitrightshift(self: _Self, other: int) -> _Self: ...
def __or__(self: _Self, other: Combinable) -> _Self: ...
def bitor(self: _Self, other: int) -> _Self: ...
def __radd__(self, other: Optional[Union[datetime, float, Combinable]]) -> Combinable: ...
def __rsub__(self, other: Union[float, Combinable]) -> Combinable: ...
def __rmul__(self, other: Union[float, Combinable]) -> Combinable: ...
@@ -52,43 +48,45 @@ class Combinable:
def __rand__(self, other: Any) -> Combinable: ...
def __ror__(self, other: Any) -> Combinable: ...
_SelfBaseExpression = TypeVar("_SelfBaseExpression", bound="BaseExpression")
class BaseExpression:
is_summary: bool = ...
filterable: bool = ...
window_compatible: bool = ...
output_field: Any
def __init__(self, output_field: Optional[_OutputField] = ...) -> None: ...
def get_db_converters(self, connection: Any) -> List[Callable]: ...
def get_source_expressions(self) -> List[Any]: ...
def set_source_expressions(self, exprs: List[Any]) -> None: ...
def set_source_expressions(self, exprs: Sequence[Combinable]) -> None: ...
@property
def contains_aggregate(self) -> bool: ...
@property
def contains_over_clause(self) -> bool: ...
@property
def contains_column_references(self) -> bool: ...
def resolve_expression(
self,
self: _Self,
query: Any = ...,
allow_joins: bool = ...,
reuse: Optional[Set[str]] = ...,
summarize: bool = ...,
for_save: bool = ...,
) -> BaseExpression: ...
) -> _Self: ...
@property
def field(self) -> Field: ...
@property
def output_field(self) -> Field: ...
@property
def convert_value(self) -> Callable: ...
def get_lookup(self, lookup: str) -> Optional[Type[Lookup]]: ...
def get_transform(self, name: str) -> Optional[Type[Expression]]: ...
def relabeled_clone(self, change_map: Dict[Optional[str], str]) -> Expression: ...
def copy(self) -> BaseExpression: ...
def get_group_by_cols(self: _SelfBaseExpression) -> List[_SelfBaseExpression]: ...
def get_group_by_cols(self: _Self) -> List[_Self]: ...
def get_source_fields(self) -> List[Optional[Field]]: ...
def asc(self, **kwargs: Any) -> Expression: ...
def desc(self, **kwargs: Any) -> Expression: ...
def reverse_ordering(self): ...
def flatten(self) -> Iterator[Expression]: ...
def __hash__(self) -> int: ...
def deconstruct(self) -> Any: ...
def as_sqlite(self, compiler: SQLCompiler, connection: Any) -> Any: ...
def as_sql(self, compiler: SQLCompiler, connection: Any, **extra_context: Any) -> Any: ...
@@ -105,28 +103,18 @@ class CombinedExpression(SQLiteNumericMixin, Expression):
def __init__(
self, lhs: Combinable, connector: str, rhs: Combinable, output_field: Optional[_OutputField] = ...
) -> None: ...
def get_source_expressions(self) -> Union[List[Combinable], List[SQLiteNumericMixin]]: ...
def set_source_expressions(self, exprs: List[Combinable]) -> None: ...
def resolve_expression(
self,
query: Any = ...,
allow_joins: bool = ...,
reuse: Optional[Set[str]] = ...,
summarize: bool = ...,
for_save: bool = ...,
) -> CombinedExpression: ...
class F(Combinable):
name: str
def __init__(self, name: str): ...
def resolve_expression(
self,
self: _Self,
query: Any = ...,
allow_joins: bool = ...,
reuse: Optional[Set[str]] = ...,
summarize: bool = ...,
for_save: bool = ...,
) -> Expression: ...
) -> _Self: ...
def asc(self, **kwargs) -> OrderBy: ...
def desc(self, **kwargs) -> OrderBy: ...
def deconstruct(self) -> Any: ...
@@ -141,8 +129,6 @@ class Subquery(Expression):
def __init__(self, queryset: QuerySet, output_field: Optional[_OutputField] = ..., **extra: Any) -> None: ...
class Exists(Subquery):
extra: Dict[Any, Any]
template: str = ...
negated: bool = ...
def __init__(self, *args: Any, negated: bool = ..., **kwargs: Any) -> None: ...
def __invert__(self) -> Exists: ...
@@ -162,30 +148,19 @@ class Value(Expression):
def __init__(self, value: Any, output_field: Optional[_OutputField] = ...) -> None: ...
class RawSQL(Expression):
output_field: Field
params: List[Any]
sql: str
def __init__(self, sql: str, params: Sequence[Any], output_field: Optional[_OutputField] = ...) -> None: ...
class Func(SQLiteNumericMixin, Expression):
function: str = ...
name: str = ...
template: str = ...
arg_joiner: str = ...
arity: int = ...
source_expressions: List[Expression] = ...
source_expressions: List[Combinable] = ...
extra: Dict[Any, Any] = ...
def __init__(self, *expressions: Any, output_field: Optional[_OutputField] = ..., **extra: Any) -> None: ...
def get_source_expressions(self) -> List[Combinable]: ...
def set_source_expressions(self, exprs: List[Expression]) -> None: ...
def resolve_expression(
self,
query: Query = ...,
allow_joins: bool = ...,
reuse: Optional[Set[Any]] = ...,
summarize: bool = ...,
for_save: bool = ...,
) -> Func: ...
def copy(self) -> Func: ...
class When(Expression):
template: str = ...
@@ -205,8 +180,6 @@ class Case(Expression):
class ExpressionWrapper(Expression):
def __init__(self, expression: Union[Q, Combinable], output_field: _OutputField): ...
def set_source_expressions(self, exprs: Sequence[Expression]) -> None: ...
def get_source_expressions(self) -> List[Expression]: ...
class Col(Expression):
def __init__(self, alias: str, target: str, output_field: Optional[_OutputField] = ...): ...
@@ -214,8 +187,7 @@ class Col(Expression):
class ExpressionList(Func):
def __init__(self, *expressions: Union[BaseExpression, Combinable], **extra: Any) -> None: ...
class Random(Expression):
output_field: FloatField
class Random(Expression): ...
class Ref(Expression):
def __init__(self, refs: str, source: Expression): ...

View File

@@ -1,51 +1,11 @@
from datetime import date
from decimal import Decimal
from typing import Any, Callable, Dict, List, Union
from django.db.models.expressions import Combinable, Expression
from typing import Any, Union
from django.db.models import Func
from django.db.models.fields import Field
class Cast(Func):
contains_aggregate: bool
convert_value: Callable
extra: Dict[Any, Any]
is_summary: bool
output_field: Field
source_expressions: List[Combinable]
function: str = ...
template: str = ...
def __init__(self, expression: Union[date, Decimal, Expression, str], output_field: Union[str, Field]) -> None: ...
def __init__(self, expression: Any, output_field: Union[str, Field]) -> None: ...
class Coalesce(Func):
contains_aggregate: bool
convert_value: Callable
extra: Dict[Any, Any]
is_summary: bool
output_field: Field
source_expressions: List[Combinable]
function: str = ...
def __init__(self, *expressions: Any, **extra: Any) -> None: ...
class Greatest(Func):
contains_aggregate: bool
contains_over_clause: bool
convert_value: Callable
extra: Dict[Any, Any]
is_summary: bool
output_field: Field
source_expressions: List[Combinable]
function: str = ...
def __init__(self, *expressions: Any, **extra: Any) -> None: ...
class Least(Func):
contains_aggregate: bool
contains_over_clause: bool
convert_value: Callable
extra: Dict[Any, Any]
is_summary: bool
output_field: Field
source_expressions: List[Combinable]
function: str = ...
def __init__(self, *expressions: Any, **extra: Any) -> None: ...
class Coalesce(Func): ...
class Greatest(Func): ...
class Least(Func): ...

View File

@@ -1,105 +1,47 @@
from typing import Any, List, Optional, Tuple, Union, Callable
from typing import Any, List, Optional, Tuple, Union
from django.db.backends.sqlite3.base import DatabaseWrapper
from django.db.models.expressions import Combinable, Expression, Value
from django.db.models.sql.compiler import SQLCompiler
from django.db.models import Func, Transform
from django.db.models.fields import Field
class BytesToCharFieldConversionMixin:
def convert_value(
self, value: str, expression: BytesToCharFieldConversionMixin, connection: DatabaseWrapper
) -> str: ...
class Chr(Transform):
contains_aggregate: bool
lookup_name: str = ...
class BytesToCharFieldConversionMixin: ...
class Chr(Transform): ...
class ConcatPair(Func):
contains_aggregate: bool
def coalesce(self) -> ConcatPair: ...
class Concat(Func):
contains_aggregate: bool
convert_value: Callable
class Concat(Func): ...
class Left(Func):
contains_aggregate: bool
contains_over_clause: bool
convert_value: Callable
output_field: Field
def __init__(self, expression: str, length: Union[Value, int], **extra: Any) -> None: ...
def get_substr(self) -> Substr: ...
def use_substr(
self, compiler: SQLCompiler, connection: DatabaseWrapper, **extra_context: Any
) -> Tuple[str, List[int]]: ...
as_oracle: Any = ...
as_sqlite: Any = ...
class Length(Transform):
contains_aggregate: bool
convert_value: Callable
lookup_name: str = ...
class Lower(Transform):
contains_aggregate: bool
contains_column_references: bool
contains_over_clause: bool
convert_value: Callable
lookup_name: str = ...
class Length(Transform): ...
class Lower(Transform): ...
class LPad(BytesToCharFieldConversionMixin, Func):
contains_aggregate: bool
convert_value: Callable
def __init__(self, expression: str, length: Union[Length, int], fill_text: Value = ..., **extra: Any) -> None: ...
class LTrim(Transform):
contains_aggregate: bool
convert_value: Callable
lookup_name: str = ...
class Ord(Transform):
contains_aggregate: bool
convert_value: Callable
lookup_name: str = ...
def as_mysql(self, compiler: Any, connection: Any, **extra_context: Any): ...
class LTrim(Transform): ...
class Ord(Transform): ...
class Repeat(BytesToCharFieldConversionMixin, Func):
contains_aggregate: bool
convert_value: Callable
output_field: django.db.models.fields.CharField
def __init__(self, expression: Union[Value, str], number: Union[Length, int], **extra: Any) -> None: ...
def as_oracle(self, compiler: Any, connection: Any, **extra_context: Any): ...
class Replace(Func):
contains_aggregate: bool
contains_over_clause: bool
convert_value: Callable
output_field: Field
def __init__(self, expression: Combinable, text: Value, replacement: Value = ..., **extra: Any) -> None: ...
class Right(Left): ...
class RPad(LPad):
output_field: Field
class RTrim(Transform):
contains_aggregate: bool
convert_value: Callable
lookup_name: str = ...
class StrIndex(Func):
contains_aggregate: bool
convert_value: Callable
output_field: Any = ...
def as_postgresql(self, compiler: Any, connection: Any): ...
class RPad(LPad): ...
class RTrim(Transform): ...
class StrIndex(Func): ...
class Substr(Func):
contains_aggregate: bool
contains_over_clause: bool
convert_value: Callable
output_field: Field
def __init__(
self,
expression: Union[Expression, str],
@@ -108,13 +50,5 @@ class Substr(Func):
**extra: Any
) -> None: ...
class Trim(Transform):
contains_aggregate: bool
convert_value: Callable
lookup_name: str = ...
class Upper(Transform):
contains_aggregate: bool
contains_over_clause: bool
convert_value: Callable
lookup_name: str = ...
class Trim(Transform): ...
class Upper(Transform): ...

View File

@@ -1,69 +1,26 @@
from typing import Any, Optional, Dict, List
from typing import Any, Optional
from django.db.models import Func
class CumeDist(Func):
function: str = ...
name: str = ...
output_field: Any = ...
window_compatible: bool = ...
class DenseRank(Func):
name: str = ...
output_field: Any = ...
window_compatible: bool = ...
class FirstValue(Func):
name: str = ...
window_compatible: bool = ...
class CumeDist(Func): ...
class DenseRank(Func): ...
class FirstValue(Func): ...
class LagLeadFunction(Func):
window_compatible: bool = ...
def __init__(
self, expression: Optional[str], offset: int = ..., default: Optional[int] = ..., **extra: Any
) -> None: ...
class Lag(LagLeadFunction):
function: str = ...
name: str = ...
class LastValue(Func):
arity: int = ...
function: str = ...
name: str = ...
window_compatible: bool = ...
class Lead(LagLeadFunction):
function: str = ...
name: str = ...
class Lag(LagLeadFunction): ...
class LastValue(Func): ...
class Lead(LagLeadFunction): ...
class NthValue(Func):
function: str = ...
name: str = ...
window_compatible: bool = ...
def __init__(self, expression: Optional[str], nth: int = ..., **extra: Any) -> Any: ...
def __init__(self, expression: Optional[str], nth: int = ..., **extra: Any) -> None: ...
class Ntile(Func):
function: str = ...
name: str = ...
output_field: Any = ...
window_compatible: bool = ...
def __init__(self, num_buckets: int = ..., **extra: Any) -> Any: ...
def __init__(self, num_buckets: int = ..., **extra: Any) -> None: ...
class PercentRank(Func):
function: str = ...
name: str = ...
output_field: Any = ...
window_compatible: bool = ...
class Rank(Func):
function: str = ...
name: str = ...
output_field: Any = ...
window_compatible: bool = ...
class RowNumber(Func):
function: str = ...
name: str = ...
output_field: Any = ...
window_compatible: bool = ...
class PercentRank(Func): ...
class Rank(Func): ...
class RowNumber(Func): ...

View File

@@ -84,14 +84,23 @@ def determine_proper_manager_type(ctx: FunctionContext) -> Type:
def set_first_generic_param_as_default_for_second(fullname: str, ctx: AnalyzeTypeContext) -> Type:
if not ctx.type.args:
return ctx.api.named_type(fullname, [AnyType(TypeOfAny.explicit),
AnyType(TypeOfAny.explicit)])
try:
return ctx.api.named_type(fullname, [AnyType(TypeOfAny.explicit),
AnyType(TypeOfAny.explicit)])
except KeyError:
# really should never happen
return AnyType(TypeOfAny.explicit)
args = ctx.type.args
if len(args) == 1:
args = [args[0], args[0]]
analyzed_args = [ctx.api.analyze_type(arg) for arg in args]
return ctx.api.named_type(fullname, analyzed_args)
try:
return ctx.api.named_type(fullname, analyzed_args)
except KeyError:
# really should never happen
return AnyType(TypeOfAny.explicit)
def return_user_model_hook(ctx: FunctionContext) -> Type: