Improve hints for database connections (DatabaseWrapper) (#612)

* `django.db.{connection, connections, router}` are now hinted -- including `ConnectionHandler` and `ConnectionRouter` classes.
* Several improvements to `BaseDatabaseWrapper` attribute hints.
* In many places, database connections were hinted as `Any`, which I changed to `BaseDatabaseWrapper`.
* In a few places I added additional `SQLCompiler` hints.
* Minor tweaks to nearby code.
This commit is contained in:
Marti Raudsepp
2021-05-13 20:22:35 +03:00
committed by GitHub
parent d1dd95181a
commit 45f6dc0362
29 changed files with 166 additions and 105 deletions

View File

@@ -1,5 +1,6 @@
from django.core.management.base import BaseCommand as BaseCommand, CommandError as CommandError
from django.db import DEFAULT_DB_ALIAS as DEFAULT_DB_ALIAS, connections as connections
from django.db.backends.base.base import BaseDatabaseWrapper
from django.db.models.constants import LOOKUP_SEP as LOOKUP_SEP
from typing import Any, Dict, Iterable, List, Tuple
@@ -10,7 +11,9 @@ class Command(BaseCommand):
def normalize_col_name(
self, col_name: str, used_column_names: List[str], is_relation: bool
) -> Tuple[str, Dict[str, str], List[str]]: ...
def get_field_type(self, connection: Any, table_name: Any, row: Any) -> Tuple[str, Dict[str, str], List[str]]: ...
def get_field_type(
self, connection: BaseDatabaseWrapper, table_name: str, row: Any
) -> Tuple[str, Dict[str, str], List[str]]: ...
def get_meta(
self, table_name: str, constraints: Any, column_to_field_name: Any, is_view: Any, is_partition: Any
) -> List[str]: ...

View File

@@ -9,6 +9,7 @@ from django.core.management.sql import (
emit_pre_migrate_signal as emit_pre_migrate_signal,
)
from django.db import DEFAULT_DB_ALIAS as DEFAULT_DB_ALIAS, connections as connections, router as router
from django.db.backends.base.base import BaseDatabaseWrapper
from django.db.migrations.autodetector import MigrationAutodetector as MigrationAutodetector
from django.db.migrations.executor import MigrationExecutor as MigrationExecutor
from django.db.migrations.loader import AmbiguityError as AmbiguityError
@@ -23,6 +24,6 @@ class Command(BaseCommand):
interactive: bool = ...
start: float = ...
def migration_progress_callback(self, action: str, migration: Optional[Any] = ..., fake: bool = ...) -> None: ...
def sync_apps(self, connection: Any, app_labels: List[str]) -> None: ...
def sync_apps(self, connection: BaseDatabaseWrapper, app_labels: List[str]) -> None: ...
@staticmethod
def describe_operation(operation: Operation, backwards: bool) -> str: ...

View File

@@ -1,10 +1,11 @@
from django.apps import apps as apps
from django.core.management.base import BaseCommand as BaseCommand
from django.db import DEFAULT_DB_ALIAS as DEFAULT_DB_ALIAS, connections as connections
from django.db.backends.base.base import BaseDatabaseWrapper
from django.db.migrations.loader import MigrationLoader as MigrationLoader
from typing import Any, List, Optional
class Command(BaseCommand):
verbosity: int = ...
def show_list(self, connection: Any, app_names: Optional[List[str]] = ...) -> None: ...
def show_plan(self, connection: Any, app_names: Optional[List[str]] = ...) -> None: ...
def show_list(self, connection: BaseDatabaseWrapper, app_names: Optional[List[str]] = ...) -> None: ...
def show_plan(self, connection: BaseDatabaseWrapper, app_names: Optional[List[str]] = ...) -> None: ...

View File

@@ -1,9 +1,14 @@
from typing import Any, List
from django.core.management.color import Style
from django.db.backends.base.base import BaseDatabaseWrapper
def sql_flush(
style: Style, connection: Any, only_django: bool = ..., reset_sequences: bool = ..., allow_cascade: bool = ...
style: Style,
connection: BaseDatabaseWrapper,
only_django: bool = ...,
reset_sequences: bool = ...,
allow_cascade: bool = ...,
) -> List[str]: ...
def emit_pre_migrate_signal(verbosity: int, interactive: bool, db: str, **kwargs: Any) -> None: ...
def emit_post_migrate_signal(verbosity: int, interactive: bool, db: str, **kwargs: Any) -> None: ...

View File

@@ -1,5 +1,6 @@
from typing import Any
from .backends.base.base import BaseDatabaseWrapper
from .utils import (
DEFAULT_DB_ALIAS as DEFAULT_DB_ALIAS,
DJANGO_VERSION_PICKLE_KEY as DJANGO_VERSION_PICKLE_KEY,
@@ -11,16 +12,19 @@ from .utils import (
NotSupportedError as NotSupportedError,
InternalError as InternalError,
InterfaceError as InterfaceError,
ConnectionHandler as ConnectionHandler,
Error as Error,
ConnectionDoesNotExist as ConnectionDoesNotExist,
# Not exported in __all__
ConnectionHandler,
ConnectionRouter,
)
from . import migrations
connections: Any
router: Any
connection: Any
connections: ConnectionHandler
router: ConnectionRouter
# Actually DefaultConnectionProxy, but quacks exactly like BaseDatabaseWrapper, it's not worth distinguishing the two.
connection: BaseDatabaseWrapper
class DefaultConnectionProxy:
def __getattr__(self, item: str) -> Any: ...

View File

@@ -1,4 +1,5 @@
from typing import Any, Callable, Dict, Iterator, List, Optional
from datetime import tzinfo
from typing import Any, Callable, Dict, Iterator, List, Optional, Type, TypeVar
from django.db.backends.base.client import BaseDatabaseClient
from django.db.backends.base.creation import BaseDatabaseCreation
@@ -11,25 +12,29 @@ from django.db.backends.base.features import BaseDatabaseFeatures
from django.db.backends.base.introspection import BaseDatabaseIntrospection
from django.db.backends.base.operations import BaseDatabaseOperations
NO_DB_ALIAS: str
_T = TypeVar("_T", bound="BaseDatabaseWrapper")
_ExecuteWrapper = Callable[[Callable[[str, Any, bool, Dict[str, Any]], Any], str, Any, bool, Dict[str, Any]], Any]
class BaseDatabaseWrapper:
data_types: Any = ...
data_types_suffix: Any = ...
data_type_check_constraints: Any = ...
ops: Any = ...
data_types: Dict[str, str] = ...
data_types_suffix: Dict[str, str] = ...
data_type_check_constraints: Dict[str, str] = ...
vendor: str = ...
display_name: str = ...
SchemaEditorClass: Optional[BaseDatabaseSchemaEditor] = ...
client_class: Any = ...
creation_class: Any = ...
features_class: Any = ...
introspection_class: Any = ...
ops_class: Any = ...
validation_class: Any = ...
SchemaEditorClass: Type[BaseDatabaseSchemaEditor] = ...
client_class: Type[BaseDatabaseClient] = ...
creation_class: Type[BaseDatabaseCreation] = ...
features_class: Type[BaseDatabaseFeatures] = ...
introspection_class: Type[BaseDatabaseIntrospection] = ...
ops_class: Type[BaseDatabaseOperations] = ...
validation_class: Type[BaseDatabaseValidation] = ...
queries_limit: int = ...
connection: Any = ...
settings_dict: Any = ...
settings_dict: Dict[str, Any] = ...
alias: str = ...
queries_log: Any = ...
force_debug_cursor: bool = ...
@@ -39,24 +44,23 @@ class BaseDatabaseWrapper:
savepoint_ids: Any = ...
commit_on_exit: bool = ...
needs_rollback: bool = ...
close_at: Optional[Any] = ...
close_at: Optional[float] = ...
closed_in_transaction: bool = ...
errors_occurred: bool = ...
allow_thread_sharing: bool = ...
run_on_commit: List[Any] = ...
run_on_commit: List[Callable[[], None]] = ...
run_commit_hooks_on_set_autocommit_on: bool = ...
execute_wrappers: List[Any] = ...
execute_wrappers: List[_ExecuteWrapper] = ...
client: BaseDatabaseClient = ...
creation: BaseDatabaseCreation = ...
features: BaseDatabaseFeatures = ...
introspection: BaseDatabaseIntrospection = ...
ops: BaseDatabaseOperations = ...
validation: BaseDatabaseValidation = ...
def __init__(
self, settings_dict: Dict[str, Dict[str, str]], alias: str = ..., allow_thread_sharing: bool = ...
) -> None: ...
def __init__(self, settings_dict: Dict[str, Any], alias: str = ..., allow_thread_sharing: bool = ...) -> None: ...
def ensure_timezone(self) -> bool: ...
def timezone(self): ...
def timezone_name(self): ...
def timezone(self) -> tzinfo: ...
def timezone_name(self) -> str: ...
@property
def queries_logged(self) -> bool: ...
@property
@@ -86,7 +90,7 @@ class BaseDatabaseWrapper:
def disable_constraint_checking(self): ...
def enable_constraint_checking(self) -> None: ...
def check_constraints(self, table_names: Optional[Any] = ...) -> None: ...
def is_usable(self) -> None: ...
def is_usable(self) -> bool: ...
def close_if_unusable_or_obsolete(self) -> None: ...
def validate_thread_sharing(self) -> None: ...
def prepare_database(self) -> None: ...
@@ -96,7 +100,7 @@ class BaseDatabaseWrapper:
def make_cursor(self, cursor: CursorWrapper) -> CursorWrapper: ...
def temporary_connection(self) -> None: ...
def schema_editor(self, *args: Any, **kwargs: Any) -> BaseDatabaseSchemaEditor: ...
def on_commit(self, func: Callable) -> None: ...
def on_commit(self, func: Callable[[], None]) -> None: ...
def run_and_clear_commit_hooks(self) -> None: ...
def execute_wrapper(self, wrapper: Callable) -> Iterator[None]: ...
def copy(self, alias: None = ..., allow_thread_sharing: None = ...) -> Any: ...
def execute_wrapper(self, wrapper: _ExecuteWrapper) -> Iterator[None]: ...
def copy(self: _T, alias: Optional[str] = ...) -> _T: ...

View File

@@ -4,6 +4,6 @@ from django.db.backends.base.base import BaseDatabaseWrapper
class BaseDatabaseClient:
executable_name: Any = ...
connection: Any = ...
connection: BaseDatabaseWrapper
def __init__(self, connection: BaseDatabaseWrapper) -> None: ...
def runshell(self) -> None: ...

View File

@@ -5,7 +5,7 @@ from django.db.backends.base.base import BaseDatabaseWrapper
TEST_DATABASE_PREFIX: str
class BaseDatabaseCreation:
connection: Any = ...
connection: BaseDatabaseWrapper
def __init__(self, connection: BaseDatabaseWrapper) -> None: ...
def create_test_db(
self, verbosity: int = ..., autoclobber: bool = ..., serialize: bool = ..., keepdb: bool = ...

View File

@@ -94,7 +94,7 @@ class BaseDatabaseFeatures:
db_functions_convert_bytes_to_str: bool = ...
supported_explain_formats: Any = ...
validates_explain_options: bool = ...
connection: Any = ...
connection: BaseDatabaseWrapper
def __init__(self, connection: BaseDatabaseWrapper) -> None: ...
def supports_explaining_query_execution(self) -> bool: ...
def supports_transactions(self): ...

View File

@@ -11,7 +11,7 @@ FieldInfo = namedtuple("FieldInfo", "name type_code display_size internal_size p
class BaseDatabaseIntrospection:
data_types_reverse: Any = ...
connection: Any = ...
connection: BaseDatabaseWrapper
def __init__(self, connection: BaseDatabaseWrapper) -> None: ...
def get_field_type(self, data_type: str, description: FieldInfo) -> str: ...
def table_name_converter(self, name: str) -> str: ...

View File

@@ -8,12 +8,8 @@ from django.db.backends.utils import CursorWrapper
from django.db.models.base import Model
from django.db.models.expressions import Case, Expression
from django.db.models.sql.compiler import SQLCompiler
from django.db import DefaultConnectionProxy
from django.db.models.fields import Field
_Connection = Union[DefaultConnectionProxy, BaseDatabaseWrapper]
class BaseDatabaseOperations:
compiler_module: str = ...
integer_field_ranges: Any = ...
@@ -25,9 +21,9 @@ class BaseDatabaseOperations:
UNBOUNDED_PRECEDING: Any = ...
UNBOUNDED_FOLLOWING: Any = ...
CURRENT_ROW: str = ...
explain_prefix: Any = ...
connection: _Connection = ...
def __init__(self, connection: Optional[_Connection]) -> None: ...
explain_prefix: Optional[str] = ...
connection: BaseDatabaseWrapper
def __init__(self, connection: BaseDatabaseWrapper) -> None: ...
def autoinc_sql(self, table: str, column: str) -> None: ...
def bulk_batch_size(self, fields: Any, objs: Any): ...
def cache_key_culling_sql(self) -> str: ...
@@ -88,7 +84,7 @@ class BaseDatabaseOperations:
def year_lookup_bounds_for_datetime_field(self, value: int) -> List[str]: ...
def get_db_converters(self, expression: Expression) -> List[Any]: ...
def convert_durationfield_value(
self, value: Optional[float], expression: Expression, connection: _Connection
self, value: Optional[float], expression: Expression, connection: BaseDatabaseWrapper
) -> Optional[timedelta]: ...
def check_expression_support(self, expression: Any) -> None: ...
def combine_expression(self, connector: str, sub_expressions: List[str]) -> str: ...

View File

@@ -1,5 +1,6 @@
from typing import Any, ContextManager, List, Optional, Sequence, Tuple, Type, Union
from django.db.backends.base.base import BaseDatabaseWrapper
from django.db.backends.ddl_references import Statement
from django.db.models.base import Model
from django.db.models.indexes import Index
@@ -35,11 +36,11 @@ class BaseDatabaseSchemaEditor(ContextManager[Any]):
sql_create_pk: str = ...
sql_delete_pk: str = ...
sql_delete_procedure: str = ...
connection: Any = ...
connection: BaseDatabaseWrapper = ...
collect_sql: bool = ...
collected_sql: Any = ...
atomic_migration: Any = ...
def __init__(self, connection: Any, collect_sql: bool = ..., atomic: bool = ...) -> None: ...
def __init__(self, connection: BaseDatabaseWrapper, collect_sql: bool = ..., atomic: bool = ...) -> None: ...
deferred_sql: Any = ...
atomic: Any = ...
def __enter__(self) -> BaseDatabaseSchemaEditor: ...

View File

@@ -5,7 +5,7 @@ from django.db.backends.base.base import BaseDatabaseWrapper
from django.db.models.fields import Field
class BaseDatabaseValidation:
connection: Any = ...
connection: BaseDatabaseWrapper = ...
def __init__(self, connection: BaseDatabaseWrapper) -> None: ...
def check(self, **kwargs: Any) -> List[Any]: ...
def check_field(self, field: Field, **kwargs: Any) -> List[Any]: ...

View File

@@ -1,6 +1,5 @@
from typing import Any, Callable, List, Optional, Set, Tuple, Union
from django.db import DefaultConnectionProxy
from django.db.backends.base.base import BaseDatabaseWrapper
from django.db.migrations.migration import Migration
@@ -9,13 +8,13 @@ from .recorder import MigrationRecorder
from .state import ProjectState
class MigrationExecutor:
connection: Any = ...
connection: BaseDatabaseWrapper = ...
loader: MigrationLoader = ...
recorder: MigrationRecorder = ...
progress_callback: Callable = ...
def __init__(
self,
connection: Optional[Union[DefaultConnectionProxy, BaseDatabaseWrapper]],
connection: Optional[BaseDatabaseWrapper],
progress_callback: Optional[Callable] = ...,
) -> None: ...
def migration_plan(

View File

@@ -1,5 +1,6 @@
from typing import Any, Dict, Optional, Sequence, Set, Tuple, Union
from django.db.backends.base.base import BaseDatabaseWrapper
from django.db.migrations.migration import Migration
from django.db.migrations.state import ProjectState
@@ -13,11 +14,13 @@ from .exceptions import (
MIGRATIONS_MODULE_NAME: str
class MigrationLoader:
connection: Any = ...
connection: Optional[BaseDatabaseWrapper] = ...
disk_migrations: Dict[Tuple[str, str], Migration] = ...
applied_migrations: Set[Tuple[str, str]] = ...
ignore_no_migrations: bool = ...
def __init__(self, connection: Any, load: bool = ..., ignore_no_migrations: bool = ...) -> None: ...
def __init__(
self, connection: Optional[BaseDatabaseWrapper], load: bool = ..., ignore_no_migrations: bool = ...
) -> None: ...
@classmethod
def migrations_module(cls, app_label: str) -> Tuple[Optional[str], bool]: ...
unmigrated_apps: Set[str] = ...
@@ -31,7 +34,7 @@ class MigrationLoader:
graph: Any = ...
replacements: Any = ...
def build_graph(self) -> None: ...
def check_consistent_history(self, connection: Any) -> None: ...
def check_consistent_history(self, connection: BaseDatabaseWrapper) -> None: ...
def detect_conflicts(self) -> Dict[str, Set[str]]: ...
def project_state(
self, nodes: Optional[Union[Tuple[str, str], Sequence[Tuple[str, str]]]] = ..., at_end: bool = ...

View File

@@ -10,8 +10,8 @@ class MigrationRecorder:
app: Any = ...
name: Any = ...
applied: Any = ...
connection: Optional[BaseDatabaseWrapper] = ...
def __init__(self, connection: Optional[BaseDatabaseWrapper]) -> None: ...
connection: BaseDatabaseWrapper = ...
def __init__(self, connection: BaseDatabaseWrapper) -> None: ...
@property
def migration_qs(self) -> QuerySet: ...
def has_table(self) -> bool: ...

View File

@@ -5,6 +5,7 @@ from typing import Any, Callable, Dict, Iterator, List, Optional, Sequence, Set,
from django.db.models.lookups import Lookup
from django.db.models.sql.compiler import SQLCompiler
from django.db.backends.base.base import BaseDatabaseWrapper
from django.db.models import Q, QuerySet
from django.db.models.fields import Field
from django.db.models.query import _BaseQuerySet
@@ -12,7 +13,9 @@ from django.db.models.query import _BaseQuerySet
_OutputField = Union[Field, str]
class SQLiteNumericMixin:
def as_sqlite(self, compiler: SQLCompiler, connection: Any, **extra_context: Any) -> Tuple[str, List[float]]: ...
def as_sqlite(
self, compiler: SQLCompiler, connection: BaseDatabaseWrapper, **extra_context: Any
) -> Tuple[str, List[float]]: ...
_Self = TypeVar("_Self")
@@ -57,7 +60,7 @@ class BaseExpression:
filterable: bool = ...
window_compatible: bool = ...
def __init__(self, output_field: Optional[_OutputField] = ...) -> None: ...
def get_db_converters(self, connection: Any) -> List[Callable]: ...
def get_db_converters(self, connection: BaseDatabaseWrapper) -> List[Callable]: ...
def get_source_expressions(self) -> List[Any]: ...
def set_source_expressions(self, exprs: Sequence[Combinable]) -> None: ...
@property
@@ -91,11 +94,11 @@ class BaseExpression:
def reverse_ordering(self): ...
def flatten(self) -> Iterator[Expression]: ...
def deconstruct(self) -> Any: ...
def as_sqlite(self, compiler: SQLCompiler, connection: Any) -> Any: ...
def as_sql(self, compiler: SQLCompiler, connection: Any, **extra_context: Any) -> Any: ...
def as_mysql(self, compiler: Any, connection: Any) -> Any: ...
def as_postgresql(self, compiler: Any, connection: Any) -> Any: ...
def as_oracle(self, compiler: Any, connection: Any): ...
def as_sqlite(self, compiler: SQLCompiler, connection: BaseDatabaseWrapper) -> Any: ...
def as_sql(self, compiler: SQLCompiler, connection: BaseDatabaseWrapper, **extra_context: Any) -> Any: ...
def as_mysql(self, compiler: SQLCompiler, connection: BaseDatabaseWrapper) -> Any: ...
def as_postgresql(self, compiler: SQLCompiler, connection: BaseDatabaseWrapper) -> Any: ...
def as_oracle(self, compiler: SQLCompiler, connection: BaseDatabaseWrapper) -> Any: ...
class Expression(BaseExpression, Combinable): ...
@@ -216,7 +219,9 @@ class WindowFrame(Expression):
template: str = ...
frame_type: str = ...
def __init__(self, start: Optional[int] = ..., end: Optional[int] = ...) -> None: ...
def window_frame_start_end(self, connection: Any, start: Optional[int], end: Optional[int]) -> Tuple[int, int]: ...
def window_frame_start_end(
self, connection: BaseDatabaseWrapper, start: Optional[int], end: Optional[int]
) -> Tuple[int, int]: ...
class RowRange(WindowFrame): ...
class ValueRange(WindowFrame): ...

View File

@@ -18,9 +18,9 @@ from typing import (
)
from django.core.checks import CheckMessage
from django.db.models import Model
from django.core.exceptions import FieldDoesNotExist as FieldDoesNotExist
from django.db.backends.base.base import BaseDatabaseWrapper
from django.db.models import Model
from django.db.models.expressions import Combinable, Col
from django.db.models.query_utils import RegisterLookupMixin
from django.forms import Field as FormField, Widget
@@ -110,12 +110,12 @@ class Field(RegisterLookupMixin, Generic[_ST, _GT]):
def __get__(self: _T, instance, owner) -> _T: ...
def deconstruct(self) -> Any: ...
def set_attributes_from_name(self, name: str) -> None: ...
def db_type(self, connection: Any) -> str: ...
def db_parameters(self, connection: Any) -> Dict[str, str]: ...
def db_type(self, connection: BaseDatabaseWrapper) -> str: ...
def db_parameters(self, connection: BaseDatabaseWrapper) -> Dict[str, str]: ...
def pre_save(self, model_instance: Model, add: bool) -> Any: ...
def get_prep_value(self, value: Any) -> Any: ...
def get_db_prep_value(self, value: Any, connection: Any, prepared: bool) -> Any: ...
def get_db_prep_save(self, value: Any, connection: Any) -> Any: ...
def get_db_prep_value(self, value: Any, connection: BaseDatabaseWrapper, prepared: bool) -> Any: ...
def get_db_prep_save(self, value: Any, connection: BaseDatabaseWrapper) -> Any: ...
def get_internal_type(self) -> str: ...
# TODO: plugin support
def formfield(self, **kwargs) -> Any: ...
@@ -150,7 +150,7 @@ class IntegerField(Field[_ST, _GT]):
_pyi_lookup_exact_type: Union[str, int]
class PositiveIntegerRelDbTypeMixin:
def rel_db_type(self, connection: Any): ...
def rel_db_type(self, connection: BaseDatabaseWrapper) -> str: ...
class PositiveIntegerField(PositiveIntegerRelDbTypeMixin, IntegerField[_ST, _GT]): ...
class PositiveSmallIntegerField(PositiveIntegerRelDbTypeMixin, IntegerField[_ST, _GT]): ...

View File

@@ -1,7 +1,9 @@
from . import Field
from .mixins import CheckFieldDefaultMixin
from django.db.backends.base.base import BaseDatabaseWrapper
from django.db.models import lookups
from django.db.models.lookups import PostgresOperatorLookup, Transform
from django.db.models.sql.compiler import SQLCompiler
from typing import Any, Optional, Callable
class JSONField(CheckFieldDefaultMixin, Field):
@@ -25,7 +27,7 @@ class JSONExact(lookups.Exact): ...
class KeyTransform(Transform):
key_name: Any = ...
def __init__(self, key_name: Any, *args: Any, **kwargs: Any) -> None: ...
def preprocess_lhs(self, compiler: Any, connection: Any, lhs_only: bool = ...): ...
def preprocess_lhs(self, compiler: SQLCompiler, connection: BaseDatabaseWrapper, lhs_only: bool = ...) -> Any: ...
class KeyTextTransform(KeyTransform): ...

View File

@@ -27,15 +27,19 @@ class Lookup(Generic[_T]):
def get_source_expressions(self) -> List[Expression]: ...
def set_source_expressions(self, new_exprs: List[Expression]) -> None: ...
def get_prep_lookup(self) -> Any: ...
def get_db_prep_lookup(self, value: Union[int, str], connection: BaseDatabaseWrapper) -> Tuple[str, List[SafeText]]: ...
def get_db_prep_lookup(
self, value: Union[int, str], connection: BaseDatabaseWrapper
) -> Tuple[str, List[SafeText]]: ...
def process_lhs(
self, compiler: SQLCompiler, connection: BaseDatabaseWrapper, lhs: Optional[Expression] = ...
) -> Tuple[str, List[Union[int, str]]]: ...
def process_rhs(self, compiler: SQLCompiler, connection: BaseDatabaseWrapper) -> Tuple[str, List[Union[int, str]]]: ...
def process_rhs(
self, compiler: SQLCompiler, connection: BaseDatabaseWrapper
) -> Tuple[str, List[Union[int, str]]]: ...
def rhs_is_direct_value(self) -> bool: ...
def relabeled_clone(self: _T, relabels: Mapping[str, str]) -> _T: ...
def get_group_by_cols(self) -> List[Expression]: ...
def as_sql(self, compiler: Any, connection: Any) -> Any: ...
def as_sql(self, compiler: SQLCompiler, connection: BaseDatabaseWrapper) -> Any: ...
def contains_aggregate(self) -> bool: ...
def contains_over_clause(self) -> bool: ...
@property
@@ -61,7 +65,7 @@ class FieldGetDbPrepValueIterableMixin(FieldGetDbPrepValueMixin):
class PostgresOperatorLookup(FieldGetDbPrepValueMixin, Lookup):
postgres_operator: str = ...
def as_postgresql(self, compiler: Any, connection: Any): ...
def as_postgresql(self, compiler: SQLCompiler, connection: BaseDatabaseWrapper) -> Any: ...
class Exact(FieldGetDbPrepValueMixin, BuiltinLookup): ...
class IExact(BuiltinLookup): ...
@@ -78,7 +82,7 @@ class IntegerGreaterThanOrEqual(IntegerFieldFloatRounding, GreaterThanOrEqual[Un
class IntegerLessThan(IntegerFieldFloatRounding, LessThan[Union[int, float]]): ...
class In(FieldGetDbPrepValueIterableMixin, BuiltinLookup):
def split_parameter_list_as_sql(self, compiler: Any, connection: Any): ...
def split_parameter_list_as_sql(self, compiler: SQLCompiler, connection: BaseDatabaseWrapper) -> Any: ...
class PatternLookup(BuiltinLookup[str]):
param_pattern: str = ...
@@ -109,7 +113,7 @@ class YearLte(YearComparisonLookup): ...
class UUIDTextMixin:
rhs: Any = ...
def process_rhs(self, qn: Any, connection: Any): ...
def process_rhs(self, qn: Any, connection: BaseDatabaseWrapper) -> Any: ...
class UUIDIExact(UUIDTextMixin, IExact): ...
class UUIDContains(UUIDTextMixin, Contains): ...

View File

@@ -1,6 +1,7 @@
from collections import namedtuple
from typing import Any, Collection, Dict, Iterator, List, Mapping, Optional, Sequence, Set, Tuple, Type
from django.db.backends.base.base import BaseDatabaseWrapper
from django.db.models.base import Model
from django.db.models.fields.mixins import FieldCacheMixin
from django.db.models.sql.compiler import SQLCompiler
@@ -20,7 +21,7 @@ class QueryWrapper:
contains_aggregate: bool = ...
data: Tuple[str, List[Any]] = ...
def __init__(self, sql: str, params: List[Any]) -> None: ...
def as_sql(self, compiler: SQLCompiler = ..., connection: Any = ...) -> Any: ...
def as_sql(self, compiler: Optional[SQLCompiler] = ..., connection: Optional[BaseDatabaseWrapper] = ...) -> Any: ...
class Q(tree.Node):
AND: str = ...
@@ -76,4 +77,4 @@ class FilteredRelation:
def __init__(self, relation_name: str, *, condition: Any = ...) -> None: ...
def clone(self) -> FilteredRelation: ...
def resolve_expression(self, *args: Any, **kwargs: Any) -> None: ...
def as_sql(self, compiler: SQLCompiler, connection: Any) -> Any: ...
def as_sql(self, compiler: SQLCompiler, connection: BaseDatabaseWrapper) -> Any: ...

View File

@@ -4,23 +4,25 @@ from itertools import chain
from typing import Any, Callable, Dict, Iterator, List, Optional, Set, Tuple, Type, Union
from uuid import UUID
from django.db.backends.base.base import BaseDatabaseWrapper
from django.db.models.base import Model
from django.db.models.expressions import BaseExpression, Expression
from django.db.models.sql.query import Query, RawQuery
FORCE: Any
class SQLCompiler:
query: Any = ...
connection: Any = ...
connection: BaseDatabaseWrapper = ...
using: Any = ...
quote_cache: Any = ...
select: Any = ...
annotation_col_map: Any = ...
klass_info: Any = ...
ordering_parts: Any = ...
def __init__(self, query: Union[Query, RawQuery], connection: Any, using: Optional[str]) -> None: ...
def __init__(
self, query: Union[Query, RawQuery], connection: BaseDatabaseWrapper, using: Optional[str]
) -> None: ...
col_count: Any = ...
def setup_query(self) -> None: ...
has_extra_select: Any = ...

View File

@@ -1,6 +1,7 @@
from collections import OrderedDict
from typing import Any, Dict, List, Optional, Tuple, Union
from django.db.backends.base.base import BaseDatabaseWrapper
from django.db.models.fields.mixins import FieldCacheMixin
from django.db.models.query_utils import FilteredRelation, PathInfo
from django.db.models.sql.compiler import SQLCompiler
@@ -31,7 +32,7 @@ class Join:
nullable: bool,
filtered_relation: Optional[FilteredRelation] = ...,
) -> None: ...
def as_sql(self, compiler: SQLCompiler, connection: Any) -> Tuple[str, List[Union[int, str]]]: ...
def as_sql(self, compiler: SQLCompiler, connection: BaseDatabaseWrapper) -> Tuple[str, List[Union[int, str]]]: ...
def relabeled_clone(self, change_map: Union[Dict[str, str], OrderedDict]) -> Join: ...
def equals(self, other: Union[BaseTable, Join], with_filtered_relation: bool) -> bool: ...
def demote(self) -> Join: ...
@@ -44,6 +45,6 @@ class BaseTable:
table_name: str = ...
table_alias: Optional[str] = ...
def __init__(self, table_name: str, alias: Optional[str]) -> None: ...
def as_sql(self, compiler: SQLCompiler, connection: Any) -> Tuple[str, List[Any]]: ...
def as_sql(self, compiler: SQLCompiler, connection: BaseDatabaseWrapper) -> Tuple[str, List[Any]]: ...
def relabeled_clone(self, change_map: OrderedDict) -> BaseTable: ...
def equals(self, other: Join, with_filtered_relation: bool) -> bool: ...

View File

@@ -2,15 +2,15 @@ import collections
from collections import OrderedDict, namedtuple
from typing import Any, Callable, Dict, Iterator, List, Optional, Sequence, Set, Tuple, Type, Union, Iterable
from django.db.backends.base.base import BaseDatabaseWrapper
from django.db.models import Expression, Field, FilteredRelation, Model, Q, QuerySet
from django.db.models.expressions import Combinable
from django.db.models.lookups import Lookup, Transform
from django.db.models.query_utils import PathInfo, RegisterLookupMixin
from django.db.models.sql.compiler import SQLCompiler
from django.db.models.sql.datastructures import BaseTable
from django.db.models.sql.where import WhereNode
from django.db.models import Expression, Field, FilteredRelation, Model, Q, QuerySet
from django.db.models.expressions import Combinable
JoinInfo = namedtuple("JoinInfo", ["final_field", "targets", "opts", "joins", "path", "transform_function"])
class RawQuery:
@@ -81,7 +81,7 @@ class Query:
def has_select_fields(self) -> bool: ...
def sql_with_params(self) -> Tuple[str, Tuple]: ...
def __deepcopy__(self, memo: Dict[str, Any]) -> Query: ...
def get_compiler(self, using: Optional[str] = ..., connection: Any = ...) -> SQLCompiler: ...
def get_compiler(self, using: Optional[str] = ..., connection: BaseDatabaseWrapper = ...) -> SQLCompiler: ...
def clone(self) -> Query: ...
def chain(self, klass: Optional[Type[Query]] = ...) -> Query: ...
def relabeled_clone(self, change_map: Union[Dict[Any, Any], OrderedDict]) -> Query: ...
@@ -101,7 +101,7 @@ class Query:
def get_initial_alias(self) -> str: ...
def count_active_tables(self) -> int: ...
def resolve_expression(self, query: Query, *args: Any, **kwargs: Any) -> Query: ...
def as_sql(self, compiler: SQLCompiler, connection: Any) -> Any: ...
def as_sql(self, compiler: SQLCompiler, connection: BaseDatabaseWrapper) -> Any: ...
def resolve_lookup_value(self, value: Any, can_reuse: Optional[Set[str]], allow_joins: bool) -> Any: ...
def solve_lookup_type(self, lookup: str) -> Tuple[Sequence[str], Sequence[str], bool]: ...
def build_filter(

View File

@@ -1,6 +1,7 @@
from collections import OrderedDict
from typing import Any, Dict, List, Optional, Tuple, Union
from django.db.backends.base.base import BaseDatabaseWrapper
from django.db.models.expressions import Expression
from django.db.models.sql.compiler import SQLCompiler
from django.db.models.sql.query import Query
@@ -18,7 +19,7 @@ class WhereNode(tree.Node):
resolved: bool = ...
conditional: bool = ...
def split_having(self, negated: bool = ...) -> Tuple[Optional[WhereNode], Optional[WhereNode]]: ...
def as_sql(self, compiler: SQLCompiler, connection: Any) -> Any: ...
def as_sql(self, compiler: SQLCompiler, connection: BaseDatabaseWrapper) -> Any: ...
def get_group_by_cols(self) -> List[Expression]: ...
def relabel_aliases(self, change_map: Union[Dict[Optional[str], str], OrderedDict]) -> None: ...
def clone(self) -> WhereNode: ...
@@ -27,14 +28,16 @@ class WhereNode(tree.Node):
class NothingNode:
contains_aggregate: bool = ...
def as_sql(self, compiler: SQLCompiler = ..., connection: Any = ...) -> Any: ...
def as_sql(self, compiler: Optional[SQLCompiler] = ..., connection: Optional[BaseDatabaseWrapper] = ...) -> Any: ...
class ExtraWhere:
contains_aggregate: bool = ...
sqls: List[str] = ...
params: Optional[Union[List[int], List[str]]] = ...
def __init__(self, sqls: List[str], params: Optional[Union[List[int], List[str]]]) -> None: ...
def as_sql(self, compiler: SQLCompiler = ..., connection: Any = ...) -> Tuple[str, Union[List[int], List[str]]]: ...
def as_sql(
self, compiler: Optional[SQLCompiler] = ..., connection: Optional[BaseDatabaseWrapper] = ...
) -> Tuple[str, Union[List[int], List[str]]]: ...
class SubqueryConstraint:
contains_aggregate: bool = ...
@@ -43,4 +46,4 @@ class SubqueryConstraint:
targets: List[str] = ...
query_object: Query = ...
def __init__(self, alias: str, columns: List[str], targets: List[str], query_object: Query) -> None: ...
def as_sql(self, compiler: SQLCompiler, connection: Any) -> Tuple[str, Tuple]: ...
def as_sql(self, compiler: SQLCompiler, connection: BaseDatabaseWrapper) -> Tuple[str, Tuple]: ...

View File

@@ -1,4 +1,8 @@
from typing import Any, Dict, Iterable, List, Optional
from typing import Any, Dict, Iterable, List, Optional, Iterator, Type
from django.apps import AppConfig
from django.db.backends.base.base import BaseDatabaseWrapper
from django.db.models import Model
DEFAULT_DB_ALIAS: str
DJANGO_VERSION_PICKLE_KEY: str
@@ -22,14 +26,22 @@ class ConnectionHandler:
def __init__(self, databases: Dict[str, Dict[str, Optional[Any]]] = ...) -> None: ...
def ensure_defaults(self, alias: str) -> None: ...
def prepare_test_settings(self, alias: str) -> None: ...
def __getitem__(self, alias: str) -> Any: ...
def __setitem__(self, key: Any, value: Any) -> None: ...
def __delitem__(self, key: Any) -> None: ...
def __iter__(self): ...
def all(self) -> List[Any]: ...
def __getitem__(self, alias: str) -> BaseDatabaseWrapper: ...
def __setitem__(self, key: str, value: BaseDatabaseWrapper) -> None: ...
def __delitem__(self, key: BaseDatabaseWrapper) -> None: ...
def __iter__(self) -> Iterator[str]: ...
def all(self) -> List[BaseDatabaseWrapper]: ...
def close_all(self) -> None: ...
class ConnectionRouter:
def __init__(self, routers: Optional[Iterable[Any]] = ...) -> None: ...
@property
def routers(self) -> List[Any]: ...
def db_for_read(self, model: Type[Model], **hints: Any) -> str: ...
def db_for_write(self, model: Type[Model], **hints: Any) -> str: ...
def allow_relation(self, obj1: Model, obj2: Model, **hints: Any) -> bool: ...
def allow_migrate(self, db: str, app_label: str, **hints: Any) -> bool: ...
def allow_migrate_model(self, db: str, model: Type[Model]) -> bool: ...
def get_migratable_models(
self, app_config: AppConfig, db: str, include_auto_created: bool = ...
) -> List[Type[Model]]: ...

View File

@@ -20,7 +20,7 @@ from django.db import connections as connections # noqa: F401
class _AssertNumQueriesContext(CaptureQueriesContext):
test_case: SimpleTestCase = ...
num: int = ...
def __init__(self, test_case: Any, num: Any, connection: Any) -> None: ...
def __init__(self, test_case: Any, num: Any, connection: BaseDatabaseWrapper) -> None: ...
class _AssertTemplateUsedContext:
test_case: SimpleTestCase = ...
@@ -200,7 +200,7 @@ class LiveServerThread(threading.Thread):
is_ready: threading.Event = ...
error: Optional[ImproperlyConfigured] = ...
static_handler: Type[WSGIHandler] = ...
connections_override: Dict[str, Any] = ...
connections_override: Dict[str, BaseDatabaseWrapper] = ...
def __init__(
self,
host: str,

View File

@@ -27,6 +27,7 @@ from django.test.runner import DiscoverRunner
from django.test.testcases import SimpleTestCase
from django.conf import LazySettings, Settings
from django.db.backends.base.base import BaseDatabaseWrapper
_TestClass = Type[SimpleTestCase]
_DecoratedTest = Union[Callable, _TestClass]
@@ -84,11 +85,11 @@ class override_system_checks(TestContextDecorator):
old_deployment_checks: Set[Callable] = ...
class CaptureQueriesContext:
connection: Any = ...
connection: BaseDatabaseWrapper = ...
force_debug_cursor: bool = ...
initial_queries: int = ...
final_queries: Optional[int] = ...
def __init__(self, connection: Any) -> None: ...
def __init__(self, connection: BaseDatabaseWrapper) -> None: ...
def __iter__(self): ...
def __getitem__(self, index: int) -> Dict[str, str]: ...
def __len__(self) -> int: ...

View File

@@ -0,0 +1,13 @@
- case: raw_default_connection
main: |
from django.db import connection
with connection.cursor() as cursor:
reveal_type(cursor) # N: Revealed type is 'django.db.backends.utils.CursorWrapper'
cursor.execute("SELECT %s", [123])
- case: raw_connections
main: |
from django.db import connections
reveal_type(connections["test"]) # N: Revealed type is 'django.db.backends.base.base.BaseDatabaseWrapper'
for connection in connections.all():
with connection.cursor() as cursor:
reveal_type(cursor) # N: Revealed type is 'django.db.backends.utils.CursorWrapper'