From 8402e7c53ee4d2fa2ea31a506b6a986b67e10c1a Mon Sep 17 00:00:00 2001 From: Maxim Kurnikov Date: Mon, 7 Oct 2019 14:50:45 +0300 Subject: [PATCH] improve annotations in some places (#202) * improve annotations in some places * linting --- django-stubs/contrib/admin/views/main.pyi | 4 +- .../db/backends/base/introspection.pyi | 4 +- django-stubs/db/backends/base/operations.pyi | 21 +++-- django-stubs/db/backends/base/schema.pyi | 2 +- .../db/backends/sqlite3/introspection.pyi | 4 +- django-stubs/db/models/base.pyi | 8 +- django-stubs/db/models/lookups.pyi | 6 +- django-stubs/db/models/options.pyi | 6 +- django-stubs/db/models/query_utils.pyi | 18 ++-- scripts/catch_non_abstract_annotation.py | 88 +++++++++++++++++++ 10 files changed, 121 insertions(+), 40 deletions(-) create mode 100644 scripts/catch_non_abstract_annotation.py diff --git a/django-stubs/contrib/admin/views/main.pyi b/django-stubs/contrib/admin/views/main.pyi index 2c1fd60..ad2deeb 100644 --- a/django-stubs/contrib/admin/views/main.pyi +++ b/django-stubs/contrib/admin/views/main.pyi @@ -81,9 +81,7 @@ class ChangeList: paginator: Any = ... def get_results(self, request: WSGIRequest) -> None: ... def get_ordering_field(self, field_name: Union[Callable, str]) -> Optional[Union[CombinedExpression, str]]: ... - def get_ordering( - self, request: WSGIRequest, queryset: QuerySet - ) -> Union[List[Union[Combinable, str]], List[Union[OrderBy, str]]]: ... + def get_ordering(self, request: WSGIRequest, queryset: QuerySet) -> List[Union[OrderBy, Combinable, str]]: ... def get_ordering_field_columns(self) -> OrderedDict: ... def get_queryset(self, request: WSGIRequest) -> QuerySet: ... def apply_select_related(self, qs: QuerySet) -> QuerySet: ... diff --git a/django-stubs/db/backends/base/introspection.pyi b/django-stubs/db/backends/base/introspection.pyi index c189e28..33d128f 100644 --- a/django-stubs/db/backends/base/introspection.pyi +++ b/django-stubs/db/backends/base/introspection.pyi @@ -1,5 +1,5 @@ from collections import namedtuple -from typing import Any, Dict, List, Optional, Set, Tuple, Type, Union +from typing import Any, Dict, List, Optional, Set, Type from django.db.backends.base.base import BaseDatabaseWrapper from django.db.backends.utils import CursorWrapper @@ -13,7 +13,7 @@ class BaseDatabaseIntrospection: data_types_reverse: Any = ... connection: Any = ... def __init__(self, connection: BaseDatabaseWrapper) -> None: ... - def get_field_type(self, data_type: str, description: FieldInfo) -> Union[Tuple[str, Dict[str, int]], str]: ... + def get_field_type(self, data_type: str, description: FieldInfo) -> str: ... def table_name_converter(self, name: str) -> str: ... def column_name_converter(self, name: str) -> str: ... def table_names(self, cursor: Optional[CursorWrapper] = ..., include_views: bool = ...) -> List[str]: ... diff --git a/django-stubs/db/backends/base/operations.pyi b/django-stubs/db/backends/base/operations.pyi index a8d5575..43926e4 100644 --- a/django-stubs/db/backends/base/operations.pyi +++ b/django-stubs/db/backends/base/operations.pyi @@ -4,7 +4,6 @@ from typing import Any, List, Optional, Sequence, Tuple, Type, Union from django.core.management.color import Style from django.db.backends.base.base import BaseDatabaseWrapper -from django.db.backends.sqlite3.base import DatabaseWrapper from django.db.backends.utils import CursorWrapper from django.db.models.base import Model from django.db.models.expressions import Case, Expression @@ -13,6 +12,8 @@ from django.db.models.sql.compiler import SQLCompiler from django.db import DefaultConnectionProxy from django.db.models.fields import Field +_Connection = Union[DefaultConnectionProxy, BaseDatabaseWrapper] + class BaseDatabaseOperations: compiler_module: str = ... integer_field_ranges: Any = ... @@ -25,8 +26,8 @@ class BaseDatabaseOperations: UNBOUNDED_FOLLOWING: Any = ... CURRENT_ROW: str = ... explain_prefix: Any = ... - connection: Any = ... - def __init__(self, connection: Optional[Union[DefaultConnectionProxy, BaseDatabaseWrapper]]) -> None: ... + connection: _Connection = ... + def __init__(self, connection: Optional[_Connection]) -> None: ... def autoinc_sql(self, table: str, column: str) -> None: ... def bulk_batch_size(self, fields: Any, objs: Any): ... def cache_key_culling_sql(self) -> str: ... @@ -75,9 +76,9 @@ class BaseDatabaseOperations: def prep_for_like_query(self, x: str) -> str: ... prep_for_iexact_query: Any = ... def validate_autopk_value(self, value: int) -> int: ... - def adapt_unknown_value(self, value: Union[datetime, Decimal, int, str]) -> Union[int, str]: ... + def adapt_unknown_value(self, value: Any) -> Any: ... def adapt_datefield_value(self, value: Optional[date]) -> Optional[str]: ... - def adapt_datetimefield_value(self, value: None) -> None: ... + def adapt_datetimefield_value(self, value: Optional[datetime]) -> Optional[str]: ... def adapt_timefield_value(self, value: Optional[datetime]) -> Optional[str]: ... def adapt_decimalfield_value( self, value: Optional[Decimal], max_digits: Optional[int] = ..., decimal_places: Optional[int] = ... @@ -87,19 +88,17 @@ class BaseDatabaseOperations: def year_lookup_bounds_for_datetime_field(self, value: int) -> List[str]: ... def get_db_converters(self, expression: Expression) -> List[Any]: ... def convert_durationfield_value( - self, value: Optional[float], expression: Expression, connection: DatabaseWrapper + self, value: Optional[float], expression: Expression, connection: _Connection ) -> Optional[timedelta]: ... def check_expression_support(self, expression: Any) -> None: ... def combine_expression(self, connector: str, sub_expressions: List[str]) -> str: ... def combine_duration_expression(self, connector: Any, sub_expressions: Any): ... def binary_placeholder_sql(self, value: Optional[Case]) -> str: ... - def modify_insert_params( - self, placeholder: str, params: Union[List[None], List[bool], List[float], List[str]] - ) -> Union[List[None], List[bool], List[float], List[str]]: ... + def modify_insert_params(self, placeholder: str, params: Any) -> Any: ... def integer_field_range(self, internal_type: Any): ... def subtract_temporals(self, internal_type: Any, lhs: Any, rhs: Any): ... def window_frame_start(self, start: Any): ... def window_frame_end(self, end: Any): ... - def window_frame_rows_start_end(self, start: None = ..., end: None = ...) -> Any: ... - def window_frame_range_start_end(self, start: Optional[Any] = ..., end: Optional[Any] = ...): ... + def window_frame_rows_start_end(self, start: Optional[int] = ..., end: Optional[int] = ...) -> Any: ... + def window_frame_range_start_end(self, start: Optional[int] = ..., end: Optional[int] = ...) -> Any: ... def explain_query_prefix(self, format: Optional[str] = ..., **options: Any) -> str: ... diff --git a/django-stubs/db/backends/base/schema.pyi b/django-stubs/db/backends/base/schema.pyi index 586a25b..4da4248 100644 --- a/django-stubs/db/backends/base/schema.pyi +++ b/django-stubs/db/backends/base/schema.pyi @@ -48,7 +48,7 @@ class BaseDatabaseSchemaEditor(ContextManager[Any]): def quote_name(self, name: str) -> str: ... def column_sql( self, model: Type[Model], field: Field, include_default: bool = ... - ) -> Union[Tuple[None, None], Tuple[str, List[Any]]]: ... + ) -> Tuple[Optional[str], Optional[List[Any]]]: ... def skip_default(self, field: Any): ... def prepare_default(self, value: Any) -> None: ... def effective_default(self, field: Field) -> Optional[Union[int, str]]: ... diff --git a/django-stubs/db/backends/sqlite3/introspection.pyi b/django-stubs/db/backends/sqlite3/introspection.pyi index e4da384..f628216 100644 --- a/django-stubs/db/backends/sqlite3/introspection.pyi +++ b/django-stubs/db/backends/sqlite3/introspection.pyi @@ -1,4 +1,4 @@ -from typing import Any, Dict, Optional, Tuple, Union +from typing import Any, Optional from django.db.backends.base.introspection import BaseDatabaseIntrospection @@ -8,6 +8,6 @@ def get_field_size(name: str) -> Optional[int]: ... class FlexibleFieldLookupDict: base_data_types_reverse: Any = ... - def __getitem__(self, key: str) -> Union[Tuple[str, Dict[str, int]], str]: ... + def __getitem__(self, key: str) -> Any: ... class DatabaseIntrospection(BaseDatabaseIntrospection): ... diff --git a/django-stubs/db/models/base.pyi b/django-stubs/db/models/base.pyi index 19f8aa1..439535c 100644 --- a/django-stubs/db/models/base.pyi +++ b/django-stubs/db/models/base.pyi @@ -1,4 +1,4 @@ -from typing import Any, Dict, List, Optional, Sequence, Set, Tuple, TypeVar, Union +from typing import Any, Dict, List, Optional, Sequence, Set, Tuple, TypeVar, Union, Collection from django.core.checks.messages import CheckMessage from django.db.models.manager import Manager @@ -18,10 +18,10 @@ class Model(metaclass=ModelBase): pk: Any = ... def __init__(self: _Self, *args, **kwargs) -> None: ... def delete(self, using: Any = ..., keep_parents: bool = ...) -> Tuple[int, Dict[str, int]]: ... - def full_clean(self, exclude: Optional[List[str]] = ..., validate_unique: bool = ...) -> None: ... + def full_clean(self, exclude: Optional[Collection[str]] = ..., validate_unique: bool = ...) -> None: ... def clean(self) -> None: ... - def clean_fields(self, exclude: List[str] = ...) -> None: ... - def validate_unique(self, exclude: List[str] = ...) -> None: ... + def clean_fields(self, exclude: Optional[Collection[str]] = ...) -> None: ... + def validate_unique(self, exclude: Optional[Collection[str]] = ...) -> None: ... def save( self, force_insert: bool = ..., diff --git a/django-stubs/db/models/lookups.pyi b/django-stubs/db/models/lookups.pyi index 3a248a5..ddf341f 100644 --- a/django-stubs/db/models/lookups.pyi +++ b/django-stubs/db/models/lookups.pyi @@ -1,5 +1,5 @@ from datetime import datetime -from typing import Any, Iterable, List, Optional, Tuple, Type, Union, Mapping, TypeVar, Generic +from typing import Any, Iterable, List, Optional, Tuple, Type, Union, Mapping, TypeVar, Generic, Sequence from django.db.backends.sqlite3.base import DatabaseWrapper from django.db.models.expressions import Expression, Func @@ -33,9 +33,9 @@ class Lookup(Generic[_T]): ) -> Tuple[str, List[Union[int, str]]]: ... def process_rhs( self, compiler: SQLCompiler, connection: DatabaseWrapper - ) -> Tuple[str, Union[List[Union[int, str]], Tuple[int, int]]]: ... + ) -> Tuple[str, Sequence[Union[int, str]]]: ... def rhs_is_direct_value(self) -> bool: ... - def relabeled_clone(self, relabels: Mapping[str, str]) -> Union[BuiltinLookup, FieldGetDbPrepValueMixin]: ... + def relabeled_clone(self: _T, relabels: Mapping[str, str]) -> _T: ... def get_group_by_cols(self) -> List[Expression]: ... def as_sql(self, compiler: Any, connection: Any) -> Any: ... def contains_aggregate(self) -> bool: ... diff --git a/django-stubs/db/models/options.pyi b/django-stubs/db/models/options.pyi index d6e7d2f..2a0b19b 100644 --- a/django-stubs/db/models/options.pyi +++ b/django-stubs/db/models/options.pyi @@ -1,5 +1,5 @@ import collections -from typing import Any, Callable, Dict, Generic, Iterator, List, Optional, Sequence, Set, Tuple, Type, TypeVar, Union +from typing import Any, Callable, Dict, Generic, Iterator, List, Optional, Sequence, Tuple, Type, TypeVar, Union from django.apps.config import AppConfig from django.apps.registry import Apps @@ -23,8 +23,8 @@ IMMUTABLE_WARNING: str DEFAULT_NAMES: Tuple[str, ...] def normalize_together( - option_together: Any -) -> Union[List[Union[Tuple[str, str], int]], Set[Tuple[str, str]], Tuple, int, str]: ... + option_together: Union[Sequence[Tuple[str, str]], Tuple[str, str]] +) -> Tuple[Tuple[str, str], ...]: ... def make_immutable_fields_list( name: str, data: Union[Iterator[Any], List[Union[ArrayField, CIText]], List[Union[Field, FieldCacheMixin]]] ) -> ImmutableList: ... diff --git a/django-stubs/db/models/query_utils.pyi b/django-stubs/db/models/query_utils.pyi index 141083c..45a4944 100644 --- a/django-stubs/db/models/query_utils.pyi +++ b/django-stubs/db/models/query_utils.pyi @@ -1,13 +1,13 @@ -from collections import OrderedDict, namedtuple -from typing import Any, Dict, Iterator, List, Optional, Set, Tuple, Type, Union +from collections import namedtuple +from typing import Any, Collection, Dict, Iterator, List, Mapping, Optional, Sequence, Set, Tuple, Type, Union from django.db.models.base import Model -from django.db.models.expressions import Expression -from django.db.models.fields import Field from django.db.models.fields.mixins import FieldCacheMixin from django.db.models.sql.compiler import SQLCompiler from django.db.models.sql.query import Query from django.db.models.sql.where import WhereNode + +from django.db.models.fields import Field from django.utils import tree PathInfo = namedtuple("PathInfo", "from_opts to_opts target_fields join_field m2m direct filtered_relation") @@ -65,15 +65,11 @@ class RegisterLookupMixin: def select_related_descend( field: Field, restricted: bool, - requested: Optional[ - Union[Dict[str, Dict[str, Dict[str, Dict[str, Dict[str, Dict[str, Dict[str, Dict[Any, Any]]]]]]]], bool] - ], - load_fields: Optional[Set[str]], + requested: Optional[Mapping[str, Any]], + load_fields: Optional[Collection[str]], reverse: bool = ..., ) -> bool: ... -def refs_expression( - lookup_parts: List[str], annotations: OrderedDict -) -> Union[Tuple[bool, Tuple], Tuple[Expression, List[str]]]: ... +def refs_expression(lookup_parts: Sequence[str], annotations: Mapping[str, bool]) -> Tuple[bool, Sequence[str]]: ... def check_rel_lookup_compatibility(model: Type[Model], target_opts: Any, field: FieldCacheMixin) -> bool: ... class FilteredRelation: diff --git a/scripts/catch_non_abstract_annotation.py b/scripts/catch_non_abstract_annotation.py new file mode 100644 index 0000000..df1417a --- /dev/null +++ b/scripts/catch_non_abstract_annotation.py @@ -0,0 +1,88 @@ +import os +from typing import Optional + +import libcst +from libcst import Annotation, BaseExpression, FunctionDef, Name, Subscript +from libcst.metadata import SyntacticPositionProvider + +BASE_DIR = 'django-stubs' + +fpath = os.path.join(BASE_DIR, 'core', 'checks', 'model_checks.pyi') +with open(fpath, 'r') as f: + contents = f.read() + + +tree = libcst.parse_module(contents) + + +class TypeAnnotationsAnalyzer(libcst.CSTVisitor): + METADATA_DEPENDENCIES = (SyntacticPositionProvider,) + + def __init__(self, fpath: str): + super().__init__() + self.fpath = fpath + + def get_node_location(self, node: FunctionDef) -> str: + start_line = self.get_metadata(SyntacticPositionProvider, node).start.line + return f'{self.fpath}:{start_line}' + + def show_error_for_node(self, node: FunctionDef, error_message: str): + print(self.get_node_location(node), error_message) + + def check_subscripted_annotation(self, annotation: BaseExpression) -> Optional[str]: + if isinstance(annotation, Subscript): + if isinstance(annotation.value, Name): + error_message = self.check_concrete_class_usage(annotation.value) + if error_message: + return error_message + + if annotation.value.value == 'Union': + for slice_param in annotation.slice: + if isinstance(slice_param.slice.value, Name): + error_message = self.check_concrete_class_usage(annotation.value) + if error_message: + return error_message + + def check_concrete_class_usage(self, name_node: Name) -> Optional[str]: + if name_node.value == 'List': + return (f'Concrete class {name_node.value!r} used for an iterable annotation. ' + f'Use abstract collection (Iterable, Collection, Sequence) instead') + + def visit_FunctionDef(self, node: FunctionDef) -> Optional[bool]: + params_node = node.params + for param_node in [*params_node.params, *params_node.default_params]: + param_name = param_node.name.value + annotation_node = param_node.annotation # type: Annotation + if annotation_node is not None: + annotation = annotation_node.annotation + if annotation.value == 'None': + self.show_error_for_node(node, f'"None" type annotation used for parameter {param_name!r}') + continue + + error_message = self.check_subscripted_annotation(annotation) + if error_message is not None: + self.show_error_for_node(node, error_message) + continue + + if node.returns is not None: + return_annotation = node.returns.annotation + if isinstance(return_annotation, Subscript) and return_annotation.value.value == 'Union': + self.show_error_for_node(node, 'Union is return type annotation') + + return False + + +for dirpath, dirnames, filenames in os.walk(BASE_DIR): + for filename in filenames: + fpath = os.path.join(dirpath, filename) + # skip all other checks for now, low priority + if not fpath.startswith(('django-stubs/db', 'django-stubs/views', 'django-stubs/apps', + 'django-stubs/http', 'django-stubs/contrib/postgres')): + continue + + with open(fpath, 'r') as f: + contents = f.read() + + tree = libcst.MetadataWrapper(libcst.parse_module(contents)) + analyzer = TypeAnnotationsAnalyzer(fpath) + tree.visit(analyzer)