Use _AsSqlType for as_sql (#1052)

* Annotate the return type of as_sql for SpatialOperator.

Its subclasses inherits the type annotation from `SpatialOperator`, so
copying `as_sql` over is unnecessary.

Signed-off-by: Zixuan James Li <p359101898@gmail.com>

* Remove unnecessary as_sql definition.

`Query` inherits `as_sql` from `BaseExpression`, `GISLookup` inherits
`as_sql` from `Lookup`, and `BuiltinLookup` inherits `as_sql` from
`Lookup[_T]`.  None is required to be redefined.

Signed-off-by: Zixuan James Li <p359101898@gmail.com>

* Unify return types of as_sql and friends as _AsSqlType.

`Tuple[str, _ParamsT]`, `Tuple[str, List[Union[str, int]]]` and other
similar type annotations are all replaced with the `_AsSqlType`
alias. Any as_sql definition that annotate the return type as `Any` is also
updated to use `_AsSqlType`.

Signed-off-by: Zixuan James Li <p359101898@gmail.com>
This commit is contained in:
PIG208
2022-07-13 02:47:01 -04:00
committed by GitHub
parent 29f0762540
commit 3648e10350
15 changed files with 40 additions and 34 deletions

View File

@@ -18,7 +18,6 @@ class SDODisjoint(SpatialOperator):
class SDORelate(SpatialOperator):
sql_template: str = ...
def check_relate_argument(self, arg: Any) -> None: ...
def as_sql(self, connection: Any, lookup: Any, template_params: Any, sql_params: Any): ...
class OracleOperations(BaseSpatialOperations, DatabaseOperations):
name: str = ...

View File

@@ -19,7 +19,6 @@ class PostGISOperator(SpatialOperator):
def __init__(
self, geography: bool = ..., raster: Union[bool, Literal["bilateral"]] = ..., **kwargs: Any
) -> None: ...
def as_sql(self, connection: Any, lookup: Any, template_params: Any, *args: Any): ...
def check_raster(self, lookup: Any, template_params: Any): ...
class ST_Polygon(Func):

View File

@@ -4,8 +4,7 @@ from django.contrib.gis.db.backends.base.operations import BaseSpatialOperations
from django.contrib.gis.db.backends.utils import SpatialOperator as SpatialOperator
from django.db.backends.sqlite3.operations import DatabaseOperations
class SpatialiteNullCheckOperator(SpatialOperator):
def as_sql(self, connection: Any, lookup: Any, template_params: Any, sql_params: Any): ...
class SpatialiteNullCheckOperator(SpatialOperator): ...
class SpatiaLiteOperations(BaseSpatialOperations, DatabaseOperations):
name: str = ...

View File

@@ -1,5 +1,7 @@
from typing import Any, Optional
from django.db.models.sql.compiler import _AsSqlType
class SpatialOperator:
sql_template: Any = ...
op: Any = ...
@@ -7,4 +9,4 @@ class SpatialOperator:
def __init__(self, op: Optional[Any] = ..., func: Optional[Any] = ...) -> None: ...
@property
def default_template(self): ...
def as_sql(self, connection: Any, lookup: Any, template_params: Any, sql_params: Any): ...
def as_sql(self, connection: Any, lookup: Any, template_params: Any, sql_params: Any) -> _AsSqlType: ...

View File

@@ -1,13 +1,16 @@
from typing import Any, Optional
from django.db.models import Aggregate
from django.db.models.sql.compiler import _AsSqlType
class GeoAggregate(Aggregate):
function: Any = ...
is_extent: bool = ...
@property
def output_field(self): ...
def as_sql(self, compiler: Any, connection: Any, function: Optional[Any] = ..., **extra_context: Any): ...
def as_sql(
self, compiler: Any, connection: Any, function: Optional[Any] = ..., **extra_context: Any
) -> _AsSqlType: ...
def as_oracle(self, compiler: Any, connection: Any, **extra_context: Any): ...
def resolve_expression(
self,

View File

@@ -2,6 +2,7 @@ from typing import Any, Optional
from django.db.models import Func
from django.db.models import Transform as StandardTransform
from django.db.models.sql.compiler import _AsSqlType
NUMERIC_TYPES: Any
@@ -11,7 +12,9 @@ class GeoFuncMixin:
def __init__(self, *expressions: Any, **extra: Any) -> None: ...
@property
def geo_field(self): ...
def as_sql(self, compiler: Any, connection: Any, function: Optional[Any] = ..., **extra_context: Any): ...
def as_sql(
self, compiler: Any, connection: Any, function: Optional[Any] = ..., **extra_context: Any
) -> _AsSqlType: ...
def resolve_expression(self, *args: Any, **kwargs: Any): ...
class GeoFunc(GeoFuncMixin, Func): ...

View File

@@ -18,7 +18,6 @@ class GISLookup(Lookup):
rhs: Any = ...
def process_rhs(self, compiler: Any, connection: Any): ...
def get_rhs_op(self, connection: Any, rhs: Any): ...
def as_sql(self, compiler: Any, connection: Any): ...
class OverlapsLeftLookup(GISLookup):
lookup_name: str = ...
@@ -107,8 +106,7 @@ class DWithinLookup(DistanceLookupBase):
def process_distance(self, compiler: Any, connection: Any): ...
def process_rhs(self, compiler: Any, connection: Any): ...
class DistanceLookupFromFunction(DistanceLookupBase):
def as_sql(self, compiler: Any, connection: Any): ...
class DistanceLookupFromFunction(DistanceLookupBase): ...
class DistanceGTLookup(DistanceLookupFromFunction):
lookup_name: str = ...

View File

@@ -1,6 +1,7 @@
from typing import Any
from django.db.models.sql import compiler as compiler
from django.db.models.sql.compiler import _AsSqlType
class SQLCompiler(compiler.SQLCompiler):
def as_subquery_condition(self, alias: Any, columns: Any, compiler: Any): ...
@@ -8,7 +9,7 @@ class SQLCompiler(compiler.SQLCompiler):
class SQLInsertCompiler(compiler.SQLInsertCompiler, SQLCompiler): ...
class SQLDeleteCompiler(compiler.SQLDeleteCompiler, SQLCompiler):
def as_sql(self): ...
def as_sql(self) -> _AsSqlType: ...
class SQLUpdateCompiler(compiler.SQLUpdateCompiler, SQLCompiler): ...
class SQLAggregateCompiler(compiler.SQLAggregateCompiler, SQLCompiler): ...

View File

@@ -99,7 +99,7 @@ class BaseExpression:
def desc(self, **kwargs: Any) -> OrderBy: ...
def reverse_ordering(self) -> BaseExpression: ...
def flatten(self) -> Iterator[BaseExpression]: ...
def as_sql(self, compiler: SQLCompiler, connection: BaseDatabaseWrapper) -> Any: ...
def as_sql(self, compiler: SQLCompiler, connection: BaseDatabaseWrapper) -> _AsSqlType: ...
def deconstruct(self) -> Any: ... # fake
class Expression(BaseExpression, Combinable): ...

View File

@@ -6,7 +6,7 @@ from django.db.backends.ddl_references import Statement
from django.db.models.base import Model
from django.db.models.expressions import BaseExpression, Combinable, Expression, Func
from django.db.models.query_utils import Q
from django.db.models.sql.compiler import SQLCompiler
from django.db.models.sql.compiler import SQLCompiler, _AsSqlType
class Index:
model: Type[Model]
@@ -54,4 +54,4 @@ class IndexExpression(Func):
summarize: bool = ...,
for_save: bool = ...,
) -> IndexExpression: ...
def as_sqlite(self, compiler: SQLCompiler, connection: BaseDatabaseWrapper, **extra_context: Any) -> Any: ...
def as_sqlite(self, compiler: SQLCompiler, connection: BaseDatabaseWrapper, **extra_context: Any) -> _AsSqlType: ...

View File

@@ -9,7 +9,7 @@ else:
from django.db.backends.base.base import BaseDatabaseWrapper
from django.db.models.expressions import Expression, Func
from django.db.models.query_utils import RegisterLookupMixin
from django.db.models.sql.compiler import SQLCompiler, _ParamsT, _ParamT
from django.db.models.sql.compiler import SQLCompiler, _AsSqlType, _ParamT
from django.utils.datastructures import OrderedSet
_L = TypeVar("_L", bound="Lookup")
@@ -30,16 +30,16 @@ class Lookup(Generic[_T]):
def get_source_expressions(self) -> List[Expression]: ...
def set_source_expressions(self, new_exprs: List[Expression]) -> None: ...
def get_prep_lookup(self) -> Any: ...
def get_db_prep_lookup(self, value: _ParamT, connection: BaseDatabaseWrapper) -> Tuple[str, _ParamsT]: ...
def get_db_prep_lookup(self, value: _ParamT, connection: BaseDatabaseWrapper) -> _AsSqlType: ...
def process_lhs(
self, compiler: SQLCompiler, connection: BaseDatabaseWrapper, lhs: Optional[Expression] = ...
) -> Tuple[str, _ParamsT]: ...
def process_rhs(self, compiler: SQLCompiler, connection: BaseDatabaseWrapper) -> Tuple[str, _ParamsT]: ...
) -> _AsSqlType: ...
def process_rhs(self, compiler: SQLCompiler, connection: BaseDatabaseWrapper) -> _AsSqlType: ...
def rhs_is_direct_value(self) -> bool: ...
def relabeled_clone(self: _L, relabels: Mapping[str, str]) -> _L: ...
def get_group_by_cols(self, alias: Optional[str] = ...) -> List[Expression]: ...
def as_sql(self, compiler: SQLCompiler, connection: BaseDatabaseWrapper) -> Tuple[str, _ParamsT]: ...
def as_oracle(self, compiler: SQLCompiler, connection: BaseDatabaseWrapper) -> Tuple[str, _ParamsT]: ...
def as_sql(self, compiler: SQLCompiler, connection: BaseDatabaseWrapper) -> _AsSqlType: ...
def as_oracle(self, compiler: SQLCompiler, connection: BaseDatabaseWrapper) -> _AsSqlType: ...
@property
def contains_aggregate(self) -> bool: ...
@property
@@ -58,24 +58,23 @@ class Transform(RegisterLookupMixin, Func):
class BuiltinLookup(Lookup[_T]):
def process_lhs(
self, compiler: SQLCompiler, connection: BaseDatabaseWrapper, lhs: Optional[Expression] = ...
) -> Tuple[str, _ParamsT]: ...
def as_sql(self, compiler: SQLCompiler, connection: BaseDatabaseWrapper) -> Tuple[str, _ParamsT]: ...
) -> _AsSqlType: ...
def get_rhs_op(self, connection: BaseDatabaseWrapper, rhs: str) -> str: ...
class FieldGetDbPrepValueMixin:
get_db_prep_lookup_value_is_iterable: bool = ...
def get_db_prep_lookup(self, value: _ParamT, connection: BaseDatabaseWrapper) -> Tuple[str, _ParamsT]: ...
def get_db_prep_lookup(self, value: _ParamT, connection: BaseDatabaseWrapper) -> _AsSqlType: ...
class FieldGetDbPrepValueIterableMixin(FieldGetDbPrepValueMixin):
get_db_prep_lookup_value_is_iterable: Literal[True] = ...
def get_prep_lookup(self) -> Iterable[Any]: ...
def resolve_expression_parameter(
self, compiler: SQLCompiler, connection: BaseDatabaseWrapper, sql: str, param: Any
) -> Tuple[str, _ParamsT]: ...
) -> _AsSqlType: ...
class PostgresOperatorLookup(FieldGetDbPrepValueMixin, Lookup[_T]):
postgres_operator: str = ...
def as_postgresql(self, compiler: SQLCompiler, connection: BaseDatabaseWrapper) -> Tuple[str, _ParamsT]: ...
def as_postgresql(self, compiler: SQLCompiler, connection: BaseDatabaseWrapper) -> _AsSqlType: ...
class Exact(FieldGetDbPrepValueMixin, BuiltinLookup[_T]): ...
class IExact(BuiltinLookup[_T]): ...

