Merge pull request #139 from mkurnikov/none-to-optional-annotations

Replace None annotations with Optional[...] around, stability fixes
This commit is contained in:
Maxim Kurnikov
2019-08-24 17:32:14 +03:00
committed by GitHub
15 changed files with 98 additions and 52 deletions

View File

@@ -6,10 +6,6 @@ from django.http.response import HttpResponse, HttpResponseBase
logger: Any
class BaseHandler:
_view_middleware: None = ...
_template_response_middleware: None = ...
_exception_middleware: None = ...
_middleware_chain: None = ...
def load_middleware(self) -> None: ...
def make_view_atomic(self, view: Callable) -> Callable: ...
def get_exception_response(self, request: Any, resolver: Any, status_code: Any, exception: Any): ...

View File

@@ -5,6 +5,6 @@ from django.db.models.base import Model
def popen_wrapper(args: List[str], stdout_encoding: str = ...) -> Tuple[str, str, int]: ...
def handle_extensions(extensions: List[str]) -> Set[str]: ...
def find_command(cmd: str, path: None = ..., pathext: None = ...) -> Optional[str]: ...
def find_command(cmd: str, path: Optional[str] = ..., pathext: Optional[str] = ...) -> Optional[str]: ...
def get_random_secret_key(): ...
def parse_apps_and_model_labels(labels: List[str]) -> Tuple[Set[Type[Model]], Set[AppConfig]]: ...

View File

@@ -1,6 +1,7 @@
import types
from datetime import date, datetime, time
from decimal import Decimal
from typing import Any, Dict, Iterator, List, Mapping, Optional, Sequence, Tuple, Union
from typing import Any, Dict, Iterator, List, Mapping, Optional, Sequence, Tuple, Union, Type
from uuid import UUID
logger: Any
@@ -16,7 +17,12 @@ class CursorWrapper:
def __getattr__(self, attr: str) -> Any: ...
def __iter__(self) -> None: ...
def __enter__(self) -> CursorWrapper: ...
def __exit__(self, type: None, value: None, traceback: None) -> None: ...
def __exit__(
self,
exc_type: Optional[Type[BaseException]],
exc_value: Optional[BaseException],
tb: Optional[types.TracebackType],
) -> None: ...
def callproc(self, procname: str, params: List[Any] = ..., kparams: Dict[str, int] = ...) -> Any: ...
def execute(
self, sql: str, params: Optional[Union[Sequence[_SQLType], Mapping[str, _SQLType]]] = ...

View File

@@ -80,9 +80,9 @@ class ForeignObject(RelatedField[_ST, _GT]):
on_delete: Callable[..., None],
from_fields: Sequence[str],
to_fields: Sequence[str],
rel: None = ...,
rel: Optional[ForeignObjectRel] = ...,
related_name: Optional[str] = ...,
related_query_name: None = ...,
related_query_name: Optional[str] = ...,
limit_choices_to: Optional[Union[Dict[str, Any], Callable[[], Any]]] = ...,
parent_link: bool = ...,
db_constraint: bool = ...,

View File

@@ -71,7 +71,7 @@ class Options(Generic[_M]):
abstract: bool = ...
managed: bool = ...
proxy: bool = ...
proxy_for_model: None = ...
proxy_for_model: Optional[Type[Model]] = ...
concrete_model: Optional[Type[Model]] = ...
swappable: None = ...
parents: collections.OrderedDict = ...

View File

@@ -92,7 +92,7 @@ class QuerySet(Generic[_T], Collection[_T], Sized):
) -> ValuesQuerySet[_T, Any]: ...
def dates(self, field_name: str, kind: str, order: str = ...) -> ValuesQuerySet[_T, datetime.date]: ...
def datetimes(
self, field_name: str, kind: str, order: str = ..., tzinfo: None = ...
self, field_name: str, kind: str, order: str = ..., tzinfo: Optional[datetime.tzinfo] = ...
) -> ValuesQuerySet[_T, datetime.datetime]: ...
def none(self) -> QuerySet[_T]: ...
def all(self) -> QuerySet[_T]: ...

