Use _AsSqlType for as_sql (#1052)

* Annotate the return type of as_sql for SpatialOperator.

Its subclasses inherits the type annotation from `SpatialOperator`, so
copying `as_sql` over is unnecessary.

Signed-off-by: Zixuan James Li <p359101898@gmail.com>

* Remove unnecessary as_sql definition.

`Query` inherits `as_sql` from `BaseExpression`, `GISLookup` inherits
`as_sql` from `Lookup`, and `BuiltinLookup` inherits `as_sql` from
`Lookup[_T]`.  None is required to be redefined.

Signed-off-by: Zixuan James Li <p359101898@gmail.com>

* Unify return types of as_sql and friends as _AsSqlType.

`Tuple[str, _ParamsT]`, `Tuple[str, List[Union[str, int]]]` and other
similar type annotations are all replaced with the `_AsSqlType`
alias. Any as_sql definition that annotate the return type as `Any` is also
updated to use `_AsSqlType`.

Signed-off-by: Zixuan James Li <p359101898@gmail.com>
This commit is contained in:
PIG208
2022-07-13 02:47:01 -04:00
committed by GitHub
parent 29f0762540
commit 3648e10350
15 changed files with 40 additions and 34 deletions

View File

@@ -18,7 +18,6 @@ class SDODisjoint(SpatialOperator):
class SDORelate(SpatialOperator): class SDORelate(SpatialOperator):
sql_template: str = ... sql_template: str = ...
def check_relate_argument(self, arg: Any) -> None: ... def check_relate_argument(self, arg: Any) -> None: ...
def as_sql(self, connection: Any, lookup: Any, template_params: Any, sql_params: Any): ...
class OracleOperations(BaseSpatialOperations, DatabaseOperations): class OracleOperations(BaseSpatialOperations, DatabaseOperations):
name: str = ... name: str = ...

View File

@@ -19,7 +19,6 @@ class PostGISOperator(SpatialOperator):
def __init__( def __init__(
self, geography: bool = ..., raster: Union[bool, Literal["bilateral"]] = ..., **kwargs: Any self, geography: bool = ..., raster: Union[bool, Literal["bilateral"]] = ..., **kwargs: Any
) -> None: ... ) -> None: ...
def as_sql(self, connection: Any, lookup: Any, template_params: Any, *args: Any): ...
def check_raster(self, lookup: Any, template_params: Any): ... def check_raster(self, lookup: Any, template_params: Any): ...
class ST_Polygon(Func): class ST_Polygon(Func):

View File

@@ -4,8 +4,7 @@ from django.contrib.gis.db.backends.base.operations import BaseSpatialOperations
from django.contrib.gis.db.backends.utils import SpatialOperator as SpatialOperator from django.contrib.gis.db.backends.utils import SpatialOperator as SpatialOperator
from django.db.backends.sqlite3.operations import DatabaseOperations from django.db.backends.sqlite3.operations import DatabaseOperations
class SpatialiteNullCheckOperator(SpatialOperator): class SpatialiteNullCheckOperator(SpatialOperator): ...
def as_sql(self, connection: Any, lookup: Any, template_params: Any, sql_params: Any): ...
class SpatiaLiteOperations(BaseSpatialOperations, DatabaseOperations): class SpatiaLiteOperations(BaseSpatialOperations, DatabaseOperations):
name: str = ... name: str = ...

View File

@@ -1,5 +1,7 @@
from typing import Any, Optional from typing import Any, Optional
from django.db.models.sql.compiler import _AsSqlType
class SpatialOperator: class SpatialOperator:
sql_template: Any = ... sql_template: Any = ...
op: Any = ... op: Any = ...
@@ -7,4 +9,4 @@ class SpatialOperator:
def __init__(self, op: Optional[Any] = ..., func: Optional[Any] = ...) -> None: ... def __init__(self, op: Optional[Any] = ..., func: Optional[Any] = ...) -> None: ...
@property @property
def default_template(self): ... def default_template(self): ...
def as_sql(self, connection: Any, lookup: Any, template_params: Any, sql_params: Any): ... def as_sql(self, connection: Any, lookup: Any, template_params: Any, sql_params: Any) -> _AsSqlType: ...

View File

@@ -1,13 +1,16 @@
from typing import Any, Optional from typing import Any, Optional
from django.db.models import Aggregate from django.db.models import Aggregate
from django.db.models.sql.compiler import _AsSqlType
class GeoAggregate(Aggregate): class GeoAggregate(Aggregate):
function: Any = ... function: Any = ...
is_extent: bool = ... is_extent: bool = ...
@property @property
def output_field(self): ... def output_field(self): ...
def as_sql(self, compiler: Any, connection: Any, function: Optional[Any] = ..., **extra_context: Any): ... def as_sql(
self, compiler: Any, connection: Any, function: Optional[Any] = ..., **extra_context: Any
) -> _AsSqlType: ...
def as_oracle(self, compiler: Any, connection: Any, **extra_context: Any): ... def as_oracle(self, compiler: Any, connection: Any, **extra_context: Any): ...
def resolve_expression( def resolve_expression(
self, self,

View File

@@ -2,6 +2,7 @@ from typing import Any, Optional
from django.db.models import Func from django.db.models import Func
from django.db.models import Transform as StandardTransform from django.db.models import Transform as StandardTransform
from django.db.models.sql.compiler import _AsSqlType
NUMERIC_TYPES: Any NUMERIC_TYPES: Any
@@ -11,7 +12,9 @@ class GeoFuncMixin:
def __init__(self, *expressions: Any, **extra: Any) -> None: ... def __init__(self, *expressions: Any, **extra: Any) -> None: ...
@property @property
def geo_field(self): ... def geo_field(self): ...
def as_sql(self, compiler: Any, connection: Any, function: Optional[Any] = ..., **extra_context: Any): ... def as_sql(
self, compiler: Any, connection: Any, function: Optional[Any] = ..., **extra_context: Any
) -> _AsSqlType: ...
def resolve_expression(self, *args: Any, **kwargs: Any): ... def resolve_expression(self, *args: Any, **kwargs: Any): ...
class GeoFunc(GeoFuncMixin, Func): ... class GeoFunc(GeoFuncMixin, Func): ...

View File

@@ -18,7 +18,6 @@ class GISLookup(Lookup):
rhs: Any = ... rhs: Any = ...
def process_rhs(self, compiler: Any, connection: Any): ... def process_rhs(self, compiler: Any, connection: Any): ...
def get_rhs_op(self, connection: Any, rhs: Any): ... def get_rhs_op(self, connection: Any, rhs: Any): ...
def as_sql(self, compiler: Any, connection: Any): ...
class OverlapsLeftLookup(GISLookup): class OverlapsLeftLookup(GISLookup):
lookup_name: str = ... lookup_name: str = ...
@@ -107,8 +106,7 @@ class DWithinLookup(DistanceLookupBase):
def process_distance(self, compiler: Any, connection: Any): ... def process_distance(self, compiler: Any, connection: Any): ...
def process_rhs(self, compiler: Any, connection: Any): ... def process_rhs(self, compiler: Any, connection: Any): ...
class DistanceLookupFromFunction(DistanceLookupBase): class DistanceLookupFromFunction(DistanceLookupBase): ...
def as_sql(self, compiler: Any, connection: Any): ...
class DistanceGTLookup(DistanceLookupFromFunction): class DistanceGTLookup(DistanceLookupFromFunction):
lookup_name: str = ... lookup_name: str = ...

View File

@@ -1,6 +1,7 @@
from typing import Any from typing import Any
from django.db.models.sql import compiler as compiler from django.db.models.sql import compiler as compiler
from django.db.models.sql.compiler import _AsSqlType
class SQLCompiler(compiler.SQLCompiler): class SQLCompiler(compiler.SQLCompiler):
def as_subquery_condition(self, alias: Any, columns: Any, compiler: Any): ... def as_subquery_condition(self, alias: Any, columns: Any, compiler: Any): ...
@@ -8,7 +9,7 @@ class SQLCompiler(compiler.SQLCompiler):
class SQLInsertCompiler(compiler.SQLInsertCompiler, SQLCompiler): ... class SQLInsertCompiler(compiler.SQLInsertCompiler, SQLCompiler): ...
class SQLDeleteCompiler(compiler.SQLDeleteCompiler, SQLCompiler): class SQLDeleteCompiler(compiler.SQLDeleteCompiler, SQLCompiler):
def as_sql(self): ... def as_sql(self) -> _AsSqlType: ...
class SQLUpdateCompiler(compiler.SQLUpdateCompiler, SQLCompiler): ... class SQLUpdateCompiler(compiler.SQLUpdateCompiler, SQLCompiler): ...
class SQLAggregateCompiler(compiler.SQLAggregateCompiler, SQLCompiler): ... class SQLAggregateCompiler(compiler.SQLAggregateCompiler, SQLCompiler): ...

View File

@@ -99,7 +99,7 @@ class BaseExpression:
def desc(self, **kwargs: Any) -> OrderBy: ... def desc(self, **kwargs: Any) -> OrderBy: ...
def reverse_ordering(self) -> BaseExpression: ... def reverse_ordering(self) -> BaseExpression: ...
def flatten(self) -> Iterator[BaseExpression]: ... def flatten(self) -> Iterator[BaseExpression]: ...
def as_sql(self, compiler: SQLCompiler, connection: BaseDatabaseWrapper) -> Any: ... def as_sql(self, compiler: SQLCompiler, connection: BaseDatabaseWrapper) -> _AsSqlType: ...
def deconstruct(self) -> Any: ... # fake def deconstruct(self) -> Any: ... # fake
class Expression(BaseExpression, Combinable): ... class Expression(BaseExpression, Combinable): ...

View File

@@ -6,7 +6,7 @@ from django.db.backends.ddl_references import Statement
from django.db.models.base import Model from django.db.models.base import Model
from django.db.models.expressions import BaseExpression, Combinable, Expression, Func from django.db.models.expressions import BaseExpression, Combinable, Expression, Func
from django.db.models.query_utils import Q from django.db.models.query_utils import Q
from django.db.models.sql.compiler import SQLCompiler from django.db.models.sql.compiler import SQLCompiler, _AsSqlType
class Index: class Index:
model: Type[Model] model: Type[Model]
@@ -54,4 +54,4 @@ class IndexExpression(Func):
summarize: bool = ..., summarize: bool = ...,
for_save: bool = ..., for_save: bool = ...,
) -> IndexExpression: ... ) -> IndexExpression: ...
def as_sqlite(self, compiler: SQLCompiler, connection: BaseDatabaseWrapper, **extra_context: Any) -> Any: ... def as_sqlite(self, compiler: SQLCompiler, connection: BaseDatabaseWrapper, **extra_context: Any) -> _AsSqlType: ...

View File

@@ -9,7 +9,7 @@ else:
from django.db.backends.base.base import BaseDatabaseWrapper from django.db.backends.base.base import BaseDatabaseWrapper
from django.db.models.expressions import Expression, Func from django.db.models.expressions import Expression, Func
from django.db.models.query_utils import RegisterLookupMixin from django.db.models.query_utils import RegisterLookupMixin
from django.db.models.sql.compiler import SQLCompiler, _ParamsT, _ParamT from django.db.models.sql.compiler import SQLCompiler, _AsSqlType, _ParamT
from django.utils.datastructures import OrderedSet from django.utils.datastructures import OrderedSet
_L = TypeVar("_L", bound="Lookup") _L = TypeVar("_L", bound="Lookup")
@@ -30,16 +30,16 @@ class Lookup(Generic[_T]):
def get_source_expressions(self) -> List[Expression]: ... def get_source_expressions(self) -> List[Expression]: ...
def set_source_expressions(self, new_exprs: List[Expression]) -> None: ... def set_source_expressions(self, new_exprs: List[Expression]) -> None: ...
def get_prep_lookup(self) -> Any: ... def get_prep_lookup(self) -> Any: ...
def get_db_prep_lookup(self, value: _ParamT, connection: BaseDatabaseWrapper) -> Tuple[str, _ParamsT]: ... def get_db_prep_lookup(self, value: _ParamT, connection: BaseDatabaseWrapper) -> _AsSqlType: ...
def process_lhs( def process_lhs(
self, compiler: SQLCompiler, connection: BaseDatabaseWrapper, lhs: Optional[Expression] = ... self, compiler: SQLCompiler, connection: BaseDatabaseWrapper, lhs: Optional[Expression] = ...
) -> Tuple[str, _ParamsT]: ... ) -> _AsSqlType: ...
def process_rhs(self, compiler: SQLCompiler, connection: BaseDatabaseWrapper) -> Tuple[str, _ParamsT]: ... def process_rhs(self, compiler: SQLCompiler, connection: BaseDatabaseWrapper) -> _AsSqlType: ...
def rhs_is_direct_value(self) -> bool: ... def rhs_is_direct_value(self) -> bool: ...
def relabeled_clone(self: _L, relabels: Mapping[str, str]) -> _L: ... def relabeled_clone(self: _L, relabels: Mapping[str, str]) -> _L: ...
def get_group_by_cols(self, alias: Optional[str] = ...) -> List[Expression]: ... def get_group_by_cols(self, alias: Optional[str] = ...) -> List[Expression]: ...
def as_sql(self, compiler: SQLCompiler, connection: BaseDatabaseWrapper) -> Tuple[str, _ParamsT]: ... def as_sql(self, compiler: SQLCompiler, connection: BaseDatabaseWrapper) -> _AsSqlType: ...
def as_oracle(self, compiler: SQLCompiler, connection: BaseDatabaseWrapper) -> Tuple[str, _ParamsT]: ... def as_oracle(self, compiler: SQLCompiler, connection: BaseDatabaseWrapper) -> _AsSqlType: ...
@property @property
def contains_aggregate(self) -> bool: ... def contains_aggregate(self) -> bool: ...
@property @property
@@ -58,24 +58,23 @@ class Transform(RegisterLookupMixin, Func):
class BuiltinLookup(Lookup[_T]): class BuiltinLookup(Lookup[_T]):
def process_lhs( def process_lhs(
self, compiler: SQLCompiler, connection: BaseDatabaseWrapper, lhs: Optional[Expression] = ... self, compiler: SQLCompiler, connection: BaseDatabaseWrapper, lhs: Optional[Expression] = ...
) -> Tuple[str, _ParamsT]: ... ) -> _AsSqlType: ...
def as_sql(self, compiler: SQLCompiler, connection: BaseDatabaseWrapper) -> Tuple[str, _ParamsT]: ...
def get_rhs_op(self, connection: BaseDatabaseWrapper, rhs: str) -> str: ... def get_rhs_op(self, connection: BaseDatabaseWrapper, rhs: str) -> str: ...
class FieldGetDbPrepValueMixin: class FieldGetDbPrepValueMixin:
get_db_prep_lookup_value_is_iterable: bool = ... get_db_prep_lookup_value_is_iterable: bool = ...
def get_db_prep_lookup(self, value: _ParamT, connection: BaseDatabaseWrapper) -> Tuple[str, _ParamsT]: ... def get_db_prep_lookup(self, value: _ParamT, connection: BaseDatabaseWrapper) -> _AsSqlType: ...
class FieldGetDbPrepValueIterableMixin(FieldGetDbPrepValueMixin): class FieldGetDbPrepValueIterableMixin(FieldGetDbPrepValueMixin):
get_db_prep_lookup_value_is_iterable: Literal[True] = ... get_db_prep_lookup_value_is_iterable: Literal[True] = ...
def get_prep_lookup(self) -> Iterable[Any]: ... def get_prep_lookup(self) -> Iterable[Any]: ...
def resolve_expression_parameter( def resolve_expression_parameter(
self, compiler: SQLCompiler, connection: BaseDatabaseWrapper, sql: str, param: Any self, compiler: SQLCompiler, connection: BaseDatabaseWrapper, sql: str, param: Any
) -> Tuple[str, _ParamsT]: ... ) -> _AsSqlType: ...
class PostgresOperatorLookup(FieldGetDbPrepValueMixin, Lookup[_T]): class PostgresOperatorLookup(FieldGetDbPrepValueMixin, Lookup[_T]):
postgres_operator: str = ... postgres_operator: str = ...
def as_postgresql(self, compiler: SQLCompiler, connection: BaseDatabaseWrapper) -> Tuple[str, _ParamsT]: ... def as_postgresql(self, compiler: SQLCompiler, connection: BaseDatabaseWrapper) -> _AsSqlType: ...
class Exact(FieldGetDbPrepValueMixin, BuiltinLookup[_T]): ... class Exact(FieldGetDbPrepValueMixin, BuiltinLookup[_T]): ...
class IExact(BuiltinLookup[_T]): ... class IExact(BuiltinLookup[_T]): ...

View File

@@ -17,6 +17,8 @@ from typing import (
Union, Union,
) )
from django.db.models.sql.compiler import _AsSqlType
if sys.version_info < (3, 8): if sys.version_info < (3, 8):
from typing_extensions import Literal from typing_extensions import Literal
else: else:
@@ -105,4 +107,4 @@ class FilteredRelation:
def __init__(self, relation_name: str, *, condition: Q = ...) -> None: ... def __init__(self, relation_name: str, *, condition: Q = ...) -> None: ...
def clone(self) -> FilteredRelation: ... def clone(self) -> FilteredRelation: ...
def resolve_expression(self, *args: Any, **kwargs: Any) -> None: ... def resolve_expression(self, *args: Any, **kwargs: Any) -> None: ...
def as_sql(self, compiler: SQLCompiler, connection: BaseDatabaseWrapper) -> Any: ... def as_sql(self, compiler: SQLCompiler, connection: BaseDatabaseWrapper) -> _AsSqlType: ...