View File

@@ -17,6 +17,8 @@ from typing import (
Union,
)
from django.db.models.sql.compiler import _AsSqlType
if sys.version_info < (3, 8):
from typing_extensions import Literal
else:
@@ -105,4 +107,4 @@ class FilteredRelation:
def __init__(self, relation_name: str, *, condition: Q = ...) -> None: ...
def clone(self) -> FilteredRelation: ...
def resolve_expression(self, *args: Any, **kwargs: Any) -> None: ...
def as_sql(self, compiler: SQLCompiler, connection: BaseDatabaseWrapper) -> Any: ...
def as_sql(self, compiler: SQLCompiler, connection: BaseDatabaseWrapper) -> _AsSqlType: ...

View File

@@ -3,7 +3,7 @@ from typing import Any, Dict, List, Optional, Tuple, Union
from django.db.backends.base.base import BaseDatabaseWrapper
from django.db.models.fields.mixins import FieldCacheMixin
from django.db.models.query_utils import FilteredRelation, PathInfo
from django.db.models.sql.compiler import SQLCompiler
from django.db.models.sql.compiler import SQLCompiler, _AsSqlType
class MultiJoin(Exception):
level: int = ...
@@ -31,7 +31,7 @@ class Join:
nullable: bool,
filtered_relation: Optional[FilteredRelation] = ...,
) -> None: ...
def as_sql(self, compiler: SQLCompiler, connection: BaseDatabaseWrapper) -> Tuple[str, List[Union[int, str]]]: ...
def as_sql(self, compiler: SQLCompiler, connection: BaseDatabaseWrapper) -> _AsSqlType: ...
def relabeled_clone(self, change_map: Dict[Optional[str], str]) -> Join: ...
def equals(self, other: Union[BaseTable, Join], with_filtered_relation: bool) -> bool: ...
def demote(self) -> Join: ...
@@ -44,6 +44,6 @@ class BaseTable:
table_name: str = ...
table_alias: Optional[str] = ...
def __init__(self, table_name: str, alias: Optional[str]) -> None: ...
def as_sql(self, compiler: SQLCompiler, connection: BaseDatabaseWrapper) -> Tuple[str, List[Any]]: ...
def as_sql(self, compiler: SQLCompiler, connection: BaseDatabaseWrapper) -> _AsSqlType: ...
def relabeled_clone(self, change_map: Dict[Optional[str], str]) -> BaseTable: ...
def equals(self, other: Join, with_filtered_relation: bool) -> bool: ...

