mirror of
https://github.com/davidhalter/django-stubs.git
synced 2025-12-10 14:01:56 +08:00
Improve hints for database connections (DatabaseWrapper) (#612)
* `django.db.{connection, connections, router}` are now hinted -- including `ConnectionHandler` and `ConnectionRouter` classes.
* Several improvements to `BaseDatabaseWrapper` attribute hints.
* In many places, database connections were hinted as `Any`, which I changed to `BaseDatabaseWrapper`.
* In a few places I added additional `SQLCompiler` hints.
* Minor tweaks to nearby code.
This commit is contained in:
@@ -1,5 +1,6 @@
|
||||
from django.core.management.base import BaseCommand as BaseCommand, CommandError as CommandError
|
||||
from django.db import DEFAULT_DB_ALIAS as DEFAULT_DB_ALIAS, connections as connections
|
||||
from django.db.backends.base.base import BaseDatabaseWrapper
|
||||
from django.db.models.constants import LOOKUP_SEP as LOOKUP_SEP
|
||||
from typing import Any, Dict, Iterable, List, Tuple
|
||||
|
||||
@@ -10,7 +11,9 @@ class Command(BaseCommand):
|
||||
def normalize_col_name(
|
||||
self, col_name: str, used_column_names: List[str], is_relation: bool
|
||||
) -> Tuple[str, Dict[str, str], List[str]]: ...
|
||||
def get_field_type(self, connection: Any, table_name: Any, row: Any) -> Tuple[str, Dict[str, str], List[str]]: ...
|
||||
def get_field_type(
|
||||
self, connection: BaseDatabaseWrapper, table_name: str, row: Any
|
||||
) -> Tuple[str, Dict[str, str], List[str]]: ...
|
||||
def get_meta(
|
||||
self, table_name: str, constraints: Any, column_to_field_name: Any, is_view: Any, is_partition: Any
|
||||
) -> List[str]: ...
|
||||
|
||||
@@ -9,6 +9,7 @@ from django.core.management.sql import (
|
||||
emit_pre_migrate_signal as emit_pre_migrate_signal,
|
||||
)
|
||||
from django.db import DEFAULT_DB_ALIAS as DEFAULT_DB_ALIAS, connections as connections, router as router
|
||||
from django.db.backends.base.base import BaseDatabaseWrapper
|
||||
from django.db.migrations.autodetector import MigrationAutodetector as MigrationAutodetector
|
||||
from django.db.migrations.executor import MigrationExecutor as MigrationExecutor
|
||||
from django.db.migrations.loader import AmbiguityError as AmbiguityError
|
||||
@@ -23,6 +24,6 @@ class Command(BaseCommand):
|
||||
interactive: bool = ...
|
||||
start: float = ...
|
||||
def migration_progress_callback(self, action: str, migration: Optional[Any] = ..., fake: bool = ...) -> None: ...
|
||||
def sync_apps(self, connection: Any, app_labels: List[str]) -> None: ...
|
||||
def sync_apps(self, connection: BaseDatabaseWrapper, app_labels: List[str]) -> None: ...
|
||||
@staticmethod
|
||||
def describe_operation(operation: Operation, backwards: bool) -> str: ...
|
||||
|
||||
@@ -1,10 +1,11 @@
|
||||
from django.apps import apps as apps
|
||||
from django.core.management.base import BaseCommand as BaseCommand
|
||||
from django.db import DEFAULT_DB_ALIAS as DEFAULT_DB_ALIAS, connections as connections
|
||||
from django.db.backends.base.base import BaseDatabaseWrapper
|
||||
from django.db.migrations.loader import MigrationLoader as MigrationLoader
|
||||
from typing import Any, List, Optional
|
||||
|
||||
class Command(BaseCommand):
|
||||
verbosity: int = ...
|
||||
def show_list(self, connection: Any, app_names: Optional[List[str]] = ...) -> None: ...
|
||||
def show_plan(self, connection: Any, app_names: Optional[List[str]] = ...) -> None: ...
|
||||
def show_list(self, connection: BaseDatabaseWrapper, app_names: Optional[List[str]] = ...) -> None: ...
|
||||
def show_plan(self, connection: BaseDatabaseWrapper, app_names: Optional[List[str]] = ...) -> None: ...
|
||||
|
||||
@@ -1,9 +1,14 @@
|
||||
from typing import Any, List
|
||||
|
||||
from django.core.management.color import Style
|
||||
from django.db.backends.base.base import BaseDatabaseWrapper
|
||||
|
||||
def sql_flush(
|
||||
style: Style, connection: Any, only_django: bool = ..., reset_sequences: bool = ..., allow_cascade: bool = ...
|
||||
style: Style,
|
||||
connection: BaseDatabaseWrapper,
|
||||
only_django: bool = ...,
|
||||
reset_sequences: bool = ...,
|
||||
allow_cascade: bool = ...,
|
||||
) -> List[str]: ...
|
||||
def emit_pre_migrate_signal(verbosity: int, interactive: bool, db: str, **kwargs: Any) -> None: ...
|
||||
def emit_post_migrate_signal(verbosity: int, interactive: bool, db: str, **kwargs: Any) -> None: ...
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
from typing import Any
|
||||
|
||||
from .backends.base.base import BaseDatabaseWrapper
|
||||
from .utils import (
|
||||
DEFAULT_DB_ALIAS as DEFAULT_DB_ALIAS,
|
||||
DJANGO_VERSION_PICKLE_KEY as DJANGO_VERSION_PICKLE_KEY,
|
||||
@@ -11,16 +12,19 @@ from .utils import (
|
||||
NotSupportedError as NotSupportedError,
|
||||
InternalError as InternalError,
|
||||
InterfaceError as InterfaceError,
|
||||
ConnectionHandler as ConnectionHandler,
|
||||
Error as Error,
|
||||
ConnectionDoesNotExist as ConnectionDoesNotExist,
|
||||
# Not exported in __all__
|
||||
ConnectionHandler,
|
||||
ConnectionRouter,
|
||||
)
|
||||
|
||||
from . import migrations
|
||||
|
||||
connections: Any
|
||||
router: Any
|
||||
connection: Any
|
||||
connections: ConnectionHandler
|
||||
router: ConnectionRouter
|
||||
# Actually DefaultConnectionProxy, but quacks exactly like BaseDatabaseWrapper, it's not worth distinguishing the two.
|
||||
connection: BaseDatabaseWrapper
|
||||
|
||||
class DefaultConnectionProxy:
|
||||
def __getattr__(self, item: str) -> Any: ...
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
from typing import Any, Callable, Dict, Iterator, List, Optional
|
||||
from datetime import tzinfo
|
||||
from typing import Any, Callable, Dict, Iterator, List, Optional, Type, TypeVar
|
||||
|
||||
from django.db.backends.base.client import BaseDatabaseClient
|
||||
from django.db.backends.base.creation import BaseDatabaseCreation
|
||||
@@ -11,25 +12,29 @@ from django.db.backends.base.features import BaseDatabaseFeatures
|
||||
|
||||
from django.db.backends.base.introspection import BaseDatabaseIntrospection
|
||||
|
||||
from django.db.backends.base.operations import BaseDatabaseOperations
|
||||
|
||||
NO_DB_ALIAS: str
|
||||
|
||||
_T = TypeVar("_T", bound="BaseDatabaseWrapper")
|
||||
_ExecuteWrapper = Callable[[Callable[[str, Any, bool, Dict[str, Any]], Any], str, Any, bool, Dict[str, Any]], Any]
|
||||
|
||||
class BaseDatabaseWrapper:
|
||||
data_types: Any = ...
|
||||
data_types_suffix: Any = ...
|
||||
data_type_check_constraints: Any = ...
|
||||
ops: Any = ...
|
||||
data_types: Dict[str, str] = ...
|
||||
data_types_suffix: Dict[str, str] = ...
|
||||
data_type_check_constraints: Dict[str, str] = ...
|
||||
vendor: str = ...
|
||||
display_name: str = ...
|
||||
SchemaEditorClass: Optional[BaseDatabaseSchemaEditor] = ...
|
||||
client_class: Any = ...
|
||||
creation_class: Any = ...
|
||||
features_class: Any = ...
|
||||
introspection_class: Any = ...
|
||||
ops_class: Any = ...
|
||||
validation_class: Any = ...
|
||||
SchemaEditorClass: Type[BaseDatabaseSchemaEditor] = ...
|
||||
client_class: Type[BaseDatabaseClient] = ...
|
||||
creation_class: Type[BaseDatabaseCreation] = ...
|
||||
features_class: Type[BaseDatabaseFeatures] = ...
|
||||
introspection_class: Type[BaseDatabaseIntrospection] = ...
|
||||
ops_class: Type[BaseDatabaseOperations] = ...
|
||||
validation_class: Type[BaseDatabaseValidation] = ...
|
||||
queries_limit: int = ...
|
||||
connection: Any = ...
|
||||
settings_dict: Any = ...
|
||||
settings_dict: Dict[str, Any] = ...
|
||||
alias: str = ...
|
||||
queries_log: Any = ...
|
||||
force_debug_cursor: bool = ...
|
||||
@@ -39,24 +44,23 @@ class BaseDatabaseWrapper:
|
||||
savepoint_ids: Any = ...
|
||||
commit_on_exit: bool = ...
|
||||
needs_rollback: bool = ...
|
||||
close_at: Optional[Any] = ...
|
||||
close_at: Optional[float] = ...
|
||||
closed_in_transaction: bool = ...
|
||||
errors_occurred: bool = ...
|
||||
allow_thread_sharing: bool = ...
|
||||
run_on_commit: List[Any] = ...
|
||||
run_on_commit: List[Callable[[], None]] = ...
|
||||
run_commit_hooks_on_set_autocommit_on: bool = ...
|
||||
execute_wrappers: List[Any] = ...
|
||||
execute_wrappers: List[_ExecuteWrapper] = ...
|
||||
client: BaseDatabaseClient = ...
|
||||
creation: BaseDatabaseCreation = ...
|
||||
features: BaseDatabaseFeatures = ...
|
||||
introspection: BaseDatabaseIntrospection = ...
|
||||
ops: BaseDatabaseOperations = ...
|
||||
validation: BaseDatabaseValidation = ...
|
||||
def __init__(
|
||||
self, settings_dict: Dict[str, Dict[str, str]], alias: str = ..., allow_thread_sharing: bool = ...
|
||||
) -> None: ...
|
||||
def __init__(self, settings_dict: Dict[str, Any], alias: str = ..., allow_thread_sharing: bool = ...) -> None: ...
|
||||
def ensure_timezone(self) -> bool: ...
|
||||
def timezone(self): ...
|
||||
def timezone_name(self): ...
|
||||
def timezone(self) -> tzinfo: ...
|
||||
def timezone_name(self) -> str: ...
|
||||
@property
|
||||
def queries_logged(self) -> bool: ...
|
||||
@property
|
||||
@@ -86,7 +90,7 @@ class BaseDatabaseWrapper:
|
||||
def disable_constraint_checking(self): ...
|
||||
def enable_constraint_checking(self) -> None: ...
|
||||
def check_constraints(self, table_names: Optional[Any] = ...) -> None: ...
|
||||
def is_usable(self) -> None: ...
|
||||
def is_usable(self) -> bool: ...
|
||||
def close_if_unusable_or_obsolete(self) -> None: ...
|
||||
def validate_thread_sharing(self) -> None: ...
|
||||
def prepare_database(self) -> None: ...
|
||||
@@ -96,7 +100,7 @@ class BaseDatabaseWrapper:
|
||||
def make_cursor(self, cursor: CursorWrapper) -> CursorWrapper: ...
|
||||
def temporary_connection(self) -> None: ...
|
||||
def schema_editor(self, *args: Any, **kwargs: Any) -> BaseDatabaseSchemaEditor: ...
|
||||
def on_commit(self, func: Callable) -> None: ...
|
||||
def on_commit(self, func: Callable[[], None]) -> None: ...
|
||||
def run_and_clear_commit_hooks(self) -> None: ...
|
||||
def execute_wrapper(self, wrapper: Callable) -> Iterator[None]: ...
|
||||
def copy(self, alias: None = ..., allow_thread_sharing: None = ...) -> Any: ...
|
||||
def execute_wrapper(self, wrapper: _ExecuteWrapper) -> Iterator[None]: ...
|
||||
def copy(self: _T, alias: Optional[str] = ...) -> _T: ...
|
||||
|
||||
@@ -4,6 +4,6 @@ from django.db.backends.base.base import BaseDatabaseWrapper
|
||||
|
||||
class BaseDatabaseClient:
|
||||
executable_name: Any = ...
|
||||
connection: Any = ...
|
||||
connection: BaseDatabaseWrapper
|
||||
def __init__(self, connection: BaseDatabaseWrapper) -> None: ...
|
||||
def runshell(self) -> None: ...
|
||||
|
||||
@@ -5,7 +5,7 @@ from django.db.backends.base.base import BaseDatabaseWrapper
|
||||
TEST_DATABASE_PREFIX: str
|
||||
|
||||
class BaseDatabaseCreation:
|
||||
connection: Any = ...
|
||||
connection: BaseDatabaseWrapper
|
||||
def __init__(self, connection: BaseDatabaseWrapper) -> None: ...
|
||||
def create_test_db(
|
||||
self, verbosity: int = ..., autoclobber: bool = ..., serialize: bool = ..., keepdb: bool = ...
|
||||
|
||||
@@ -94,7 +94,7 @@ class BaseDatabaseFeatures:
|
||||
db_functions_convert_bytes_to_str: bool = ...
|
||||
supported_explain_formats: Any = ...
|
||||
validates_explain_options: bool = ...
|
||||
connection: Any = ...
|
||||
connection: BaseDatabaseWrapper
|
||||
def __init__(self, connection: BaseDatabaseWrapper) -> None: ...
|
||||
def supports_explaining_query_execution(self) -> bool: ...
|
||||
def supports_transactions(self): ...
|
||||
|
||||
@@ -11,7 +11,7 @@ FieldInfo = namedtuple("FieldInfo", "name type_code display_size internal_size p
|
||||
|
||||
class BaseDatabaseIntrospection:
|
||||
data_types_reverse: Any = ...
|
||||
connection: Any = ...
|
||||
connection: BaseDatabaseWrapper
|
||||
def __init__(self, connection: BaseDatabaseWrapper) -> None: ...
|
||||
def get_field_type(self, data_type: str, description: FieldInfo) -> str: ...
|
||||
def table_name_converter(self, name: str) -> str: ...
|
||||
|
||||
@@ -8,12 +8,8 @@ from django.db.backends.utils import CursorWrapper
|
||||
from django.db.models.base import Model
|
||||
from django.db.models.expressions import Case, Expression
|
||||
from django.db.models.sql.compiler import SQLCompiler
|
||||
|
||||
from django.db import DefaultConnectionProxy
|
||||
from django.db.models.fields import Field
|
||||
|
||||
_Connection = Union[DefaultConnectionProxy, BaseDatabaseWrapper]
|
||||
|
||||
class BaseDatabaseOperations:
|
||||
compiler_module: str = ...
|
||||
integer_field_ranges: Any = ...
|
||||
@@ -25,9 +21,9 @@ class BaseDatabaseOperations:
|
||||
UNBOUNDED_PRECEDING: Any = ...
|
||||
UNBOUNDED_FOLLOWING: Any = ...
|
||||
CURRENT_ROW: str = ...
|
||||
explain_prefix: Any = ...
|
||||
connection: _Connection = ...
|
||||
def __init__(self, connection: Optional[_Connection]) -> None: ...
|
||||
explain_prefix: Optional[str] = ...
|
||||
connection: BaseDatabaseWrapper
|
||||
def __init__(self, connection: BaseDatabaseWrapper) -> None: ...
|
||||
def autoinc_sql(self, table: str, column: str) -> None: ...
|
||||
def bulk_batch_size(self, fields: Any, objs: Any): ...
|
||||
def cache_key_culling_sql(self) -> str: ...
|
||||
@@ -88,7 +84,7 @@ class BaseDatabaseOperations:
|
||||
def year_lookup_bounds_for_datetime_field(self, value: int) -> List[str]: ...
|
||||
def get_db_converters(self, expression: Expression) -> List[Any]: ...
|
||||
def convert_durationfield_value(
|
||||
self, value: Optional[float], expression: Expression, connection: _Connection
|
||||
self, value: Optional[float], expression: Expression, connection: BaseDatabaseWrapper
|
||||
) -> Optional[timedelta]: ...
|
||||
def check_expression_support(self, expression: Any) -> None: ...
|
||||
def combine_expression(self, connector: str, sub_expressions: List[str]) -> str: ...
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
from typing import Any, ContextManager, List, Optional, Sequence, Tuple, Type, Union
|
||||
|
||||
from django.db.backends.base.base import BaseDatabaseWrapper
|
||||
from django.db.backends.ddl_references import Statement
|
||||
from django.db.models.base import Model
|
||||
from django.db.models.indexes import Index
|
||||
@@ -35,11 +36,11 @@ class BaseDatabaseSchemaEditor(ContextManager[Any]):
|
||||
sql_create_pk: str = ...
|
||||
sql_delete_pk: str = ...
|
||||
sql_delete_procedure: str = ...
|
||||
connection: Any = ...
|
||||
connection: BaseDatabaseWrapper = ...
|
||||
collect_sql: bool = ...
|
||||
collected_sql: Any = ...
|
||||
atomic_migration: Any = ...
|
||||
def __init__(self, connection: Any, collect_sql: bool = ..., atomic: bool = ...) -> None: ...
|
||||
def __init__(self, connection: BaseDatabaseWrapper, collect_sql: bool = ..., atomic: bool = ...) -> None: ...
|
||||
deferred_sql: Any = ...
|
||||
atomic: Any = ...
|
||||
def __enter__(self) -> BaseDatabaseSchemaEditor: ...
|
||||
|
||||
@@ -5,7 +5,7 @@ from django.db.backends.base.base import BaseDatabaseWrapper
|
||||
from django.db.models.fields import Field
|
||||
|
||||
class BaseDatabaseValidation:
|
||||
connection: Any = ...
|
||||
connection: BaseDatabaseWrapper = ...
|
||||
def __init__(self, connection: BaseDatabaseWrapper) -> None: ...
|
||||
def check(self, **kwargs: Any) -> List[Any]: ...
|
||||
def check_field(self, field: Field, **kwargs: Any) -> List[Any]: ...
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
from typing import Any, Callable, List, Optional, Set, Tuple, Union
|
||||
|
||||
from django.db import DefaultConnectionProxy
|
||||
from django.db.backends.base.base import BaseDatabaseWrapper
|
||||
from django.db.migrations.migration import Migration
|
||||
|
||||
@@ -9,13 +8,13 @@ from .recorder import MigrationRecorder
|
||||
from .state import ProjectState
|
||||
|
||||
class MigrationExecutor:
|
||||
connection: Any = ...
|
||||
connection: BaseDatabaseWrapper = ...
|
||||
loader: MigrationLoader = ...
|
||||
recorder: MigrationRecorder = ...
|
||||
progress_callback: Callable = ...
|
||||
def __init__(
|
||||
self,
|
||||
connection: Optional[Union[DefaultConnectionProxy, BaseDatabaseWrapper]],
|
||||
connection: Optional[BaseDatabaseWrapper],
|
||||
progress_callback: Optional[Callable] = ...,
|
||||
) -> None: ...
|
||||
def migration_plan(
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
from typing import Any, Dict, Optional, Sequence, Set, Tuple, Union
|
||||
|
||||
from django.db.backends.base.base import BaseDatabaseWrapper
|
||||
from django.db.migrations.migration import Migration
|
||||
from django.db.migrations.state import ProjectState
|
||||
|
||||
@@ -13,11 +14,13 @@ from .exceptions import (
|
||||
MIGRATIONS_MODULE_NAME: str
|
||||
|
||||
class MigrationLoader:
|
||||
connection: Any = ...
|
||||
connection: Optional[BaseDatabaseWrapper] = ...
|
||||
disk_migrations: Dict[Tuple[str, str], Migration] = ...
|
||||
applied_migrations: Set[Tuple[str, str]] = ...
|
||||
ignore_no_migrations: bool = ...
|
||||
def __init__(self, connection: Any, load: bool = ..., ignore_no_migrations: bool = ...) -> None: ...
|
||||
def __init__(
|
||||
self, connection: Optional[BaseDatabaseWrapper], load: bool = ..., ignore_no_migrations: bool = ...
|
||||
) -> None: ...
|
||||
@classmethod
|
||||
def migrations_module(cls, app_label: str) -> Tuple[Optional[str], bool]: ...
|
||||
unmigrated_apps: Set[str] = ...
|
||||
@@ -31,7 +34,7 @@ class MigrationLoader:
|
||||
graph: Any = ...
|
||||
replacements: Any = ...
|
||||
def build_graph(self) -> None: ...
|
||||
def check_consistent_history(self, connection: Any) -> None: ...
|
||||
def check_consistent_history(self, connection: BaseDatabaseWrapper) -> None: ...
|
||||
def detect_conflicts(self) -> Dict[str, Set[str]]: ...
|
||||
def project_state(
|
||||
self, nodes: Optional[Union[Tuple[str, str], Sequence[Tuple[str, str]]]] = ..., at_end: bool = ...
|
||||
|
||||
@@ -10,8 +10,8 @@ class MigrationRecorder:
|
||||
app: Any = ...
|
||||
name: Any = ...
|
||||
applied: Any = ...
|
||||
connection: Optional[BaseDatabaseWrapper] = ...
|
||||
def __init__(self, connection: Optional[BaseDatabaseWrapper]) -> None: ...
|
||||
connection: BaseDatabaseWrapper = ...
|
||||
def __init__(self, connection: BaseDatabaseWrapper) -> None: ...
|
||||
@property
|
||||
def migration_qs(self) -> QuerySet: ...
|
||||
def has_table(self) -> bool: ...
|
||||
|
||||
@@ -5,6 +5,7 @@ from typing import Any, Callable, Dict, Iterator, List, Optional, Sequence, Set,
|
||||
from django.db.models.lookups import Lookup
|
||||
from django.db.models.sql.compiler import SQLCompiler
|
||||
|
||||
from django.db.backends.base.base import BaseDatabaseWrapper
|
||||
from django.db.models import Q, QuerySet
|
||||
from django.db.models.fields import Field
|
||||
from django.db.models.query import _BaseQuerySet
|
||||
@@ -12,7 +13,9 @@ from django.db.models.query import _BaseQuerySet
|
||||
_OutputField = Union[Field, str]
|
||||
|
||||
class SQLiteNumericMixin:
|
||||
def as_sqlite(self, compiler: SQLCompiler, connection: Any, **extra_context: Any) -> Tuple[str, List[float]]: ...
|
||||
def as_sqlite(
|
||||
self, compiler: SQLCompiler, connection: BaseDatabaseWrapper, **extra_context: Any
|
||||
) -> Tuple[str, List[float]]: ...
|
||||
|
||||
_Self = TypeVar("_Self")
|
||||
|
||||
@@ -57,7 +60,7 @@ class BaseExpression:
|
||||
filterable: bool = ...
|
||||
window_compatible: bool = ...
|
||||
def __init__(self, output_field: Optional[_OutputField] = ...) -> None: ...
|
||||
def get_db_converters(self, connection: Any) -> List[Callable]: ...
|
||||
def get_db_converters(self, connection: BaseDatabaseWrapper) -> List[Callable]: ...
|
||||
def get_source_expressions(self) -> List[Any]: ...
|
||||
def set_source_expressions(self, exprs: Sequence[Combinable]) -> None: ...
|
||||
@property
|
||||
@@ -91,11 +94,11 @@ class BaseExpression:
|
||||
def reverse_ordering(self): ...
|
||||
def flatten(self) -> Iterator[Expression]: ...
|
||||
def deconstruct(self) -> Any: ...
|
||||
def as_sqlite(self, compiler: SQLCompiler, connection: Any) -> Any: ...
|
||||
def as_sql(self, compiler: SQLCompiler, connection: Any, **extra_context: Any) -> Any: ...
|
||||
def as_mysql(self, compiler: Any, connection: Any) -> Any: ...
|
||||
def as_postgresql(self, compiler: Any, connection: Any) -> Any: ...
|
||||
def as_oracle(self, compiler: Any, connection: Any): ...
|
||||
def as_sqlite(self, compiler: SQLCompiler, connection: BaseDatabaseWrapper) -> Any: ...
|
||||
def as_sql(self, compiler: SQLCompiler, connection: BaseDatabaseWrapper, **extra_context: Any) -> Any: ...
|
||||
def as_mysql(self, compiler: SQLCompiler, connection: BaseDatabaseWrapper) -> Any: ...
|
||||
def as_postgresql(self, compiler: SQLCompiler, connection: BaseDatabaseWrapper) -> Any: ...
|
||||
def as_oracle(self, compiler: SQLCompiler, connection: BaseDatabaseWrapper) -> Any: ...
|
||||
|
||||
class Expression(BaseExpression, Combinable): ...
|
||||
|
||||
@@ -216,7 +219,9 @@ class WindowFrame(Expression):
|
||||
template: str = ...
|
||||
frame_type: str = ...
|
||||
def __init__(self, start: Optional[int] = ..., end: Optional[int] = ...) -> None: ...
|
||||
def window_frame_start_end(self, connection: Any, start: Optional[int], end: Optional[int]) -> Tuple[int, int]: ...
|
||||
def window_frame_start_end(
|
||||
self, connection: BaseDatabaseWrapper, start: Optional[int], end: Optional[int]
|
||||
) -> Tuple[int, int]: ...
|
||||
|
||||
class RowRange(WindowFrame): ...
|
||||
class ValueRange(WindowFrame): ...
|
||||
|
||||
@@ -18,9 +18,9 @@ from typing import (
|
||||
)
|
||||
|
||||
from django.core.checks import CheckMessage
|
||||
|
||||
from django.db.models import Model
|
||||
from django.core.exceptions import FieldDoesNotExist as FieldDoesNotExist
|
||||
from django.db.backends.base.base import BaseDatabaseWrapper
|
||||
from django.db.models import Model
|
||||
from django.db.models.expressions import Combinable, Col
|
||||
from django.db.models.query_utils import RegisterLookupMixin
|
||||
from django.forms import Field as FormField, Widget
|
||||
@@ -110,12 +110,12 @@ class Field(RegisterLookupMixin, Generic[_ST, _GT]):
|
||||
def __get__(self: _T, instance, owner) -> _T: ...
|
||||
def deconstruct(self) -> Any: ...
|
||||
def set_attributes_from_name(self, name: str) -> None: ...
|
||||
def db_type(self, connection: Any) -> str: ...
|
||||
def db_parameters(self, connection: Any) -> Dict[str, str]: ...
|
||||
def db_type(self, connection: BaseDatabaseWrapper) -> str: ...
|
||||
def db_parameters(self, connection: BaseDatabaseWrapper) -> Dict[str, str]: ...
|
||||
def pre_save(self, model_instance: Model, add: bool) -> Any: ...
|
||||
def get_prep_value(self, value: Any) -> Any: ...
|
||||
def get_db_prep_value(self, value: Any, connection: Any, prepared: bool) -> Any: ...
|
||||
def get_db_prep_save(self, value: Any, connection: Any) -> Any: ...
|
||||
def get_db_prep_value(self, value: Any, connection: BaseDatabaseWrapper, prepared: bool) -> Any: ...
|
||||
def get_db_prep_save(self, value: Any, connection: BaseDatabaseWrapper) -> Any: ...
|
||||
def get_internal_type(self) -> str: ...
|
||||
# TODO: plugin support
|
||||
def formfield(self, **kwargs) -> Any: ...
|
||||
@@ -150,7 +150,7 @@ class IntegerField(Field[_ST, _GT]):
|
||||
_pyi_lookup_exact_type: Union[str, int]
|
||||
|
||||
class PositiveIntegerRelDbTypeMixin:
|
||||
def rel_db_type(self, connection: Any): ...
|
||||
def rel_db_type(self, connection: BaseDatabaseWrapper) -> str: ...
|
||||
|
||||
class PositiveIntegerField(PositiveIntegerRelDbTypeMixin, IntegerField[_ST, _GT]): ...
|
||||
class PositiveSmallIntegerField(PositiveIntegerRelDbTypeMixin, IntegerField[_ST, _GT]): ...
|
||||
|
||||
@@ -1,7 +1,9 @@
|
||||
from . import Field
|
||||
from .mixins import CheckFieldDefaultMixin
|
||||
from django.db.backends.base.base import BaseDatabaseWrapper
|
||||
from django.db.models import lookups
|
||||
from django.db.models.lookups import PostgresOperatorLookup, Transform
|
||||
from django.db.models.sql.compiler import SQLCompiler
|
||||
from typing import Any, Optional, Callable
|
||||
|
||||
class JSONField(CheckFieldDefaultMixin, Field):
|
||||
@@ -25,7 +27,7 @@ class JSONExact(lookups.Exact): ...
|
||||
class KeyTransform(Transform):
|
||||
key_name: Any = ...
|
||||
def __init__(self, key_name: Any, *args: Any, **kwargs: Any) -> None: ...
|
||||
def preprocess_lhs(self, compiler: Any, connection: Any, lhs_only: bool = ...): ...
|
||||
def preprocess_lhs(self, compiler: SQLCompiler, connection: BaseDatabaseWrapper, lhs_only: bool = ...) -> Any: ...
|
||||
|
||||
class KeyTextTransform(KeyTransform): ...
|
||||
|
||||
|
||||
@@ -27,15 +27,19 @@ class Lookup(Generic[_T]):
|
||||
def get_source_expressions(self) -> List[Expression]: ...
|
||||
def set_source_expressions(self, new_exprs: List[Expression]) -> None: ...
|
||||
def get_prep_lookup(self) -> Any: ...
|
||||
def get_db_prep_lookup(self, value: Union[int, str], connection: BaseDatabaseWrapper) -> Tuple[str, List[SafeText]]: ...
|
||||
def get_db_prep_lookup(
|
||||
self, value: Union[int, str], connection: BaseDatabaseWrapper
|
||||
) -> Tuple[str, List[SafeText]]: ...
|
||||
def process_lhs(
|
||||
self, compiler: SQLCompiler, connection: BaseDatabaseWrapper, lhs: Optional[Expression] = ...
|
||||
) -> Tuple[str, List[Union[int, str]]]: ...
|
||||
def process_rhs(self, compiler: SQLCompiler, connection: BaseDatabaseWrapper) -> Tuple[str, List[Union[int, str]]]: ...
|
||||
def process_rhs(
|
||||
self, compiler: SQLCompiler, connection: BaseDatabaseWrapper
|
||||
) -> Tuple[str, List[Union[int, str]]]: ...
|
||||
def rhs_is_direct_value(self) -> bool: ...
|
||||
def relabeled_clone(self: _T, relabels: Mapping[str, str]) -> _T: ...
|
||||
def get_group_by_cols(self) -> List[Expression]: ...
|
||||
def as_sql(self, compiler: Any, connection: Any) -> Any: ...
|
||||
def as_sql(self, compiler: SQLCompiler, connection: BaseDatabaseWrapper) -> Any: ...
|
||||
def contains_aggregate(self) -> bool: ...
|
||||
def contains_over_clause(self) -> bool: ...
|
||||
@property
|
||||
@@ -61,7 +65,7 @@ class FieldGetDbPrepValueIterableMixin(FieldGetDbPrepValueMixin):
|
||||
|
||||
class PostgresOperatorLookup(FieldGetDbPrepValueMixin, Lookup):
|
||||
postgres_operator: str = ...
|
||||
def as_postgresql(self, compiler: Any, connection: Any): ...
|
||||
def as_postgresql(self, compiler: SQLCompiler, connection: BaseDatabaseWrapper) -> Any: ...
|
||||
|
||||
class Exact(FieldGetDbPrepValueMixin, BuiltinLookup): ...
|
||||
class IExact(BuiltinLookup): ...
|
||||
@@ -78,7 +82,7 @@ class IntegerGreaterThanOrEqual(IntegerFieldFloatRounding, GreaterThanOrEqual[Un
|
||||
class IntegerLessThan(IntegerFieldFloatRounding, LessThan[Union[int, float]]): ...
|
||||
|
||||
class In(FieldGetDbPrepValueIterableMixin, BuiltinLookup):
|
||||
def split_parameter_list_as_sql(self, compiler: Any, connection: Any): ...
|
||||
def split_parameter_list_as_sql(self, compiler: SQLCompiler, connection: BaseDatabaseWrapper) -> Any: ...
|
||||
|
||||
class PatternLookup(BuiltinLookup[str]):
|
||||
param_pattern: str = ...
|
||||
@@ -109,7 +113,7 @@ class YearLte(YearComparisonLookup): ...
|
||||
|
||||
class UUIDTextMixin:
|
||||
rhs: Any = ...
|
||||
def process_rhs(self, qn: Any, connection: Any): ...
|
||||
def process_rhs(self, qn: Any, connection: BaseDatabaseWrapper) -> Any: ...
|
||||
|
||||
class UUIDIExact(UUIDTextMixin, IExact): ...
|
||||
class UUIDContains(UUIDTextMixin, Contains): ...
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
from collections import namedtuple
|
||||
from typing import Any, Collection, Dict, Iterator, List, Mapping, Optional, Sequence, Set, Tuple, Type
|
||||
|
||||
from django.db.backends.base.base import BaseDatabaseWrapper
|
||||
from django.db.models.base import Model
|
||||
from django.db.models.fields.mixins import FieldCacheMixin
|
||||
from django.db.models.sql.compiler import SQLCompiler
|
||||
@@ -20,7 +21,7 @@ class QueryWrapper:
|
||||
contains_aggregate: bool = ...
|
||||
data: Tuple[str, List[Any]] = ...
|
||||
def __init__(self, sql: str, params: List[Any]) -> None: ...
|
||||
def as_sql(self, compiler: SQLCompiler = ..., connection: Any = ...) -> Any: ...
|
||||
def as_sql(self, compiler: Optional[SQLCompiler] = ..., connection: Optional[BaseDatabaseWrapper] = ...) -> Any: ...
|
||||
|
||||
class Q(tree.Node):
|
||||
AND: str = ...
|
||||
@@ -76,4 +77,4 @@ class FilteredRelation:
|
||||
def __init__(self, relation_name: str, *, condition: Any = ...) -> None: ...
|
||||
def clone(self) -> FilteredRelation: ...
|
||||
def resolve_expression(self, *args: Any, **kwargs: Any) -> None: ...
|
||||
def as_sql(self, compiler: SQLCompiler, connection: Any) -> Any: ...
|
||||
def as_sql(self, compiler: SQLCompiler, connection: BaseDatabaseWrapper) -> Any: ...
|
||||
|
||||
@@ -4,23 +4,25 @@ from itertools import chain
|
||||
from typing import Any, Callable, Dict, Iterator, List, Optional, Set, Tuple, Type, Union
|
||||
from uuid import UUID
|
||||
|
||||
from django.db.backends.base.base import BaseDatabaseWrapper
|
||||
from django.db.models.base import Model
|
||||
from django.db.models.expressions import BaseExpression, Expression
|
||||
|
||||
from django.db.models.sql.query import Query, RawQuery
|
||||
|
||||
FORCE: Any
|
||||
|
||||
class SQLCompiler:
|
||||
query: Any = ...
|
||||
connection: Any = ...
|
||||
connection: BaseDatabaseWrapper = ...
|
||||
using: Any = ...
|
||||
quote_cache: Any = ...
|
||||
select: Any = ...
|
||||
annotation_col_map: Any = ...
|
||||
klass_info: Any = ...
|
||||
ordering_parts: Any = ...
|
||||
def __init__(self, query: Union[Query, RawQuery], connection: Any, using: Optional[str]) -> None: ...
|
||||
def __init__(
|
||||
self, query: Union[Query, RawQuery], connection: BaseDatabaseWrapper, using: Optional[str]
|
||||
) -> None: ...
|
||||
col_count: Any = ...
|
||||
def setup_query(self) -> None: ...
|
||||
has_extra_select: Any = ...
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
from collections import OrderedDict
|
||||
from typing import Any, Dict, List, Optional, Tuple, Union
|
||||
|
||||
from django.db.backends.base.base import BaseDatabaseWrapper
|
||||
from django.db.models.fields.mixins import FieldCacheMixin
|
||||
from django.db.models.query_utils import FilteredRelation, PathInfo
|
||||
from django.db.models.sql.compiler import SQLCompiler
|
||||
@@ -31,7 +32,7 @@ class Join:
|
||||
nullable: bool,
|
||||
filtered_relation: Optional[FilteredRelation] = ...,
|
||||
) -> None: ...
|
||||
def as_sql(self, compiler: SQLCompiler, connection: Any) -> Tuple[str, List[Union[int, str]]]: ...
|
||||
def as_sql(self, compiler: SQLCompiler, connection: BaseDatabaseWrapper) -> Tuple[str, List[Union[int, str]]]: ...
|
||||
def relabeled_clone(self, change_map: Union[Dict[str, str], OrderedDict]) -> Join: ...
|
||||
def equals(self, other: Union[BaseTable, Join], with_filtered_relation: bool) -> bool: ...
|
||||
def demote(self) -> Join: ...
|
||||
@@ -44,6 +45,6 @@ class BaseTable:
|
||||
table_name: str = ...
|
||||
table_alias: Optional[str] = ...
|
||||
def __init__(self, table_name: str, alias: Optional[str]) -> None: ...
|
||||
def as_sql(self, compiler: SQLCompiler, connection: Any) -> Tuple[str, List[Any]]: ...
|
||||
def as_sql(self, compiler: SQLCompiler, connection: BaseDatabaseWrapper) -> Tuple[str, List[Any]]: ...
|
||||
def relabeled_clone(self, change_map: OrderedDict) -> BaseTable: ...
|
||||
def equals(self, other: Join, with_filtered_relation: bool) -> bool: ...
|
||||
|
||||
@@ -2,15 +2,15 @@ import collections
|
||||
from collections import OrderedDict, namedtuple
|
||||
from typing import Any, Callable, Dict, Iterator, List, Optional, Sequence, Set, Tuple, Type, Union, Iterable
|
||||
|
||||
from django.db.backends.base.base import BaseDatabaseWrapper
|
||||
from django.db.models import Expression, Field, FilteredRelation, Model, Q, QuerySet
|
||||
from django.db.models.expressions import Combinable
|
||||
from django.db.models.lookups import Lookup, Transform
|
||||
from django.db.models.query_utils import PathInfo, RegisterLookupMixin
|
||||
from django.db.models.sql.compiler import SQLCompiler
|
||||
from django.db.models.sql.datastructures import BaseTable
|
||||
from django.db.models.sql.where import WhereNode
|
||||
|
||||
from django.db.models import Expression, Field, FilteredRelation, Model, Q, QuerySet
|
||||
from django.db.models.expressions import Combinable
|
||||
|
||||
JoinInfo = namedtuple("JoinInfo", ["final_field", "targets", "opts", "joins", "path", "transform_function"])
|
||||
|
||||
class RawQuery:
|
||||
@@ -81,7 +81,7 @@ class Query:
|
||||
def has_select_fields(self) -> bool: ...
|
||||
def sql_with_params(self) -> Tuple[str, Tuple]: ...
|
||||
def __deepcopy__(self, memo: Dict[str, Any]) -> Query: ...
|
||||
def get_compiler(self, using: Optional[str] = ..., connection: Any = ...) -> SQLCompiler: ...
|
||||
def get_compiler(self, using: Optional[str] = ..., connection: BaseDatabaseWrapper = ...) -> SQLCompiler: ...
|
||||
def clone(self) -> Query: ...
|
||||
def chain(self, klass: Optional[Type[Query]] = ...) -> Query: ...
|
||||
def relabeled_clone(self, change_map: Union[Dict[Any, Any], OrderedDict]) -> Query: ...
|
||||
@@ -101,7 +101,7 @@ class Query:
|
||||
def get_initial_alias(self) -> str: ...
|
||||
def count_active_tables(self) -> int: ...
|
||||
def resolve_expression(self, query: Query, *args: Any, **kwargs: Any) -> Query: ...
|
||||
def as_sql(self, compiler: SQLCompiler, connection: Any) -> Any: ...
|
||||
def as_sql(self, compiler: SQLCompiler, connection: BaseDatabaseWrapper) -> Any: ...
|
||||
def resolve_lookup_value(self, value: Any, can_reuse: Optional[Set[str]], allow_joins: bool) -> Any: ...
|
||||
def solve_lookup_type(self, lookup: str) -> Tuple[Sequence[str], Sequence[str], bool]: ...
|
||||
def build_filter(
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
from collections import OrderedDict
|
||||
from typing import Any, Dict, List, Optional, Tuple, Union
|
||||
|
||||
from django.db.backends.base.base import BaseDatabaseWrapper
|
||||
from django.db.models.expressions import Expression
|
||||
from django.db.models.sql.compiler import SQLCompiler
|
||||
from django.db.models.sql.query import Query
|
||||
@@ -18,7 +19,7 @@ class WhereNode(tree.Node):
|
||||
resolved: bool = ...
|
||||
conditional: bool = ...
|
||||
def split_having(self, negated: bool = ...) -> Tuple[Optional[WhereNode], Optional[WhereNode]]: ...
|
||||
def as_sql(self, compiler: SQLCompiler, connection: Any) -> Any: ...
|
||||
def as_sql(self, compiler: SQLCompiler, connection: BaseDatabaseWrapper) -> Any: ...
|
||||
def get_group_by_cols(self) -> List[Expression]: ...
|
||||
def relabel_aliases(self, change_map: Union[Dict[Optional[str], str], OrderedDict]) -> None: ...
|
||||
def clone(self) -> WhereNode: ...
|
||||
@@ -27,14 +28,16 @@ class WhereNode(tree.Node):
|
||||
|
||||
class NothingNode:
|
||||
contains_aggregate: bool = ...
|
||||
def as_sql(self, compiler: SQLCompiler = ..., connection: Any = ...) -> Any: ...
|
||||
def as_sql(self, compiler: Optional[SQLCompiler] = ..., connection: Optional[BaseDatabaseWrapper] = ...) -> Any: ...
|
||||
|
||||
class ExtraWhere:
|
||||
contains_aggregate: bool = ...
|
||||
sqls: List[str] = ...
|
||||
params: Optional[Union[List[int], List[str]]] = ...
|
||||
def __init__(self, sqls: List[str], params: Optional[Union[List[int], List[str]]]) -> None: ...
|
||||
def as_sql(self, compiler: SQLCompiler = ..., connection: Any = ...) -> Tuple[str, Union[List[int], List[str]]]: ...
|
||||
def as_sql(
|
||||
self, compiler: Optional[SQLCompiler] = ..., connection: Optional[BaseDatabaseWrapper] = ...
|
||||
) -> Tuple[str, Union[List[int], List[str]]]: ...
|
||||
|
||||
class SubqueryConstraint:
|
||||
contains_aggregate: bool = ...
|
||||
@@ -43,4 +46,4 @@ class SubqueryConstraint:
|
||||
targets: List[str] = ...
|
||||
query_object: Query = ...
|
||||
def __init__(self, alias: str, columns: List[str], targets: List[str], query_object: Query) -> None: ...
|
||||
def as_sql(self, compiler: SQLCompiler, connection: Any) -> Tuple[str, Tuple]: ...
|
||||
def as_sql(self, compiler: SQLCompiler, connection: BaseDatabaseWrapper) -> Tuple[str, Tuple]: ...
|
||||
|
||||
@@ -1,4 +1,8 @@
|
||||
from typing import Any, Dict, Iterable, List, Optional
|
||||
from typing import Any, Dict, Iterable, List, Optional, Iterator, Type
|
||||
|
||||
from django.apps import AppConfig
|
||||
from django.db.backends.base.base import BaseDatabaseWrapper
|
||||
from django.db.models import Model
|
||||
|
||||
DEFAULT_DB_ALIAS: str
|
||||
DJANGO_VERSION_PICKLE_KEY: str
|
||||
@@ -22,14 +26,22 @@ class ConnectionHandler:
|
||||
def __init__(self, databases: Dict[str, Dict[str, Optional[Any]]] = ...) -> None: ...
|
||||
def ensure_defaults(self, alias: str) -> None: ...
|
||||
def prepare_test_settings(self, alias: str) -> None: ...
|
||||
def __getitem__(self, alias: str) -> Any: ...
|
||||
def __setitem__(self, key: Any, value: Any) -> None: ...
|
||||
def __delitem__(self, key: Any) -> None: ...
|
||||
def __iter__(self): ...
|
||||
def all(self) -> List[Any]: ...
|
||||
def __getitem__(self, alias: str) -> BaseDatabaseWrapper: ...
|
||||
def __setitem__(self, key: str, value: BaseDatabaseWrapper) -> None: ...
|
||||
def __delitem__(self, key: BaseDatabaseWrapper) -> None: ...
|
||||
def __iter__(self) -> Iterator[str]: ...
|
||||
def all(self) -> List[BaseDatabaseWrapper]: ...
|
||||
def close_all(self) -> None: ...
|
||||
|
||||
class ConnectionRouter:
|
||||
def __init__(self, routers: Optional[Iterable[Any]] = ...) -> None: ...
|
||||
@property
|
||||
def routers(self) -> List[Any]: ...
|
||||
def db_for_read(self, model: Type[Model], **hints: Any) -> str: ...
|
||||
def db_for_write(self, model: Type[Model], **hints: Any) -> str: ...
|
||||
def allow_relation(self, obj1: Model, obj2: Model, **hints: Any) -> bool: ...
|
||||
def allow_migrate(self, db: str, app_label: str, **hints: Any) -> bool: ...
|
||||
def allow_migrate_model(self, db: str, model: Type[Model]) -> bool: ...
|
||||
def get_migratable_models(
|
||||
self, app_config: AppConfig, db: str, include_auto_created: bool = ...
|
||||
) -> List[Type[Model]]: ...
|
||||
|
||||
@@ -20,7 +20,7 @@ from django.db import connections as connections # noqa: F401
|
||||
class _AssertNumQueriesContext(CaptureQueriesContext):
|
||||
test_case: SimpleTestCase = ...
|
||||
num: int = ...
|
||||
def __init__(self, test_case: Any, num: Any, connection: Any) -> None: ...
|
||||
def __init__(self, test_case: Any, num: Any, connection: BaseDatabaseWrapper) -> None: ...
|
||||
|
||||
class _AssertTemplateUsedContext:
|
||||
test_case: SimpleTestCase = ...
|
||||
@@ -200,7 +200,7 @@ class LiveServerThread(threading.Thread):
|
||||
is_ready: threading.Event = ...
|
||||
error: Optional[ImproperlyConfigured] = ...
|
||||
static_handler: Type[WSGIHandler] = ...
|
||||
connections_override: Dict[str, Any] = ...
|
||||
connections_override: Dict[str, BaseDatabaseWrapper] = ...
|
||||
def __init__(
|
||||
self,
|
||||
host: str,
|
||||
|
||||
@@ -27,6 +27,7 @@ from django.test.runner import DiscoverRunner
|
||||
from django.test.testcases import SimpleTestCase
|
||||
|
||||
from django.conf import LazySettings, Settings
|
||||
from django.db.backends.base.base import BaseDatabaseWrapper
|
||||
|
||||
_TestClass = Type[SimpleTestCase]
|
||||
_DecoratedTest = Union[Callable, _TestClass]
|
||||
@@ -84,11 +85,11 @@ class override_system_checks(TestContextDecorator):
|
||||
old_deployment_checks: Set[Callable] = ...
|
||||
|
||||
class CaptureQueriesContext:
|
||||
connection: Any = ...
|
||||
connection: BaseDatabaseWrapper = ...
|
||||
force_debug_cursor: bool = ...
|
||||
initial_queries: int = ...
|
||||
final_queries: Optional[int] = ...
|
||||
def __init__(self, connection: Any) -> None: ...
|
||||
def __init__(self, connection: BaseDatabaseWrapper) -> None: ...
|
||||
def __iter__(self): ...
|
||||
def __getitem__(self, index: int) -> Dict[str, str]: ...
|
||||
def __len__(self) -> int: ...
|
||||
|
||||
13
tests/typecheck/db/test_connection.yml
Normal file
13
tests/typecheck/db/test_connection.yml
Normal file
@@ -0,0 +1,13 @@
|
||||
- case: raw_default_connection
|
||||
main: |
|
||||
from django.db import connection
|
||||
with connection.cursor() as cursor:
|
||||
reveal_type(cursor) # N: Revealed type is 'django.db.backends.utils.CursorWrapper'
|
||||
cursor.execute("SELECT %s", [123])
|
||||
- case: raw_connections
|
||||
main: |
|
||||
from django.db import connections
|
||||
reveal_type(connections["test"]) # N: Revealed type is 'django.db.backends.base.base.BaseDatabaseWrapper'
|
||||
for connection in connections.all():
|
||||
with connection.cursor() as cursor:
|
||||
reveal_type(cursor) # N: Revealed type is 'django.db.backends.utils.CursorWrapper'
|
||||
Reference in New Issue
Block a user