View File

@@ -3,7 +3,7 @@ from typing import Any, Dict, List, Optional, Tuple, Union
from django.db.backends.base.base import BaseDatabaseWrapper from django.db.backends.base.base import BaseDatabaseWrapper
from django.db.models.fields.mixins import FieldCacheMixin from django.db.models.fields.mixins import FieldCacheMixin
from django.db.models.query_utils import FilteredRelation, PathInfo from django.db.models.query_utils import FilteredRelation, PathInfo
from django.db.models.sql.compiler import SQLCompiler from django.db.models.sql.compiler import SQLCompiler, _AsSqlType
class MultiJoin(Exception): class MultiJoin(Exception):
level: int = ... level: int = ...
@@ -31,7 +31,7 @@ class Join:
nullable: bool, nullable: bool,
filtered_relation: Optional[FilteredRelation] = ..., filtered_relation: Optional[FilteredRelation] = ...,
) -> None: ... ) -> None: ...
def as_sql(self, compiler: SQLCompiler, connection: BaseDatabaseWrapper) -> Tuple[str, List[Union[int, str]]]: ... def as_sql(self, compiler: SQLCompiler, connection: BaseDatabaseWrapper) -> _AsSqlType: ...
def relabeled_clone(self, change_map: Dict[Optional[str], str]) -> Join: ... def relabeled_clone(self, change_map: Dict[Optional[str], str]) -> Join: ...
def equals(self, other: Union[BaseTable, Join], with_filtered_relation: bool) -> bool: ... def equals(self, other: Union[BaseTable, Join], with_filtered_relation: bool) -> bool: ...
def demote(self) -> Join: ... def demote(self) -> Join: ...
@@ -44,6 +44,6 @@ class BaseTable:
table_name: str = ... table_name: str = ...
table_alias: Optional[str] = ... table_alias: Optional[str] = ...
def __init__(self, table_name: str, alias: Optional[str]) -> None: ... def __init__(self, table_name: str, alias: Optional[str]) -> None: ...
def as_sql(self, compiler: SQLCompiler, connection: BaseDatabaseWrapper) -> Tuple[str, List[Any]]: ... def as_sql(self, compiler: SQLCompiler, connection: BaseDatabaseWrapper) -> _AsSqlType: ...
def relabeled_clone(self, change_map: Dict[Optional[str], str]) -> BaseTable: ... def relabeled_clone(self, change_map: Dict[Optional[str], str]) -> BaseTable: ...
def equals(self, other: Join, with_filtered_relation: bool) -> bool: ... def equals(self, other: Join, with_filtered_relation: bool) -> bool: ...