View File

@@ -119,15 +119,17 @@ class SimpleTestCase(unittest.TestCase):
field_kwargs: None = ...,
empty_value: str = ...,
) -> Any: ...
def assertHTMLEqual(self, html1: str, html2: str, msg: None = ...) -> None: ...
def assertHTMLNotEqual(self, html1: str, html2: str, msg: None = ...) -> None: ...
def assertHTMLEqual(self, html1: str, html2: str, msg: Optional[str] = ...) -> None: ...
def assertHTMLNotEqual(self, html1: str, html2: str, msg: Optional[str] = ...) -> None: ...
def assertInHTML(
self, needle: str, haystack: SafeText, count: Optional[int] = ..., msg_prefix: str = ...
) -> None: ...
def assertJSONEqual(self, raw: str, expected_data: Union[Dict[str, str], bool, str], msg: None = ...) -> None: ...
def assertJSONNotEqual(self, raw: str, expected_data: str, msg: None = ...) -> None: ...
def assertXMLEqual(self, xml1: str, xml2: str, msg: None = ...) -> None: ...
def assertXMLNotEqual(self, xml1: str, xml2: str, msg: None = ...) -> None: ...
def assertJSONEqual(
self, raw: str, expected_data: Union[Dict[str, str], bool, str], msg: Optional[str] = ...
) -> None: ...
def assertJSONNotEqual(self, raw: str, expected_data: str, msg: Optional[str] = ...) -> None: ...
def assertXMLEqual(self, xml1: str, xml2: str, msg: Optional[str] = ...) -> None: ...
def assertXMLNotEqual(self, xml1: str, xml2: str, msg: Optional[str] = ...) -> None: ...
class TransactionTestCase(SimpleTestCase):
reset_sequences: bool = ...
@@ -141,7 +143,7 @@ class TransactionTestCase(SimpleTestCase):
values: Union[List[None], List[Tuple[str, str]], List[date], List[int], List[str], Set[str], QuerySet],
transform: Union[Callable, Type[str]] = ...,
ordered: bool = ...,
msg: None = ...,
msg: Optional[str] = ...,
) -> None: ...
def assertNumQueries(
self, num: int, func: Optional[Union[Callable, Type[list]]] = ..., *args: Any, using: Any = ..., **kwargs: Any

View File

@@ -1,8 +1,10 @@
from datetime import date
from typing import Any, Optional
from typing import Any, Optional, Dict
TIME_STRINGS: Any
TIME_STRINGS: Dict[str, str]
TIMESINCE_CHUNKS: Any
def timesince(d: date, now: Optional[date] = ..., reversed: bool = ..., time_strings: None = ...) -> str: ...
def timeuntil(d: date, now: Optional[date] = ..., time_strings: None = ...) -> str: ...
def timesince(
d: date, now: Optional[date] = ..., reversed: bool = ..., time_strings: Optional[Dict[str, str]] = ...
) -> str: ...
def timeuntil(d: date, now: Optional[date] = ..., time_strings: Optional[Dict[str, str]] = ...) -> str: ...

View File

@@ -1,7 +1,7 @@
import os
from collections import defaultdict
from contextlib import contextmanager
from typing import Dict, Iterator, Optional, Set, TYPE_CHECKING, Tuple, Type
from typing import Dict, Iterator, Optional, Set, TYPE_CHECKING, Tuple, Type, Union
from django.core.exceptions import FieldError
from django.db.models.base import Model
@@ -113,11 +113,13 @@ class DjangoFieldsContext:
is_nullable = self.get_field_nullability(field, method)
if isinstance(field, RelatedField):
related_model_cls = self.django_context.fields_context.get_related_model_cls(field)
if method == 'values':
primary_key_field = self.django_context.get_primary_key_field(field.related_model)
primary_key_field = self.django_context.get_primary_key_field(related_model_cls)
return self.get_field_get_type(api, primary_key_field, method=method)
model_info = helpers.lookup_class_typeinfo(api, field.related_model)
model_info = helpers.lookup_class_typeinfo(api, related_model_cls)
if model_info is None:
return AnyType(TypeOfAny.unannotated)
@@ -126,6 +128,17 @@ class DjangoFieldsContext:
return helpers.get_private_descriptor_type(field_info, '_pyi_private_get_type',
is_nullable=is_nullable)
def get_related_model_cls(self, field: Union[RelatedField, ForeignObjectRel]) -> Type[Model]:
if isinstance(field, RelatedField):
related_model_cls = field.remote_field.model
else:
related_model_cls = field.field.model
if isinstance(related_model_cls, str):
related_model_cls = self.django_context.apps_registry.get_model(related_model_cls)
return related_model_cls
class DjangoLookupsContext:
def __init__(self, django_context: 'DjangoContext'):
@@ -144,12 +157,12 @@ class DjangoLookupsContext:
return self.django_context.get_primary_key_field(currently_observed_model)
current_field = currently_observed_model._meta.get_field(field_part)
if not isinstance(current_field, (ForeignObjectRel, RelatedField)):
continue
currently_observed_model = self.django_context.fields_context.get_related_model_cls(current_field)
if isinstance(current_field, ForeignObjectRel):
currently_observed_model = current_field.related_model
current_field = self.django_context.get_primary_key_field(currently_observed_model)
else:
if isinstance(current_field, RelatedField):
currently_observed_model = current_field.related_model
# if it is None, solve_lookup_type() will fail earlier
assert current_field is not None
@@ -213,10 +226,11 @@ class DjangoContext:
from django.contrib.contenttypes.fields import GenericForeignKey
expected_types = {}
# add pk
primary_key_field = self.get_primary_key_field(model_cls)
field_set_type = self.fields_context.get_field_set_type(api, primary_key_field, method=method)
expected_types['pk'] = field_set_type
# add pk if not abstract=True
if not model_cls._meta.abstract:
primary_key_field = self.get_primary_key_field(model_cls)
field_set_type = self.fields_context.get_field_set_type(api, primary_key_field, method=method)
expected_types['pk'] = field_set_type
for field in model_cls._meta.get_fields():
if isinstance(field, Field):
@@ -232,9 +246,9 @@ class DjangoContext:
expected_types[field_name] = AnyType(TypeOfAny.unannotated)
continue
related_model = field.related_model
if related_model._meta.proxy_for_model:
related_model = field.related_model._meta.proxy_for_model
related_model = self.fields_context.get_related_model_cls(field)
if related_model._meta.proxy_for_model is not None:
related_model = related_model._meta.proxy_for_model
related_model_info = helpers.lookup_class_typeinfo(api, related_model)
if related_model_info is None:

View File

@@ -1,19 +1,15 @@
from collections import OrderedDict
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set, Union, cast
from typing import Any, Dict, List, Optional, Set, TYPE_CHECKING, Union, cast
from mypy import checker
from mypy.checker import TypeChecker
from mypy.mro import calculate_mro
from mypy.nodes import (
GDEF, MDEF, Block, ClassDef, Expression, MemberExpr, MypyFile, NameExpr, StrExpr, SymbolNode, SymbolTable,
SymbolTableNode, TypeInfo, Var,
)
from mypy.nodes import (Block, ClassDef, Expression, GDEF, MDEF, MemberExpr, MypyFile, NameExpr, StrExpr, SymbolNode,
SymbolTable, SymbolTableNode, TypeInfo, Var)
from mypy.plugin import (
AttributeContext, CheckerPluginInterface, FunctionContext, MethodContext,
)
from mypy.types import AnyType, Instance, NoneTyp, TupleType
from mypy.types import Type as MypyType
from mypy.types import TypedDictType, TypeOfAny, UnionType
from mypy.types import AnyType, Instance, NoneTyp, TupleType, Type as MypyType, TypeOfAny, TypedDictType, UnionType
if TYPE_CHECKING:
from mypy_django_plugin.django.context import DjangoContext

View File

@@ -148,12 +148,14 @@ class NewSemanalDjangoPlugin(Plugin):
# forward relations
for field in self.django_context.get_model_fields(model_class):
if isinstance(field, RelatedField):
related_model_module = field.related_model.__module__
related_model_cls = self.django_context.fields_context.get_related_model_cls(field)
related_model_module = related_model_cls.__module__
if related_model_module != file.fullname():
deps.add(self._new_dependency(related_model_module))
# reverse relations
for relation in model_class._meta.related_objects:
related_model_module = relation.related_model.__module__
related_model_cls = self.django_context.fields_context.get_related_model_cls(relation)
related_model_module = related_model_cls.__module__
if related_model_module != file.fullname():
deps.add(self._new_dependency(related_model_module))
return list(deps)

View File

@@ -44,9 +44,12 @@ def fill_descriptor_types_for_related_field(ctx: FunctionContext, django_context
assert isinstance(current_field, RelatedField)
related_model = related_model_to_set = current_field.related_model
if related_model_to_set._meta.proxy_for_model:
related_model_to_set = related_model._meta.proxy_for_model
related_model_cls = django_context.fields_context.get_related_model_cls(current_field)
related_model = related_model_cls
related_model_to_set = related_model_cls
if related_model_to_set._meta.proxy_for_model is not None:
related_model_to_set = related_model_to_set._meta.proxy_for_model
typechecker_api = helpers.get_typechecker_api(ctx)

View File

@@ -100,7 +100,8 @@ class AddRelatedModelsId(ModelClassInitializer):
def run_with_model_cls(self, model_cls: Type[Model]) -> None:
for field in model_cls._meta.get_fields():
if isinstance(field, ForeignKey):
rel_primary_key_field = self.django_context.get_primary_key_field(field.related_model)
related_model_cls = self.django_context.fields_context.get_related_model_cls(field)
rel_primary_key_field = self.django_context.get_primary_key_field(related_model_cls)
field_info = self.lookup_class_typeinfo_or_incomplete_defn_error(rel_primary_key_field.__class__)
is_nullable = self.django_context.fields_context.get_field_nullability(field, None)
set_type, get_type = get_field_descriptor_types(field_info, is_nullable)
@@ -156,7 +157,8 @@ class AddManagers(ModelClassInitializer):
# no reverse accessor
continue
related_model_info = self.lookup_class_typeinfo_or_incomplete_defn_error(relation.related_model)
related_model_cls = self.django_context.fields_context.get_related_model_cls(relation)
related_model_info = self.lookup_class_typeinfo_or_incomplete_defn_error(related_model_cls)
if isinstance(relation, OneToOneRel):
self.add_new_node_to_model_class(attname, Instance(related_model_info, []))

View File

@@ -42,7 +42,8 @@ def get_field_type_from_lookup(ctx: MethodContext, django_context: DjangoContext
return None
if isinstance(lookup_field, RelatedField) and lookup_field.column == lookup:
lookup_field = django_context.get_primary_key_field(lookup_field.related_model)
related_model_cls = django_context.fields_context.get_related_model_cls(lookup_field)
lookup_field = django_context.get_primary_key_field(related_model_cls)
field_get_type = django_context.fields_context.get_field_get_type(helpers.get_typechecker_api(ctx),
lookup_field, method=method)

View File

@@ -233,3 +233,25 @@
name = models.CharField(primary_key=True, max_length=100)
class Book(models.Model):
publisher = models.ForeignKey(Publisher, on_delete=models.CASCADE)
- case: init_in_abstract_model_classmethod_should_not_throw_error_for_valid_fields
main: |
from myapp.models import MyModel
MyModel.base_init()
installed_apps:
- myapp
files:
- path: myapp/__init__.py
- path: myapp/models.py
content: |
from django.db import models
class AbstractModel(models.Model):
class Meta:
abstract = True
text = models.CharField(max_length=100)
@classmethod
def base_init(cls) -> 'AbstractModel':
return cls(text='mytext')
class MyModel(AbstractModel):
pass