Merge pull request #139 from mkurnikov/none-to-optional-annotations

Replace None annotations with Optional[...] around, stability fixes
This commit is contained in:
Maxim Kurnikov
2019-08-24 17:32:14 +03:00
committed by GitHub
15 changed files with 98 additions and 52 deletions

View File

@@ -6,10 +6,6 @@ from django.http.response import HttpResponse, HttpResponseBase
logger: Any logger: Any
class BaseHandler: class BaseHandler:
_view_middleware: None = ...
_template_response_middleware: None = ...
_exception_middleware: None = ...
_middleware_chain: None = ...
def load_middleware(self) -> None: ... def load_middleware(self) -> None: ...
def make_view_atomic(self, view: Callable) -> Callable: ... def make_view_atomic(self, view: Callable) -> Callable: ...
def get_exception_response(self, request: Any, resolver: Any, status_code: Any, exception: Any): ... def get_exception_response(self, request: Any, resolver: Any, status_code: Any, exception: Any): ...

View File

@@ -5,6 +5,6 @@ from django.db.models.base import Model
def popen_wrapper(args: List[str], stdout_encoding: str = ...) -> Tuple[str, str, int]: ... def popen_wrapper(args: List[str], stdout_encoding: str = ...) -> Tuple[str, str, int]: ...
def handle_extensions(extensions: List[str]) -> Set[str]: ... def handle_extensions(extensions: List[str]) -> Set[str]: ...
def find_command(cmd: str, path: None = ..., pathext: None = ...) -> Optional[str]: ... def find_command(cmd: str, path: Optional[str] = ..., pathext: Optional[str] = ...) -> Optional[str]: ...
def get_random_secret_key(): ... def get_random_secret_key(): ...
def parse_apps_and_model_labels(labels: List[str]) -> Tuple[Set[Type[Model]], Set[AppConfig]]: ... def parse_apps_and_model_labels(labels: List[str]) -> Tuple[Set[Type[Model]], Set[AppConfig]]: ...

View File

@@ -1,6 +1,7 @@
import types
from datetime import date, datetime, time from datetime import date, datetime, time
from decimal import Decimal from decimal import Decimal
from typing import Any, Dict, Iterator, List, Mapping, Optional, Sequence, Tuple, Union from typing import Any, Dict, Iterator, List, Mapping, Optional, Sequence, Tuple, Union, Type
from uuid import UUID from uuid import UUID
logger: Any logger: Any
@@ -16,7 +17,12 @@ class CursorWrapper:
def __getattr__(self, attr: str) -> Any: ... def __getattr__(self, attr: str) -> Any: ...
def __iter__(self) -> None: ... def __iter__(self) -> None: ...
def __enter__(self) -> CursorWrapper: ... def __enter__(self) -> CursorWrapper: ...
def __exit__(self, type: None, value: None, traceback: None) -> None: ... def __exit__(
self,
exc_type: Optional[Type[BaseException]],
exc_value: Optional[BaseException],
tb: Optional[types.TracebackType],
) -> None: ...
def callproc(self, procname: str, params: List[Any] = ..., kparams: Dict[str, int] = ...) -> Any: ... def callproc(self, procname: str, params: List[Any] = ..., kparams: Dict[str, int] = ...) -> Any: ...
def execute( def execute(
self, sql: str, params: Optional[Union[Sequence[_SQLType], Mapping[str, _SQLType]]] = ... self, sql: str, params: Optional[Union[Sequence[_SQLType], Mapping[str, _SQLType]]] = ...

View File

@@ -80,9 +80,9 @@ class ForeignObject(RelatedField[_ST, _GT]):
on_delete: Callable[..., None], on_delete: Callable[..., None],
from_fields: Sequence[str], from_fields: Sequence[str],
to_fields: Sequence[str], to_fields: Sequence[str],
rel: None = ..., rel: Optional[ForeignObjectRel] = ...,
related_name: Optional[str] = ..., related_name: Optional[str] = ...,
related_query_name: None = ..., related_query_name: Optional[str] = ...,
limit_choices_to: Optional[Union[Dict[str, Any], Callable[[], Any]]] = ..., limit_choices_to: Optional[Union[Dict[str, Any], Callable[[], Any]]] = ...,
parent_link: bool = ..., parent_link: bool = ...,
db_constraint: bool = ..., db_constraint: bool = ...,

View File

@@ -71,7 +71,7 @@ class Options(Generic[_M]):
abstract: bool = ... abstract: bool = ...
managed: bool = ... managed: bool = ...
proxy: bool = ... proxy: bool = ...
proxy_for_model: None = ... proxy_for_model: Optional[Type[Model]] = ...
concrete_model: Optional[Type[Model]] = ... concrete_model: Optional[Type[Model]] = ...
swappable: None = ... swappable: None = ...
parents: collections.OrderedDict = ... parents: collections.OrderedDict = ...

View File

@@ -92,7 +92,7 @@ class QuerySet(Generic[_T], Collection[_T], Sized):
) -> ValuesQuerySet[_T, Any]: ... ) -> ValuesQuerySet[_T, Any]: ...
def dates(self, field_name: str, kind: str, order: str = ...) -> ValuesQuerySet[_T, datetime.date]: ... def dates(self, field_name: str, kind: str, order: str = ...) -> ValuesQuerySet[_T, datetime.date]: ...
def datetimes( def datetimes(
self, field_name: str, kind: str, order: str = ..., tzinfo: None = ... self, field_name: str, kind: str, order: str = ..., tzinfo: Optional[datetime.tzinfo] = ...
) -> ValuesQuerySet[_T, datetime.datetime]: ... ) -> ValuesQuerySet[_T, datetime.datetime]: ...
def none(self) -> QuerySet[_T]: ... def none(self) -> QuerySet[_T]: ...
def all(self) -> QuerySet[_T]: ... def all(self) -> QuerySet[_T]: ...