View File

@@ -116,7 +116,6 @@ class Query(BaseExpression):
def get_initial_alias(self) -> str: ... def get_initial_alias(self) -> str: ...
def count_active_tables(self) -> int: ... def count_active_tables(self) -> int: ...
def resolve_expression(self, query: Query, *args: Any, **kwargs: Any) -> Query: ... # type: ignore def resolve_expression(self, query: Query, *args: Any, **kwargs: Any) -> Query: ... # type: ignore
def as_sql(self, compiler: SQLCompiler, connection: BaseDatabaseWrapper) -> Any: ...
def resolve_lookup_value(self, value: Any, can_reuse: Optional[Set[str]], allow_joins: bool) -> Any: ... def resolve_lookup_value(self, value: Any, can_reuse: Optional[Set[str]], allow_joins: bool) -> Any: ...
def solve_lookup_type( def solve_lookup_type(
self, lookup: str self, lookup: str

View File

@@ -2,7 +2,7 @@ from typing import Any, Dict, List, Optional, Sequence, Tuple, Union
from django.db.backends.base.base import BaseDatabaseWrapper from django.db.backends.base.base import BaseDatabaseWrapper
from django.db.models.expressions import Expression from django.db.models.expressions import Expression
from django.db.models.sql.compiler import SQLCompiler from django.db.models.sql.compiler import SQLCompiler, _AsSqlType
from django.db.models.sql.query import Query from django.db.models.sql.query import Query
from django.utils import tree from django.utils import tree
@@ -31,7 +31,9 @@ class WhereNode(tree.Node):
class NothingNode: class NothingNode:
contains_aggregate: bool = ... contains_aggregate: bool = ...
def as_sql(self, compiler: Optional[SQLCompiler] = ..., connection: Optional[BaseDatabaseWrapper] = ...) -> Any: ... def as_sql(
self, compiler: Optional[SQLCompiler] = ..., connection: Optional[BaseDatabaseWrapper] = ...
) -> _AsSqlType: ...
class ExtraWhere: class ExtraWhere:
contains_aggregate: bool = ... contains_aggregate: bool = ...
@@ -40,7 +42,7 @@ class ExtraWhere:
def __init__(self, sqls: Sequence[str], params: Optional[Union[Sequence[int], Sequence[str]]]) -> None: ... def __init__(self, sqls: Sequence[str], params: Optional[Union[Sequence[int], Sequence[str]]]) -> None: ...
def as_sql( def as_sql(
self, compiler: Optional[SQLCompiler] = ..., connection: Optional[BaseDatabaseWrapper] = ... self, compiler: Optional[SQLCompiler] = ..., connection: Optional[BaseDatabaseWrapper] = ...
) -> Tuple[str, Union[List[int], List[str]]]: ... ) -> _AsSqlType: ...
class SubqueryConstraint: class SubqueryConstraint:
contains_aggregate: bool = ... contains_aggregate: bool = ...
@@ -49,4 +51,4 @@ class SubqueryConstraint:
targets: List[str] = ... targets: List[str] = ...
query_object: Query = ... query_object: Query = ...
def __init__(self, alias: str, columns: List[str], targets: List[str], query_object: Query) -> None: ... def __init__(self, alias: str, columns: List[str], targets: List[str], query_object: Query) -> None: ...
def as_sql(self, compiler: SQLCompiler, connection: BaseDatabaseWrapper) -> Tuple[str, Tuple]: ... def as_sql(self, compiler: SQLCompiler, connection: BaseDatabaseWrapper) -> _AsSqlType: ...