mirror of
https://github.com/davidhalter/django-stubs.git
synced 2025-12-06 20:24:31 +08:00
improve annotations in some places (#202)
* improve annotations in some places * linting
This commit is contained in:
@@ -81,9 +81,7 @@ class ChangeList:
|
||||
paginator: Any = ...
|
||||
def get_results(self, request: WSGIRequest) -> None: ...
|
||||
def get_ordering_field(self, field_name: Union[Callable, str]) -> Optional[Union[CombinedExpression, str]]: ...
|
||||
def get_ordering(
|
||||
self, request: WSGIRequest, queryset: QuerySet
|
||||
) -> Union[List[Union[Combinable, str]], List[Union[OrderBy, str]]]: ...
|
||||
def get_ordering(self, request: WSGIRequest, queryset: QuerySet) -> List[Union[OrderBy, Combinable, str]]: ...
|
||||
def get_ordering_field_columns(self) -> OrderedDict: ...
|
||||
def get_queryset(self, request: WSGIRequest) -> QuerySet: ...
|
||||
def apply_select_related(self, qs: QuerySet) -> QuerySet: ...
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
from collections import namedtuple
|
||||
from typing import Any, Dict, List, Optional, Set, Tuple, Type, Union
|
||||
from typing import Any, Dict, List, Optional, Set, Type
|
||||
|
||||
from django.db.backends.base.base import BaseDatabaseWrapper
|
||||
from django.db.backends.utils import CursorWrapper
|
||||
@@ -13,7 +13,7 @@ class BaseDatabaseIntrospection:
|
||||
data_types_reverse: Any = ...
|
||||
connection: Any = ...
|
||||
def __init__(self, connection: BaseDatabaseWrapper) -> None: ...
|
||||
def get_field_type(self, data_type: str, description: FieldInfo) -> Union[Tuple[str, Dict[str, int]], str]: ...
|
||||
def get_field_type(self, data_type: str, description: FieldInfo) -> str: ...
|
||||
def table_name_converter(self, name: str) -> str: ...
|
||||
def column_name_converter(self, name: str) -> str: ...
|
||||
def table_names(self, cursor: Optional[CursorWrapper] = ..., include_views: bool = ...) -> List[str]: ...
|
||||
|
||||
@@ -4,7 +4,6 @@ from typing import Any, List, Optional, Sequence, Tuple, Type, Union
|
||||
|
||||
from django.core.management.color import Style
|
||||
from django.db.backends.base.base import BaseDatabaseWrapper
|
||||
from django.db.backends.sqlite3.base import DatabaseWrapper
|
||||
from django.db.backends.utils import CursorWrapper
|
||||
from django.db.models.base import Model
|
||||
from django.db.models.expressions import Case, Expression
|
||||
@@ -13,6 +12,8 @@ from django.db.models.sql.compiler import SQLCompiler
|
||||
from django.db import DefaultConnectionProxy
|
||||
from django.db.models.fields import Field
|
||||
|
||||
_Connection = Union[DefaultConnectionProxy, BaseDatabaseWrapper]
|
||||
|
||||
class BaseDatabaseOperations:
|
||||
compiler_module: str = ...
|
||||
integer_field_ranges: Any = ...
|
||||
@@ -25,8 +26,8 @@ class BaseDatabaseOperations:
|
||||
UNBOUNDED_FOLLOWING: Any = ...
|
||||
CURRENT_ROW: str = ...
|
||||
explain_prefix: Any = ...
|
||||
connection: Any = ...
|
||||
def __init__(self, connection: Optional[Union[DefaultConnectionProxy, BaseDatabaseWrapper]]) -> None: ...
|
||||
connection: _Connection = ...
|
||||
def __init__(self, connection: Optional[_Connection]) -> None: ...
|
||||
def autoinc_sql(self, table: str, column: str) -> None: ...
|
||||
def bulk_batch_size(self, fields: Any, objs: Any): ...
|
||||
def cache_key_culling_sql(self) -> str: ...
|
||||
@@ -75,9 +76,9 @@ class BaseDatabaseOperations:
|
||||
def prep_for_like_query(self, x: str) -> str: ...
|
||||
prep_for_iexact_query: Any = ...
|
||||
def validate_autopk_value(self, value: int) -> int: ...
|
||||
def adapt_unknown_value(self, value: Union[datetime, Decimal, int, str]) -> Union[int, str]: ...
|
||||
def adapt_unknown_value(self, value: Any) -> Any: ...
|
||||
def adapt_datefield_value(self, value: Optional[date]) -> Optional[str]: ...
|
||||
def adapt_datetimefield_value(self, value: None) -> None: ...
|
||||
def adapt_datetimefield_value(self, value: Optional[datetime]) -> Optional[str]: ...
|
||||
def adapt_timefield_value(self, value: Optional[datetime]) -> Optional[str]: ...
|
||||
def adapt_decimalfield_value(
|
||||
self, value: Optional[Decimal], max_digits: Optional[int] = ..., decimal_places: Optional[int] = ...
|
||||
@@ -87,19 +88,17 @@ class BaseDatabaseOperations:
|
||||
def year_lookup_bounds_for_datetime_field(self, value: int) -> List[str]: ...
|
||||
def get_db_converters(self, expression: Expression) -> List[Any]: ...
|
||||
def convert_durationfield_value(
|
||||
self, value: Optional[float], expression: Expression, connection: DatabaseWrapper
|
||||
self, value: Optional[float], expression: Expression, connection: _Connection
|
||||
) -> Optional[timedelta]: ...
|
||||
def check_expression_support(self, expression: Any) -> None: ...
|
||||
def combine_expression(self, connector: str, sub_expressions: List[str]) -> str: ...
|
||||
def combine_duration_expression(self, connector: Any, sub_expressions: Any): ...
|
||||
def binary_placeholder_sql(self, value: Optional[Case]) -> str: ...
|
||||
def modify_insert_params(
|
||||
self, placeholder: str, params: Union[List[None], List[bool], List[float], List[str]]
|
||||
) -> Union[List[None], List[bool], List[float], List[str]]: ...
|
||||
def modify_insert_params(self, placeholder: str, params: Any) -> Any: ...
|
||||
def integer_field_range(self, internal_type: Any): ...
|
||||
def subtract_temporals(self, internal_type: Any, lhs: Any, rhs: Any): ...
|
||||
def window_frame_start(self, start: Any): ...
|
||||
def window_frame_end(self, end: Any): ...
|
||||
def window_frame_rows_start_end(self, start: None = ..., end: None = ...) -> Any: ...
|
||||
def window_frame_range_start_end(self, start: Optional[Any] = ..., end: Optional[Any] = ...): ...
|
||||
def window_frame_rows_start_end(self, start: Optional[int] = ..., end: Optional[int] = ...) -> Any: ...
|
||||
def window_frame_range_start_end(self, start: Optional[int] = ..., end: Optional[int] = ...) -> Any: ...
|
||||
def explain_query_prefix(self, format: Optional[str] = ..., **options: Any) -> str: ...
|
||||
|
||||
@@ -48,7 +48,7 @@ class BaseDatabaseSchemaEditor(ContextManager[Any]):
|
||||
def quote_name(self, name: str) -> str: ...
|
||||
def column_sql(
|
||||
self, model: Type[Model], field: Field, include_default: bool = ...
|
||||
) -> Union[Tuple[None, None], Tuple[str, List[Any]]]: ...
|
||||
) -> Tuple[Optional[str], Optional[List[Any]]]: ...
|
||||
def skip_default(self, field: Any): ...
|
||||
def prepare_default(self, value: Any) -> None: ...
|
||||
def effective_default(self, field: Field) -> Optional[Union[int, str]]: ...
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
from typing import Any, Dict, Optional, Tuple, Union
|
||||
from typing import Any, Optional
|
||||
|
||||
from django.db.backends.base.introspection import BaseDatabaseIntrospection
|
||||
|
||||
@@ -8,6 +8,6 @@ def get_field_size(name: str) -> Optional[int]: ...
|
||||
|
||||
class FlexibleFieldLookupDict:
|
||||
base_data_types_reverse: Any = ...
|
||||
def __getitem__(self, key: str) -> Union[Tuple[str, Dict[str, int]], str]: ...
|
||||
def __getitem__(self, key: str) -> Any: ...
|
||||
|
||||
class DatabaseIntrospection(BaseDatabaseIntrospection): ...
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
from typing import Any, Dict, List, Optional, Sequence, Set, Tuple, TypeVar, Union
|
||||
from typing import Any, Dict, List, Optional, Sequence, Set, Tuple, TypeVar, Union, Collection
|
||||
|
||||
from django.core.checks.messages import CheckMessage
|
||||
from django.db.models.manager import Manager
|
||||
@@ -18,10 +18,10 @@ class Model(metaclass=ModelBase):
|
||||
pk: Any = ...
|
||||
def __init__(self: _Self, *args, **kwargs) -> None: ...
|
||||
def delete(self, using: Any = ..., keep_parents: bool = ...) -> Tuple[int, Dict[str, int]]: ...
|
||||
def full_clean(self, exclude: Optional[List[str]] = ..., validate_unique: bool = ...) -> None: ...
|
||||
def full_clean(self, exclude: Optional[Collection[str]] = ..., validate_unique: bool = ...) -> None: ...
|
||||
def clean(self) -> None: ...
|
||||
def clean_fields(self, exclude: List[str] = ...) -> None: ...
|
||||
def validate_unique(self, exclude: List[str] = ...) -> None: ...
|
||||
def clean_fields(self, exclude: Optional[Collection[str]] = ...) -> None: ...
|
||||
def validate_unique(self, exclude: Optional[Collection[str]] = ...) -> None: ...
|
||||
def save(
|
||||
self,
|
||||
force_insert: bool = ...,
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
from datetime import datetime
|
||||
from typing import Any, Iterable, List, Optional, Tuple, Type, Union, Mapping, TypeVar, Generic
|
||||
from typing import Any, Iterable, List, Optional, Tuple, Type, Union, Mapping, TypeVar, Generic, Sequence
|
||||
|
||||
from django.db.backends.sqlite3.base import DatabaseWrapper
|
||||
from django.db.models.expressions import Expression, Func
|
||||
@@ -33,9 +33,9 @@ class Lookup(Generic[_T]):
|
||||
) -> Tuple[str, List[Union[int, str]]]: ...
|
||||
def process_rhs(
|
||||
self, compiler: SQLCompiler, connection: DatabaseWrapper
|
||||
) -> Tuple[str, Union[List[Union[int, str]], Tuple[int, int]]]: ...
|
||||
) -> Tuple[str, Sequence[Union[int, str]]]: ...
|
||||
def rhs_is_direct_value(self) -> bool: ...
|
||||
def relabeled_clone(self, relabels: Mapping[str, str]) -> Union[BuiltinLookup, FieldGetDbPrepValueMixin]: ...
|
||||
def relabeled_clone(self: _T, relabels: Mapping[str, str]) -> _T: ...
|
||||
def get_group_by_cols(self) -> List[Expression]: ...
|
||||
def as_sql(self, compiler: Any, connection: Any) -> Any: ...
|
||||
def contains_aggregate(self) -> bool: ...
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
import collections
|
||||
from typing import Any, Callable, Dict, Generic, Iterator, List, Optional, Sequence, Set, Tuple, Type, TypeVar, Union
|
||||
from typing import Any, Callable, Dict, Generic, Iterator, List, Optional, Sequence, Tuple, Type, TypeVar, Union
|
||||
|
||||
from django.apps.config import AppConfig
|
||||
from django.apps.registry import Apps
|
||||
@@ -23,8 +23,8 @@ IMMUTABLE_WARNING: str
|
||||
DEFAULT_NAMES: Tuple[str, ...]
|
||||
|
||||
def normalize_together(
|
||||
option_together: Any
|
||||
) -> Union[List[Union[Tuple[str, str], int]], Set[Tuple[str, str]], Tuple, int, str]: ...
|
||||
option_together: Union[Sequence[Tuple[str, str]], Tuple[str, str]]
|
||||
) -> Tuple[Tuple[str, str], ...]: ...
|
||||
def make_immutable_fields_list(
|
||||
name: str, data: Union[Iterator[Any], List[Union[ArrayField, CIText]], List[Union[Field, FieldCacheMixin]]]
|
||||
) -> ImmutableList: ...
|
||||
|
||||
@@ -1,13 +1,13 @@
|
||||
from collections import OrderedDict, namedtuple
|
||||
from typing import Any, Dict, Iterator, List, Optional, Set, Tuple, Type, Union
|
||||
from collections import namedtuple
|
||||
from typing import Any, Collection, Dict, Iterator, List, Mapping, Optional, Sequence, Set, Tuple, Type, Union
|
||||
|
||||
from django.db.models.base import Model
|
||||
from django.db.models.expressions import Expression
|
||||
from django.db.models.fields import Field
|
||||
from django.db.models.fields.mixins import FieldCacheMixin
|
||||
from django.db.models.sql.compiler import SQLCompiler
|
||||
from django.db.models.sql.query import Query
|
||||
from django.db.models.sql.where import WhereNode
|
||||
|
||||
from django.db.models.fields import Field
|
||||
from django.utils import tree
|
||||
|
||||
PathInfo = namedtuple("PathInfo", "from_opts to_opts target_fields join_field m2m direct filtered_relation")
|
||||
@@ -65,15 +65,11 @@ class RegisterLookupMixin:
|
||||
def select_related_descend(
|
||||
field: Field,
|
||||
restricted: bool,
|
||||
requested: Optional[
|
||||
Union[Dict[str, Dict[str, Dict[str, Dict[str, Dict[str, Dict[str, Dict[str, Dict[Any, Any]]]]]]]], bool]
|
||||
],
|
||||
load_fields: Optional[Set[str]],
|
||||
requested: Optional[Mapping[str, Any]],
|
||||
load_fields: Optional[Collection[str]],
|
||||
reverse: bool = ...,
|
||||
) -> bool: ...
|
||||
def refs_expression(
|
||||
lookup_parts: List[str], annotations: OrderedDict
|
||||
) -> Union[Tuple[bool, Tuple], Tuple[Expression, List[str]]]: ...
|
||||
def refs_expression(lookup_parts: Sequence[str], annotations: Mapping[str, bool]) -> Tuple[bool, Sequence[str]]: ...
|
||||
def check_rel_lookup_compatibility(model: Type[Model], target_opts: Any, field: FieldCacheMixin) -> bool: ...
|
||||
|
||||
class FilteredRelation:
|
||||
|
||||
88
scripts/catch_non_abstract_annotation.py
Normal file
88
scripts/catch_non_abstract_annotation.py
Normal file
@@ -0,0 +1,88 @@
|
||||
import os
|
||||
from typing import Optional
|
||||
|
||||
import libcst
|
||||
from libcst import Annotation, BaseExpression, FunctionDef, Name, Subscript
|
||||
from libcst.metadata import SyntacticPositionProvider
|
||||
|
||||
BASE_DIR = 'django-stubs'
|
||||
|
||||
fpath = os.path.join(BASE_DIR, 'core', 'checks', 'model_checks.pyi')
|
||||
with open(fpath, 'r') as f:
|
||||
contents = f.read()
|
||||
|
||||
|
||||
tree = libcst.parse_module(contents)
|
||||
|
||||
|
||||
class TypeAnnotationsAnalyzer(libcst.CSTVisitor):
|
||||
METADATA_DEPENDENCIES = (SyntacticPositionProvider,)
|
||||
|
||||
def __init__(self, fpath: str):
|
||||
super().__init__()
|
||||
self.fpath = fpath
|
||||
|
||||
def get_node_location(self, node: FunctionDef) -> str:
|
||||
start_line = self.get_metadata(SyntacticPositionProvider, node).start.line
|
||||
return f'{self.fpath}:{start_line}'
|
||||
|
||||
def show_error_for_node(self, node: FunctionDef, error_message: str):
|
||||
print(self.get_node_location(node), error_message)
|
||||
|
||||
def check_subscripted_annotation(self, annotation: BaseExpression) -> Optional[str]:
|
||||
if isinstance(annotation, Subscript):
|
||||
if isinstance(annotation.value, Name):
|
||||
error_message = self.check_concrete_class_usage(annotation.value)
|
||||
if error_message:
|
||||
return error_message
|
||||
|
||||
if annotation.value.value == 'Union':
|
||||
for slice_param in annotation.slice:
|
||||
if isinstance(slice_param.slice.value, Name):
|
||||
error_message = self.check_concrete_class_usage(annotation.value)
|
||||
if error_message:
|
||||
return error_message
|
||||
|
||||
def check_concrete_class_usage(self, name_node: Name) -> Optional[str]:
|
||||
if name_node.value == 'List':
|
||||
return (f'Concrete class {name_node.value!r} used for an iterable annotation. '
|
||||
f'Use abstract collection (Iterable, Collection, Sequence) instead')
|
||||
|
||||
def visit_FunctionDef(self, node: FunctionDef) -> Optional[bool]:
|
||||
params_node = node.params
|
||||
for param_node in [*params_node.params, *params_node.default_params]:
|
||||
param_name = param_node.name.value
|
||||
annotation_node = param_node.annotation # type: Annotation
|
||||
if annotation_node is not None:
|
||||
annotation = annotation_node.annotation
|
||||
if annotation.value == 'None':
|
||||
self.show_error_for_node(node, f'"None" type annotation used for parameter {param_name!r}')
|
||||
continue
|
||||
|
||||
error_message = self.check_subscripted_annotation(annotation)
|
||||
if error_message is not None:
|
||||
self.show_error_for_node(node, error_message)
|
||||
continue
|
||||
|
||||
if node.returns is not None:
|
||||
return_annotation = node.returns.annotation
|
||||
if isinstance(return_annotation, Subscript) and return_annotation.value.value == 'Union':
|
||||
self.show_error_for_node(node, 'Union is return type annotation')
|
||||
|
||||
return False
|
||||
|
||||
|
||||
for dirpath, dirnames, filenames in os.walk(BASE_DIR):
|
||||
for filename in filenames:
|
||||
fpath = os.path.join(dirpath, filename)
|
||||
# skip all other checks for now, low priority
|
||||
if not fpath.startswith(('django-stubs/db', 'django-stubs/views', 'django-stubs/apps',
|
||||
'django-stubs/http', 'django-stubs/contrib/postgres')):
|
||||
continue
|
||||
|
||||
with open(fpath, 'r') as f:
|
||||
contents = f.read()
|
||||
|
||||
tree = libcst.MetadataWrapper(libcst.parse_module(contents))
|
||||
analyzer = TypeAnnotationsAnalyzer(fpath)
|
||||
tree.visit(analyzer)
|
||||
Reference in New Issue
Block a user