From 45f6dc0362a032af4e8f9d6e7eac4de583dcab5f Mon Sep 17 00:00:00 2001 From: Marti Raudsepp Date: Thu, 13 May 2021 20:22:35 +0300 Subject: [PATCH] Improve hints for database connections (DatabaseWrapper) (#612) * `django.db.{connection, connections, router}` are now hinted -- including `ConnectionHandler` and `ConnectionRouter` classes. * Several improvements to `BaseDatabaseWrapper` attribute hints. * In many places, database connections were hinted as `Any`, which I changed to `BaseDatabaseWrapper`. * In a few places I added additional `SQLCompiler` hints. * Minor tweaks to nearby code. --- .../core/management/commands/inspectdb.pyi | 5 +- .../core/management/commands/migrate.pyi | 3 +- .../management/commands/showmigrations.pyi | 5 +- django-stubs/core/management/sql.pyi | 7 ++- django-stubs/db/__init__.pyi | 12 +++-- django-stubs/db/backends/base/base.pyi | 54 ++++++++++--------- django-stubs/db/backends/base/client.pyi | 2 +- django-stubs/db/backends/base/creation.pyi | 2 +- django-stubs/db/backends/base/features.pyi | 2 +- .../db/backends/base/introspection.pyi | 2 +- django-stubs/db/backends/base/operations.pyi | 12 ++--- django-stubs/db/backends/base/schema.pyi | 5 +- django-stubs/db/backends/base/validation.pyi | 2 +- django-stubs/db/migrations/executor.pyi | 5 +- django-stubs/db/migrations/loader.pyi | 9 ++-- django-stubs/db/migrations/recorder.pyi | 4 +- django-stubs/db/models/expressions.pyi | 21 +++++--- django-stubs/db/models/fields/__init__.pyi | 14 ++--- django-stubs/db/models/fields/json.pyi | 4 +- django-stubs/db/models/lookups.pyi | 16 +++--- django-stubs/db/models/query_utils.pyi | 5 +- django-stubs/db/models/sql/compiler.pyi | 8 +-- django-stubs/db/models/sql/datastructures.pyi | 5 +- django-stubs/db/models/sql/query.pyi | 10 ++-- django-stubs/db/models/sql/where.pyi | 11 ++-- django-stubs/db/utils.pyi | 24 ++++++--- django-stubs/test/testcases.pyi | 4 +- django-stubs/test/utils.pyi | 5 +- tests/typecheck/db/test_connection.yml | 13 +++++ 29 files changed, 166 insertions(+), 105 deletions(-) create mode 100644 tests/typecheck/db/test_connection.yml diff --git a/django-stubs/core/management/commands/inspectdb.pyi b/django-stubs/core/management/commands/inspectdb.pyi index de0e24d..1777997 100644 --- a/django-stubs/core/management/commands/inspectdb.pyi +++ b/django-stubs/core/management/commands/inspectdb.pyi @@ -1,5 +1,6 @@ from django.core.management.base import BaseCommand as BaseCommand, CommandError as CommandError from django.db import DEFAULT_DB_ALIAS as DEFAULT_DB_ALIAS, connections as connections +from django.db.backends.base.base import BaseDatabaseWrapper from django.db.models.constants import LOOKUP_SEP as LOOKUP_SEP from typing import Any, Dict, Iterable, List, Tuple @@ -10,7 +11,9 @@ class Command(BaseCommand): def normalize_col_name( self, col_name: str, used_column_names: List[str], is_relation: bool ) -> Tuple[str, Dict[str, str], List[str]]: ... - def get_field_type(self, connection: Any, table_name: Any, row: Any) -> Tuple[str, Dict[str, str], List[str]]: ... + def get_field_type( + self, connection: BaseDatabaseWrapper, table_name: str, row: Any + ) -> Tuple[str, Dict[str, str], List[str]]: ... def get_meta( self, table_name: str, constraints: Any, column_to_field_name: Any, is_view: Any, is_partition: Any ) -> List[str]: ... diff --git a/django-stubs/core/management/commands/migrate.pyi b/django-stubs/core/management/commands/migrate.pyi index 17c432e..e0b709a 100644 --- a/django-stubs/core/management/commands/migrate.pyi +++ b/django-stubs/core/management/commands/migrate.pyi @@ -9,6 +9,7 @@ from django.core.management.sql import ( emit_pre_migrate_signal as emit_pre_migrate_signal, ) from django.db import DEFAULT_DB_ALIAS as DEFAULT_DB_ALIAS, connections as connections, router as router +from django.db.backends.base.base import BaseDatabaseWrapper from django.db.migrations.autodetector import MigrationAutodetector as MigrationAutodetector from django.db.migrations.executor import MigrationExecutor as MigrationExecutor from django.db.migrations.loader import AmbiguityError as AmbiguityError @@ -23,6 +24,6 @@ class Command(BaseCommand): interactive: bool = ... start: float = ... def migration_progress_callback(self, action: str, migration: Optional[Any] = ..., fake: bool = ...) -> None: ... - def sync_apps(self, connection: Any, app_labels: List[str]) -> None: ... + def sync_apps(self, connection: BaseDatabaseWrapper, app_labels: List[str]) -> None: ... @staticmethod def describe_operation(operation: Operation, backwards: bool) -> str: ... diff --git a/django-stubs/core/management/commands/showmigrations.pyi b/django-stubs/core/management/commands/showmigrations.pyi index 52272e6..17313df 100644 --- a/django-stubs/core/management/commands/showmigrations.pyi +++ b/django-stubs/core/management/commands/showmigrations.pyi @@ -1,10 +1,11 @@ from django.apps import apps as apps from django.core.management.base import BaseCommand as BaseCommand from django.db import DEFAULT_DB_ALIAS as DEFAULT_DB_ALIAS, connections as connections +from django.db.backends.base.base import BaseDatabaseWrapper from django.db.migrations.loader import MigrationLoader as MigrationLoader from typing import Any, List, Optional class Command(BaseCommand): verbosity: int = ... - def show_list(self, connection: Any, app_names: Optional[List[str]] = ...) -> None: ... - def show_plan(self, connection: Any, app_names: Optional[List[str]] = ...) -> None: ... + def show_list(self, connection: BaseDatabaseWrapper, app_names: Optional[List[str]] = ...) -> None: ... + def show_plan(self, connection: BaseDatabaseWrapper, app_names: Optional[List[str]] = ...) -> None: ... diff --git a/django-stubs/core/management/sql.pyi b/django-stubs/core/management/sql.pyi index 8798e12..9ebb4b9 100644 --- a/django-stubs/core/management/sql.pyi +++ b/django-stubs/core/management/sql.pyi @@ -1,9 +1,14 @@ from typing import Any, List from django.core.management.color import Style +from django.db.backends.base.base import BaseDatabaseWrapper def sql_flush( - style: Style, connection: Any, only_django: bool = ..., reset_sequences: bool = ..., allow_cascade: bool = ... + style: Style, + connection: BaseDatabaseWrapper, + only_django: bool = ..., + reset_sequences: bool = ..., + allow_cascade: bool = ..., ) -> List[str]: ... def emit_pre_migrate_signal(verbosity: int, interactive: bool, db: str, **kwargs: Any) -> None: ... def emit_post_migrate_signal(verbosity: int, interactive: bool, db: str, **kwargs: Any) -> None: ... diff --git a/django-stubs/db/__init__.pyi b/django-stubs/db/__init__.pyi index 5211356..b823f1c 100644 --- a/django-stubs/db/__init__.pyi +++ b/django-stubs/db/__init__.pyi @@ -1,5 +1,6 @@ from typing import Any +from .backends.base.base import BaseDatabaseWrapper from .utils import ( DEFAULT_DB_ALIAS as DEFAULT_DB_ALIAS, DJANGO_VERSION_PICKLE_KEY as DJANGO_VERSION_PICKLE_KEY, @@ -11,16 +12,19 @@ from .utils import ( NotSupportedError as NotSupportedError, InternalError as InternalError, InterfaceError as InterfaceError, - ConnectionHandler as ConnectionHandler, Error as Error, ConnectionDoesNotExist as ConnectionDoesNotExist, + # Not exported in __all__ + ConnectionHandler, + ConnectionRouter, ) from . import migrations -connections: Any -router: Any -connection: Any +connections: ConnectionHandler +router: ConnectionRouter +# Actually DefaultConnectionProxy, but quacks exactly like BaseDatabaseWrapper, it's not worth distinguishing the two. +connection: BaseDatabaseWrapper class DefaultConnectionProxy: def __getattr__(self, item: str) -> Any: ... diff --git a/django-stubs/db/backends/base/base.pyi b/django-stubs/db/backends/base/base.pyi index eafdad5..2bec936 100644 --- a/django-stubs/db/backends/base/base.pyi +++ b/django-stubs/db/backends/base/base.pyi @@ -1,4 +1,5 @@ -from typing import Any, Callable, Dict, Iterator, List, Optional +from datetime import tzinfo +from typing import Any, Callable, Dict, Iterator, List, Optional, Type, TypeVar from django.db.backends.base.client import BaseDatabaseClient from django.db.backends.base.creation import BaseDatabaseCreation @@ -11,25 +12,29 @@ from django.db.backends.base.features import BaseDatabaseFeatures from django.db.backends.base.introspection import BaseDatabaseIntrospection +from django.db.backends.base.operations import BaseDatabaseOperations + NO_DB_ALIAS: str +_T = TypeVar("_T", bound="BaseDatabaseWrapper") +_ExecuteWrapper = Callable[[Callable[[str, Any, bool, Dict[str, Any]], Any], str, Any, bool, Dict[str, Any]], Any] + class BaseDatabaseWrapper: - data_types: Any = ... - data_types_suffix: Any = ... - data_type_check_constraints: Any = ... - ops: Any = ... + data_types: Dict[str, str] = ... + data_types_suffix: Dict[str, str] = ... + data_type_check_constraints: Dict[str, str] = ... vendor: str = ... display_name: str = ... - SchemaEditorClass: Optional[BaseDatabaseSchemaEditor] = ... - client_class: Any = ... - creation_class: Any = ... - features_class: Any = ... - introspection_class: Any = ... - ops_class: Any = ... - validation_class: Any = ... + SchemaEditorClass: Type[BaseDatabaseSchemaEditor] = ... + client_class: Type[BaseDatabaseClient] = ... + creation_class: Type[BaseDatabaseCreation] = ... + features_class: Type[BaseDatabaseFeatures] = ... + introspection_class: Type[BaseDatabaseIntrospection] = ... + ops_class: Type[BaseDatabaseOperations] = ... + validation_class: Type[BaseDatabaseValidation] = ... queries_limit: int = ... connection: Any = ... - settings_dict: Any = ... + settings_dict: Dict[str, Any] = ... alias: str = ... queries_log: Any = ... force_debug_cursor: bool = ... @@ -39,24 +44,23 @@ class BaseDatabaseWrapper: savepoint_ids: Any = ... commit_on_exit: bool = ... needs_rollback: bool = ... - close_at: Optional[Any] = ... + close_at: Optional[float] = ... closed_in_transaction: bool = ... errors_occurred: bool = ... allow_thread_sharing: bool = ... - run_on_commit: List[Any] = ... + run_on_commit: List[Callable[[], None]] = ... run_commit_hooks_on_set_autocommit_on: bool = ... - execute_wrappers: List[Any] = ... + execute_wrappers: List[_ExecuteWrapper] = ... client: BaseDatabaseClient = ... creation: BaseDatabaseCreation = ... features: BaseDatabaseFeatures = ... introspection: BaseDatabaseIntrospection = ... + ops: BaseDatabaseOperations = ... validation: BaseDatabaseValidation = ... - def __init__( - self, settings_dict: Dict[str, Dict[str, str]], alias: str = ..., allow_thread_sharing: bool = ... - ) -> None: ... + def __init__(self, settings_dict: Dict[str, Any], alias: str = ..., allow_thread_sharing: bool = ...) -> None: ... def ensure_timezone(self) -> bool: ... - def timezone(self): ... - def timezone_name(self): ... + def timezone(self) -> tzinfo: ... + def timezone_name(self) -> str: ... @property def queries_logged(self) -> bool: ... @property @@ -86,7 +90,7 @@ class BaseDatabaseWrapper: def disable_constraint_checking(self): ... def enable_constraint_checking(self) -> None: ... def check_constraints(self, table_names: Optional[Any] = ...) -> None: ... - def is_usable(self) -> None: ... + def is_usable(self) -> bool: ... def close_if_unusable_or_obsolete(self) -> None: ... def validate_thread_sharing(self) -> None: ... def prepare_database(self) -> None: ... @@ -96,7 +100,7 @@ class BaseDatabaseWrapper: def make_cursor(self, cursor: CursorWrapper) -> CursorWrapper: ... def temporary_connection(self) -> None: ... def schema_editor(self, *args: Any, **kwargs: Any) -> BaseDatabaseSchemaEditor: ... - def on_commit(self, func: Callable) -> None: ... + def on_commit(self, func: Callable[[], None]) -> None: ... def run_and_clear_commit_hooks(self) -> None: ... - def execute_wrapper(self, wrapper: Callable) -> Iterator[None]: ... - def copy(self, alias: None = ..., allow_thread_sharing: None = ...) -> Any: ... + def execute_wrapper(self, wrapper: _ExecuteWrapper) -> Iterator[None]: ... + def copy(self: _T, alias: Optional[str] = ...) -> _T: ... diff --git a/django-stubs/db/backends/base/client.pyi b/django-stubs/db/backends/base/client.pyi index 6fdf5ce..eaebd92 100644 --- a/django-stubs/db/backends/base/client.pyi +++ b/django-stubs/db/backends/base/client.pyi @@ -4,6 +4,6 @@ from django.db.backends.base.base import BaseDatabaseWrapper class BaseDatabaseClient: executable_name: Any = ... - connection: Any = ... + connection: BaseDatabaseWrapper def __init__(self, connection: BaseDatabaseWrapper) -> None: ... def runshell(self) -> None: ... diff --git a/django-stubs/db/backends/base/creation.pyi b/django-stubs/db/backends/base/creation.pyi index 8904232..f904624 100644 --- a/django-stubs/db/backends/base/creation.pyi +++ b/django-stubs/db/backends/base/creation.pyi @@ -5,7 +5,7 @@ from django.db.backends.base.base import BaseDatabaseWrapper TEST_DATABASE_PREFIX: str class BaseDatabaseCreation: - connection: Any = ... + connection: BaseDatabaseWrapper def __init__(self, connection: BaseDatabaseWrapper) -> None: ... def create_test_db( self, verbosity: int = ..., autoclobber: bool = ..., serialize: bool = ..., keepdb: bool = ... diff --git a/django-stubs/db/backends/base/features.pyi b/django-stubs/db/backends/base/features.pyi index 88032c2..d7b382a 100644 --- a/django-stubs/db/backends/base/features.pyi +++ b/django-stubs/db/backends/base/features.pyi @@ -94,7 +94,7 @@ class BaseDatabaseFeatures: db_functions_convert_bytes_to_str: bool = ... supported_explain_formats: Any = ... validates_explain_options: bool = ... - connection: Any = ... + connection: BaseDatabaseWrapper def __init__(self, connection: BaseDatabaseWrapper) -> None: ... def supports_explaining_query_execution(self) -> bool: ... def supports_transactions(self): ... diff --git a/django-stubs/db/backends/base/introspection.pyi b/django-stubs/db/backends/base/introspection.pyi index 33d128f..3bcd9a0 100644 --- a/django-stubs/db/backends/base/introspection.pyi +++ b/django-stubs/db/backends/base/introspection.pyi @@ -11,7 +11,7 @@ FieldInfo = namedtuple("FieldInfo", "name type_code display_size internal_size p class BaseDatabaseIntrospection: data_types_reverse: Any = ... - connection: Any = ... + connection: BaseDatabaseWrapper def __init__(self, connection: BaseDatabaseWrapper) -> None: ... def get_field_type(self, data_type: str, description: FieldInfo) -> str: ... def table_name_converter(self, name: str) -> str: ... diff --git a/django-stubs/db/backends/base/operations.pyi b/django-stubs/db/backends/base/operations.pyi index fa9905c..a1a2779 100644 --- a/django-stubs/db/backends/base/operations.pyi +++ b/django-stubs/db/backends/base/operations.pyi @@ -8,12 +8,8 @@ from django.db.backends.utils import CursorWrapper from django.db.models.base import Model from django.db.models.expressions import Case, Expression from django.db.models.sql.compiler import SQLCompiler - -from django.db import DefaultConnectionProxy from django.db.models.fields import Field -_Connection = Union[DefaultConnectionProxy, BaseDatabaseWrapper] - class BaseDatabaseOperations: compiler_module: str = ... integer_field_ranges: Any = ... @@ -25,9 +21,9 @@ class BaseDatabaseOperations: UNBOUNDED_PRECEDING: Any = ... UNBOUNDED_FOLLOWING: Any = ... CURRENT_ROW: str = ... - explain_prefix: Any = ... - connection: _Connection = ... - def __init__(self, connection: Optional[_Connection]) -> None: ... + explain_prefix: Optional[str] = ... + connection: BaseDatabaseWrapper + def __init__(self, connection: BaseDatabaseWrapper) -> None: ... def autoinc_sql(self, table: str, column: str) -> None: ... def bulk_batch_size(self, fields: Any, objs: Any): ... def cache_key_culling_sql(self) -> str: ... @@ -88,7 +84,7 @@ class BaseDatabaseOperations: def year_lookup_bounds_for_datetime_field(self, value: int) -> List[str]: ... def get_db_converters(self, expression: Expression) -> List[Any]: ... def convert_durationfield_value( - self, value: Optional[float], expression: Expression, connection: _Connection + self, value: Optional[float], expression: Expression, connection: BaseDatabaseWrapper ) -> Optional[timedelta]: ... def check_expression_support(self, expression: Any) -> None: ... def combine_expression(self, connector: str, sub_expressions: List[str]) -> str: ... diff --git a/django-stubs/db/backends/base/schema.pyi b/django-stubs/db/backends/base/schema.pyi index 4da4248..d4c0967 100644 --- a/django-stubs/db/backends/base/schema.pyi +++ b/django-stubs/db/backends/base/schema.pyi @@ -1,5 +1,6 @@ from typing import Any, ContextManager, List, Optional, Sequence, Tuple, Type, Union +from django.db.backends.base.base import BaseDatabaseWrapper from django.db.backends.ddl_references import Statement from django.db.models.base import Model from django.db.models.indexes import Index @@ -35,11 +36,11 @@ class BaseDatabaseSchemaEditor(ContextManager[Any]): sql_create_pk: str = ... sql_delete_pk: str = ... sql_delete_procedure: str = ... - connection: Any = ... + connection: BaseDatabaseWrapper = ... collect_sql: bool = ... collected_sql: Any = ... atomic_migration: Any = ... - def __init__(self, connection: Any, collect_sql: bool = ..., atomic: bool = ...) -> None: ... + def __init__(self, connection: BaseDatabaseWrapper, collect_sql: bool = ..., atomic: bool = ...) -> None: ... deferred_sql: Any = ... atomic: Any = ... def __enter__(self) -> BaseDatabaseSchemaEditor: ... diff --git a/django-stubs/db/backends/base/validation.pyi b/django-stubs/db/backends/base/validation.pyi index 382f4c1..706c19c 100644 --- a/django-stubs/db/backends/base/validation.pyi +++ b/django-stubs/db/backends/base/validation.pyi @@ -5,7 +5,7 @@ from django.db.backends.base.base import BaseDatabaseWrapper from django.db.models.fields import Field class BaseDatabaseValidation: - connection: Any = ... + connection: BaseDatabaseWrapper = ... def __init__(self, connection: BaseDatabaseWrapper) -> None: ... def check(self, **kwargs: Any) -> List[Any]: ... def check_field(self, field: Field, **kwargs: Any) -> List[Any]: ... diff --git a/django-stubs/db/migrations/executor.pyi b/django-stubs/db/migrations/executor.pyi index dc759d9..186e64a 100644 --- a/django-stubs/db/migrations/executor.pyi +++ b/django-stubs/db/migrations/executor.pyi @@ -1,6 +1,5 @@ from typing import Any, Callable, List, Optional, Set, Tuple, Union -from django.db import DefaultConnectionProxy from django.db.backends.base.base import BaseDatabaseWrapper from django.db.migrations.migration import Migration @@ -9,13 +8,13 @@ from .recorder import MigrationRecorder from .state import ProjectState class MigrationExecutor: - connection: Any = ... + connection: BaseDatabaseWrapper = ... loader: MigrationLoader = ... recorder: MigrationRecorder = ... progress_callback: Callable = ... def __init__( self, - connection: Optional[Union[DefaultConnectionProxy, BaseDatabaseWrapper]], + connection: Optional[BaseDatabaseWrapper], progress_callback: Optional[Callable] = ..., ) -> None: ... def migration_plan( diff --git a/django-stubs/db/migrations/loader.pyi b/django-stubs/db/migrations/loader.pyi index 0713719..7492be1 100644 --- a/django-stubs/db/migrations/loader.pyi +++ b/django-stubs/db/migrations/loader.pyi @@ -1,5 +1,6 @@ from typing import Any, Dict, Optional, Sequence, Set, Tuple, Union +from django.db.backends.base.base import BaseDatabaseWrapper from django.db.migrations.migration import Migration from django.db.migrations.state import ProjectState @@ -13,11 +14,13 @@ from .exceptions import ( MIGRATIONS_MODULE_NAME: str class MigrationLoader: - connection: Any = ... + connection: Optional[BaseDatabaseWrapper] = ... disk_migrations: Dict[Tuple[str, str], Migration] = ... applied_migrations: Set[Tuple[str, str]] = ... ignore_no_migrations: bool = ... - def __init__(self, connection: Any, load: bool = ..., ignore_no_migrations: bool = ...) -> None: ... + def __init__( + self, connection: Optional[BaseDatabaseWrapper], load: bool = ..., ignore_no_migrations: bool = ... + ) -> None: ... @classmethod def migrations_module(cls, app_label: str) -> Tuple[Optional[str], bool]: ... unmigrated_apps: Set[str] = ... @@ -31,7 +34,7 @@ class MigrationLoader: graph: Any = ... replacements: Any = ... def build_graph(self) -> None: ... - def check_consistent_history(self, connection: Any) -> None: ... + def check_consistent_history(self, connection: BaseDatabaseWrapper) -> None: ... def detect_conflicts(self) -> Dict[str, Set[str]]: ... def project_state( self, nodes: Optional[Union[Tuple[str, str], Sequence[Tuple[str, str]]]] = ..., at_end: bool = ... diff --git a/django-stubs/db/migrations/recorder.pyi b/django-stubs/db/migrations/recorder.pyi index e2918e4..7a96322 100644 --- a/django-stubs/db/migrations/recorder.pyi +++ b/django-stubs/db/migrations/recorder.pyi @@ -10,8 +10,8 @@ class MigrationRecorder: app: Any = ... name: Any = ... applied: Any = ... - connection: Optional[BaseDatabaseWrapper] = ... - def __init__(self, connection: Optional[BaseDatabaseWrapper]) -> None: ... + connection: BaseDatabaseWrapper = ... + def __init__(self, connection: BaseDatabaseWrapper) -> None: ... @property def migration_qs(self) -> QuerySet: ... def has_table(self) -> bool: ... diff --git a/django-stubs/db/models/expressions.pyi b/django-stubs/db/models/expressions.pyi index e4372ac..0b3b50e 100644 --- a/django-stubs/db/models/expressions.pyi +++ b/django-stubs/db/models/expressions.pyi @@ -5,6 +5,7 @@ from typing import Any, Callable, Dict, Iterator, List, Optional, Sequence, Set, from django.db.models.lookups import Lookup from django.db.models.sql.compiler import SQLCompiler +from django.db.backends.base.base import BaseDatabaseWrapper from django.db.models import Q, QuerySet from django.db.models.fields import Field from django.db.models.query import _BaseQuerySet @@ -12,7 +13,9 @@ from django.db.models.query import _BaseQuerySet _OutputField = Union[Field, str] class SQLiteNumericMixin: - def as_sqlite(self, compiler: SQLCompiler, connection: Any, **extra_context: Any) -> Tuple[str, List[float]]: ... + def as_sqlite( + self, compiler: SQLCompiler, connection: BaseDatabaseWrapper, **extra_context: Any + ) -> Tuple[str, List[float]]: ... _Self = TypeVar("_Self") @@ -57,7 +60,7 @@ class BaseExpression: filterable: bool = ... window_compatible: bool = ... def __init__(self, output_field: Optional[_OutputField] = ...) -> None: ... - def get_db_converters(self, connection: Any) -> List[Callable]: ... + def get_db_converters(self, connection: BaseDatabaseWrapper) -> List[Callable]: ... def get_source_expressions(self) -> List[Any]: ... def set_source_expressions(self, exprs: Sequence[Combinable]) -> None: ... @property @@ -91,11 +94,11 @@ class BaseExpression: def reverse_ordering(self): ... def flatten(self) -> Iterator[Expression]: ... def deconstruct(self) -> Any: ... - def as_sqlite(self, compiler: SQLCompiler, connection: Any) -> Any: ... - def as_sql(self, compiler: SQLCompiler, connection: Any, **extra_context: Any) -> Any: ... - def as_mysql(self, compiler: Any, connection: Any) -> Any: ... - def as_postgresql(self, compiler: Any, connection: Any) -> Any: ... - def as_oracle(self, compiler: Any, connection: Any): ... + def as_sqlite(self, compiler: SQLCompiler, connection: BaseDatabaseWrapper) -> Any: ... + def as_sql(self, compiler: SQLCompiler, connection: BaseDatabaseWrapper, **extra_context: Any) -> Any: ... + def as_mysql(self, compiler: SQLCompiler, connection: BaseDatabaseWrapper) -> Any: ... + def as_postgresql(self, compiler: SQLCompiler, connection: BaseDatabaseWrapper) -> Any: ... + def as_oracle(self, compiler: SQLCompiler, connection: BaseDatabaseWrapper) -> Any: ... class Expression(BaseExpression, Combinable): ... @@ -216,7 +219,9 @@ class WindowFrame(Expression): template: str = ... frame_type: str = ... def __init__(self, start: Optional[int] = ..., end: Optional[int] = ...) -> None: ... - def window_frame_start_end(self, connection: Any, start: Optional[int], end: Optional[int]) -> Tuple[int, int]: ... + def window_frame_start_end( + self, connection: BaseDatabaseWrapper, start: Optional[int], end: Optional[int] + ) -> Tuple[int, int]: ... class RowRange(WindowFrame): ... class ValueRange(WindowFrame): ... diff --git a/django-stubs/db/models/fields/__init__.pyi b/django-stubs/db/models/fields/__init__.pyi index 3013355..6dc2a5c 100644 --- a/django-stubs/db/models/fields/__init__.pyi +++ b/django-stubs/db/models/fields/__init__.pyi @@ -18,9 +18,9 @@ from typing import ( ) from django.core.checks import CheckMessage - -from django.db.models import Model from django.core.exceptions import FieldDoesNotExist as FieldDoesNotExist +from django.db.backends.base.base import BaseDatabaseWrapper +from django.db.models import Model from django.db.models.expressions import Combinable, Col from django.db.models.query_utils import RegisterLookupMixin from django.forms import Field as FormField, Widget @@ -110,12 +110,12 @@ class Field(RegisterLookupMixin, Generic[_ST, _GT]): def __get__(self: _T, instance, owner) -> _T: ... def deconstruct(self) -> Any: ... def set_attributes_from_name(self, name: str) -> None: ... - def db_type(self, connection: Any) -> str: ... - def db_parameters(self, connection: Any) -> Dict[str, str]: ... + def db_type(self, connection: BaseDatabaseWrapper) -> str: ... + def db_parameters(self, connection: BaseDatabaseWrapper) -> Dict[str, str]: ... def pre_save(self, model_instance: Model, add: bool) -> Any: ... def get_prep_value(self, value: Any) -> Any: ... - def get_db_prep_value(self, value: Any, connection: Any, prepared: bool) -> Any: ... - def get_db_prep_save(self, value: Any, connection: Any) -> Any: ... + def get_db_prep_value(self, value: Any, connection: BaseDatabaseWrapper, prepared: bool) -> Any: ... + def get_db_prep_save(self, value: Any, connection: BaseDatabaseWrapper) -> Any: ... def get_internal_type(self) -> str: ... # TODO: plugin support def formfield(self, **kwargs) -> Any: ... @@ -150,7 +150,7 @@ class IntegerField(Field[_ST, _GT]): _pyi_lookup_exact_type: Union[str, int] class PositiveIntegerRelDbTypeMixin: - def rel_db_type(self, connection: Any): ... + def rel_db_type(self, connection: BaseDatabaseWrapper) -> str: ... class PositiveIntegerField(PositiveIntegerRelDbTypeMixin, IntegerField[_ST, _GT]): ... class PositiveSmallIntegerField(PositiveIntegerRelDbTypeMixin, IntegerField[_ST, _GT]): ... diff --git a/django-stubs/db/models/fields/json.pyi b/django-stubs/db/models/fields/json.pyi index dc4484c..9dbdffb 100644 --- a/django-stubs/db/models/fields/json.pyi +++ b/django-stubs/db/models/fields/json.pyi @@ -1,7 +1,9 @@ from . import Field from .mixins import CheckFieldDefaultMixin +from django.db.backends.base.base import BaseDatabaseWrapper from django.db.models import lookups from django.db.models.lookups import PostgresOperatorLookup, Transform +from django.db.models.sql.compiler import SQLCompiler from typing import Any, Optional, Callable class JSONField(CheckFieldDefaultMixin, Field): @@ -25,7 +27,7 @@ class JSONExact(lookups.Exact): ... class KeyTransform(Transform): key_name: Any = ... def __init__(self, key_name: Any, *args: Any, **kwargs: Any) -> None: ... - def preprocess_lhs(self, compiler: Any, connection: Any, lhs_only: bool = ...): ... + def preprocess_lhs(self, compiler: SQLCompiler, connection: BaseDatabaseWrapper, lhs_only: bool = ...) -> Any: ... class KeyTextTransform(KeyTransform): ... diff --git a/django-stubs/db/models/lookups.pyi b/django-stubs/db/models/lookups.pyi index 58040ed..3f9230c 100644 --- a/django-stubs/db/models/lookups.pyi +++ b/django-stubs/db/models/lookups.pyi @@ -27,15 +27,19 @@ class Lookup(Generic[_T]): def get_source_expressions(self) -> List[Expression]: ... def set_source_expressions(self, new_exprs: List[Expression]) -> None: ... def get_prep_lookup(self) -> Any: ... - def get_db_prep_lookup(self, value: Union[int, str], connection: BaseDatabaseWrapper) -> Tuple[str, List[SafeText]]: ... + def get_db_prep_lookup( + self, value: Union[int, str], connection: BaseDatabaseWrapper + ) -> Tuple[str, List[SafeText]]: ... def process_lhs( self, compiler: SQLCompiler, connection: BaseDatabaseWrapper, lhs: Optional[Expression] = ... ) -> Tuple[str, List[Union[int, str]]]: ... - def process_rhs(self, compiler: SQLCompiler, connection: BaseDatabaseWrapper) -> Tuple[str, List[Union[int, str]]]: ... + def process_rhs( + self, compiler: SQLCompiler, connection: BaseDatabaseWrapper + ) -> Tuple[str, List[Union[int, str]]]: ... def rhs_is_direct_value(self) -> bool: ... def relabeled_clone(self: _T, relabels: Mapping[str, str]) -> _T: ... def get_group_by_cols(self) -> List[Expression]: ... - def as_sql(self, compiler: Any, connection: Any) -> Any: ... + def as_sql(self, compiler: SQLCompiler, connection: BaseDatabaseWrapper) -> Any: ... def contains_aggregate(self) -> bool: ... def contains_over_clause(self) -> bool: ... @property @@ -61,7 +65,7 @@ class FieldGetDbPrepValueIterableMixin(FieldGetDbPrepValueMixin): class PostgresOperatorLookup(FieldGetDbPrepValueMixin, Lookup): postgres_operator: str = ... - def as_postgresql(self, compiler: Any, connection: Any): ... + def as_postgresql(self, compiler: SQLCompiler, connection: BaseDatabaseWrapper) -> Any: ... class Exact(FieldGetDbPrepValueMixin, BuiltinLookup): ... class IExact(BuiltinLookup): ... @@ -78,7 +82,7 @@ class IntegerGreaterThanOrEqual(IntegerFieldFloatRounding, GreaterThanOrEqual[Un class IntegerLessThan(IntegerFieldFloatRounding, LessThan[Union[int, float]]): ... class In(FieldGetDbPrepValueIterableMixin, BuiltinLookup): - def split_parameter_list_as_sql(self, compiler: Any, connection: Any): ... + def split_parameter_list_as_sql(self, compiler: SQLCompiler, connection: BaseDatabaseWrapper) -> Any: ... class PatternLookup(BuiltinLookup[str]): param_pattern: str = ... @@ -109,7 +113,7 @@ class YearLte(YearComparisonLookup): ... class UUIDTextMixin: rhs: Any = ... - def process_rhs(self, qn: Any, connection: Any): ... + def process_rhs(self, qn: Any, connection: BaseDatabaseWrapper) -> Any: ... class UUIDIExact(UUIDTextMixin, IExact): ... class UUIDContains(UUIDTextMixin, Contains): ... diff --git a/django-stubs/db/models/query_utils.pyi b/django-stubs/db/models/query_utils.pyi index 7ee77bc..4aedc12 100644 --- a/django-stubs/db/models/query_utils.pyi +++ b/django-stubs/db/models/query_utils.pyi @@ -1,6 +1,7 @@ from collections import namedtuple from typing import Any, Collection, Dict, Iterator, List, Mapping, Optional, Sequence, Set, Tuple, Type +from django.db.backends.base.base import BaseDatabaseWrapper from django.db.models.base import Model from django.db.models.fields.mixins import FieldCacheMixin from django.db.models.sql.compiler import SQLCompiler @@ -20,7 +21,7 @@ class QueryWrapper: contains_aggregate: bool = ... data: Tuple[str, List[Any]] = ... def __init__(self, sql: str, params: List[Any]) -> None: ... - def as_sql(self, compiler: SQLCompiler = ..., connection: Any = ...) -> Any: ... + def as_sql(self, compiler: Optional[SQLCompiler] = ..., connection: Optional[BaseDatabaseWrapper] = ...) -> Any: ... class Q(tree.Node): AND: str = ... @@ -76,4 +77,4 @@ class FilteredRelation: def __init__(self, relation_name: str, *, condition: Any = ...) -> None: ... def clone(self) -> FilteredRelation: ... def resolve_expression(self, *args: Any, **kwargs: Any) -> None: ... - def as_sql(self, compiler: SQLCompiler, connection: Any) -> Any: ... + def as_sql(self, compiler: SQLCompiler, connection: BaseDatabaseWrapper) -> Any: ... diff --git a/django-stubs/db/models/sql/compiler.pyi b/django-stubs/db/models/sql/compiler.pyi index c0c257d..5b9d0ec 100644 --- a/django-stubs/db/models/sql/compiler.pyi +++ b/django-stubs/db/models/sql/compiler.pyi @@ -4,23 +4,25 @@ from itertools import chain from typing import Any, Callable, Dict, Iterator, List, Optional, Set, Tuple, Type, Union from uuid import UUID +from django.db.backends.base.base import BaseDatabaseWrapper from django.db.models.base import Model from django.db.models.expressions import BaseExpression, Expression - from django.db.models.sql.query import Query, RawQuery FORCE: Any class SQLCompiler: query: Any = ... - connection: Any = ... + connection: BaseDatabaseWrapper = ... using: Any = ... quote_cache: Any = ... select: Any = ... annotation_col_map: Any = ... klass_info: Any = ... ordering_parts: Any = ... - def __init__(self, query: Union[Query, RawQuery], connection: Any, using: Optional[str]) -> None: ... + def __init__( + self, query: Union[Query, RawQuery], connection: BaseDatabaseWrapper, using: Optional[str] + ) -> None: ... col_count: Any = ... def setup_query(self) -> None: ... has_extra_select: Any = ... diff --git a/django-stubs/db/models/sql/datastructures.pyi b/django-stubs/db/models/sql/datastructures.pyi index 8ef7656..6b96f8b 100644 --- a/django-stubs/db/models/sql/datastructures.pyi +++ b/django-stubs/db/models/sql/datastructures.pyi @@ -1,6 +1,7 @@ from collections import OrderedDict from typing import Any, Dict, List, Optional, Tuple, Union +from django.db.backends.base.base import BaseDatabaseWrapper from django.db.models.fields.mixins import FieldCacheMixin from django.db.models.query_utils import FilteredRelation, PathInfo from django.db.models.sql.compiler import SQLCompiler @@ -31,7 +32,7 @@ class Join: nullable: bool, filtered_relation: Optional[FilteredRelation] = ..., ) -> None: ... - def as_sql(self, compiler: SQLCompiler, connection: Any) -> Tuple[str, List[Union[int, str]]]: ... + def as_sql(self, compiler: SQLCompiler, connection: BaseDatabaseWrapper) -> Tuple[str, List[Union[int, str]]]: ... def relabeled_clone(self, change_map: Union[Dict[str, str], OrderedDict]) -> Join: ... def equals(self, other: Union[BaseTable, Join], with_filtered_relation: bool) -> bool: ... def demote(self) -> Join: ... @@ -44,6 +45,6 @@ class BaseTable: table_name: str = ... table_alias: Optional[str] = ... def __init__(self, table_name: str, alias: Optional[str]) -> None: ... - def as_sql(self, compiler: SQLCompiler, connection: Any) -> Tuple[str, List[Any]]: ... + def as_sql(self, compiler: SQLCompiler, connection: BaseDatabaseWrapper) -> Tuple[str, List[Any]]: ... def relabeled_clone(self, change_map: OrderedDict) -> BaseTable: ... def equals(self, other: Join, with_filtered_relation: bool) -> bool: ... diff --git a/django-stubs/db/models/sql/query.pyi b/django-stubs/db/models/sql/query.pyi index 2ac6e61..3ab8652 100644 --- a/django-stubs/db/models/sql/query.pyi +++ b/django-stubs/db/models/sql/query.pyi @@ -2,15 +2,15 @@ import collections from collections import OrderedDict, namedtuple from typing import Any, Callable, Dict, Iterator, List, Optional, Sequence, Set, Tuple, Type, Union, Iterable +from django.db.backends.base.base import BaseDatabaseWrapper +from django.db.models import Expression, Field, FilteredRelation, Model, Q, QuerySet +from django.db.models.expressions import Combinable from django.db.models.lookups import Lookup, Transform from django.db.models.query_utils import PathInfo, RegisterLookupMixin from django.db.models.sql.compiler import SQLCompiler from django.db.models.sql.datastructures import BaseTable from django.db.models.sql.where import WhereNode -from django.db.models import Expression, Field, FilteredRelation, Model, Q, QuerySet -from django.db.models.expressions import Combinable - JoinInfo = namedtuple("JoinInfo", ["final_field", "targets", "opts", "joins", "path", "transform_function"]) class RawQuery: @@ -81,7 +81,7 @@ class Query: def has_select_fields(self) -> bool: ... def sql_with_params(self) -> Tuple[str, Tuple]: ... def __deepcopy__(self, memo: Dict[str, Any]) -> Query: ... - def get_compiler(self, using: Optional[str] = ..., connection: Any = ...) -> SQLCompiler: ... + def get_compiler(self, using: Optional[str] = ..., connection: BaseDatabaseWrapper = ...) -> SQLCompiler: ... def clone(self) -> Query: ... def chain(self, klass: Optional[Type[Query]] = ...) -> Query: ... def relabeled_clone(self, change_map: Union[Dict[Any, Any], OrderedDict]) -> Query: ... @@ -101,7 +101,7 @@ class Query: def get_initial_alias(self) -> str: ... def count_active_tables(self) -> int: ... def resolve_expression(self, query: Query, *args: Any, **kwargs: Any) -> Query: ... - def as_sql(self, compiler: SQLCompiler, connection: Any) -> Any: ... + def as_sql(self, compiler: SQLCompiler, connection: BaseDatabaseWrapper) -> Any: ... def resolve_lookup_value(self, value: Any, can_reuse: Optional[Set[str]], allow_joins: bool) -> Any: ... def solve_lookup_type(self, lookup: str) -> Tuple[Sequence[str], Sequence[str], bool]: ... def build_filter( diff --git a/django-stubs/db/models/sql/where.pyi b/django-stubs/db/models/sql/where.pyi index dfab304..0bf2e16 100644 --- a/django-stubs/db/models/sql/where.pyi +++ b/django-stubs/db/models/sql/where.pyi @@ -1,6 +1,7 @@ from collections import OrderedDict from typing import Any, Dict, List, Optional, Tuple, Union +from django.db.backends.base.base import BaseDatabaseWrapper from django.db.models.expressions import Expression from django.db.models.sql.compiler import SQLCompiler from django.db.models.sql.query import Query @@ -18,7 +19,7 @@ class WhereNode(tree.Node): resolved: bool = ... conditional: bool = ... def split_having(self, negated: bool = ...) -> Tuple[Optional[WhereNode], Optional[WhereNode]]: ... - def as_sql(self, compiler: SQLCompiler, connection: Any) -> Any: ... + def as_sql(self, compiler: SQLCompiler, connection: BaseDatabaseWrapper) -> Any: ... def get_group_by_cols(self) -> List[Expression]: ... def relabel_aliases(self, change_map: Union[Dict[Optional[str], str], OrderedDict]) -> None: ... def clone(self) -> WhereNode: ... @@ -27,14 +28,16 @@ class WhereNode(tree.Node): class NothingNode: contains_aggregate: bool = ... - def as_sql(self, compiler: SQLCompiler = ..., connection: Any = ...) -> Any: ... + def as_sql(self, compiler: Optional[SQLCompiler] = ..., connection: Optional[BaseDatabaseWrapper] = ...) -> Any: ... class ExtraWhere: contains_aggregate: bool = ... sqls: List[str] = ... params: Optional[Union[List[int], List[str]]] = ... def __init__(self, sqls: List[str], params: Optional[Union[List[int], List[str]]]) -> None: ... - def as_sql(self, compiler: SQLCompiler = ..., connection: Any = ...) -> Tuple[str, Union[List[int], List[str]]]: ... + def as_sql( + self, compiler: Optional[SQLCompiler] = ..., connection: Optional[BaseDatabaseWrapper] = ... + ) -> Tuple[str, Union[List[int], List[str]]]: ... class SubqueryConstraint: contains_aggregate: bool = ... @@ -43,4 +46,4 @@ class SubqueryConstraint: targets: List[str] = ... query_object: Query = ... def __init__(self, alias: str, columns: List[str], targets: List[str], query_object: Query) -> None: ... - def as_sql(self, compiler: SQLCompiler, connection: Any) -> Tuple[str, Tuple]: ... + def as_sql(self, compiler: SQLCompiler, connection: BaseDatabaseWrapper) -> Tuple[str, Tuple]: ... diff --git a/django-stubs/db/utils.pyi b/django-stubs/db/utils.pyi index 998bbf8..18832e0 100644 --- a/django-stubs/db/utils.pyi +++ b/django-stubs/db/utils.pyi @@ -1,4 +1,8 @@ -from typing import Any, Dict, Iterable, List, Optional +from typing import Any, Dict, Iterable, List, Optional, Iterator, Type + +from django.apps import AppConfig +from django.db.backends.base.base import BaseDatabaseWrapper +from django.db.models import Model DEFAULT_DB_ALIAS: str DJANGO_VERSION_PICKLE_KEY: str @@ -22,14 +26,22 @@ class ConnectionHandler: def __init__(self, databases: Dict[str, Dict[str, Optional[Any]]] = ...) -> None: ... def ensure_defaults(self, alias: str) -> None: ... def prepare_test_settings(self, alias: str) -> None: ... - def __getitem__(self, alias: str) -> Any: ... - def __setitem__(self, key: Any, value: Any) -> None: ... - def __delitem__(self, key: Any) -> None: ... - def __iter__(self): ... - def all(self) -> List[Any]: ... + def __getitem__(self, alias: str) -> BaseDatabaseWrapper: ... + def __setitem__(self, key: str, value: BaseDatabaseWrapper) -> None: ... + def __delitem__(self, key: BaseDatabaseWrapper) -> None: ... + def __iter__(self) -> Iterator[str]: ... + def all(self) -> List[BaseDatabaseWrapper]: ... def close_all(self) -> None: ... class ConnectionRouter: def __init__(self, routers: Optional[Iterable[Any]] = ...) -> None: ... @property def routers(self) -> List[Any]: ... + def db_for_read(self, model: Type[Model], **hints: Any) -> str: ... + def db_for_write(self, model: Type[Model], **hints: Any) -> str: ... + def allow_relation(self, obj1: Model, obj2: Model, **hints: Any) -> bool: ... + def allow_migrate(self, db: str, app_label: str, **hints: Any) -> bool: ... + def allow_migrate_model(self, db: str, model: Type[Model]) -> bool: ... + def get_migratable_models( + self, app_config: AppConfig, db: str, include_auto_created: bool = ... + ) -> List[Type[Model]]: ... diff --git a/django-stubs/test/testcases.pyi b/django-stubs/test/testcases.pyi index 2a70b07..eb94d1a 100644 --- a/django-stubs/test/testcases.pyi +++ b/django-stubs/test/testcases.pyi @@ -20,7 +20,7 @@ from django.db import connections as connections # noqa: F401 class _AssertNumQueriesContext(CaptureQueriesContext): test_case: SimpleTestCase = ... num: int = ... - def __init__(self, test_case: Any, num: Any, connection: Any) -> None: ... + def __init__(self, test_case: Any, num: Any, connection: BaseDatabaseWrapper) -> None: ... class _AssertTemplateUsedContext: test_case: SimpleTestCase = ... @@ -200,7 +200,7 @@ class LiveServerThread(threading.Thread): is_ready: threading.Event = ... error: Optional[ImproperlyConfigured] = ... static_handler: Type[WSGIHandler] = ... - connections_override: Dict[str, Any] = ... + connections_override: Dict[str, BaseDatabaseWrapper] = ... def __init__( self, host: str, diff --git a/django-stubs/test/utils.pyi b/django-stubs/test/utils.pyi index 03d4da1..36bec18 100644 --- a/django-stubs/test/utils.pyi +++ b/django-stubs/test/utils.pyi @@ -27,6 +27,7 @@ from django.test.runner import DiscoverRunner from django.test.testcases import SimpleTestCase from django.conf import LazySettings, Settings +from django.db.backends.base.base import BaseDatabaseWrapper _TestClass = Type[SimpleTestCase] _DecoratedTest = Union[Callable, _TestClass] @@ -84,11 +85,11 @@ class override_system_checks(TestContextDecorator): old_deployment_checks: Set[Callable] = ... class CaptureQueriesContext: - connection: Any = ... + connection: BaseDatabaseWrapper = ... force_debug_cursor: bool = ... initial_queries: int = ... final_queries: Optional[int] = ... - def __init__(self, connection: Any) -> None: ... + def __init__(self, connection: BaseDatabaseWrapper) -> None: ... def __iter__(self): ... def __getitem__(self, index: int) -> Dict[str, str]: ... def __len__(self) -> int: ... diff --git a/tests/typecheck/db/test_connection.yml b/tests/typecheck/db/test_connection.yml new file mode 100644 index 0000000..9e2e939 --- /dev/null +++ b/tests/typecheck/db/test_connection.yml @@ -0,0 +1,13 @@ +- case: raw_default_connection + main: | + from django.db import connection + with connection.cursor() as cursor: + reveal_type(cursor) # N: Revealed type is 'django.db.backends.utils.CursorWrapper' + cursor.execute("SELECT %s", [123]) +- case: raw_connections + main: | + from django.db import connections + reveal_type(connections["test"]) # N: Revealed type is 'django.db.backends.base.base.BaseDatabaseWrapper' + for connection in connections.all(): + with connection.cursor() as cursor: + reveal_type(cursor) # N: Revealed type is 'django.db.backends.utils.CursorWrapper'