Improve hints for database connections (DatabaseWrapper) (#612)

* `django.db.{connection, connections, router}` are now hinted -- including `ConnectionHandler` and `ConnectionRouter` classes.
* Several improvements to `BaseDatabaseWrapper` attribute hints.
* In many places, database connections were hinted as `Any`, which I changed to `BaseDatabaseWrapper`.
* In a few places I added additional `SQLCompiler` hints.
* Minor tweaks to nearby code.
This commit is contained in:
Marti Raudsepp
2021-05-13 20:22:35 +03:00
committed by GitHub
parent d1dd95181a
commit 45f6dc0362
29 changed files with 166 additions and 105 deletions

View File

@@ -1,5 +1,6 @@
from django.core.management.base import BaseCommand as BaseCommand, CommandError as CommandError from django.core.management.base import BaseCommand as BaseCommand, CommandError as CommandError
from django.db import DEFAULT_DB_ALIAS as DEFAULT_DB_ALIAS, connections as connections from django.db import DEFAULT_DB_ALIAS as DEFAULT_DB_ALIAS, connections as connections
from django.db.backends.base.base import BaseDatabaseWrapper
from django.db.models.constants import LOOKUP_SEP as LOOKUP_SEP from django.db.models.constants import LOOKUP_SEP as LOOKUP_SEP
from typing import Any, Dict, Iterable, List, Tuple from typing import Any, Dict, Iterable, List, Tuple
@@ -10,7 +11,9 @@ class Command(BaseCommand):
def normalize_col_name( def normalize_col_name(
self, col_name: str, used_column_names: List[str], is_relation: bool self, col_name: str, used_column_names: List[str], is_relation: bool
) -> Tuple[str, Dict[str, str], List[str]]: ... ) -> Tuple[str, Dict[str, str], List[str]]: ...
def get_field_type(self, connection: Any, table_name: Any, row: Any) -> Tuple[str, Dict[str, str], List[str]]: ... def get_field_type(
self, connection: BaseDatabaseWrapper, table_name: str, row: Any
) -> Tuple[str, Dict[str, str], List[str]]: ...
def get_meta( def get_meta(
self, table_name: str, constraints: Any, column_to_field_name: Any, is_view: Any, is_partition: Any self, table_name: str, constraints: Any, column_to_field_name: Any, is_view: Any, is_partition: Any
) -> List[str]: ... ) -> List[str]: ...

View File

@@ -9,6 +9,7 @@ from django.core.management.sql import (
emit_pre_migrate_signal as emit_pre_migrate_signal, emit_pre_migrate_signal as emit_pre_migrate_signal,
) )
from django.db import DEFAULT_DB_ALIAS as DEFAULT_DB_ALIAS, connections as connections, router as router from django.db import DEFAULT_DB_ALIAS as DEFAULT_DB_ALIAS, connections as connections, router as router
from django.db.backends.base.base import BaseDatabaseWrapper
from django.db.migrations.autodetector import MigrationAutodetector as MigrationAutodetector from django.db.migrations.autodetector import MigrationAutodetector as MigrationAutodetector
from django.db.migrations.executor import MigrationExecutor as MigrationExecutor from django.db.migrations.executor import MigrationExecutor as MigrationExecutor
from django.db.migrations.loader import AmbiguityError as AmbiguityError from django.db.migrations.loader import AmbiguityError as AmbiguityError
@@ -23,6 +24,6 @@ class Command(BaseCommand):
interactive: bool = ... interactive: bool = ...
start: float = ... start: float = ...
def migration_progress_callback(self, action: str, migration: Optional[Any] = ..., fake: bool = ...) -> None: ... def migration_progress_callback(self, action: str, migration: Optional[Any] = ..., fake: bool = ...) -> None: ...
def sync_apps(self, connection: Any, app_labels: List[str]) -> None: ... def sync_apps(self, connection: BaseDatabaseWrapper, app_labels: List[str]) -> None: ...
@staticmethod @staticmethod
def describe_operation(operation: Operation, backwards: bool) -> str: ... def describe_operation(operation: Operation, backwards: bool) -> str: ...

View File

@@ -1,10 +1,11 @@
from django.apps import apps as apps from django.apps import apps as apps
from django.core.management.base import BaseCommand as BaseCommand from django.core.management.base import BaseCommand as BaseCommand
from django.db import DEFAULT_DB_ALIAS as DEFAULT_DB_ALIAS, connections as connections from django.db import DEFAULT_DB_ALIAS as DEFAULT_DB_ALIAS, connections as connections
from django.db.backends.base.base import BaseDatabaseWrapper
from django.db.migrations.loader import MigrationLoader as MigrationLoader from django.db.migrations.loader import MigrationLoader as MigrationLoader
from typing import Any, List, Optional from typing import Any, List, Optional
class Command(BaseCommand): class Command(BaseCommand):
verbosity: int = ... verbosity: int = ...
def show_list(self, connection: Any, app_names: Optional[List[str]] = ...) -> None: ... def show_list(self, connection: BaseDatabaseWrapper, app_names: Optional[List[str]] = ...) -> None: ...
def show_plan(self, connection: Any, app_names: Optional[List[str]] = ...) -> None: ... def show_plan(self, connection: BaseDatabaseWrapper, app_names: Optional[List[str]] = ...) -> None: ...

View File

@@ -1,9 +1,14 @@
from typing import Any, List from typing import Any, List
from django.core.management.color import Style from django.core.management.color import Style
from django.db.backends.base.base import BaseDatabaseWrapper
def sql_flush( def sql_flush(
style: Style, connection: Any, only_django: bool = ..., reset_sequences: bool = ..., allow_cascade: bool = ... style: Style,
connection: BaseDatabaseWrapper,
only_django: bool = ...,
reset_sequences: bool = ...,
allow_cascade: bool = ...,
) -> List[str]: ... ) -> List[str]: ...
def emit_pre_migrate_signal(verbosity: int, interactive: bool, db: str, **kwargs: Any) -> None: ... def emit_pre_migrate_signal(verbosity: int, interactive: bool, db: str, **kwargs: Any) -> None: ...
def emit_post_migrate_signal(verbosity: int, interactive: bool, db: str, **kwargs: Any) -> None: ... def emit_post_migrate_signal(verbosity: int, interactive: bool, db: str, **kwargs: Any) -> None: ...

View File