View File

@@ -116,7 +116,6 @@ class Query(BaseExpression):
def get_initial_alias(self) -> str: ...
def count_active_tables(self) -> int: ...
def resolve_expression(self, query: Query, *args: Any, **kwargs: Any) -> Query: ... # type: ignore
def as_sql(self, compiler: SQLCompiler, connection: BaseDatabaseWrapper) -> Any: ...
def resolve_lookup_value(self, value: Any, can_reuse: Optional[Set[str]], allow_joins: bool) -> Any: ...
def solve_lookup_type(
self, lookup: str

View File

@@ -2,7 +2,7 @@ from typing import Any, Dict, List, Optional, Sequence, Tuple, Union
from django.db.backends.base.base import BaseDatabaseWrapper
from django.db.models.expressions import Expression
from django.db.models.sql.compiler import SQLCompiler
from django.db.models.sql.compiler import SQLCompiler, _AsSqlType
from django.db.models.sql.query import Query
from django.utils import tree
@@ -31,7 +31,9 @@ class WhereNode(tree.Node):
class NothingNode:
contains_aggregate: bool = ...
def as_sql(self, compiler: Optional[SQLCompiler] = ..., connection: Optional[BaseDatabaseWrapper] = ...) -> Any: ...
def as_sql(
self, compiler: Optional[SQLCompiler] = ..., connection: Optional[BaseDatabaseWrapper] = ...
) -> _AsSqlType: ...
class ExtraWhere:
contains_aggregate: bool = ...
@@ -40,7 +42,7 @@ class ExtraWhere:
def __init__(self, sqls: Sequence[str], params: Optional[Union[Sequence[int], Sequence[str]]]) -> None: ...
def as_sql(
self, compiler: Optional[SQLCompiler] = ..., connection: Optional[BaseDatabaseWrapper] = ...
) -> Tuple[str, Union[List[int], List[str]]]: ...
) -> _AsSqlType: ...
class SubqueryConstraint:
contains_aggregate: bool = ...
@@ -49,4 +51,4 @@ class SubqueryConstraint:
targets: List[str] = ...
query_object: Query = ...
def __init__(self, alias: str, columns: List[str], targets: List[str], query_object: Query) -> None: ...
def as_sql(self, compiler: SQLCompiler, connection: BaseDatabaseWrapper) -> Tuple[str, Tuple]: ...
def as_sql(self, compiler: SQLCompiler, connection: BaseDatabaseWrapper) -> _AsSqlType: ...