improve annotations in some places (#202)

* improve annotations in some places

* linting
This commit is contained in:
Maxim Kurnikov
2019-10-07 14:50:45 +03:00
committed by GitHub
parent dceb075152
commit 8402e7c53e
10 changed files with 121 additions and 40 deletions

View File

@@ -81,9 +81,7 @@ class ChangeList:
paginator: Any = ... paginator: Any = ...
def get_results(self, request: WSGIRequest) -> None: ... def get_results(self, request: WSGIRequest) -> None: ...
def get_ordering_field(self, field_name: Union[Callable, str]) -> Optional[Union[CombinedExpression, str]]: ... def get_ordering_field(self, field_name: Union[Callable, str]) -> Optional[Union[CombinedExpression, str]]: ...
def get_ordering( def get_ordering(self, request: WSGIRequest, queryset: QuerySet) -> List[Union[OrderBy, Combinable, str]]: ...
self, request: WSGIRequest, queryset: QuerySet
) -> Union[List[Union[Combinable, str]], List[Union[OrderBy, str]]]: ...
def get_ordering_field_columns(self) -> OrderedDict: ... def get_ordering_field_columns(self) -> OrderedDict: ...
def get_queryset(self, request: WSGIRequest) -> QuerySet: ... def get_queryset(self, request: WSGIRequest) -> QuerySet: ...
def apply_select_related(self, qs: QuerySet) -> QuerySet: ... def apply_select_related(self, qs: QuerySet) -> QuerySet: ...

View File

@@ -1,5 +1,5 @@
from collections import namedtuple from collections import namedtuple
from typing import Any, Dict, List, Optional, Set, Tuple, Type, Union from typing import Any, Dict, List, Optional, Set, Type
from django.db.backends.base.base import BaseDatabaseWrapper from django.db.backends.base.base import BaseDatabaseWrapper
from django.db.backends.utils import CursorWrapper from django.db.backends.utils import CursorWrapper
@@ -13,7 +13,7 @@ class BaseDatabaseIntrospection:
data_types_reverse: Any = ... data_types_reverse: Any = ...
connection: Any = ... connection: Any = ...
def __init__(self, connection: BaseDatabaseWrapper) -> None: ... def __init__(self, connection: BaseDatabaseWrapper) -> None: ...
def get_field_type(self, data_type: str, description: FieldInfo) -> Union[Tuple[str, Dict[str, int]], str]: ... def get_field_type(self, data_type: str, description: FieldInfo) -> str: ...
def table_name_converter(self, name: str) -> str: ... def table_name_converter(self, name: str) -> str: ...
def column_name_converter(self, name: str) -> str: ... def column_name_converter(self, name: str) -> str: ...
def table_names(self, cursor: Optional[CursorWrapper] = ..., include_views: bool = ...) -> List[str]: ... def table_names(self, cursor: Optional[CursorWrapper] = ..., include_views: bool = ...) -> List[str]: ...

View File

@@ -4,7 +4,6 @@ from typing import Any, List, Optional, Sequence, Tuple, Type, Union
from django.core.management.color import Style from django.core.management.color import Style
from django.db.backends.base.base import BaseDatabaseWrapper from django.db.backends.base.base import BaseDatabaseWrapper
from django.db.backends.sqlite3.base import DatabaseWrapper
from django.db.backends.utils import CursorWrapper from django.db.backends.utils import CursorWrapper
from django.db.models.base import Model from django.db.models.base import Model
from django.db.models.expressions import Case, Expression from django.db.models.expressions import Case, Expression
@@ -13,6 +12,8 @@ from django.db.models.sql.compiler import SQLCompiler
from django.db import DefaultConnectionProxy from django.db import DefaultConnectionProxy
from django.db.models.fields import Field from django.db.models.fields import Field
_Connection = Union[DefaultConnectionProxy, BaseDatabaseWrapper]
class BaseDatabaseOperations: class BaseDatabaseOperations:
compiler_module: str = ... compiler_module: str = ...
integer_field_ranges: Any = ... integer_field_ranges: Any = ...
@@ -25,8 +26,8 @@ class BaseDatabaseOperations:
UNBOUNDED_FOLLOWING: Any = ... UNBOUNDED_FOLLOWING: Any = ...
CURRENT_ROW: str = ... CURRENT_ROW: str = ...
explain_prefix: Any = ... explain_prefix: Any = ...
connection: Any = ... connection: _Connection = ...
def __init__(self, connection: Optional[Union[DefaultConnectionProxy, BaseDatabaseWrapper]]) -> None: ... def __init__(self, connection: Optional[_Connection]) -> None: ...
def autoinc_sql(self, table: str, column: str) -> None: ... def autoinc_sql(self, table: str, column: str) -> None: ...
def bulk_batch_size(self, fields: Any, objs: Any): ... def bulk_batch_size(self, fields: Any, objs: Any): ...
def cache_key_culling_sql(self) -> str: ... def cache_key_culling_sql(self) -> str: ...
@@ -75,9 +76,9 @@ class BaseDatabaseOperations:
def prep_for_like_query(self, x: str) -> str: ... def prep_for_like_query(self, x: str) -> str: ...
prep_for_iexact_query: Any = ... prep_for_iexact_query: Any = ...
def validate_autopk_value(self, value: int) -> int: ... def validate_autopk_value(self, value: int) -> int: ...
def adapt_unknown_value(self, value: Union[datetime, Decimal, int, str]) -> Union[int, str]: ... def adapt_unknown_value(self, value: Any) -> Any: ...
def adapt_datefield_value(self, value: Optional[date]) -> Optional[str]: ... def adapt_datefield_value(self, value: Optional[date]) -> Optional[str]: ...
def adapt_datetimefield_value(self, value: None) -> None: ... def adapt_datetimefield_value(self, value: Optional[datetime]) -> Optional[str]: ...
def adapt_timefield_value(self, value: Optional[datetime]) -> Optional[str]: ... def adapt_timefield_value(self, value: Optional[datetime]) -> Optional[str]: ...
def adapt_decimalfield_value( def adapt_decimalfield_value(
self, value: Optional[Decimal], max_digits: Optional[int] = ..., decimal_places: Optional[int] = ... self, value: Optional[Decimal], max_digits: Optional[int] = ..., decimal_places: Optional[int] = ...
@@ -87,19 +88,17 @@ class BaseDatabaseOperations:
def year_lookup_bounds_for_datetime_field(self, value: int) -> List[str]: ... def year_lookup_bounds_for_datetime_field(self, value: int) -> List[str]: ...
def get_db_converters(self, expression: Expression) -> List[Any]: ... def get_db_converters(self, expression: Expression) -> List[Any]: ...
def convert_durationfield_value( def convert_durationfield_value(
self, value: Optional[float], expression: Expression, connection: DatabaseWrapper self, value: Optional[float], expression: Expression, connection: _Connection
) -> Optional[timedelta]: ... ) -> Optional[timedelta]: ...
def check_expression_support(self, expression: Any) -> None: ... def check_expression_support(self, expression: Any) -> None: ...
def combine_expression(self, connector: str, sub_expressions: List[str]) -> str: ... def combine_expression(self, connector: str, sub_expressions: List[str]) -> str: ...
def combine_duration_expression(self, connector: Any, sub_expressions: Any): ... def combine_duration_expression(self, connector: Any, sub_expressions: Any): ...
def binary_placeholder_sql(self, value: Optional[Case]) -> str: ... def binary_placeholder_sql(self, value: Optional[Case]) -> str: ...
def modify_insert_params( def modify_insert_params(self, placeholder: str, params: Any) -> Any: ...
self, placeholder: str, params: Union[List[None], List[bool], List[float], List[str]]
) -> Union[List[None], List[bool], List[float], List[str]]: ...
def integer_field_range(self, internal_type: Any): ... def integer_field_range(self, internal_type: Any): ...
def subtract_temporals(self, internal_type: Any, lhs: Any, rhs: Any): ... def subtract_temporals(self, internal_type: Any, lhs: Any, rhs: Any): ...
def window_frame_start(self, start: Any): ... def window_frame_start(self, start: Any): ...
def window_frame_end(self, end: Any): ... def window_frame_end(self, end: Any): ...
def window_frame_rows_start_end(self, start: None = ..., end: None = ...) -> Any: ... def window_frame_rows_start_end(self, start: Optional[int] = ..., end: Optional[int] = ...) -> Any: ...
def window_frame_range_start_end(self, start: Optional[Any] = ..., end: Optional[Any] = ...): ... def window_frame_range_start_end(self, start: Optional[int] = ..., end: Optional[int] = ...) -> Any: ...
def explain_query_prefix(self, format: Optional[str] = ..., **options: Any) -> str: ... def explain_query_prefix(self, format: Optional[str] = ..., **options: Any) -> str: ...

View File

@@ -48,7 +48,7 @@ class BaseDatabaseSchemaEditor(ContextManager[Any]):
def quote_name(self, name: str) -> str: ... def quote_name(self, name: str) -> str: ...
def column_sql( def column_sql(
self, model: Type[Model], field: Field, include_default: bool = ... self, model: Type[Model], field: Field, include_default: bool = ...
) -> Union[Tuple[None, None], Tuple[str, List[Any]]]: ... ) -> Tuple[Optional[str], Optional[List[Any]]]: ...
def skip_default(self, field: Any): ... def skip_default(self, field: Any): ...
def prepare_default(self, value: Any) -> None: ... def prepare_default(self, value: Any) -> None: ...
def effective_default(self, field: Field) -> Optional[Union[int, str]]: ... def effective_default(self, field: Field) -> Optional[Union[int, str]]: ...

View File

@@ -1,4 +1,4 @@
from typing import Any, Dict, Optional, Tuple, Union from typing import Any, Optional
from django.db.backends.base.introspection import BaseDatabaseIntrospection from django.db.backends.base.introspection import BaseDatabaseIntrospection
@@ -8,6 +8,6 @@ def get_field_size(name: str) -> Optional[int]: ...
class FlexibleFieldLookupDict: class FlexibleFieldLookupDict:
base_data_types_reverse: Any = ... base_data_types_reverse: Any = ...
def __getitem__(self, key: str) -> Union[Tuple[str, Dict[str, int]], str]: ... def __getitem__(self, key: str) -> Any: ...
class DatabaseIntrospection(BaseDatabaseIntrospection): ... class DatabaseIntrospection(BaseDatabaseIntrospection): ...

View File

@@ -1,4 +1,4 @@
from typing import Any, Dict, List, Optional, Sequence, Set, Tuple, TypeVar, Union from typing import Any, Dict, List, Optional, Sequence, Set, Tuple, TypeVar, Union, Collection
from django.core.checks.messages import CheckMessage from django.core.checks.messages import CheckMessage
from django.db.models.manager import Manager from django.db.models.manager import Manager
@@ -18,10 +18,10 @@ class Model(metaclass=ModelBase):
pk: Any = ... pk: Any = ...
def __init__(self: _Self, *args, **kwargs) -> None: ... def __init__(self: _Self, *args, **kwargs) -> None: ...
def delete(self, using: Any = ..., keep_parents: bool = ...) -> Tuple[int, Dict[str, int]]: ... def delete(self, using: Any = ..., keep_parents: bool = ...) -> Tuple[int, Dict[str, int]]: ...
def full_clean(self, exclude: Optional[List[str]] = ..., validate_unique: bool = ...) -> None: ... def full_clean(self, exclude: Optional[Collection[str]] = ..., validate_unique: bool = ...) -> None: ...
def clean(self) -> None: ... def clean(self) -> None: ...
def clean_fields(self, exclude: List[str] = ...) -> None: ... def clean_fields(self, exclude: Optional[Collection[str]] = ...) -> None: ...
def validate_unique(self, exclude: List[str] = ...) -> None: ... def validate_unique(self, exclude: Optional[Collection[str]] = ...) -> None: ...
def save( def save(
self, self,
force_insert: bool = ..., force_insert: bool = ...,

View File

@@ -1,5 +1,5 @@
from datetime import datetime from datetime import datetime
from typing import Any, Iterable, List, Optional, Tuple, Type, Union, Mapping, TypeVar, Generic from typing import Any, Iterable, List, Optional, Tuple, Type, Union, Mapping, TypeVar, Generic, Sequence
from django.db.backends.sqlite3.base import DatabaseWrapper from django.db.backends.sqlite3.base import DatabaseWrapper
from django.db.models.expressions import Expression, Func from django.db.models.expressions import Expression, Func
@@ -33,9 +33,9 @@ class Lookup(Generic[_T]):
) -> Tuple[str, List[Union[int, str]]]: ... ) -> Tuple[str, List[Union[int, str]]]: ...
def process_rhs( def process_rhs(
self, compiler: SQLCompiler, connection: DatabaseWrapper self, compiler: SQLCompiler, connection: DatabaseWrapper
) -> Tuple[str, Union[List[Union[int, str]], Tuple[int, int]]]: ... ) -> Tuple[str, Sequence[Union[int, str]]]: ...
def rhs_is_direct_value(self) -> bool: ... def rhs_is_direct_value(self) -> bool: ...
def relabeled_clone(self, relabels: Mapping[str, str]) -> Union[BuiltinLookup, FieldGetDbPrepValueMixin]: ... def relabeled_clone(self: _T, relabels: Mapping[str, str]) -> _T: ...
def get_group_by_cols(self) -> List[Expression]: ... def get_group_by_cols(self) -> List[Expression]: ...
def as_sql(self, compiler: Any, connection: Any) -> Any: ... def as_sql(self, compiler: Any, connection: Any) -> Any: ...
def contains_aggregate(self) -> bool: ... def contains_aggregate(self) -> bool: ...

View File

@@ -1,5 +1,5 @@
import collections import collections
from typing import Any, Callable, Dict, Generic, Iterator, List, Optional, Sequence, Set, Tuple, Type, TypeVar, Union from typing import Any, Callable, Dict, Generic, Iterator, List, Optional, Sequence, Tuple, Type, TypeVar, Union
from django.apps.config import AppConfig from django.apps.config import AppConfig
from django.apps.registry import Apps from django.apps.registry import Apps
@@ -23,8 +23,8 @@ IMMUTABLE_WARNING: str
DEFAULT_NAMES: Tuple[str, ...] DEFAULT_NAMES: Tuple[str, ...]
def normalize_together( def normalize_together(
option_together: Any option_together: Union[Sequence[Tuple[str, str]], Tuple[str, str]]
) -> Union[List[Union[Tuple[str, str], int]], Set[Tuple[str, str]], Tuple, int, str]: ... ) -> Tuple[Tuple[str, str], ...]: ...
def make_immutable_fields_list( def make_immutable_fields_list(
name: str, data: Union[Iterator[Any], List[Union[ArrayField, CIText]], List[Union[Field, FieldCacheMixin]]] name: str, data: Union[Iterator[Any], List[Union[ArrayField, CIText]], List[Union[Field, FieldCacheMixin]]]
) -> ImmutableList: ... ) -> ImmutableList: ...

View File

@@ -1,13 +1,13 @@
from collections import OrderedDict, namedtuple from collections import namedtuple
from typing import Any, Dict, Iterator, List, Optional, Set, Tuple, Type, Union from typing import Any, Collection, Dict, Iterator, List, Mapping, Optional, Sequence, Set, Tuple, Type, Union
from django.db.models.base import Model from django.db.models.base import Model
from django.db.models.expressions import Expression
from django.db.models.fields import Field
from django.db.models.fields.mixins import FieldCacheMixin from django.db.models.fields.mixins import FieldCacheMixin
from django.db.models.sql.compiler import SQLCompiler from django.db.models.sql.compiler import SQLCompiler
from django.db.models.sql.query import Query from django.db.models.sql.query import Query
from django.db.models.sql.where import WhereNode from django.db.models.sql.where import WhereNode
from django.db.models.fields import Field
from django.utils import tree from django.utils import tree
PathInfo = namedtuple("PathInfo", "from_opts to_opts target_fields join_field m2m direct filtered_relation") PathInfo = namedtuple("PathInfo", "from_opts to_opts target_fields join_field m2m direct filtered_relation")
@@ -65,15 +65,11 @@ class RegisterLookupMixin:
def select_related_descend( def select_related_descend(
field: Field, field: Field,
restricted: bool, restricted: bool,
requested: Optional[ requested: Optional[Mapping[str, Any]],
Union[Dict[str, Dict[str, Dict[str, Dict[str, Dict[str, Dict[str, Dict[str, Dict[Any, Any]]]]]]]], bool] load_fields: Optional[Collection[str]],
],
load_fields: Optional[Set[str]],
reverse: bool = ..., reverse: bool = ...,
) -> bool: ... ) -> bool: ...
def refs_expression( def refs_expression(lookup_parts: Sequence[str], annotations: Mapping[str, bool]) -> Tuple[bool, Sequence[str]]: ...
lookup_parts: List[str], annotations: OrderedDict
) -> Union[Tuple[bool, Tuple], Tuple[Expression, List[str]]]: ...
def check_rel_lookup_compatibility(model: Type[Model], target_opts: Any, field: FieldCacheMixin) -> bool: ... def check_rel_lookup_compatibility(model: Type[Model], target_opts: Any, field: FieldCacheMixin) -> bool: ...
class FilteredRelation: class FilteredRelation:

View File

@@ -0,0 +1,88 @@
import os
from typing import Optional
import libcst
from libcst import Annotation, BaseExpression, FunctionDef, Name, Subscript
from libcst.metadata import SyntacticPositionProvider
BASE_DIR = 'django-stubs'
fpath = os.path.join(BASE_DIR, 'core', 'checks', 'model_checks.pyi')
with open(fpath, 'r') as f:
contents = f.read()
tree = libcst.parse_module(contents)
class TypeAnnotationsAnalyzer(libcst.CSTVisitor):
METADATA_DEPENDENCIES = (SyntacticPositionProvider,)
def __init__(self, fpath: str):
super().__init__()
self.fpath = fpath
def get_node_location(self, node: FunctionDef) -> str:
start_line = self.get_metadata(SyntacticPositionProvider, node).start.line
return f'{self.fpath}:{start_line}'
def show_error_for_node(self, node: FunctionDef, error_message: str):
print(self.get_node_location(node), error_message)
def check_subscripted_annotation(self, annotation: BaseExpression) -> Optional[str]:
if isinstance(annotation, Subscript):
if isinstance(annotation.value, Name):
error_message = self.check_concrete_class_usage(annotation.value)
if error_message:
return error_message
if annotation.value.value == 'Union':
for slice_param in annotation.slice:
if isinstance(slice_param.slice.value, Name):
error_message = self.check_concrete_class_usage(annotation.value)
if error_message:
return error_message
def check_concrete_class_usage(self, name_node: Name) -> Optional[str]:
if name_node.value == 'List':
return (f'Concrete class {name_node.value!r} used for an iterable annotation. '
f'Use abstract collection (Iterable, Collection, Sequence) instead')
def visit_FunctionDef(self, node: FunctionDef) -> Optional[bool]:
params_node = node.params
for param_node in [*params_node.params, *params_node.default_params]:
param_name = param_node.name.value
annotation_node = param_node.annotation # type: Annotation
if annotation_node is not None:
annotation = annotation_node.annotation
if annotation.value == 'None':
self.show_error_for_node(node, f'"None" type annotation used for parameter {param_name!r}')
continue
error_message = self.check_subscripted_annotation(annotation)
if error_message is not None:
self.show_error_for_node(node, error_message)
continue
if node.returns is not None:
return_annotation = node.returns.annotation
if isinstance(return_annotation, Subscript) and return_annotation.value.value == 'Union':
self.show_error_for_node(node, 'Union is return type annotation')
return False
for dirpath, dirnames, filenames in os.walk(BASE_DIR):
for filename in filenames:
fpath = os.path.join(dirpath, filename)
# skip all other checks for now, low priority
if not fpath.startswith(('django-stubs/db', 'django-stubs/views', 'django-stubs/apps',
'django-stubs/http', 'django-stubs/contrib/postgres')):
continue
with open(fpath, 'r') as f:
contents = f.read()
tree = libcst.MetadataWrapper(libcst.parse_module(contents))
analyzer = TypeAnnotationsAnalyzer(fpath)
tree.visit(analyzer)