View File

@@ -119,15 +119,17 @@ class SimpleTestCase(unittest.TestCase):
field_kwargs: None = ..., field_kwargs: None = ...,
empty_value: str = ..., empty_value: str = ...,
) -> Any: ... ) -> Any: ...
def assertHTMLEqual(self, html1: str, html2: str, msg: None = ...) -> None: ... def assertHTMLEqual(self, html1: str, html2: str, msg: Optional[str] = ...) -> None: ...
def assertHTMLNotEqual(self, html1: str, html2: str, msg: None = ...) -> None: ... def assertHTMLNotEqual(self, html1: str, html2: str, msg: Optional[str] = ...) -> None: ...
def assertInHTML( def assertInHTML(
self, needle: str, haystack: SafeText, count: Optional[int] = ..., msg_prefix: str = ... self, needle: str, haystack: SafeText, count: Optional[int] = ..., msg_prefix: str = ...
) -> None: ... ) -> None: ...
def assertJSONEqual(self, raw: str, expected_data: Union[Dict[str, str], bool, str], msg: None = ...) -> None: ... def assertJSONEqual(
def assertJSONNotEqual(self, raw: str, expected_data: str, msg: None = ...) -> None: ... self, raw: str, expected_data: Union[Dict[str, str], bool, str], msg: Optional[str] = ...
def assertXMLEqual(self, xml1: str, xml2: str, msg: None = ...) -> None: ... ) -> None: ...
def assertXMLNotEqual(self, xml1: str, xml2: str, msg: None = ...) -> None: ... def assertJSONNotEqual(self, raw: str, expected_data: str, msg: Optional[str] = ...) -> None: ...
def assertXMLEqual(self, xml1: str, xml2: str, msg: Optional[str] = ...) -> None: ...
def assertXMLNotEqual(self, xml1: str, xml2: str, msg: Optional[str] = ...) -> None: ...
class TransactionTestCase(SimpleTestCase): class TransactionTestCase(SimpleTestCase):
reset_sequences: bool = ... reset_sequences: bool = ...
@@ -141,7 +143,7 @@ class TransactionTestCase(SimpleTestCase):
values: Union[List[None], List[Tuple[str, str]], List[date], List[int], List[str], Set[str], QuerySet], values: Union[List[None], List[Tuple[str, str]], List[date], List[int], List[str], Set[str], QuerySet],
transform: Union[Callable, Type[str]] = ..., transform: Union[Callable, Type[str]] = ...,
ordered: bool = ..., ordered: bool = ...,
msg: None = ..., msg: Optional[str] = ...,
) -> None: ... ) -> None: ...
def assertNumQueries( def assertNumQueries(
self, num: int, func: Optional[Union[Callable, Type[list]]] = ..., *args: Any, using: Any = ..., **kwargs: Any self, num: int, func: Optional[Union[Callable, Type[list]]] = ..., *args: Any, using: Any = ..., **kwargs: Any

View File

@@ -1,8 +1,10 @@
from datetime import date from datetime import date
from typing import Any, Optional from typing import Any, Optional, Dict
TIME_STRINGS: Any TIME_STRINGS: Dict[str, str]
TIMESINCE_CHUNKS: Any TIMESINCE_CHUNKS: Any
def timesince(d: date, now: Optional[date] = ..., reversed: bool = ..., time_strings: None = ...) -> str: ... def timesince(
def timeuntil(d: date, now: Optional[date] = ..., time_strings: None = ...) -> str: ... d: date, now: Optional[date] = ..., reversed: bool = ..., time_strings: Optional[Dict[str, str]] = ...
) -> str: ...
def timeuntil(d: date, now: Optional[date] = ..., time_strings: Optional[Dict[str, str]] = ...) -> str: ...

View File

@@ -1,7 +1,7 @@
import os import os
from collections import defaultdict from collections import defaultdict
from contextlib import contextmanager from contextlib import contextmanager
from typing import Dict, Iterator, Optional, Set, TYPE_CHECKING, Tuple, Type from typing import Dict, Iterator, Optional, Set, TYPE_CHECKING, Tuple, Type, Union
from django.core.exceptions import FieldError from django.core.exceptions import FieldError
from django.db.models.base import Model from django.db.models.base import Model
@@ -113,11 +113,13 @@ class DjangoFieldsContext:
is_nullable = self.get_field_nullability(field, method) is_nullable = self.get_field_nullability(field, method)
if isinstance(field, RelatedField): if isinstance(field, RelatedField):
related_model_cls = self.django_context.fields_context.get_related_model_cls(field)
if method == 'values': if method == 'values':
primary_key_field = self.django_context.get_primary_key_field(field.related_model) primary_key_field = self.django_context.get_primary_key_field(related_model_cls)
return self.get_field_get_type(api, primary_key_field, method=method) return self.get_field_get_type(api, primary_key_field, method=method)
model_info = helpers.lookup_class_typeinfo(api, field.related_model) model_info = helpers.lookup_class_typeinfo(api, related_model_cls)
if model_info is None: if model_info is None:
return AnyType(TypeOfAny.unannotated) return AnyType(TypeOfAny.unannotated)
@@ -126,6 +128,17 @@ class DjangoFieldsContext:
return helpers.get_private_descriptor_type(field_info, '_pyi_private_get_type', return helpers.get_private_descriptor_type(field_info, '_pyi_private_get_type',
is_nullable=is_nullable) is_nullable=is_nullable)
def get_related_model_cls(self, field: Union[RelatedField, ForeignObjectRel]) -> Type[Model]:
if isinstance(field, RelatedField):
related_model_cls = field.remote_field.model
else:
related_model_cls = field.field.model
if isinstance(related_model_cls, str):
related_model_cls = self.django_context.apps_registry.get_model(related_model_cls)
return related_model_cls
class DjangoLookupsContext: class DjangoLookupsContext:
def __init__(self, django_context: 'DjangoContext'): def __init__(self, django_context: 'DjangoContext'):
@@ -144,12 +157,12 @@ class DjangoLookupsContext:
return self.django_context.get_primary_key_field(currently_observed_model) return self.django_context.get_primary_key_field(currently_observed_model)
current_field = currently_observed_model._meta.get_field(field_part) current_field = currently_observed_model._meta.get_field(field_part)
if not isinstance(current_field, (ForeignObjectRel, RelatedField)):
continue
currently_observed_model = self.django_context.fields_context.get_related_model_cls(current_field)
if isinstance(current_field, ForeignObjectRel): if isinstance(current_field, ForeignObjectRel):
currently_observed_model = current_field.related_model
current_field = self.django_context.get_primary_key_field(currently_observed_model) current_field = self.django_context.get_primary_key_field(currently_observed_model)
else:
if isinstance(current_field, RelatedField):
currently_observed_model = current_field.related_model
# if it is None, solve_lookup_type() will fail earlier # if it is None, solve_lookup_type() will fail earlier
assert current_field is not None assert current_field is not None
@@ -213,7 +226,8 @@ class DjangoContext:
from django.contrib.contenttypes.fields import GenericForeignKey from django.contrib.contenttypes.fields import GenericForeignKey
expected_types = {} expected_types = {}
# add pk # add pk if not abstract=True
if not model_cls._meta.abstract:
primary_key_field = self.get_primary_key_field(model_cls) primary_key_field = self.get_primary_key_field(model_cls)
field_set_type = self.fields_context.get_field_set_type(api, primary_key_field, method=method) field_set_type = self.fields_context.get_field_set_type(api, primary_key_field, method=method)
expected_types['pk'] = field_set_type expected_types['pk'] = field_set_type
@@ -232,9 +246,9 @@ class DjangoContext:
expected_types[field_name] = AnyType(TypeOfAny.unannotated) expected_types[field_name] = AnyType(TypeOfAny.unannotated)
continue continue
related_model = field.related_model related_model = self.fields_context.get_related_model_cls(field)
if related_model._meta.proxy_for_model: if related_model._meta.proxy_for_model is not None:
related_model = field.related_model._meta.proxy_for_model related_model = related_model._meta.proxy_for_model
related_model_info = helpers.lookup_class_typeinfo(api, related_model) related_model_info = helpers.lookup_class_typeinfo(api, related_model)
if related_model_info is None: if related_model_info is None:

View File

@@ -1,19 +1,15 @@
from collections import OrderedDict from collections import OrderedDict
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set, Union, cast from typing import Any, Dict, List, Optional, Set, TYPE_CHECKING, Union, cast
from mypy import checker from mypy import checker
from mypy.checker import TypeChecker from mypy.checker import TypeChecker
from mypy.mro import calculate_mro from mypy.mro import calculate_mro
from mypy.nodes import ( from mypy.nodes import (Block, ClassDef, Expression, GDEF, MDEF, MemberExpr, MypyFile, NameExpr, StrExpr, SymbolNode,
GDEF, MDEF, Block, ClassDef, Expression, MemberExpr, MypyFile, NameExpr, StrExpr, SymbolNode, SymbolTable, SymbolTable, SymbolTableNode, TypeInfo, Var)
SymbolTableNode, TypeInfo, Var,
)
from mypy.plugin import ( from mypy.plugin import (
AttributeContext, CheckerPluginInterface, FunctionContext, MethodContext, AttributeContext, CheckerPluginInterface, FunctionContext, MethodContext,
) )
from mypy.types import AnyType, Instance, NoneTyp, TupleType from mypy.types import AnyType, Instance, NoneTyp, TupleType, Type as MypyType, TypeOfAny, TypedDictType, UnionType
from mypy.types import Type as MypyType
from mypy.types import TypedDictType, TypeOfAny, UnionType
if TYPE_CHECKING: if TYPE_CHECKING:
from mypy_django_plugin.django.context import DjangoContext from mypy_django_plugin.django.context import DjangoContext

View File

@@ -148,12 +148,14 @@ class NewSemanalDjangoPlugin(Plugin):
# forward relations # forward relations
for field in self.django_context.get_model_fields(model_class): for field in self.django_context.get_model_fields(model_class):
if isinstance(field, RelatedField): if isinstance(field, RelatedField):
related_model_module = field.related_model.__module__ related_model_cls = self.django_context.fields_context.get_related_model_cls(field)
related_model_module = related_model_cls.__module__
if related_model_module != file.fullname(): if related_model_module != file.fullname():
deps.add(self._new_dependency(related_model_module)) deps.add(self._new_dependency(related_model_module))
# reverse relations # reverse relations
for relation in model_class._meta.related_objects: for relation in model_class._meta.related_objects:
related_model_module = relation.related_model.__module__ related_model_cls = self.django_context.fields_context.get_related_model_cls(relation)
related_model_module = related_model_cls.__module__
if related_model_module != file.fullname(): if related_model_module != file.fullname():
deps.add(self._new_dependency(related_model_module)) deps.add(self._new_dependency(related_model_module))
return list(deps) return list(deps)

View File

@@ -44,9 +44,12 @@ def fill_descriptor_types_for_related_field(ctx: FunctionContext, django_context
assert isinstance(current_field, RelatedField) assert isinstance(current_field, RelatedField)
related_model = related_model_to_set = current_field.related_model related_model_cls = django_context.fields_context.get_related_model_cls(current_field)
if related_model_to_set._meta.proxy_for_model:
related_model_to_set = related_model._meta.proxy_for_model related_model = related_model_cls
related_model_to_set = related_model_cls
if related_model_to_set._meta.proxy_for_model is not None:
related_model_to_set = related_model_to_set._meta.proxy_for_model
typechecker_api = helpers.get_typechecker_api(ctx) typechecker_api = helpers.get_typechecker_api(ctx)

View File

@@ -100,7 +100,8 @@ class AddRelatedModelsId(ModelClassInitializer):
def run_with_model_cls(self, model_cls: Type[Model]) -> None: def run_with_model_cls(self, model_cls: Type[Model]) -> None:
for field in model_cls._meta.get_fields(): for field in model_cls._meta.get_fields():
if isinstance(field, ForeignKey): if isinstance(field, ForeignKey):
rel_primary_key_field = self.django_context.get_primary_key_field(field.related_model) related_model_cls = self.django_context.fields_context.get_related_model_cls(field)
rel_primary_key_field = self.django_context.get_primary_key_field(related_model_cls)
field_info = self.lookup_class_typeinfo_or_incomplete_defn_error(rel_primary_key_field.__class__) field_info = self.lookup_class_typeinfo_or_incomplete_defn_error(rel_primary_key_field.__class__)
is_nullable = self.django_context.fields_context.get_field_nullability(field, None) is_nullable = self.django_context.fields_context.get_field_nullability(field, None)
set_type, get_type = get_field_descriptor_types(field_info, is_nullable) set_type, get_type = get_field_descriptor_types(field_info, is_nullable)
@@ -156,7 +157,8 @@ class AddManagers(ModelClassInitializer):
# no reverse accessor # no reverse accessor
continue continue
related_model_info = self.lookup_class_typeinfo_or_incomplete_defn_error(relation.related_model) related_model_cls = self.django_context.fields_context.get_related_model_cls(relation)
related_model_info = self.lookup_class_typeinfo_or_incomplete_defn_error(related_model_cls)
if isinstance(relation, OneToOneRel): if isinstance(relation, OneToOneRel):
self.add_new_node_to_model_class(attname, Instance(related_model_info, [])) self.add_new_node_to_model_class(attname, Instance(related_model_info, []))

View File

@@ -42,7 +42,8 @@ def get_field_type_from_lookup(ctx: MethodContext, django_context: DjangoContext
return None return None
if isinstance(lookup_field, RelatedField) and lookup_field.column == lookup: if isinstance(lookup_field, RelatedField) and lookup_field.column == lookup:
lookup_field = django_context.get_primary_key_field(lookup_field.related_model) related_model_cls = django_context.fields_context.get_related_model_cls(lookup_field)
lookup_field = django_context.get_primary_key_field(related_model_cls)
field_get_type = django_context.fields_context.get_field_get_type(helpers.get_typechecker_api(ctx), field_get_type = django_context.fields_context.get_field_get_type(helpers.get_typechecker_api(ctx),
lookup_field, method=method) lookup_field, method=method)

View File

@@ -233,3 +233,25 @@
name = models.CharField(primary_key=True, max_length=100) name = models.CharField(primary_key=True, max_length=100)
class Book(models.Model): class Book(models.Model):
publisher = models.ForeignKey(Publisher, on_delete=models.CASCADE) publisher = models.ForeignKey(Publisher, on_delete=models.CASCADE)
- case: init_in_abstract_model_classmethod_should_not_throw_error_for_valid_fields
main: |
from myapp.models import MyModel
MyModel.base_init()
installed_apps:
- myapp
files:
- path: myapp/__init__.py
- path: myapp/models.py
content: |
from django.db import models
class AbstractModel(models.Model):
class Meta:
abstract = True
text = models.CharField(max_length=100)
@classmethod
def base_init(cls) -> 'AbstractModel':
return cls(text='mytext')
class MyModel(AbstractModel):
pass