@@ -1,5 +1,6 @@
from typing import Any from typing import Any
from .backends.base.base import BaseDatabaseWrapper
from .utils import ( from .utils import (
DEFAULT_DB_ALIAS as DEFAULT_DB_ALIAS, DEFAULT_DB_ALIAS as DEFAULT_DB_ALIAS,
DJANGO_VERSION_PICKLE_KEY as DJANGO_VERSION_PICKLE_KEY, DJANGO_VERSION_PICKLE_KEY as DJANGO_VERSION_PICKLE_KEY,
@@ -11,16 +12,19 @@ from .utils import (
NotSupportedError as NotSupportedError, NotSupportedError as NotSupportedError,
InternalError as InternalError, InternalError as InternalError,
InterfaceError as InterfaceError, InterfaceError as InterfaceError,
ConnectionHandler as ConnectionHandler,
Error as Error, Error as Error,
ConnectionDoesNotExist as ConnectionDoesNotExist, ConnectionDoesNotExist as ConnectionDoesNotExist,
# Not exported in __all__
ConnectionHandler,
ConnectionRouter,
) )
from . import migrations from . import migrations
connections: Any connections: ConnectionHandler
router: Any router: ConnectionRouter
connection: Any # Actually DefaultConnectionProxy, but quacks exactly like BaseDatabaseWrapper, it's not worth distinguishing the two.
connection: BaseDatabaseWrapper
class DefaultConnectionProxy: class DefaultConnectionProxy:
def __getattr__(self, item: str) -> Any: ... def __getattr__(self, item: str) -> Any: ...

View File

@@ -1,4 +1,5 @@
from typing import Any, Callable, Dict, Iterator, List, Optional from datetime import tzinfo
from typing import Any, Callable, Dict, Iterator, List, Optional, Type, TypeVar
from django.db.backends.base.client import BaseDatabaseClient from django.db.backends.base.client import BaseDatabaseClient
from django.db.backends.base.creation import BaseDatabaseCreation from django.db.backends.base.creation import BaseDatabaseCreation
@@ -11,25 +12,29 @@ from django.db.backends.base.features import BaseDatabaseFeatures
from django.db.backends.base.introspection import BaseDatabaseIntrospection from django.db.backends.base.introspection import BaseDatabaseIntrospection
from django.db.backends.base.operations import BaseDatabaseOperations
NO_DB_ALIAS: str NO_DB_ALIAS: str
_T = TypeVar("_T", bound="BaseDatabaseWrapper")
_ExecuteWrapper = Callable[[Callable[[str, Any, bool, Dict[str, Any]], Any], str, Any, bool, Dict[str, Any]], Any]
class BaseDatabaseWrapper: class BaseDatabaseWrapper:
data_types: Any = ... data_types: Dict[str, str] = ...
data_types_suffix: Any = ... data_types_suffix: Dict[str, str] = ...
data_type_check_constraints: Any = ... data_type_check_constraints: Dict[str, str] = ...
ops: Any = ...
vendor: str = ... vendor: str = ...
display_name: str = ... display_name: str = ...
SchemaEditorClass: Optional[BaseDatabaseSchemaEditor] = ... SchemaEditorClass: Type[BaseDatabaseSchemaEditor] = ...
client_class: Any = ... client_class: Type[BaseDatabaseClient] = ...
creation_class: Any = ... creation_class: Type[BaseDatabaseCreation] = ...
features_class: Any = ... features_class: Type[BaseDatabaseFeatures] = ...
introspection_class: Any = ... introspection_class: Type[BaseDatabaseIntrospection] = ...
ops_class: Any = ... ops_class: Type[BaseDatabaseOperations] = ...
validation_class: Any = ... validation_class: Type[BaseDatabaseValidation] = ...
queries_limit: int = ... queries_limit: int = ...
connection: Any = ... connection: Any = ...
settings_dict: Any = ... settings_dict: Dict[str, Any] = ...
alias: str = ... alias: str = ...
queries_log: Any = ... queries_log: Any = ...
force_debug_cursor: bool = ... force_debug_cursor: bool = ...
@@ -39,24 +44,23 @@ class BaseDatabaseWrapper:
savepoint_ids: Any = ... savepoint_ids: Any = ...
commit_on_exit: bool = ... commit_on_exit: bool = ...
needs_rollback: bool = ... needs_rollback: bool = ...
close_at: Optional[Any] = ... close_at: Optional[float] = ...
closed_in_transaction: bool = ... closed_in_transaction: bool = ...
errors_occurred: bool = ... errors_occurred: bool = ...
allow_thread_sharing: bool = ... allow_thread_sharing: bool = ...
run_on_commit: List[Any] = ... run_on_commit: List[Callable[[], None]] = ...
run_commit_hooks_on_set_autocommit_on: bool = ... run_commit_hooks_on_set_autocommit_on: bool = ...
execute_wrappers: List[Any] = ... execute_wrappers: List[_ExecuteWrapper] = ...
client: BaseDatabaseClient = ... client: BaseDatabaseClient = ...
creation: BaseDatabaseCreation = ... creation: BaseDatabaseCreation = ...
features: BaseDatabaseFeatures = ... features: BaseDatabaseFeatures = ...
introspection: BaseDatabaseIntrospection = ... introspection: BaseDatabaseIntrospection = ...
ops: BaseDatabaseOperations = ...
validation: BaseDatabaseValidation = ... validation: BaseDatabaseValidation = ...
def __init__( def __init__(self, settings_dict: Dict[str, Any], alias: str = ..., allow_thread_sharing: bool = ...) -> None: ...
self, settings_dict: Dict[str, Dict[str, str]], alias: str = ..., allow_thread_sharing: bool = ...
) -> None: ...
def ensure_timezone(self) -> bool: ... def ensure_timezone(self) -> bool: ...
def timezone(self): ... def timezone(self) -> tzinfo: ...
def timezone_name(self): ... def timezone_name(self) -> str: ...
@property @property
def queries_logged(self) -> bool: ... def queries_logged(self) -> bool: ...
@property @property
@@ -86,7 +90,7 @@ class BaseDatabaseWrapper:
def disable_constraint_checking(self): ... def disable_constraint_checking(self): ...
def enable_constraint_checking(self) -> None: ... def enable_constraint_checking(self) -> None: ...
def check_constraints(self, table_names: Optional[Any] = ...) -> None: ... def check_constraints(self, table_names: Optional[Any] = ...) -> None: ...
def is_usable(self) -> None: ... def is_usable(self) -> bool: ...
def close_if_unusable_or_obsolete(self) -> None: ... def close_if_unusable_or_obsolete(self) -> None: ...
def validate_thread_sharing(self) -> None: ... def validate_thread_sharing(self) -> None: ...
def prepare_database(self) -> None: ... def prepare_database(self) -> None: ...
@@ -96,7 +100,7 @@ class BaseDatabaseWrapper:
def make_cursor(self, cursor: CursorWrapper) -> CursorWrapper: ... def make_cursor(self, cursor: CursorWrapper) -> CursorWrapper: ...
def temporary_connection(self) -> None: ... def temporary_connection(self) -> None: ...
def schema_editor(self, *args: Any, **kwargs: Any) -> BaseDatabaseSchemaEditor: ... def schema_editor(self, *args: Any, **kwargs: Any) -> BaseDatabaseSchemaEditor: ...
def on_commit(self, func: Callable) -> None: ... def on_commit(self, func: Callable[[], None]) -> None: ...
def run_and_clear_commit_hooks(self) -> None: ... def run_and_clear_commit_hooks(self) -> None: ...
def execute_wrapper(self, wrapper: Callable) -> Iterator[None]: ... def execute_wrapper(self, wrapper: _ExecuteWrapper) -> Iterator[None]: ...
def copy(self, alias: None = ..., allow_thread_sharing: None = ...) -> Any: ... def copy(self: _T, alias: Optional[str] = ...) -> _T: ...

View File

@@ -4,6 +4,6 @@ from django.db.backends.base.base import BaseDatabaseWrapper
class BaseDatabaseClient: class BaseDatabaseClient:
executable_name: Any = ... executable_name: Any = ...
connection: Any = ... connection: BaseDatabaseWrapper
def __init__(self, connection: BaseDatabaseWrapper) -> None: ... def __init__(self, connection: BaseDatabaseWrapper) -> None: ...
def runshell(self) -> None: ... def runshell(self) -> None: ...

View File

@@ -5,7 +5,7 @@ from django.db.backends.base.base import BaseDatabaseWrapper
TEST_DATABASE_PREFIX: str TEST_DATABASE_PREFIX: str
class BaseDatabaseCreation: class BaseDatabaseCreation:
connection: Any = ... connection: BaseDatabaseWrapper
def __init__(self, connection: BaseDatabaseWrapper) -> None: ... def __init__(self, connection: BaseDatabaseWrapper) -> None: ...
def create_test_db( def create_test_db(
self, verbosity: int = ..., autoclobber: bool = ..., serialize: bool = ..., keepdb: bool = ... self, verbosity: int = ..., autoclobber: bool = ..., serialize: bool = ..., keepdb: bool = ...

View File

@@ -94,7 +94,7 @@ class BaseDatabaseFeatures:
db_functions_convert_bytes_to_str: bool = ... db_functions_convert_bytes_to_str: bool = ...
supported_explain_formats: Any = ... supported_explain_formats: Any = ...
validates_explain_options: bool = ... validates_explain_options: bool = ...
connection: Any = ... connection: BaseDatabaseWrapper
def __init__(self, connection: BaseDatabaseWrapper) -> None: ... def __init__(self, connection: BaseDatabaseWrapper) -> None: ...
def supports_explaining_query_execution(self) -> bool: ... def supports_explaining_query_execution(self) -> bool: ...
def supports_transactions(self): ... def supports_transactions(self): ...

View File

@@ -11,7 +11,7 @@ FieldInfo = namedtuple("FieldInfo", "name type_code display_size internal_size p
class BaseDatabaseIntrospection: class BaseDatabaseIntrospection:
data_types_reverse: Any = ... data_types_reverse: Any = ...
connection: Any = ... connection: BaseDatabaseWrapper
def __init__(self, connection: BaseDatabaseWrapper) -> None: ... def __init__(self, connection: BaseDatabaseWrapper) -> None: ...
def get_field_type(self, data_type: str, description: FieldInfo) -> str: ... def get_field_type(self, data_type: str, description: FieldInfo) -> str: ...
def table_name_converter(self, name: str) -> str: ... def table_name_converter(self, name: str) -> str: ...

View File

@@ -8,12 +8,8 @@ from django.db.backends.utils import CursorWrapper
from django.db.models.base import Model from django.db.models.base import Model
from django.db.models.expressions import Case, Expression from django.db.models.expressions import Case, Expression
from django.db.models.sql.compiler import SQLCompiler from django.db.models.sql.compiler import SQLCompiler
from django.db import DefaultConnectionProxy
from django.db.models.fields import Field from django.db.models.fields import Field
_Connection = Union[DefaultConnectionProxy, BaseDatabaseWrapper]
class BaseDatabaseOperations: class BaseDatabaseOperations:
compiler_module: str = ... compiler_module: str = ...
integer_field_ranges: Any = ... integer_field_ranges: Any = ...
@@ -25,9 +21,9 @@ class BaseDatabaseOperations:
UNBOUNDED_PRECEDING: Any = ... UNBOUNDED_PRECEDING: Any = ...
UNBOUNDED_FOLLOWING: Any = ... UNBOUNDED_FOLLOWING: Any = ...
CURRENT_ROW: str = ... CURRENT_ROW: str = ...
explain_prefix: Any = ... explain_prefix: Optional[str] = ...
connection: _Connection = ... connection: BaseDatabaseWrapper
def __init__(self, connection: Optional[_Connection]) -> None: ... def __init__(self, connection: BaseDatabaseWrapper) -> None: ...
def autoinc_sql(self, table: str, column: str) -> None: ... def autoinc_sql(self, table: str, column: str) -> None: ...
def bulk_batch_size(self, fields: Any, objs: Any): ... def bulk_batch_size(self, fields: Any, objs: Any): ...
def cache_key_culling_sql(self) -> str: ... def cache_key_culling_sql(self) -> str: ...
@@ -88,7 +84,7 @@ class BaseDatabaseOperations:
def year_lookup_bounds_for_datetime_field(self, value: int) -> List[str]: ... def year_lookup_bounds_for_datetime_field(self, value: int) -> List[str]: ...
def get_db_converters(self, expression: Expression) -> List[Any]: ... def get_db_converters(self, expression: Expression) -> List[Any]: ...
def convert_durationfield_value( def convert_durationfield_value(
self, value: Optional[float], expression: Expression, connection: _Connection self, value: Optional[float], expression: Expression, connection: BaseDatabaseWrapper
) -> Optional[timedelta]: ... ) -> Optional[timedelta]: ...
def check_expression_support(self, expression: Any) -> None: ... def check_expression_support(self, expression: Any) -> None: ...
def combine_expression(self, connector: str, sub_expressions: List[str]) -> str: ... def combine_expression(self, connector: str, sub_expressions: List[str]) -> str: ...

View File

@@ -1,5 +1,6 @@
from typing import Any, ContextManager, List, Optional, Sequence, Tuple, Type, Union from typing import Any, ContextManager, List, Optional, Sequence, Tuple, Type, Union
from django.db.backends.base.base import BaseDatabaseWrapper
from django.db.backends.ddl_references import Statement from django.db.backends.ddl_references import Statement
from django.db.models.base import Model from django.db.models.base import Model
from django.db.models.indexes import Index from django.db.models.indexes import Index
@@ -35,11 +36,11 @@ class BaseDatabaseSchemaEditor(ContextManager[Any]):
sql_create_pk: str = ... sql_create_pk: str = ...
sql_delete_pk: str = ... sql_delete_pk: str = ...
sql_delete_procedure: str = ... sql_delete_procedure: str = ...
connection: Any = ... connection: BaseDatabaseWrapper = ...
collect_sql: bool = ... collect_sql: bool = ...
collected_sql: Any = ... collected_sql: Any = ...
atomic_migration: Any = ... atomic_migration: Any = ...
def __init__(self, connection: Any, collect_sql: bool = ..., atomic: bool = ...) -> None: ... def __init__(self, connection: BaseDatabaseWrapper, collect_sql: bool = ..., atomic: bool = ...) -> None: ...
deferred_sql: Any = ... deferred_sql: Any = ...
atomic: Any = ... atomic: Any = ...
def __enter__(self) -> BaseDatabaseSchemaEditor: ... def __enter__(self) -> BaseDatabaseSchemaEditor: ...

View File

@@ -5,7 +5,7 @@ from django.db.backends.base.base import BaseDatabaseWrapper
from django.db.models.fields import Field from django.db.models.fields import Field
class BaseDatabaseValidation: class BaseDatabaseValidation:
connection: Any = ... connection: BaseDatabaseWrapper = ...
def __init__(self, connection: BaseDatabaseWrapper) -> None: ... def __init__(self, connection: BaseDatabaseWrapper) -> None: ...
def check(self, **kwargs: Any) -> List[Any]: ... def check(self, **kwargs: Any) -> List[Any]: ...
def check_field(self, field: Field, **kwargs: Any) -> List[Any]: ... def check_field(self, field: Field, **kwargs: Any) -> List[Any]: ...

View File

@@ -1,6 +1,5 @@
from typing import Any, Callable, List, Optional, Set, Tuple, Union from typing import Any, Callable, List, Optional, Set, Tuple, Union
from django.db import DefaultConnectionProxy
from django.db.backends.base.base import BaseDatabaseWrapper from django.db.backends.base.base import BaseDatabaseWrapper
from django.db.migrations.migration import Migration from django.db.migrations.migration import Migration
@@ -9,13 +8,13 @@ from .recorder import MigrationRecorder
from .state import ProjectState from .state import ProjectState
class MigrationExecutor: class MigrationExecutor:
connection: Any = ... connection: BaseDatabaseWrapper = ...
loader: MigrationLoader = ... loader: MigrationLoader = ...
recorder: MigrationRecorder = ... recorder: MigrationRecorder = ...
progress_callback: Callable = ... progress_callback: Callable = ...
def __init__( def __init__(
self, self,
connection: Optional[Union[DefaultConnectionProxy, BaseDatabaseWrapper]], connection: Optional[BaseDatabaseWrapper],
progress_callback: Optional[Callable] = ..., progress_callback: Optional[Callable] = ...,
) -> None: ... ) -> None: ...
def migration_plan( def migration_plan(

View File

@@ -1,5 +1,6 @@
from typing import Any, Dict, Optional, Sequence, Set, Tuple, Union from typing import Any, Dict, Optional, Sequence, Set, Tuple, Union
from django.db.backends.base.base import BaseDatabaseWrapper
from django.db.migrations.migration import Migration from django.db.migrations.migration import Migration
from django.db.migrations.state import ProjectState from django.db.migrations.state import ProjectState
@@ -13,11 +14,13 @@ from .exceptions import (
MIGRATIONS_MODULE_NAME: str MIGRATIONS_MODULE_NAME: str
class MigrationLoader: class MigrationLoader:
connection: Any = ... connection: Optional[BaseDatabaseWrapper] = ...
disk_migrations: Dict[Tuple[str, str], Migration] = ... disk_migrations: Dict[Tuple[str, str], Migration] = ...
applied_migrations: Set[Tuple[str, str]] = ... applied_migrations: Set[Tuple[str, str]] = ...
ignore_no_migrations: bool = ... ignore_no_migrations: bool = ...
def __init__(self, connection: Any, load: bool = ..., ignore_no_migrations: bool = ...) -> None: ... def __init__(
self, connection: Optional[BaseDatabaseWrapper], load: bool = ..., ignore_no_migrations: bool = ...
) -> None: ...
@classmethod @classmethod
def migrations_module(cls, app_label: str) -> Tuple[Optional[str], bool]: ... def migrations_module(cls, app_label: str) -> Tuple[Optional[str], bool]: ...
unmigrated_apps: Set[str] = ... unmigrated_apps: Set[str] = ...
@@ -31,7 +34,7 @@ class MigrationLoader:
graph: Any = ... graph: Any = ...
replacements: Any = ... replacements: Any = ...
def build_graph(self) -> None: ... def build_graph(self) -> None: ...
def check_consistent_history(self, connection: Any) -> None: ... def check_consistent_history(self, connection: BaseDatabaseWrapper) -> None: ...
def detect_conflicts(self) -> Dict[str, Set[str]]: ... def detect_conflicts(self) -> Dict[str, Set[str]]: ...
def project_state( def project_state(
self, nodes: Optional[Union[Tuple[str, str], Sequence[Tuple[str, str]]]] = ..., at_end: bool = ... self, nodes: Optional[Union[Tuple[str, str], Sequence[Tuple[str, str]]]] = ..., at_end: bool = ...

View File

@@ -10,8 +10,8 @@ class MigrationRecorder:
app: Any = ... app: Any = ...
name: Any = ... name: Any = ...
applied: Any = ... applied: Any = ...
connection: Optional[BaseDatabaseWrapper] = ... connection: BaseDatabaseWrapper = ...
def __init__(self, connection: Optional[BaseDatabaseWrapper]) -> None: ... def __init__(self, connection: BaseDatabaseWrapper) -> None: ...
@property @property
def migration_qs(self) -> QuerySet: ... def migration_qs(self) -> QuerySet: ...
def has_table(self) -> bool: ... def has_table(self) -> bool: ...

View File

@@ -5,6 +5,7 @@ from typing import Any, Callable, Dict, Iterator, List, Optional, Sequence, Set,
from django.db.models.lookups import Lookup from django.db.models.lookups import Lookup
from django.db.models.sql.compiler import SQLCompiler from django.db.models.sql.compiler import SQLCompiler
from django.db.backends.base.base import BaseDatabaseWrapper
from django.db.models import Q, QuerySet from django.db.models import Q, QuerySet
from django.db.models.fields import Field from django.db.models.fields import Field
from django.db.models.query import _BaseQuerySet from django.db.models.query import _BaseQuerySet
@@ -12,7 +13,9 @@ from django.db.models.query import _BaseQuerySet
_OutputField = Union[Field, str] _OutputField = Union[Field, str]
class SQLiteNumericMixin: class SQLiteNumericMixin:
def as_sqlite(self, compiler: SQLCompiler, connection: Any, **extra_context: Any) -> Tuple[str, List[float]]: ... def as_sqlite(
self, compiler: SQLCompiler, connection: BaseDatabaseWrapper, **extra_context: Any
) -> Tuple[str, List[float]]: ...
_Self = TypeVar("_Self") _Self = TypeVar("_Self")
@@ -57,7 +60,7 @@ class BaseExpression:
filterable: bool = ... filterable: bool = ...
window_compatible: bool = ... window_compatible: bool = ...
def __init__(self, output_field: Optional[_OutputField] = ...) -> None: ... def __init__(self, output_field: Optional[_OutputField] = ...) -> None: ...
def get_db_converters(self, connection: Any) -> List[Callable]: ... def get_db_converters(self, connection: BaseDatabaseWrapper) -> List[Callable]: ...
def get_source_expressions(self) -> List[Any]: ... def get_source_expressions(self) -> List[Any]: ...
def set_source_expressions(self, exprs: Sequence[Combinable]) -> None: ... def set_source_expressions(self, exprs: Sequence[Combinable]) -> None: ...
@property @property
@@ -91,11 +94,11 @@ class BaseExpression:
def reverse_ordering(self): ... def reverse_ordering(self): ...
def flatten(self) -> Iterator[Expression]: ... def flatten(self) -> Iterator[Expression]: ...
def deconstruct(self) -> Any: ... def deconstruct(self) -> Any: ...
def as_sqlite(self, compiler: SQLCompiler, connection: Any) -> Any: ... def as_sqlite(self, compiler: SQLCompiler, connection: BaseDatabaseWrapper) -> Any: ...
def as_sql(self, compiler: SQLCompiler, connection: Any, **extra_context: Any) -> Any: ... def as_sql(self, compiler: SQLCompiler, connection: BaseDatabaseWrapper, **extra_context: Any) -> Any: ...
def as_mysql(self, compiler: Any, connection: Any) -> Any: ... def as_mysql(self, compiler: SQLCompiler, connection: BaseDatabaseWrapper) -> Any: ...
def as_postgresql(self, compiler: Any, connection: Any) -> Any: ... def as_postgresql(self, compiler: SQLCompiler, connection: BaseDatabaseWrapper) -> Any: ...
def as_oracle(self, compiler: Any, connection: Any): ... def as_oracle(self, compiler: SQLCompiler, connection: BaseDatabaseWrapper) -> Any: ...
class Expression(BaseExpression, Combinable): ... class Expression(BaseExpression, Combinable): ...
@@ -216,7 +219,9 @@ class WindowFrame(Expression):
template: str = ... template: str = ...
frame_type: str = ... frame_type: str = ...
def __init__(self, start: Optional[int] = ..., end: Optional[int] = ...) -> None: ... def __init__(self, start: Optional[int] = ..., end: Optional[int] = ...) -> None: ...
def window_frame_start_end(self, connection: Any, start: Optional[int], end: Optional[int]) -> Tuple[int, int]: ... def window_frame_start_end(
self, connection: BaseDatabaseWrapper, start: Optional[int], end: Optional[int]
) -> Tuple[int, int]: ...
class RowRange(WindowFrame): ... class RowRange(WindowFrame): ...
class ValueRange(WindowFrame): ... class ValueRange(WindowFrame): ...

View File

@@ -18,9 +18,9 @@ from typing import (
) )
from django.core.checks import CheckMessage from django.core.checks import CheckMessage
from django.db.models import Model
from django.core.exceptions import FieldDoesNotExist as FieldDoesNotExist from django.core.exceptions import FieldDoesNotExist as FieldDoesNotExist
from django.db.backends.base.base import BaseDatabaseWrapper
from django.db.models import Model
from django.db.models.expressions import Combinable, Col from django.db.models.expressions import Combinable, Col
from django.db.models.query_utils import RegisterLookupMixin from django.db.models.query_utils import RegisterLookupMixin
from django.forms import Field as FormField, Widget from django.forms import Field as FormField, Widget
@@ -110,12 +110,12 @@ class Field(RegisterLookupMixin, Generic[_ST, _GT]):
def __get__(self: _T, instance, owner) -> _T: ... def __get__(self: _T, instance, owner) -> _T: ...
def deconstruct(self) -> Any: ... def deconstruct(self) -> Any: ...
def set_attributes_from_name(self, name: str) -> None: ... def set_attributes_from_name(self, name: str) -> None: ...
def db_type(self, connection: Any) -> str: ... def db_type(self, connection: BaseDatabaseWrapper) -> str: ...
def db_parameters(self, connection: Any) -> Dict[str, str]: ... def db_parameters(self, connection: BaseDatabaseWrapper) -> Dict[str, str]: ...
def pre_save(self, model_instance: Model, add: bool) -> Any: ... def pre_save(self, model_instance: Model, add: bool) -> Any: ...
def get_prep_value(self, value: Any) -> Any: ... def get_prep_value(self, value: Any) -> Any: ...
def get_db_prep_value(self, value: Any, connection: Any, prepared: bool) -> Any: ... def get_db_prep_value(self, value: Any, connection: BaseDatabaseWrapper, prepared: bool) -> Any: ...
def get_db_prep_save(self, value: Any, connection: Any) -> Any: ... def get_db_prep_save(self, value: Any, connection: BaseDatabaseWrapper) -> Any: ...
def get_internal_type(self) -> str: ... def get_internal_type(self) -> str: ...
# TODO: plugin support # TODO: plugin support
def formfield(self, **kwargs) -> Any: ... def formfield(self, **kwargs) -> Any: ...
@@ -150,7 +150,7 @@ class IntegerField(Field[_ST, _GT]):
_pyi_lookup_exact_type: Union[str, int] _pyi_lookup_exact_type: Union[str, int]
class PositiveIntegerRelDbTypeMixin: class PositiveIntegerRelDbTypeMixin:
def rel_db_type(self, connection: Any): ... def rel_db_type(self, connection: BaseDatabaseWrapper) -> str: ...
class PositiveIntegerField(PositiveIntegerRelDbTypeMixin, IntegerField[_ST, _GT]): ... class PositiveIntegerField(PositiveIntegerRelDbTypeMixin, IntegerField[_ST, _GT]): ...
class PositiveSmallIntegerField(PositiveIntegerRelDbTypeMixin, IntegerField[_ST, _GT]): ... class PositiveSmallIntegerField(PositiveIntegerRelDbTypeMixin, IntegerField[_ST, _GT]): ...

View File

@@ -1,7 +1,9 @@
from . import Field from . import Field
from .mixins import CheckFieldDefaultMixin from .mixins import CheckFieldDefaultMixin
from django.db.backends.base.base import BaseDatabaseWrapper
from django.db.models import lookups from django.db.models import lookups
from django.db.models.lookups import PostgresOperatorLookup, Transform from django.db.models.lookups import PostgresOperatorLookup, Transform
from django.db.models.sql.compiler import SQLCompiler
from typing import Any, Optional, Callable from typing import Any, Optional, Callable
class JSONField(CheckFieldDefaultMixin, Field): class JSONField(CheckFieldDefaultMixin, Field):
@@ -25,7 +27,7 @@ class JSONExact(lookups.Exact): ...
class KeyTransform(Transform): class KeyTransform(Transform):
key_name: Any = ... key_name: Any = ...
def __init__(self, key_name: Any, *args: Any, **kwargs: Any) -> None: ... def __init__(self, key_name: Any, *args: Any, **kwargs: Any) -> None: ...
def preprocess_lhs(self, compiler: Any, connection: Any, lhs_only: bool = ...): ... def preprocess_lhs(self, compiler: SQLCompiler, connection: BaseDatabaseWrapper, lhs_only: bool = ...) -> Any: ...
class KeyTextTransform(KeyTransform): ... class KeyTextTransform(KeyTransform): ...

View File

@@ -27,15 +27,19 @@ class Lookup(Generic[_T]):
def get_source_expressions(self) -> List[Expression]: ... def get_source_expressions(self) -> List[Expression]: ...
def set_source_expressions(self, new_exprs: List[Expression]) -> None: ... def set_source_expressions(self, new_exprs: List[Expression]) -> None: ...
def get_prep_lookup(self) -> Any: ... def get_prep_lookup(self) -> Any: ...
def get_db_prep_lookup(self, value: Union[int, str], connection: BaseDatabaseWrapper) -> Tuple[str, List[SafeText]]: ... def get_db_prep_lookup(
self, value: Union[int, str], connection: BaseDatabaseWrapper
) -> Tuple[str, List[SafeText]]: ...
def process_lhs( def process_lhs(
self, compiler: SQLCompiler, connection: BaseDatabaseWrapper, lhs: Optional[Expression] = ... self, compiler: SQLCompiler, connection: BaseDatabaseWrapper, lhs: Optional[Expression] = ...
) -> Tuple[str, List[Union[int, str]]]: ... ) -> Tuple[str, List[Union[int, str]]]: ...
def process_rhs(self, compiler: SQLCompiler, connection: BaseDatabaseWrapper) -> Tuple[str, List[Union[int, str]]]: ... def process_rhs(
self, compiler: SQLCompiler, connection: BaseDatabaseWrapper
) -> Tuple[str, List[Union[int, str]]]: ...
def rhs_is_direct_value(self) -> bool: ... def rhs_is_direct_value(self) -> bool: ...
def relabeled_clone(self: _T, relabels: Mapping[str, str]) -> _T: ... def relabeled_clone(self: _T, relabels: Mapping[str, str]) -> _T: ...
def get_group_by_cols(self) -> List[Expression]: ... def get_group_by_cols(self) -> List[Expression]: ...
def as_sql(self, compiler: Any, connection: Any) -> Any: ... def as_sql(self, compiler: SQLCompiler, connection: BaseDatabaseWrapper) -> Any: ...
def contains_aggregate(self) -> bool: ... def contains_aggregate(self) -> bool: ...
def contains_over_clause(self) -> bool: ... def contains_over_clause(self) -> bool: ...
@property @property
@@ -61,7 +65,7 @@ class FieldGetDbPrepValueIterableMixin(FieldGetDbPrepValueMixin):
class PostgresOperatorLookup(FieldGetDbPrepValueMixin, Lookup): class PostgresOperatorLookup(FieldGetDbPrepValueMixin, Lookup):
postgres_operator: str = ... postgres_operator: str = ...
def as_postgresql(self, compiler: Any, connection: Any): ... def as_postgresql(self, compiler: SQLCompiler, connection: BaseDatabaseWrapper) -> Any: ...
class Exact(FieldGetDbPrepValueMixin, BuiltinLookup): ... class Exact(FieldGetDbPrepValueMixin, BuiltinLookup): ...
class IExact(BuiltinLookup): ... class IExact(BuiltinLookup): ...
@@ -78,7 +82,7 @@ class IntegerGreaterThanOrEqual(IntegerFieldFloatRounding, GreaterThanOrEqual[Un
class IntegerLessThan(IntegerFieldFloatRounding, LessThan[Union[int, float]]): ... class IntegerLessThan(IntegerFieldFloatRounding, LessThan[Union[int, float]]): ...
class In(FieldGetDbPrepValueIterableMixin, BuiltinLookup): class In(FieldGetDbPrepValueIterableMixin, BuiltinLookup):
def split_parameter_list_as_sql(self, compiler: Any, connection: Any): ... def split_parameter_list_as_sql(self, compiler: SQLCompiler, connection: BaseDatabaseWrapper) -> Any: ...
class PatternLookup(BuiltinLookup[str]): class PatternLookup(BuiltinLookup[str]):
param_pattern: str = ... param_pattern: str = ...
@@ -109,7 +113,7 @@ class YearLte(YearComparisonLookup): ...
class UUIDTextMixin: class UUIDTextMixin:
rhs: Any = ... rhs: Any = ...
def process_rhs(self, qn: Any, connection: Any): ... def process_rhs(self, qn: Any, connection: BaseDatabaseWrapper) -> Any: ...
class UUIDIExact(UUIDTextMixin, IExact): ... class UUIDIExact(UUIDTextMixin, IExact): ...
class UUIDContains(UUIDTextMixin, Contains): ... class UUIDContains(UUIDTextMixin, Contains): ...

View File

@@ -1,6 +1,7 @@
from collections import namedtuple from collections import namedtuple
from typing import Any, Collection, Dict, Iterator, List, Mapping, Optional, Sequence, Set, Tuple, Type from typing import Any, Collection, Dict, Iterator, List, Mapping, Optional, Sequence, Set, Tuple, Type
from django.db.backends.base.base import BaseDatabaseWrapper
from django.db.models.base import Model from django.db.models.base import Model
from django.db.models.fields.mixins import FieldCacheMixin from django.db.models.fields.mixins import FieldCacheMixin
from django.db.models.sql.compiler import SQLCompiler from django.db.models.sql.compiler import SQLCompiler
@@ -20,7 +21,7 @@ class QueryWrapper:
contains_aggregate: bool = ... contains_aggregate: bool = ...
data: Tuple[str, List[Any]] = ... data: Tuple[str, List[Any]] = ...
def __init__(self, sql: str, params: List[Any]) -> None: ... def __init__(self, sql: str, params: List[Any]) -> None: ...
def as_sql(self, compiler: SQLCompiler = ..., connection: Any = ...) -> Any: ... def as_sql(self, compiler: Optional[SQLCompiler] = ..., connection: Optional[BaseDatabaseWrapper] = ...) -> Any: ...
class Q(tree.Node): class Q(tree.Node):
AND: str = ... AND: str = ...
@@ -76,4 +77,4 @@ class FilteredRelation:
def __init__(self, relation_name: str, *, condition: Any = ...) -> None: ... def __init__(self, relation_name: str, *, condition: Any = ...) -> None: ...
def clone(self) -> FilteredRelation: ... def clone(self) -> FilteredRelation: ...
def resolve_expression(self, *args: Any, **kwargs: Any) -> None: ... def resolve_expression(self, *args: Any, **kwargs: Any) -> None: ...
def as_sql(self, compiler: SQLCompiler, connection: Any) -> Any: ... def as_sql(self, compiler: SQLCompiler, connection: BaseDatabaseWrapper) -> Any: ...

View File

@@ -4,23 +4,25 @@ from itertools import chain
from typing import Any, Callable, Dict, Iterator, List, Optional, Set, Tuple, Type, Union from typing import Any, Callable, Dict, Iterator, List, Optional, Set, Tuple, Type, Union
from uuid import UUID from uuid import UUID
from django.db.backends.base.base import BaseDatabaseWrapper
from django.db.models.base import Model from django.db.models.base import Model
from django.db.models.expressions import BaseExpression, Expression from django.db.models.expressions import BaseExpression, Expression
from django.db.models.sql.query import Query, RawQuery from django.db.models.sql.query import Query, RawQuery
FORCE: Any FORCE: Any
class SQLCompiler: class SQLCompiler:
query: Any = ... query: Any = ...
connection: Any = ... connection: BaseDatabaseWrapper = ...
using: Any = ... using: Any = ...
quote_cache: Any = ... quote_cache: Any = ...
select: Any = ... select: Any = ...
annotation_col_map: Any = ... annotation_col_map: Any = ...
klass_info: Any = ... klass_info: Any = ...
ordering_parts: Any = ... ordering_parts: Any = ...
def __init__(self, query: Union[Query, RawQuery], connection: Any, using: Optional[str]) -> None: ... def __init__(
self, query: Union[Query, RawQuery], connection: BaseDatabaseWrapper, using: Optional[str]
) -> None: ...
col_count: Any = ... col_count: Any = ...
def setup_query(self) -> None: ... def setup_query(self) -> None: ...
has_extra_select: Any = ... has_extra_select: Any = ...

View File

@@ -1,6 +1,7 @@
from collections import OrderedDict from collections import OrderedDict
from typing import Any, Dict, List, Optional, Tuple, Union from typing import Any, Dict, List, Optional, Tuple, Union
from django.db.backends.base.base import BaseDatabaseWrapper
from django.db.models.fields.mixins import FieldCacheMixin from django.db.models.fields.mixins import FieldCacheMixin
from django.db.models.query_utils import FilteredRelation, PathInfo from django.db.models.query_utils import FilteredRelation, PathInfo
from django.db.models.sql.compiler import SQLCompiler from django.db.models.sql.compiler import SQLCompiler
@@ -31,7 +32,7 @@ class Join:
nullable: bool, nullable: bool,
filtered_relation: Optional[FilteredRelation] = ..., filtered_relation: Optional[FilteredRelation] = ...,
) -> None: ... ) -> None: ...
def as_sql(self, compiler: SQLCompiler, connection: Any) -> Tuple[str, List[Union[int, str]]]: ... def as_sql(self, compiler: SQLCompiler, connection: BaseDatabaseWrapper) -> Tuple[str, List[Union[int, str]]]: ...
def relabeled_clone(self, change_map: Union[Dict[str, str], OrderedDict]) -> Join: ... def relabeled_clone(self, change_map: Union[Dict[str, str], OrderedDict]) -> Join: ...
def equals(self, other: Union[BaseTable, Join], with_filtered_relation: bool) -> bool: ... def equals(self, other: Union[BaseTable, Join], with_filtered_relation: bool) -> bool: ...
def demote(self) -> Join: ... def demote(self) -> Join: ...
@@ -44,6 +45,6 @@ class BaseTable:
table_name: str = ... table_name: str = ...
table_alias: Optional[str] = ... table_alias: Optional[str] = ...
def __init__(self, table_name: str, alias: Optional[str]) -> None: ... def __init__(self, table_name: str, alias: Optional[str]) -> None: ...
def as_sql(self, compiler: SQLCompiler, connection: Any) -> Tuple[str, List[Any]]: ... def as_sql(self, compiler: SQLCompiler, connection: BaseDatabaseWrapper) -> Tuple[str, List[Any]]: ...
def relabeled_clone(self, change_map: OrderedDict) -> BaseTable: ... def relabeled_clone(self, change_map: OrderedDict) -> BaseTable: ...
def equals(self, other: Join, with_filtered_relation: bool) -> bool: ... def equals(self, other: Join, with_filtered_relation: bool) -> bool: ...

View File

@@ -2,15 +2,15 @@ import collections
from collections import OrderedDict, namedtuple from collections import OrderedDict, namedtuple
from typing import Any, Callable, Dict, Iterator, List, Optional, Sequence, Set, Tuple, Type, Union, Iterable from typing import Any, Callable, Dict, Iterator, List, Optional, Sequence, Set, Tuple, Type, Union, Iterable
from django.db.backends.base.base import BaseDatabaseWrapper
from django.db.models import Expression, Field, FilteredRelation, Model, Q, QuerySet
from django.db.models.expressions import Combinable
from django.db.models.lookups import Lookup, Transform from django.db.models.lookups import Lookup, Transform
from django.db.models.query_utils import PathInfo, RegisterLookupMixin from django.db.models.query_utils import PathInfo, RegisterLookupMixin
from django.db.models.sql.compiler import SQLCompiler from django.db.models.sql.compiler import SQLCompiler
from django.db.models.sql.datastructures import BaseTable from django.db.models.sql.datastructures import BaseTable
from django.db.models.sql.where import WhereNode from django.db.models.sql.where import WhereNode
from django.db.models import Expression, Field, FilteredRelation, Model, Q, QuerySet
from django.db.models.expressions import Combinable
JoinInfo = namedtuple("JoinInfo", ["final_field", "targets", "opts", "joins", "path", "transform_function"]) JoinInfo = namedtuple("JoinInfo", ["final_field", "targets", "opts", "joins", "path", "transform_function"])
class RawQuery: class RawQuery:
@@ -81,7 +81,7 @@ class Query:
def has_select_fields(self) -> bool: ... def has_select_fields(self) -> bool: ...
def sql_with_params(self) -> Tuple[str, Tuple]: ... def sql_with_params(self) -> Tuple[str, Tuple]: ...
def __deepcopy__(self, memo: Dict[str, Any]) -> Query: ... def __deepcopy__(self, memo: Dict[str, Any]) -> Query: ...
def get_compiler(self, using: Optional[str] = ..., connection: Any = ...) -> SQLCompiler: ... def get_compiler(self, using: Optional[str] = ..., connection: BaseDatabaseWrapper = ...) -> SQLCompiler: ...
def clone(self) -> Query: ... def clone(self) -> Query: ...
def chain(self, klass: Optional[Type[Query]] = ...) -> Query: ... def chain(self, klass: Optional[Type[Query]] = ...) -> Query: ...
def relabeled_clone(self, change_map: Union[Dict[Any, Any], OrderedDict]) -> Query: ... def relabeled_clone(self, change_map: Union[Dict[Any, Any], OrderedDict]) -> Query: ...
@@ -101,7 +101,7 @@ class Query:
def get_initial_alias(self) -> str: ... def get_initial_alias(self) -> str: ...
def count_active_tables(self) -> int: ... def count_active_tables(self) -> int: ...
def resolve_expression(self, query: Query, *args: Any, **kwargs: Any) -> Query: ... def resolve_expression(self, query: Query, *args: Any, **kwargs: Any) -> Query: ...
def as_sql(self, compiler: SQLCompiler, connection: Any) -> Any: ... def as_sql(self, compiler: SQLCompiler, connection: BaseDatabaseWrapper) -> Any: ...
def resolve_lookup_value(self, value: Any, can_reuse: Optional[Set[str]], allow_joins: bool) -> Any: ... def resolve_lookup_value(self, value: Any, can_reuse: Optional[Set[str]], allow_joins: bool) -> Any: ...
def solve_lookup_type(self, lookup: str) -> Tuple[Sequence[str], Sequence[str], bool]: ... def solve_lookup_type(self, lookup: str) -> Tuple[Sequence[str], Sequence[str], bool]: ...
def build_filter( def build_filter(

View File

@@ -1,6 +1,7 @@
from collections import OrderedDict from collections import OrderedDict
from typing import Any, Dict, List, Optional, Tuple, Union from typing import Any, Dict, List, Optional, Tuple, Union
from django.db.backends.base.base import BaseDatabaseWrapper
from django.db.models.expressions import Expression from django.db.models.expressions import Expression
from django.db.models.sql.compiler import SQLCompiler from django.db.models.sql.compiler import SQLCompiler
from django.db.models.sql.query import Query from django.db.models.sql.query import Query
@@ -18,7 +19,7 @@ class WhereNode(tree.Node):
resolved: bool = ... resolved: bool = ...
conditional: bool = ... conditional: bool = ...
def split_having(self, negated: bool = ...) -> Tuple[Optional[WhereNode], Optional[WhereNode]]: ... def split_having(self, negated: bool = ...) -> Tuple[Optional[WhereNode], Optional[WhereNode]]: ...
def as_sql(self, compiler: SQLCompiler, connection: Any) -> Any: ... def as_sql(self, compiler: SQLCompiler, connection: BaseDatabaseWrapper) -> Any: ...
def get_group_by_cols(self) -> List[Expression]: ... def get_group_by_cols(self) -> List[Expression]: ...
def relabel_aliases(self, change_map: Union[Dict[Optional[str], str], OrderedDict]) -> None: ... def relabel_aliases(self, change_map: Union[Dict[Optional[str], str], OrderedDict]) -> None: ...
def clone(self) -> WhereNode: ... def clone(self) -> WhereNode: ...
@@ -27,14 +28,16 @@ class WhereNode(tree.Node):
class NothingNode: class NothingNode:
contains_aggregate: bool = ... contains_aggregate: bool = ...
def as_sql(self, compiler: SQLCompiler = ..., connection: Any = ...) -> Any: ... def as_sql(self, compiler: Optional[SQLCompiler] = ..., connection: Optional[BaseDatabaseWrapper] = ...) -> Any: ...
class ExtraWhere: class ExtraWhere:
contains_aggregate: bool = ... contains_aggregate: bool = ...
sqls: List[str] = ... sqls: List[str] = ...
params: Optional[Union[List[int], List[str]]] = ... params: Optional[Union[List[int], List[str]]] = ...
def __init__(self, sqls: List[str], params: Optional[Union[List[int], List[str]]]) -> None: ... def __init__(self, sqls: List[str], params: Optional[Union[List[int], List[str]]]) -> None: ...
def as_sql(self, compiler: SQLCompiler = ..., connection: Any = ...) -> Tuple[str, Union[List[int], List[str]]]: ... def as_sql(
self, compiler: Optional[SQLCompiler] = ..., connection: Optional[BaseDatabaseWrapper] = ...
) -> Tuple[str, Union[List[int], List[str]]]: ...
class SubqueryConstraint: class SubqueryConstraint:
contains_aggregate: bool = ... contains_aggregate: bool = ...
@@ -43,4 +46,4 @@ class SubqueryConstraint:
targets: List[str] = ... targets: List[str] = ...
query_object: Query = ... query_object: Query = ...
def __init__(self, alias: str, columns: List[str], targets: List[str], query_object: Query) -> None: ... def __init__(self, alias: str, columns: List[str], targets: List[str], query_object: Query) -> None: ...
def as_sql(self, compiler: SQLCompiler, connection: Any) -> Tuple[str, Tuple]: ... def as_sql(self, compiler: SQLCompiler, connection: BaseDatabaseWrapper) -> Tuple[str, Tuple]: ...

View File

@@ -1,4 +1,8 @@
from typing import Any, Dict, Iterable, List, Optional from typing import Any, Dict, Iterable, List, Optional, Iterator, Type
from django.apps import AppConfig
from django.db.backends.base.base import BaseDatabaseWrapper
from django.db.models import Model
DEFAULT_DB_ALIAS: str DEFAULT_DB_ALIAS: str
DJANGO_VERSION_PICKLE_KEY: str DJANGO_VERSION_PICKLE_KEY: str
@@ -22,14 +26,22 @@ class ConnectionHandler:
def __init__(self, databases: Dict[str, Dict[str, Optional[Any]]] = ...) -> None: ... def __init__(self, databases: Dict[str, Dict[str, Optional[Any]]] = ...) -> None: ...
def ensure_defaults(self, alias: str) -> None: ... def ensure_defaults(self, alias: str) -> None: ...
def prepare_test_settings(self, alias: str) -> None: ... def prepare_test_settings(self, alias: str) -> None: ...
def __getitem__(self, alias: str) -> Any: ... def __getitem__(self, alias: str) -> BaseDatabaseWrapper: ...
def __setitem__(self, key: Any, value: Any) -> None: ... def __setitem__(self, key: str, value: BaseDatabaseWrapper) -> None: ...
def __delitem__(self, key: Any) -> None: ... def __delitem__(self, key: BaseDatabaseWrapper) -> None: ...
def __iter__(self): ... def __iter__(self) -> Iterator[str]: ...
def all(self) -> List[Any]: ... def all(self) -> List[BaseDatabaseWrapper]: ...
def close_all(self) -> None: ... def close_all(self) -> None: ...
class ConnectionRouter: class ConnectionRouter:
def __init__(self, routers: Optional[Iterable[Any]] = ...) -> None: ... def __init__(self, routers: Optional[Iterable[Any]] = ...) -> None: ...
@property @property
def routers(self) -> List[Any]: ... def routers(self) -> List[Any]: ...
def db_for_read(self, model: Type[Model], **hints: Any) -> str: ...
def db_for_write(self, model: Type[Model], **hints: Any) -> str: ...
def allow_relation(self, obj1: Model, obj2: Model, **hints: Any) -> bool: ...
def allow_migrate(self, db: str, app_label: str, **hints: Any) -> bool: ...
def allow_migrate_model(self, db: str, model: Type[Model]) -> bool: ...
def get_migratable_models(
self, app_config: AppConfig, db: str, include_auto_created: bool = ...
) -> List[Type[Model]]: ...

View File

@@ -20,7 +20,7 @@ from django.db import connections as connections # noqa: F401
class _AssertNumQueriesContext(CaptureQueriesContext): class _AssertNumQueriesContext(CaptureQueriesContext):
test_case: SimpleTestCase = ... test_case: SimpleTestCase = ...
num: int = ... num: int = ...
def __init__(self, test_case: Any, num: Any, connection: Any) -> None: ... def __init__(self, test_case: Any, num: Any, connection: BaseDatabaseWrapper) -> None: ...
class _AssertTemplateUsedContext: class _AssertTemplateUsedContext:
test_case: SimpleTestCase = ... test_case: SimpleTestCase = ...
@@ -200,7 +200,7 @@ class LiveServerThread(threading.Thread):
is_ready: threading.Event = ... is_ready: threading.Event = ...
error: Optional[ImproperlyConfigured] = ... error: Optional[ImproperlyConfigured] = ...
static_handler: Type[WSGIHandler] = ... static_handler: Type[WSGIHandler] = ...
connections_override: Dict[str, Any] = ... connections_override: Dict[str, BaseDatabaseWrapper] = ...
def __init__( def __init__(
self, self,
host: str, host: str,

View File

@@ -27,6 +27,7 @@ from django.test.runner import DiscoverRunner
from django.test.testcases import SimpleTestCase from django.test.testcases import SimpleTestCase
from django.conf import LazySettings, Settings from django.conf import LazySettings, Settings
from django.db.backends.base.base import BaseDatabaseWrapper
_TestClass = Type[SimpleTestCase] _TestClass = Type[SimpleTestCase]
_DecoratedTest = Union[Callable, _TestClass] _DecoratedTest = Union[Callable, _TestClass]
@@ -84,11 +85,11 @@ class override_system_checks(TestContextDecorator):
old_deployment_checks: Set[Callable] = ... old_deployment_checks: Set[Callable] = ...
class CaptureQueriesContext: class CaptureQueriesContext:
connection: Any = ... connection: BaseDatabaseWrapper = ...
force_debug_cursor: bool = ... force_debug_cursor: bool = ...
initial_queries: int = ... initial_queries: int = ...
final_queries: Optional[int] = ... final_queries: Optional[int] = ...
def __init__(self, connection: Any) -> None: ... def __init__(self, connection: BaseDatabaseWrapper) -> None: ...
def __iter__(self): ... def __iter__(self): ...
def __getitem__(self, index: int) -> Dict[str, str]: ... def __getitem__(self, index: int) -> Dict[str, str]: ...
def __len__(self) -> int: ... def __len__(self) -> int: ...

View File

@@ -0,0 +1,13 @@
- case: raw_default_connection
main: |
from django.db import connection
with connection.cursor() as cursor:
reveal_type(cursor) # N: Revealed type is 'django.db.backends.utils.CursorWrapper'
cursor.execute("SELECT %s", [123])
- case: raw_connections
main: |
from django.db import connections
reveal_type(connections["test"]) # N: Revealed type is 'django.db.backends.base.base.BaseDatabaseWrapper'
for connection in connections.all():
with connection.cursor() as cursor:
reveal_type(cursor) # N: Revealed type is 'django.db.backends.utils.CursorWrapper'