15 Commits

Author SHA1 Message Date
Maxim Kurnikov
4cb10390cf bump version 2019-02-14 03:34:49 +03:00
Maxim Kurnikov
c1640b619f fix stale import 2019-02-14 03:21:11 +03:00
Maxim Kurnikov
a08ad80a0d fix star import parsing for settings 2019-02-14 03:16:07 +03:00
Maxim Kurnikov
f30cd092f1 add default for MYPY_DJANGO_CONFIG 2019-02-13 23:02:49 +03:00
Maxim Kurnikov
dcd9ee0bb8 enable 'validation' test folder 2019-02-13 21:12:58 +03:00
Maxim Kurnikov
26a80a8279 add properly typed FOREIGN_KEY_FIELD_NAME_id fields to models 2019-02-13 21:05:02 +03:00
Maxim Kurnikov
82de0a8791 lint 2019-02-13 20:00:42 +03:00
Maxim Kurnikov
79ebe20f2e add more test folders 2019-02-13 19:44:25 +03:00
Maxim Kurnikov
587c2c484b more accurate types for from_queryset() 2019-02-13 17:55:50 +03:00
Maxim Kurnikov
4a22da29cb add support for default related managers, fixes #18 2019-02-13 17:11:22 +03:00
Maxim Kurnikov
70378b8f40 preserve fallback to Any for unrecognized field types for init/create 2019-02-13 17:00:35 +03:00
Maxim Kurnikov
b7f7713c5a add support for get_user_model(), fixes #16 2019-02-13 15:56:21 +03:00
Maxim Kurnikov
2720b74242 add proper generic support for get_object_or_404/get_list_or_404, fixes #22 2019-02-13 14:52:10 +03:00
Maxim Kurnikov
563c0add5e add release script 2019-02-13 14:36:33 +03:00
Maxim Kurnikov
3191740c6b bump version 2019-02-13 14:36:17 +03:00
34 changed files with 603 additions and 239 deletions

View File

@@ -26,7 +26,7 @@ in your `mypy.ini` file.
## Configuration ## Configuration
In order to specify config file, set `MYPY_DJANGO_CONFIG` environment variable with path to the config file. In order to specify config file, set `MYPY_DJANGO_CONFIG` environment variable with path to the config file. Default is `./mypy_django.ini`
Config file format (.ini): Config file format (.ini):
``` ```

View File

@@ -1,4 +1,4 @@
from typing import Any, Callable, Dict, List, Optional, Tuple, Type from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Iterator
from django.contrib.admin.options import ModelAdmin from django.contrib.admin.options import ModelAdmin
from django.core.handlers.wsgi import WSGIRequest from django.core.handlers.wsgi import WSGIRequest
@@ -16,7 +16,7 @@ class ListFilter:
self, request: WSGIRequest, params: Dict[str, str], model: Type[Model], model_admin: ModelAdmin self, request: WSGIRequest, params: Dict[str, str], model: Type[Model], model_admin: ModelAdmin
) -> None: ... ) -> None: ...
def has_output(self) -> bool: ... def has_output(self) -> bool: ...
def choices(self, changelist: Any) -> None: ... def choices(self, changelist: Any) -> Optional[Iterator[Dict[str, Any]]]: ...
def queryset(self, request: Any, queryset: QuerySet) -> Optional[QuerySet]: ... def queryset(self, request: Any, queryset: QuerySet) -> Optional[QuerySet]: ...
def expected_parameters(self) -> Optional[List[str]]: ... def expected_parameters(self) -> Optional[List[str]]: ...

View File

@@ -1,4 +1,4 @@
from typing import Any, Optional, Tuple, List from typing import Any, Optional, Tuple, List, overload
from django.db import models from django.db import models
@@ -30,4 +30,8 @@ class AbstractBaseUser(models.Model):
@classmethod @classmethod
def get_email_field_name(cls) -> str: ... def get_email_field_name(cls) -> str: ...
@classmethod @classmethod
@overload
def normalize_username(cls, username: str) -> str: ... def normalize_username(cls, username: str) -> str: ...
@classmethod
@overload
def normalize_username(cls, username: Any) -> Any: ...

View File

@@ -4,12 +4,7 @@ from django.db import models
from django.db.models.base import Model from django.db.models.base import Model
from django.db.models.query import QuerySet from django.db.models.query import QuerySet
class ContentTypeManager(models.Manager): class ContentTypeManager(models.Manager["ContentType"]):
creation_counter: int
model: None
name: None
use_in_migrations: bool = ...
def __init__(self, *args: Any, **kwargs: Any) -> None: ...
def get_by_natural_key(self, app_label: str, model: str) -> ContentType: ... def get_by_natural_key(self, app_label: str, model: str) -> ContentType: ...
def get_for_model(self, model: Union[Type[Model], Model], for_concrete_model: bool = ...) -> ContentType: ... def get_for_model(self, model: Union[Type[Model], Model], for_concrete_model: bool = ...) -> ContentType: ...
def get_for_models(self, *models: Any, for_concrete_models: bool = ...) -> Dict[Type[Model], ContentType]: ... def get_for_models(self, *models: Any, for_concrete_models: bool = ...) -> Dict[Type[Model], ContentType]: ...
@@ -18,9 +13,9 @@ class ContentTypeManager(models.Manager):
class ContentType(models.Model): class ContentType(models.Model):
id: int id: int
app_label: str = ... app_label: models.CharField = ...
model: str = ... model: models.CharField = ...
objects: Any = ... objects: ContentTypeManager = ...
@property @property
def name(self) -> str: ... def name(self) -> str: ...
def model_class(self) -> Optional[Type[Model]]: ... def model_class(self) -> Optional[Type[Model]]: ...

View File

@@ -1,14 +1,13 @@
from typing import Any, Optional from django.contrib.sites.models import Site
from django.db import models from django.db import models
class FlatPage(models.Model): class FlatPage(models.Model):
id: None url: models.CharField = ...
url: str = ... title: models.CharField = ...
title: str = ... content: models.TextField = ...
content: str = ... enable_comments: models.BooleanField = ...
enable_comments: bool = ... template_name: models.CharField = ...
template_name: str = ... registration_required: models.BooleanField = ...
registration_required: bool = ... sites: models.ManyToManyField[Site] = ...
sites: Any = ...
def get_absolute_url(self) -> str: ... def get_absolute_url(self) -> str: ...

View File

@@ -1,40 +1,43 @@
# Stubs for django.core.files.uploadedfile (Python 3.5)
from typing import Any, Dict, IO, Iterator, Optional, Union from typing import Any, Dict, IO, Iterator, Optional, Union
from django.core.files import temp as tempfile from django.core.files import temp as tempfile
from django.core.files.base import File from django.core.files.base import File
class UploadedFile(File): class UploadedFile(File):
content_type = ... # type: Optional[str] content_type: Optional[str] = ...
charset = ... # type: Optional[str] charset: Optional[str] = ...
content_type_extra = ... # type: Optional[Dict[str, str]] content_type_extra: Optional[Dict[str, str]] = ...
def __init__( def __init__(
self, self,
file: IO, file: Optional[IO] = ...,
name: str = None, name: Optional[str] = ...,
content_type: str = None, content_type: Optional[str] = ...,
size: int = None, size: Optional[int] = ...,
charset: str = None, charset: Optional[str] = ...,
content_type_extra: Dict[str, str] = None, content_type_extra: Optional[Dict[str, str]] = ...,
) -> None: ... ) -> None: ...
class TemporaryUploadedFile(UploadedFile): class TemporaryUploadedFile(UploadedFile):
def __init__( def __init__(
self, name: str, content_type: str, size: int, charset: str, content_type_extra: Dict[str, str] = None self,
name: Optional[str],
content_type: Optional[str],
size: Optional[int],
charset: Optional[str],
content_type_extra: Optional[Dict[str, str]] = ...,
) -> None: ... ) -> None: ...
def temporary_file_path(self) -> str: ... def temporary_file_path(self) -> str: ...
class InMemoryUploadedFile(UploadedFile): class InMemoryUploadedFile(UploadedFile):
field_name = ... # type: Optional[str] field_name: Optional[str] = ...
def __init__( def __init__(
self, self,
file: IO, file: IO,
field_name: Optional[str], field_name: Optional[str],
name: str, name: Optional[str],
content_type: Optional[str], content_type: Optional[str],
size: int, size: Optional[int],
charset: Optional[str], charset: Optional[str],
content_type_extra: Dict[str, str] = None, content_type_extra: Dict[str, str] = ...,
) -> None: ... ) -> None: ...
def chunks(self, chunk_size: int = None) -> Iterator[bytes]: ... def chunks(self, chunk_size: int = None) -> Iterator[bytes]: ...
def multiple_chunks(self, chunk_size: int = None) -> bool: ... def multiple_chunks(self, chunk_size: int = None) -> bool: ...

View File

@@ -168,12 +168,12 @@ class RawSQL(Expression):
def __init__(self, sql: str, params: Sequence[Any], output_field: Optional[_OutputField] = ...) -> None: ... def __init__(self, sql: str, params: Sequence[Any], output_field: Optional[_OutputField] = ...) -> None: ...
class Func(SQLiteNumericMixin, Expression): class Func(SQLiteNumericMixin, Expression):
function: Any = ... function: str = ...
template: str = ... template: str = ...
arg_joiner: str = ... arg_joiner: str = ...
arity: Any = ... arity: int = ...
source_expressions: List[Expression] = ... source_expressions: List[Expression] = ...
extra: Any = ... extra: Dict[Any, Any] = ...
def __init__(self, *expressions: Any, output_field: Optional[_OutputField] = ..., **extra: Any) -> None: ... def __init__(self, *expressions: Any, output_field: Optional[_OutputField] = ..., **extra: Any) -> None: ...
def get_source_expressions(self) -> List[Combinable]: ... def get_source_expressions(self) -> List[Combinable]: ...
def set_source_expressions(self, exprs: List[Expression]) -> None: ... def set_source_expressions(self, exprs: List[Expression]) -> None: ...

View File

@@ -207,6 +207,7 @@ class GenericIPAddressField(Field):
validators: Iterable[_ValidatorCallable] = ..., validators: Iterable[_ValidatorCallable] = ...,
error_messages: Optional[_ErrorMessagesToOverride] = ..., error_messages: Optional[_ErrorMessagesToOverride] = ...,
) -> None: ... ) -> None: ...
def __set__(self, instance, value: Union[str, int, Callable[..., Any], Combinable]): ...
def __get__(self, instance, owner) -> str: ... def __get__(self, instance, owner) -> str: ...
class DateTimeCheckMixin: ... class DateTimeCheckMixin: ...
@@ -269,7 +270,7 @@ class DateTimeField(DateField):
def __get__(self, instance, owner) -> datetime: ... def __get__(self, instance, owner) -> datetime: ...
class UUIDField(Field): class UUIDField(Field):
def __set__(self, instance, value: Any) -> None: ... def __set__(self, instance, value: Union[str, uuid.UUID]) -> None: ...
def __get__(self, instance, owner) -> uuid.UUID: ... def __get__(self, instance, owner) -> uuid.UUID: ...
class FilePathField(Field): class FilePathField(Field):

View File

@@ -33,7 +33,7 @@ from django.db.models.fields.reverse_related import (
) )
from django.db.models.query_utils import PathInfo, Q from django.db.models.query_utils import PathInfo, Q
from django.db.models.expressions import F from django.db.models.expressions import Combinable
if TYPE_CHECKING: if TYPE_CHECKING:
from django.db.models.manager import RelatedManager from django.db.models.manager import RelatedManager
@@ -105,11 +105,12 @@ class ForeignObject(RelatedField):
class ForeignKey(RelatedField, Generic[_T]): class ForeignKey(RelatedField, Generic[_T]):
def __init__(self, to: Union[Type[_T], str], on_delete: Any, related_name: str = ..., **kwargs): ... def __init__(self, to: Union[Type[_T], str], on_delete: Any, related_name: str = ..., **kwargs): ...
def __set__(self, instance, value: Union[Model, F]) -> None: ... def __set__(self, instance, value: Union[Model, Combinable]) -> None: ...
def __get__(self, instance, owner) -> _T: ... def __get__(self, instance, owner) -> _T: ...
class OneToOneField(RelatedField, Generic[_T]): class OneToOneField(RelatedField, Generic[_T]):
def __init__(self, to: Union[Type[_T], str], on_delete: Any, related_name: str = ..., **kwargs): ... def __init__(self, to: Union[Type[_T], str], on_delete: Any, related_name: str = ..., **kwargs): ...
def __set__(self, instance, value: Union[Model, Combinable]) -> None: ...
def __get__(self, instance, owner) -> _T: ... def __get__(self, instance, owner) -> _T: ...
class ManyToManyField(RelatedField, Generic[_T]): class ManyToManyField(RelatedField, Generic[_T]):

View File

@@ -9,22 +9,19 @@ class CumeDist(Func):
window_compatible: bool = ... window_compatible: bool = ...
class DenseRank(Func): class DenseRank(Func):
extra: Dict[Any, Any]
source_expressions: List[Any]
function: str = ...
name: str = ... name: str = ...
output_field: Any = ... output_field: Any = ...
window_compatible: bool = ... window_compatible: bool = ...
class FirstValue(Func): class FirstValue(Func):
arity: int = ...
function: str = ...
name: str = ... name: str = ...
window_compatible: bool = ... window_compatible: bool = ...
class LagLeadFunction(Func): class LagLeadFunction(Func):
window_compatible: bool = ... window_compatible: bool = ...
def __init__(self, expression: Optional[str], offset: int = ..., default: None = ..., **extra: Any) -> Any: ... def __init__(
self, expression: Optional[str], offset: int = ..., default: Optional[int] = ..., **extra: Any
) -> None: ...
class Lag(LagLeadFunction): class Lag(LagLeadFunction):
function: str = ... function: str = ...

View File

@@ -17,7 +17,9 @@ class BaseManager(QuerySet[_T]):
def deconstruct(self) -> Tuple[bool, str, None, Tuple, Dict[str, int]]: ... def deconstruct(self) -> Tuple[bool, str, None, Tuple, Dict[str, int]]: ...
def check(self, **kwargs: Any) -> List[Any]: ... def check(self, **kwargs: Any) -> List[Any]: ...
@classmethod @classmethod
def from_queryset(cls: Type[_Self], queryset_class: Any, class_name: Optional[Any] = ...) -> Type[_Self]: ... def from_queryset(
cls: Type[_Self], queryset_class: Type[QuerySet], class_name: Optional[str] = ...
) -> Type[_Self]: ...
@classmethod @classmethod
def _get_queryset_methods(cls, queryset_class: type) -> Dict[str, Any]: ... def _get_queryset_methods(cls, queryset_class: type) -> Dict[str, Any]: ...
def contribute_to_class(self, model: Type[Model], name: str) -> None: ... def contribute_to_class(self, model: Type[Model], name: str) -> None: ...

View File

@@ -1,4 +1,4 @@
from io import BytesIO from io import BytesIO, StringIO
from typing import Any, Dict, Iterator, List, Optional, Tuple, Union from typing import Any, Dict, Iterator, List, Optional, Tuple, Union
from django.http.request import QueryDict from django.http.request import QueryDict
@@ -11,7 +11,7 @@ class MultiPartParser:
def __init__( def __init__(
self, self,
META: Dict[str, Any], META: Dict[str, Any],
input_data: BytesIO, input_data: Union[StringIO, BytesIO],
upload_handlers: Union[List[Any], ImmutableList], upload_handlers: Union[List[Any], ImmutableList],
encoding: Optional[str] = ..., encoding: Optional[str] = ...,
) -> None: ... ) -> None: ...

View File

@@ -27,7 +27,7 @@ class HttpResponseBase(Iterable[AnyStr]):
charset: Optional[str] = ..., charset: Optional[str] = ...,
) -> None: ... ) -> None: ...
def serialize_headers(self) -> bytes: ... def serialize_headers(self) -> bytes: ...
def __setitem__(self, header: Union[str, bytes], value: Union[str, bytes]) -> None: ... def __setitem__(self, header: Union[str, bytes], value: Union[str, bytes, int]) -> None: ...
def __delitem__(self, header: Union[str, bytes]) -> None: ... def __delitem__(self, header: Union[str, bytes]) -> None: ...
def __getitem__(self, header: Union[str, bytes]) -> str: ... def __getitem__(self, header: Union[str, bytes]) -> str: ...
def has_header(self, header: str) -> bool: ... def has_header(self, header: str) -> bool: ...

View File

@@ -1,7 +1,7 @@
from typing import Any, Optional from typing import Any, Optional
from django.http.request import HttpRequest from django.http.request import HttpRequest
from django.http.response import HttpResponse, HttpResponseNotFound, HttpResponsePermanentRedirect from django.http.response import HttpResponseBase, HttpResponsePermanentRedirect
from django.utils.deprecation import MiddlewareMixin from django.utils.deprecation import MiddlewareMixin
class CommonMiddleware(MiddlewareMixin): class CommonMiddleware(MiddlewareMixin):
@@ -9,9 +9,9 @@ class CommonMiddleware(MiddlewareMixin):
def process_request(self, request: HttpRequest) -> Optional[HttpResponsePermanentRedirect]: ... def process_request(self, request: HttpRequest) -> Optional[HttpResponsePermanentRedirect]: ...
def should_redirect_with_slash(self, request: HttpRequest) -> bool: ... def should_redirect_with_slash(self, request: HttpRequest) -> bool: ...
def get_full_path_with_slash(self, request: HttpRequest) -> str: ... def get_full_path_with_slash(self, request: HttpRequest) -> str: ...
def process_response(self, request: HttpRequest, response: HttpResponse) -> HttpResponse: ... def process_response(self, request: HttpRequest, response: HttpResponseBase) -> HttpResponseBase: ...
class BrokenLinkEmailsMiddleware(MiddlewareMixin): class BrokenLinkEmailsMiddleware(MiddlewareMixin):
def process_response(self, request: HttpRequest, response: HttpResponseNotFound) -> HttpResponseNotFound: ... def process_response(self, request: HttpRequest, response: HttpResponseBase) -> HttpResponseBase: ...
def is_internal_request(self, domain: str, referer: str) -> bool: ... def is_internal_request(self, domain: str, referer: str) -> bool: ...
def is_ignorable_request(self, request: HttpRequest, uri: str, domain: str, referer: str) -> bool: ... def is_ignorable_request(self, request: HttpRequest, uri: str, domain: str, referer: str) -> bool: ...

View File

@@ -1,9 +1,9 @@
from typing import Any, Callable, Dict, List, Optional, Type, Union, Sequence, Protocol from typing import Any, Callable, Dict, List, Optional, Protocol, Sequence, Type, TypeVar, Union
from django.db.models import Manager, QuerySet
from django.db.models.base import Model from django.db.models.base import Model
from django.http.response import HttpResponse as HttpResponse, HttpResponseRedirect as HttpResponseRedirect from django.http.response import HttpResponse as HttpResponse, HttpResponseRedirect as HttpResponseRedirect
from django.db.models import Manager, QuerySet
from django.http import HttpRequest from django.http import HttpRequest
def render_to_response( def render_to_response(
@@ -28,6 +28,9 @@ class SupportsGetAbsoluteUrl(Protocol):
def redirect( def redirect(
to: Union[Callable, str, SupportsGetAbsoluteUrl], *args: Any, permanent: bool = ..., **kwargs: Any to: Union[Callable, str, SupportsGetAbsoluteUrl], *args: Any, permanent: bool = ..., **kwargs: Any
) -> HttpResponseRedirect: ... ) -> HttpResponseRedirect: ...
def get_object_or_404(klass: Union[Type[Model], Manager, QuerySet], *args: Any, **kwargs: Any) -> Model: ...
def get_list_or_404(klass: Union[Type[Model], Manager, QuerySet], *args: Any, **kwargs: Any) -> List[Model]: ... _T = TypeVar("_T", bound=Model)
def get_object_or_404(klass: Union[Type[_T], Manager[_T], QuerySet[_T]], *args: Any, **kwargs: Any) -> _T: ...
def get_list_or_404(klass: Union[Type[_T], Manager[_T], QuerySet[_T]], *args: Any, **kwargs: Any) -> List[_T]: ...
def resolve_url(to: Union[Callable, Model, str], *args: Any, **kwargs: Any) -> str: ... def resolve_url(to: Union[Callable, Model, str], *args: Any, **kwargs: Any) -> str: ...

View File

@@ -53,7 +53,7 @@ class Template:
nodelist: NodeList = ... nodelist: NodeList = ...
def __init__( def __init__(
self, self,
template_string: str, template_string: Union[Template, str],
origin: Optional[Origin] = ..., origin: Optional[Origin] = ...,
name: Optional[str] = ..., name: Optional[str] = ...,
engine: Optional[Engine] = ..., engine: Optional[Engine] = ...,

View File

@@ -1,4 +1,4 @@
from datetime import datetime from datetime import datetime, date, time
from decimal import Decimal from decimal import Decimal
from typing import Any, Iterator, List, Optional, Union from typing import Any, Iterator, List, Optional, Union
@@ -8,14 +8,16 @@ FORMAT_SETTINGS: Any
def reset_format_cache() -> None: ... def reset_format_cache() -> None: ...
def iter_format_modules(lang: str, format_module_path: Optional[Union[List[str], str]] = ...) -> Iterator[Any]: ... def iter_format_modules(lang: str, format_module_path: Optional[Union[List[str], str]] = ...) -> Iterator[Any]: ...
def get_format_modules(lang: Optional[str] = ..., reverse: bool = ...) -> List[Any]: ... def get_format_modules(lang: Optional[str] = ..., reverse: bool = ...) -> List[Any]: ...
def get_format( def get_format(format_type: str, lang: Optional[str] = ..., use_l10n: Optional[bool] = ...) -> str: ...
format_type: str, lang: Optional[str] = ..., use_l10n: Optional[bool] = ...
) -> Union[List[str], int, str]: ...
get_format_lazy: Any get_format_lazy: Any
def date_format(value: Union[datetime, str], format: Optional[str] = ..., use_l10n: Optional[bool] = ...) -> str: ... def date_format(
def time_format(value: Union[datetime, str], format: Optional[str] = ..., use_l10n: None = ...) -> str: ... value: Union[date, datetime, str], format: Optional[str] = ..., use_l10n: Optional[bool] = ...
) -> str: ...
def time_format(
value: Union[time, datetime, str], format: Optional[str] = ..., use_l10n: Optional[bool] = ...
) -> str: ...
def number_format( def number_format(
value: Union[Decimal, float, str], value: Union[Decimal, float, str],
decimal_pos: Optional[int] = ..., decimal_pos: Optional[int] = ...,

View File

@@ -1,5 +1,6 @@
import logging.config import logging.config
from typing import Any, Callable, Dict, Optional from logging import LogRecord
from typing import Any, Callable, Dict, Optional, Union
from django.core.mail.backends.locmem import EmailBackend from django.core.mail.backends.locmem import EmailBackend
from django.core.management.color import Style from django.core.management.color import Style
@@ -20,9 +21,13 @@ class AdminEmailHandler(logging.Handler):
class CallbackFilter(logging.Filter): class CallbackFilter(logging.Filter):
callback: Callable = ... callback: Callable = ...
def __init__(self, callback: Callable) -> None: ... def __init__(self, callback: Callable) -> None: ...
def filter(self, record: Union[str, LogRecord]) -> bool: ...
class RequireDebugFalse(logging.Filter): ... class RequireDebugFalse(logging.Filter):
class RequireDebugTrue(logging.Filter): ... def filter(self, record: Union[str, LogRecord]) -> bool: ...
class RequireDebugTrue(logging.Filter):
def filter(self, record: Union[str, LogRecord]) -> bool: ...
class ServerFormatter(logging.Formatter): class ServerFormatter(logging.Formatter):
datefmt: None datefmt: None

View File

@@ -1,11 +1,11 @@
from decimal import Decimal from decimal import Decimal
from typing import Any, Optional, Tuple, Union from typing import Optional, Sequence, Union
def format( def format(
number: Union[Decimal, float, str], number: Union[Decimal, float, str],
decimal_sep: str, decimal_sep: str,
decimal_pos: Optional[int] = ..., decimal_pos: Optional[int] = ...,
grouping: Union[Tuple[int, int, int], int] = ..., grouping: Union[int, Sequence[int]] = ...,
thousand_sep: str = ..., thousand_sep: str = ...,
force_grouping: bool = ..., force_grouping: bool = ...,
use_l10n: Optional[bool] = ..., use_l10n: Optional[bool] = ...,

View File

@@ -1,19 +1,21 @@
from typing import Any, Dict, List, Optional, Type from typing import Any, Callable, Dict, Optional, Sequence, Type, Union
from django.db import models
from django.forms import models as model_forms, Form # type: ignore # This will be solved when adding forms module
from django.http import HttpResponse, HttpRequest
from django.views.generic.base import ContextMixin, TemplateResponseMixin, View from django.views.generic.base import ContextMixin, TemplateResponseMixin, View
from django.views.generic.detail import BaseDetailView, SingleObjectMixin, SingleObjectTemplateResponseMixin from django.views.generic.detail import BaseDetailView, SingleObjectMixin, SingleObjectTemplateResponseMixin
from typing_extensions import Literal
from django.db import models
from django.forms import Form
from django.http import HttpRequest, HttpResponse
class FormMixin(ContextMixin): class FormMixin(ContextMixin):
initial = ... # type: Dict[str, object] initial: Dict[str, Any] = ...
form_class = ... # type: Optional[Type[Form]] form_class: Optional[Type[Form]] = ...
success_url = ... # type: Optional[str] success_url: Optional[Union[str, Callable[..., Any]]] = ...
prefix = ... # type: Optional[str] prefix: Optional[str] = ...
request = ... # type: HttpRequest request: HttpRequest = ...
def render_to_response(self, context: Dict[str, object], **response_kwargs: object) -> HttpResponse: ... def render_to_response(self, context: Dict[str, Any], **response_kwargs: object) -> HttpResponse: ...
def get_initial(self) -> Dict[str, object]: ... def get_initial(self) -> Dict[str, Any]: ...
def get_prefix(self) -> Optional[str]: ... def get_prefix(self) -> Optional[str]: ...
def get_form_class(self) -> Type[Form]: ... def get_form_class(self) -> Type[Form]: ...
def get_form(self, form_class: Type[Form] = None) -> Form: ... def get_form(self, form_class: Type[Form] = None) -> Form: ...
@@ -24,8 +26,8 @@ class FormMixin(ContextMixin):
def get_context_data(self, **kwargs: object) -> Dict[str, Any]: ... def get_context_data(self, **kwargs: object) -> Dict[str, Any]: ...
class ModelFormMixin(FormMixin, SingleObjectMixin): class ModelFormMixin(FormMixin, SingleObjectMixin):
fields = ... # type: Optional[List[str]] fields: Optional[Union[Sequence[str], Literal["__all__"]]] = ...
object = ... # type: models.Model object: models.Model = ...
def get_form_class(self) -> Type[Form]: ... def get_form_class(self) -> Type[Form]: ...
def get_form_kwargs(self) -> Dict[str, object]: ... def get_form_kwargs(self) -> Dict[str, object]: ...
def get_success_url(self) -> str: ... def get_success_url(self) -> str: ...

View File

@@ -1,5 +1,5 @@
from configparser import ConfigParser from configparser import ConfigParser
from typing import Optional from typing import List, Optional
from dataclasses import dataclass from dataclasses import dataclass
@@ -10,12 +10,16 @@ class Config:
ignore_missing_settings: bool = False ignore_missing_settings: bool = False
@classmethod @classmethod
def from_config_file(self, fpath: str) -> 'Config': def from_config_file(cls, fpath: str) -> 'Config':
ini_config = ConfigParser() ini_config = ConfigParser()
ini_config.read(fpath) ini_config.read(fpath)
if not ini_config.has_section('mypy_django_plugin'): if not ini_config.has_section('mypy_django_plugin'):
raise ValueError('Invalid config file: no [mypy_django_plugin] section') raise ValueError('Invalid config file: no [mypy_django_plugin] section')
return Config(django_settings_module=ini_config.get('mypy_django_plugin', 'django_settings',
fallback=None), django_settings = ini_config.get('mypy_django_plugin', 'django_settings',
fallback=None)
if django_settings:
django_settings = django_settings.strip()
return Config(django_settings_module=django_settings,
ignore_missing_settings=ini_config.get('mypy_django_plugin', 'ignore_missing_settings', ignore_missing_settings=ini_config.get('mypy_django_plugin', 'ignore_missing_settings',
fallback=False)) fallback=False))

View File

@@ -1,9 +1,11 @@
import typing import typing
from typing import Dict, Optional from typing import Dict, Optional
from mypy.nodes import Expression, ImportedName, MypyFile, NameExpr, SymbolNode, TypeInfo from mypy.checker import TypeChecker
from mypy.nodes import AssignmentStmt, ClassDef, Expression, FuncDef, ImportedName, Lvalue, MypyFile, NameExpr, SymbolNode, \
TypeInfo
from mypy.plugin import FunctionContext from mypy.plugin import FunctionContext
from mypy.types import AnyType, Instance, Type, TypeOfAny, TypeVarType from mypy.types import AnyType, CallableType, Instance, Type, TypeOfAny, TypeVarType, UnionType
MODEL_CLASS_FULLNAME = 'django.db.models.base.Model' MODEL_CLASS_FULLNAME = 'django.db.models.base.Model'
FIELD_FULLNAME = 'django.db.models.fields.Field' FIELD_FULLNAME = 'django.db.models.fields.Field'
@@ -146,3 +148,109 @@ def get_argument_type_by_name(ctx: FunctionContext, name: str) -> Optional[Type]
# Either an error or no value passed. # Either an error or no value passed.
return None return None
return arg_types[0] return arg_types[0]
def get_setting_expr(api: TypeChecker, setting_name: str) -> Optional[Expression]:
try:
settings_sym = api.modules['django.conf'].names['settings']
except KeyError:
return None
settings_type: TypeInfo = settings_sym.type.type
auth_user_model_sym = settings_type.get(setting_name)
if not auth_user_model_sym:
return None
module, _, name = auth_user_model_sym.fullname.rpartition('.')
if module not in api.modules:
return None
module_file = api.modules.get(module)
for name_expr, value_expr in iter_over_assignments(module_file):
if isinstance(name_expr, NameExpr) and name_expr.name == setting_name:
return value_expr
return None
def iter_over_assignments(
class_or_module: typing.Union[ClassDef, MypyFile]) -> typing.Iterator[typing.Tuple[Lvalue, Expression]]:
if isinstance(class_or_module, ClassDef):
statements = class_or_module.defs.body
else:
statements = class_or_module.defs
for stmt in statements:
if not isinstance(stmt, AssignmentStmt):
continue
if len(stmt.lvalues) > 1:
# not supported yet
continue
yield stmt.lvalues[0], stmt.rvalue
def extract_field_setter_type(tp: Instance) -> Optional[Type]:
if not isinstance(tp, Instance):
return None
if tp.type.has_base(FIELD_FULLNAME):
set_method = tp.type.get_method('__set__')
if isinstance(set_method, FuncDef) and isinstance(set_method.type, CallableType):
if 'value' in set_method.type.arg_names:
set_value_type = set_method.type.arg_types[set_method.type.arg_names.index('value')]
if isinstance(set_value_type, Instance):
set_value_type = fill_typevars(tp, set_value_type)
return set_value_type
elif isinstance(set_value_type, UnionType):
items_no_typevars = []
for item in set_value_type.items:
if isinstance(item, Instance):
item = fill_typevars(tp, item)
items_no_typevars.append(item)
return UnionType(items_no_typevars)
field_getter_type = extract_field_getter_type(tp)
if field_getter_type:
return field_getter_type
return None
def extract_field_getter_type(tp: Instance) -> Optional[Type]:
if not isinstance(tp, Instance):
return None
if tp.type.has_base(FIELD_FULLNAME):
get_method = tp.type.get_method('__get__')
if isinstance(get_method, FuncDef) and isinstance(get_method.type, CallableType):
return get_method.type.ret_type
# GenericForeignKey
if tp.type.has_base(GENERIC_FOREIGN_KEY_FULLNAME):
return AnyType(TypeOfAny.special_form)
return None
def get_django_metadata(model: TypeInfo) -> Dict[str, typing.Any]:
return model.metadata.setdefault('django', {})
def get_related_field_primary_key_names(base_model: TypeInfo) -> typing.List[str]:
django_metadata = get_django_metadata(base_model)
return django_metadata.setdefault('related_field_primary_keys', [])
def get_fields_metadata(model: TypeInfo) -> Dict[str, typing.Any]:
return get_django_metadata(model).setdefault('fields', {})
def extract_primary_key_type_for_set(model: TypeInfo) -> Optional[Type]:
for field_name, props in get_fields_metadata(model).items():
is_primary_key = props.get('primary_key', False)
if is_primary_key:
return extract_field_setter_type(model.names[field_name].type)
return None
def extract_primary_key_type_for_get(model: TypeInfo) -> Optional[Type]:
for field_name, props in get_fields_metadata(model).items():
is_primary_key = props.get('primary_key', False)
if is_primary_key:
return extract_field_getter_type(model.names[field_name].type)
return None

View File

@@ -2,19 +2,18 @@ import os
from typing import Callable, Dict, Optional, cast from typing import Callable, Dict, Optional, cast
from mypy.checker import TypeChecker from mypy.checker import TypeChecker
from mypy.nodes import TypeInfo from mypy.nodes import MemberExpr, TypeInfo
from mypy.options import Options from mypy.options import Options
from mypy.plugin import ClassDefContext, FunctionContext, MethodContext, Plugin from mypy.plugin import AttributeContext, ClassDefContext, FunctionContext, MethodContext, Plugin
from mypy.types import Instance, Type from mypy.types import AnyType, Instance, Type, TypeOfAny, TypeType
from mypy_django_plugin import helpers, monkeypatch from mypy_django_plugin import helpers, monkeypatch
from mypy_django_plugin.config import Config from mypy_django_plugin.config import Config
from mypy_django_plugin.plugins import init_create
from mypy_django_plugin.plugins.fields import determine_type_of_array_field, record_field_properties_into_outer_model_class from mypy_django_plugin.plugins.fields import determine_type_of_array_field, record_field_properties_into_outer_model_class
from mypy_django_plugin.plugins.init_create import redefine_and_typecheck_model_init, redefine_and_typecheck_model_create from mypy_django_plugin.plugins.migrations import determine_model_cls_from_string_for_migrations, get_string_value_from_expr
from mypy_django_plugin.plugins.migrations import determine_model_cls_from_string_for_migrations
from mypy_django_plugin.plugins.models import process_model_class from mypy_django_plugin.plugins.models import process_model_class
from mypy_django_plugin.plugins.related_fields import extract_to_parameter_as_get_ret_type_for_related_field, reparametrize_with from mypy_django_plugin.plugins.related_fields import extract_to_parameter_as_get_ret_type_for_related_field, reparametrize_with
from mypy_django_plugin.plugins.settings import AddSettingValuesToDjangoConfObject from mypy_django_plugin.plugins.settings import AddSettingValuesToDjangoConfObject, get_settings_metadata
def transform_model_class(ctx: ClassDefContext) -> None: def transform_model_class(ctx: ClassDefContext) -> None:
@@ -56,6 +55,75 @@ def determine_proper_manager_type(ctx: FunctionContext) -> Type:
return ret return ret
def return_user_model_hook(ctx: FunctionContext) -> Type:
api = cast(TypeChecker, ctx.api)
setting_expr = helpers.get_setting_expr(api, 'AUTH_USER_MODEL')
if setting_expr is None:
return ctx.default_return_type
model_path = get_string_value_from_expr(setting_expr)
if model_path is None:
return ctx.default_return_type
app_label, _, model_class_name = model_path.rpartition('.')
if app_label is None:
return ctx.default_return_type
model_fullname = helpers.get_model_fullname(app_label, model_class_name,
all_modules=api.modules)
if model_fullname is None:
api.fail(f'"{app_label}.{model_class_name}" model class is not imported so far. Try to import it '
f'(under if TYPE_CHECKING) at the beginning of the current file',
context=ctx.context)
return ctx.default_return_type
model_info = helpers.lookup_fully_qualified_generic(model_fullname,
all_modules=api.modules)
if model_info is None or not isinstance(model_info, TypeInfo):
return ctx.default_return_type
return TypeType(Instance(model_info, []))
def extract_and_return_primary_key_of_bound_related_field_parameter(ctx: AttributeContext) -> Type:
if not isinstance(ctx.default_attr_type, Instance) or not (ctx.default_attr_type.type.fullname() == 'builtins.int'):
return ctx.default_attr_type
if not isinstance(ctx.type, Instance) or not ctx.type.type.has_base(helpers.MODEL_CLASS_FULLNAME):
return ctx.default_attr_type
field_name = ctx.context.name.split('_')[0]
sym = ctx.type.type.get(field_name)
if sym and isinstance(sym.type, Instance) and len(sym.type.args) > 0:
to_arg = sym.type.args[0]
if isinstance(to_arg, AnyType):
return AnyType(TypeOfAny.special_form)
model_type: TypeInfo = to_arg.type
primary_key_type = helpers.extract_primary_key_type_for_get(model_type)
if primary_key_type:
return primary_key_type
return ctx.default_attr_type
class ExtractSettingType:
def __init__(self, module_fullname: str):
self.module_fullname = module_fullname
def __call__(self, ctx: AttributeContext) -> Type:
api = cast(TypeChecker, ctx.api)
original_module = api.modules.get(self.module_fullname)
if original_module is None:
return ctx.default_attr_type
definition = ctx.context
if isinstance(definition, MemberExpr):
sym = original_module.names.get(definition.name)
if sym and sym.type:
return sym.type
return ctx.default_attr_type
class DjangoPlugin(Plugin): class DjangoPlugin(Plugin):
def __init__(self, options: Options) -> None: def __init__(self, options: Options) -> None:
super().__init__(options) super().__init__(options)
@@ -63,20 +131,20 @@ class DjangoPlugin(Plugin):
monkeypatch.restore_original_load_graph() monkeypatch.restore_original_load_graph()
monkeypatch.restore_original_dependencies_handling() monkeypatch.restore_original_dependencies_handling()
config_fpath = os.environ.get('MYPY_DJANGO_CONFIG') config_fpath = os.environ.get('MYPY_DJANGO_CONFIG', 'mypy_django.ini')
if config_fpath: if config_fpath and os.path.exists(config_fpath):
self.config = Config.from_config_file(config_fpath) self.config = Config.from_config_file(config_fpath)
self.django_settings = self.config.django_settings_module self.django_settings_module = self.config.django_settings_module
else: else:
self.config = Config() self.config = Config()
self.django_settings = None self.django_settings_module = None
if 'DJANGO_SETTINGS_MODULE' in os.environ: if 'DJANGO_SETTINGS_MODULE' in os.environ:
self.django_settings = os.environ['DJANGO_SETTINGS_MODULE'] self.django_settings_module = os.environ['DJANGO_SETTINGS_MODULE']
settings_modules = ['django.conf.global_settings'] settings_modules = ['django.conf.global_settings']
if self.django_settings: if self.django_settings_module:
settings_modules.append(self.django_settings) settings_modules.append(self.django_settings_module)
monkeypatch.add_modules_as_a_source_seed_files(settings_modules) monkeypatch.add_modules_as_a_source_seed_files(settings_modules)
monkeypatch.inject_modules_as_dependencies_for_django_conf_settings(settings_modules) monkeypatch.inject_modules_as_dependencies_for_django_conf_settings(settings_modules)
@@ -105,6 +173,9 @@ class DjangoPlugin(Plugin):
def get_function_hook(self, fullname: str def get_function_hook(self, fullname: str
) -> Optional[Callable[[FunctionContext], Type]]: ) -> Optional[Callable[[FunctionContext], Type]]:
if fullname == 'django.contrib.auth.get_user_model':
return return_user_model_hook
if fullname in {helpers.FOREIGN_KEY_FULLNAME, if fullname in {helpers.FOREIGN_KEY_FULLNAME,
helpers.ONETOONE_FIELD_FULLNAME, helpers.ONETOONE_FIELD_FULLNAME,
helpers.MANYTOMANY_FIELD_FULLNAME}: helpers.MANYTOMANY_FIELD_FULLNAME}:
@@ -121,15 +192,16 @@ class DjangoPlugin(Plugin):
if sym and isinstance(sym.node, TypeInfo): if sym and isinstance(sym.node, TypeInfo):
if sym.node.has_base(helpers.FIELD_FULLNAME): if sym.node.has_base(helpers.FIELD_FULLNAME):
return record_field_properties_into_outer_model_class return record_field_properties_into_outer_model_class
if sym.node.metadata.get('django', {}).get('generated_init'): if sym.node.metadata.get('django', {}).get('generated_init'):
return redefine_and_typecheck_model_init return init_create.redefine_and_typecheck_model_init
def get_method_hook(self, fullname: str def get_method_hook(self, fullname: str
) -> Optional[Callable[[MethodContext], Type]]: ) -> Optional[Callable[[MethodContext], Type]]:
manager_classes = self._get_current_manager_bases() manager_classes = self._get_current_manager_bases()
class_fullname, _, method_name = fullname.rpartition('.') class_fullname, _, method_name = fullname.rpartition('.')
if class_fullname in manager_classes and method_name == 'create': if class_fullname in manager_classes and method_name == 'create':
return redefine_and_typecheck_model_create return init_create.redefine_and_typecheck_model_create
if fullname in {'django.apps.registry.Apps.get_model', if fullname in {'django.apps.registry.Apps.get_model',
'django.db.migrations.state.StateApps.get_model'}: 'django.db.migrations.state.StateApps.get_model'}:
@@ -143,8 +215,8 @@ class DjangoPlugin(Plugin):
if fullname == helpers.DUMMY_SETTINGS_BASE_CLASS: if fullname == helpers.DUMMY_SETTINGS_BASE_CLASS:
settings_modules = ['django.conf.global_settings'] settings_modules = ['django.conf.global_settings']
if self.django_settings: if self.django_settings_module:
settings_modules.append(self.django_settings) settings_modules.append(self.django_settings_module)
return AddSettingValuesToDjangoConfObject(settings_modules, return AddSettingValuesToDjangoConfObject(settings_modules,
self.config.ignore_missing_settings) self.config.ignore_missing_settings)
@@ -153,6 +225,17 @@ class DjangoPlugin(Plugin):
return None return None
def get_attribute_hook(self, fullname: str
) -> Optional[Callable[[AttributeContext], Type]]:
module, _, name = fullname.rpartition('.')
sym = self.lookup_fully_qualified('django.conf.LazySettings')
if sym and isinstance(sym.node, TypeInfo):
metadata = get_settings_metadata(sym.node)
if module == 'builtins.object' and name in metadata:
return ExtractSettingType(module_fullname=metadata[name])
return extract_and_return_primary_key_of_bound_related_field_parameter
def plugin(version): def plugin(version):
return DjangoPlugin return DjangoPlugin

View File

@@ -1,11 +1,12 @@
from typing import Dict, Optional, Set, cast, Any from typing import Dict, Optional, Set, cast
from mypy.checker import TypeChecker from mypy.checker import TypeChecker
from mypy.nodes import FuncDef, TypeInfo, Var from mypy.nodes import TypeInfo, Var
from mypy.plugin import FunctionContext, MethodContext from mypy.plugin import FunctionContext, MethodContext
from mypy.types import AnyType, CallableType, Instance, Type, TypeOfAny, UnionType from mypy.types import AnyType, Instance, Type, TypeOfAny, UnionType
from mypy_django_plugin import helpers from mypy_django_plugin import helpers
from mypy_django_plugin.helpers import extract_field_setter_type, extract_primary_key_type_for_set, get_fields_metadata
def extract_base_pointer_args(model: TypeInfo) -> Set[str]: def extract_base_pointer_args(model: TypeInfo) -> Set[str]:
@@ -69,11 +70,14 @@ def redefine_and_typecheck_model_init(ctx: FunctionContext) -> Type:
def redefine_and_typecheck_model_create(ctx: MethodContext) -> Type: def redefine_and_typecheck_model_create(ctx: MethodContext) -> Type:
api = cast(TypeChecker, ctx.api) api = cast(TypeChecker, ctx.api)
if isinstance(ctx.type, Instance) and len(ctx.type.args) > 0: if isinstance(ctx.type, Instance) and len(ctx.type.args) > 0:
model: TypeInfo = ctx.type.args[0].type model_generic_arg = ctx.type.args[0]
else: else:
if isinstance(ctx.default_return_type, AnyType): model_generic_arg = ctx.default_return_type
return ctx.default_return_type
model: TypeInfo = ctx.default_return_type.type if isinstance(model_generic_arg, AnyType):
return ctx.default_return_type
model: TypeInfo = model_generic_arg.type
# extract name of base models for _ptr # extract name of base models for _ptr
base_pointer_args = extract_base_pointer_args(model) base_pointer_args = extract_base_pointer_args(model)
@@ -100,46 +104,6 @@ def redefine_and_typecheck_model_create(ctx: MethodContext) -> Type:
return ctx.default_return_type return ctx.default_return_type
def extract_field_setter_type(tp: Instance) -> Optional[Type]:
if not isinstance(tp, Instance):
return None
if tp.type.has_base(helpers.FIELD_FULLNAME):
set_method = tp.type.get_method('__set__')
if isinstance(set_method, FuncDef) and isinstance(set_method.type, CallableType):
if 'value' in set_method.type.arg_names:
set_value_type = set_method.type.arg_types[set_method.type.arg_names.index('value')]
if isinstance(set_value_type, Instance):
set_value_type = helpers.fill_typevars(tp, set_value_type)
return set_value_type
elif isinstance(set_value_type, UnionType):
items_no_typevars = []
for item in set_value_type.items:
if isinstance(item, Instance):
item = helpers.fill_typevars(tp, item)
items_no_typevars.append(item)
return UnionType(items_no_typevars)
get_method = tp.type.get_method('__get__')
if isinstance(get_method, FuncDef) and isinstance(get_method.type, CallableType):
return get_method.type.ret_type
# GenericForeignKey
if tp.type.has_base(helpers.GENERIC_FOREIGN_KEY_FULLNAME):
return AnyType(TypeOfAny.special_form)
return None
def get_fields_metadata(model: TypeInfo) -> Dict[str, Any]:
return model.metadata.setdefault('django', {}).setdefault('fields', {})
def extract_primary_key_type(model: TypeInfo) -> Optional[Type]:
for field_name, props in get_fields_metadata(model).items():
is_primary_key = props.get('primary_key', False)
if is_primary_key:
return extract_field_setter_type(model.names[field_name].type)
return None
def extract_choices_type(model: TypeInfo, field_name: str) -> Optional[str]: def extract_choices_type(model: TypeInfo, field_name: str) -> Optional[str]:
field_metadata = get_fields_metadata(model).get(field_name, {}) field_metadata = get_fields_metadata(model).get(field_name, {})
if 'choices' in field_metadata: if 'choices' in field_metadata:
@@ -150,33 +114,39 @@ def extract_choices_type(model: TypeInfo, field_name: str) -> Optional[str]:
def extract_expected_types(ctx: FunctionContext, model: TypeInfo) -> Dict[str, Type]: def extract_expected_types(ctx: FunctionContext, model: TypeInfo) -> Dict[str, Type]:
expected_types: Dict[str, Type] = {} expected_types: Dict[str, Type] = {}
primary_key_type = extract_primary_key_type(model) primary_key_type = extract_primary_key_type_for_set(model)
if not primary_key_type: if not primary_key_type:
# no explicit primary key, set pk to Any and add id # no explicit primary key, set pk to Any and add id
primary_key_type = AnyType(TypeOfAny.special_form) primary_key_type = AnyType(TypeOfAny.special_form)
expected_types['id'] = ctx.api.named_generic_type('builtins.int', []) expected_types['id'] = ctx.api.named_generic_type('builtins.int', [])
expected_types['pk'] = primary_key_type expected_types['pk'] = primary_key_type
for base in model.mro: for base in model.mro:
for name, sym in base.names.items(): for name, sym in base.names.items():
if isinstance(sym.node, Var) and isinstance(sym.node.type, Instance): # do not redefine special attrs
tp = sym.node.type if name in {'_meta', 'pk'}:
field_type = extract_field_setter_type(tp) continue
if field_type is None: if isinstance(sym.node, Var):
continue if sym.node.type is None or isinstance(sym.node.type, AnyType):
# types are not ready, fallback to Any
expected_types[name] = AnyType(TypeOfAny.from_unimported_type)
expected_types[name + '_id'] = AnyType(TypeOfAny.from_unimported_type)
choices_type_fullname = extract_choices_type(model, name) elif isinstance(sym.node.type, Instance):
if choices_type_fullname: tp = sym.node.type
field_type = UnionType([field_type, ctx.api.named_generic_type(choices_type_fullname, [])]) field_type = extract_field_setter_type(tp)
if field_type is None:
continue
if tp.type.fullname() in {helpers.FOREIGN_KEY_FULLNAME, helpers.ONETOONE_FIELD_FULLNAME}: if tp.type.fullname() in {helpers.FOREIGN_KEY_FULLNAME, helpers.ONETOONE_FIELD_FULLNAME}:
ref_to_model = tp.args[0] ref_to_model = tp.args[0]
if isinstance(ref_to_model, Instance) and ref_to_model.type.has_base(helpers.MODEL_CLASS_FULLNAME): primary_key_type = AnyType(TypeOfAny.special_form)
primary_key_type = extract_primary_key_type(ref_to_model.type) if isinstance(ref_to_model, Instance) and ref_to_model.type.has_base(helpers.MODEL_CLASS_FULLNAME):
if not primary_key_type: typ = extract_primary_key_type_for_set(ref_to_model.type)
primary_key_type = AnyType(TypeOfAny.special_form) if typ:
primary_key_type = typ
expected_types[name + '_id'] = primary_key_type expected_types[name + '_id'] = primary_key_type
if field_type: if field_type:
expected_types[name] = field_type expected_types[name] = field_type
return expected_types return expected_types

View File

@@ -2,7 +2,7 @@ from abc import ABCMeta, abstractmethod
from typing import Dict, Iterator, List, Optional, Tuple, cast from typing import Dict, Iterator, List, Optional, Tuple, cast
import dataclasses import dataclasses
from mypy.nodes import ARG_STAR, ARG_STAR2, Argument, AssignmentStmt, CallExpr, ClassDef, Context, Expression, IndexExpr, \ from mypy.nodes import ARG_STAR, ARG_STAR2, Argument, CallExpr, ClassDef, Context, Expression, IndexExpr, \
Lvalue, MDEF, MemberExpr, MypyFile, NameExpr, StrExpr, SymbolTableNode, TypeInfo, Var Lvalue, MDEF, MemberExpr, MypyFile, NameExpr, StrExpr, SymbolTableNode, TypeInfo, Var
from mypy.plugin import ClassDefContext from mypy.plugin import ClassDefContext
from mypy.plugins.common import add_method from mypy.plugins.common import add_method
@@ -10,6 +10,7 @@ from mypy.semanal import SemanticAnalyzerPass2
from mypy.types import AnyType, Instance, NoneTyp, TypeOfAny from mypy.types import AnyType, Instance, NoneTyp, TypeOfAny
from mypy_django_plugin import helpers from mypy_django_plugin import helpers
from mypy_django_plugin.helpers import iter_over_assignments
@dataclasses.dataclass @dataclasses.dataclass
@@ -55,16 +56,6 @@ class ModelClassInitializer(metaclass=ABCMeta):
raise NotImplementedError() raise NotImplementedError()
def iter_over_assignments(klass: ClassDef) -> Iterator[Tuple[Lvalue, Expression]]:
for stmt in klass.defs.body:
if not isinstance(stmt, AssignmentStmt):
continue
if len(stmt.lvalues) > 1:
# not supported yet
continue
yield stmt.lvalues[0], stmt.rvalue
def iter_call_assignments(klass: ClassDef) -> Iterator[Tuple[Lvalue, CallExpr]]: def iter_call_assignments(klass: ClassDef) -> Iterator[Tuple[Lvalue, CallExpr]]:
for lvalue, rvalue in iter_over_assignments(klass): for lvalue, rvalue in iter_over_assignments(klass):
if isinstance(rvalue, CallExpr): if isinstance(rvalue, CallExpr):
@@ -83,8 +74,11 @@ def iter_over_one_to_n_related_fields(klass: ClassDef) -> Iterator[Tuple[NameExp
class SetIdAttrsForRelatedFields(ModelClassInitializer): class SetIdAttrsForRelatedFields(ModelClassInitializer):
def run(self) -> None: def run(self) -> None:
for lvalue, rvalue in iter_over_one_to_n_related_fields(self.model_classdef): for lvalue, rvalue in iter_over_one_to_n_related_fields(self.model_classdef):
self.add_new_node_to_model_class(lvalue.name + '_id', # base_model_info = self.api.named_type('builtins.object').type
typ=self.api.named_type('__builtins__.int')) # helpers.get_related_field_primary_key_names(base_model_info).append(node_name)
node_name = lvalue.name + '_id'
self.add_new_node_to_model_class(name=node_name,
typ=self.api.builtin_type('builtins.int'))
class InjectAnyAsBaseForNestedMeta(ModelClassInitializer): class InjectAnyAsBaseForNestedMeta(ModelClassInitializer):
@@ -188,15 +182,17 @@ class AddRelatedManagers(ModelClassInitializer):
return None return None
if self.model_classdef.fullname == ref_to_fullname: if self.model_classdef.fullname == ref_to_fullname:
related_manager_name = defn.name.lower() + '_set'
if 'related_name' in rvalue.arg_names: if 'related_name' in rvalue.arg_names:
related_name_expr = rvalue.args[rvalue.arg_names.index('related_name')] related_name_expr = rvalue.args[rvalue.arg_names.index('related_name')]
if not isinstance(related_name_expr, StrExpr): if not isinstance(related_name_expr, StrExpr):
return None return None
related_name = related_name_expr.value related_manager_name = related_name_expr.value
typ = get_related_field_type(rvalue, self.api, defn.info)
if typ is None: typ = get_related_field_type(rvalue, self.api, defn.info)
return None if typ is None:
self.add_new_node_to_model_class(related_name, typ) return None
self.add_new_node_to_model_class(related_manager_name, typ)
def iter_over_classdefs(module_file: MypyFile) -> Iterator[ClassDef]: def iter_over_classdefs(module_file: MypyFile) -> Iterator[ClassDef]:

View File

@@ -1,9 +1,10 @@
from typing import List, Optional, cast from typing import Iterable, List, Optional, cast
from mypy.nodes import ClassDef, Context, MypyFile, SymbolNode, SymbolTableNode, Var from mypy.nodes import ClassDef, Context, ImportAll, MypyFile, SymbolNode, SymbolTableNode, TypeInfo, Var
from mypy.plugin import ClassDefContext from mypy.plugin import ClassDefContext
from mypy.semanal import SemanticAnalyzerPass2 from mypy.semanal import SemanticAnalyzerPass2
from mypy.types import Instance, NoneTyp, Type, UnionType from mypy.types import AnyType, Instance, NoneTyp, Type, TypeOfAny, UnionType
from mypy.util import correct_relative_import
def get_error_context(node: SymbolNode) -> Context: def get_error_context(node: SymbolNode) -> Context:
@@ -36,14 +37,44 @@ def make_sym_copy_of_setting(sym: SymbolTableNode) -> Optional[SymbolTableNode]:
return None return None
def load_settings_from_module(settings_classdef: ClassDef, module: MypyFile) -> None: def get_settings_metadata(lazy_settings_info: TypeInfo):
for name, sym in module.names.items(): return lazy_settings_info.metadata.setdefault('django', {}).setdefault('settings', {})
if name.isupper() and isinstance(sym.node, Var):
if sym.type is not None:
copied = make_sym_copy_of_setting(sym) def load_settings_from_names(settings_classdef: ClassDef,
if copied is None: modules: Iterable[MypyFile],
continue api: SemanticAnalyzerPass2) -> None:
settings_classdef.info.names[name] = copied settings_metadata = get_settings_metadata(settings_classdef.info)
for module in modules:
for name, sym in module.names.items():
if name.isupper() and isinstance(sym.node, Var):
if sym.type is not None:
copied = make_sym_copy_of_setting(sym)
if copied is None:
continue
settings_classdef.info.names[name] = copied
else:
var = Var(name, AnyType(TypeOfAny.unannotated))
var.info = api.named_type('__builtins__.object').type
settings_classdef.info.names[name] = SymbolTableNode(sym.kind, var)
settings_metadata[name] = module.fullname()
def get_import_star_modules(api: SemanticAnalyzerPass2, module: MypyFile) -> List[str]:
import_star_modules = []
for module_import in module.imports:
# relative import * are not resolved by mypy
if isinstance(module_import, ImportAll) and module_import.relative:
absolute_import_path, correct = correct_relative_import(module.fullname(), module_import.relative, module_import.id,
is_cur_package_init_file=False)
if not correct:
return []
for path in [absolute_import_path] + get_import_star_modules(api, module=api.modules.get(absolute_import_path)):
if path not in import_star_modules:
import_star_modules.append(path)
return import_star_modules
class AddSettingValuesToDjangoConfObject: class AddSettingValuesToDjangoConfObject:
@@ -55,7 +86,9 @@ class AddSettingValuesToDjangoConfObject:
api = cast(SemanticAnalyzerPass2, ctx.api) api = cast(SemanticAnalyzerPass2, ctx.api)
for module_name in self.settings_modules: for module_name in self.settings_modules:
module = api.modules[module_name] module = api.modules[module_name]
load_settings_from_module(ctx.cls, module=module) star_deps = [api.modules[star_dep]
for star_dep in reversed(get_import_star_modules(api, module))]
load_settings_from_names(ctx.cls, modules=star_deps + [module], api=api)
if self.ignore_missing_settings: if self.ignore_missing_settings:
ctx.cls.info.fallback_to_any = True ctx.cls.info.fallback_to_any = True

4
release.xsh Normal file
View File

@@ -0,0 +1,4 @@
#!/usr/local/bin/xonsh
python setup.py sdist
twine upload dist/*

View File

@@ -19,7 +19,7 @@ DJANGO_COMMIT_SHA = '03219b5f709dcd5b0bfacd963508625557ec1ef0'
# Some errors occur for the test suite itself, and cannot be addressed via django-stubs. They should be ignored # Some errors occur for the test suite itself, and cannot be addressed via django-stubs. They should be ignored
# using this constant. # using this constant.
MOCK_OBJECTS = ['MockRequest', 'MockCompiler', 'modelz', 'call_count', 'call_args_list', 'call_args'] MOCK_OBJECTS = ['MockRequest', 'MockCompiler', 'modelz', 'call_count', 'call_args_list', 'call_args', 'MockUser']
IGNORED_ERRORS = { IGNORED_ERRORS = {
'__common__': [ '__common__': [
*MOCK_OBJECTS, *MOCK_OBJECTS,
@@ -28,6 +28,7 @@ IGNORED_ERRORS = {
'Need type annotation for', 'Need type annotation for',
'Invalid value for a to= parameter', 'Invalid value for a to= parameter',
'already defined (possibly by an import)', 'already defined (possibly by an import)',
'gets multiple values for keyword argument',
'Cannot assign to a type', 'Cannot assign to a type',
re.compile(r'Cannot assign to class variable "[a-z_]+" via instance'), re.compile(r'Cannot assign to class variable "[a-z_]+" via instance'),
# forms <-> models plugin support # forms <-> models plugin support
@@ -135,6 +136,9 @@ IGNORED_ERRORS = {
'dispatch': [ 'dispatch': [
'Argument 1 to "connect" of "Signal" has incompatible type "object"; expected "Callable[..., Any]"' 'Argument 1 to "connect" of "Signal" has incompatible type "object"; expected "Callable[..., Any]"'
], ],
'deprecation': [
re.compile('"(old|new)" undefined in superclass')
],
'db_typecasts': [ 'db_typecasts': [
'"object" has no attribute "__iter__"; maybe "__str__" or "__dir__"? (not iterable)' '"object" has no attribute "__iter__"; maybe "__str__" or "__dir__"? (not iterable)'
], ],
@@ -147,13 +151,18 @@ IGNORED_ERRORS = {
'field_deconstruction': [ 'field_deconstruction': [
'Incompatible types in assignment (expression has type "ForeignKey[Any]", variable has type "CharField")' 'Incompatible types in assignment (expression has type "ForeignKey[Any]", variable has type "CharField")'
], ],
'file_uploads': [
'"handle_uncaught_exception" undefined in superclass'
],
'fixtures': [
'Incompatible types in assignment (expression has type "int", target has type "Iterable[str]")'
],
'get_object_or_404': [ 'get_object_or_404': [
'Argument 1 to "get_object_or_404" has incompatible type "str"; ' 'Argument 1 to "get_object_or_404" has incompatible type "str"; '
+ 'expected "Union[Type[Model], Manager[Any], QuerySet[Any]]"', + 'expected "Union[Type[<nothing>], Manager[<nothing>], QuerySet[<nothing>]]"',
'Argument 1 to "get_object_or_404" has incompatible type "Type[CustomClass]"; '
+ 'expected "Union[Type[Model], Manager[Any], QuerySet[Any]]"',
'Argument 1 to "get_list_or_404" has incompatible type "List[Type[Article]]"; ' 'Argument 1 to "get_list_or_404" has incompatible type "List[Type[Article]]"; '
+ 'expected "Union[Type[Model], Manager[Any], QuerySet[Any]]"' + 'expected "Union[Type[<nothing>], Manager[<nothing>], QuerySet[<nothing>]]"',
'CustomClass'
], ],
'get_or_create': [ 'get_or_create': [
'Argument 1 to "update_or_create" of "QuerySet" has incompatible type "**Dict[str, object]"; expected "MutableMapping[str, Any]"' 'Argument 1 to "update_or_create" of "QuerySet" has incompatible type "**Dict[str, object]"; expected "MutableMapping[str, Any]"'
@@ -165,6 +174,9 @@ IGNORED_ERRORS = {
'Argument "max_length" to "CharField" has incompatible type "str"; expected "Optional[int]"', 'Argument "max_length" to "CharField" has incompatible type "str"; expected "Optional[int]"',
'Argument "choices" to "CharField" has incompatible type "str"' 'Argument "choices" to "CharField" has incompatible type "str"'
], ],
'logging_tests': [
re.compile('"(setUpClass|tearDownClass)" undefined in superclass')
],
'model_inheritance_regress': [ 'model_inheritance_regress': [
'Incompatible types in assignment (expression has type "List[Supplier]", variable has type "QuerySet[Supplier]")' 'Incompatible types in assignment (expression has type "List[Supplier]", variable has type "QuerySet[Supplier]")'
], ],
@@ -176,6 +188,8 @@ IGNORED_ERRORS = {
'Incompatible types in assignment (expression has type "Type[Person]", variable has type', 'Incompatible types in assignment (expression has type "Type[Person]", variable has type',
'Unexpected keyword argument "name" for "Person"', 'Unexpected keyword argument "name" for "Person"',
'Cannot assign multiple types to name "PersonTwoImages" without an explicit "Type[...]" annotation', 'Cannot assign multiple types to name "PersonTwoImages" without an explicit "Type[...]" annotation',
'Incompatible types in assignment (expression has type "Type[Person]", '
+ 'base class "ImageFieldTestMixin" defined the type as "Type[PersonWithHeightAndWidth]")'
], ],
'model_regress': [ 'model_regress': [
'Too many arguments for "Worker"', 'Too many arguments for "Worker"',
@@ -201,6 +215,13 @@ IGNORED_ERRORS = {
'FakeLoader', 'FakeLoader',
'Argument 1 to "append" of "list" has incompatible type "AddIndex"; expected "CreateModel"' 'Argument 1 to "append" of "list" has incompatible type "AddIndex"; expected "CreateModel"'
], ],
'middleware_exceptions': [
'Argument 1 to "append" of "list" has incompatible type "Tuple[Any, Any]"; expected "str"'
],
'multiple_database': [
'Unexpected attribute "extra_arg" for model "Book"',
'Too many arguments for "create" of "QuerySet"'
],
'queryset_pickle': [ 'queryset_pickle': [
'"None" has no attribute "somefield"' '"None" has no attribute "somefield"'
], ],
@@ -295,6 +316,13 @@ IGNORED_ERRORS = {
'select_related_onetoone': [ 'select_related_onetoone': [
'"None" has no attribute' '"None" has no attribute'
], ],
'servers': [
re.compile('Argument [0-9] to "WSGIRequestHandler"')
],
'sitemaps_tests': [
'Incompatible types in assignment (expression has type "str", '
+ 'base class "Sitemap" defined the type as "Callable[[Sitemap, Model], str]")'
],
'view_tests': [ 'view_tests': [
'"Handler" has no attribute "include_html"', '"Handler" has no attribute "include_html"',
'"EmailMessage" has no attribute "alternatives"' '"EmailMessage" has no attribute "alternatives"'
@@ -329,10 +357,10 @@ TESTS_DIRS = [
'builtin_server', 'builtin_server',
'bulk_create', 'bulk_create',
# TODO: 'cache', # TODO: 'cache',
# TODO: 'check_framework', 'check_framework',
'choices', 'choices',
'conditional_processing', 'conditional_processing',
# TODO: 'contenttypes_tests', 'contenttypes_tests',
'context_processors', 'context_processors',
'csrf_tests', 'csrf_tests',
'custom_columns', 'custom_columns',
@@ -348,36 +376,36 @@ TESTS_DIRS = [
'db_typecasts', 'db_typecasts',
'db_utils', 'db_utils',
'dbshell', 'dbshell',
# TODO: 'decorators', 'decorators',
'defer', 'defer',
# TODO: 'defer_regress', 'defer_regress',
'delete', 'delete',
'delete_regress', 'delete_regress',
# TODO: 'deprecation', 'deprecation',
'dispatch', 'dispatch',
'distinct_on_fields', 'distinct_on_fields',
'empty', 'empty',
'expressions', 'expressions',
'expressions_case', 'expressions_case',
# TODO: 'expressions_window', 'expressions_window',
# TODO: 'extra_regress', # TODO: 'extra_regress',
'field_deconstruction', 'field_deconstruction',
'field_defaults', 'field_defaults',
'field_subclassing', 'field_subclassing',
# TODO: 'file_storage', # TODO: 'file_storage',
# TODO: 'file_uploads', 'file_uploads',
# TODO: 'files', # TODO: 'files',
'filtered_relation', 'filtered_relation',
# TODO: 'fixtures', 'fixtures',
'fixtures_model_package', 'fixtures_model_package',
# TODO: 'fixtures_regress', 'fixtures_regress',
# TODO: 'flatpages_tests', 'flatpages_tests',
'force_insert_update', 'force_insert_update',
'foreign_object', 'foreign_object',
# TODO: 'forms_tests', # TODO: 'forms_tests',
'from_db_value', 'from_db_value',
# TODO: 'generic_inline_admin', 'generic_inline_admin',
# TODO: 'generic_relations', 'generic_relations',
'generic_relations_regress', 'generic_relations_regress',
# TODO: 'generic_views', # TODO: 'generic_views',
'get_earliest_or_latest', 'get_earliest_or_latest',
@@ -398,7 +426,7 @@ TESTS_DIRS = [
# 'invalid_models_tests', # 'invalid_models_tests',
'known_related_objects', 'known_related_objects',
# TODO: 'logging_tests', 'logging_tests',
'lookup', 'lookup',
'm2m_and_m2o', 'm2m_and_m2o',
'm2m_intermediary', 'm2m_intermediary',
@@ -416,13 +444,13 @@ TESTS_DIRS = [
'many_to_one_null', 'many_to_one_null',
'max_lengths', 'max_lengths',
# TODO: 'messages_tests', # TODO: 'messages_tests',
# TODO: 'middleware', 'middleware',
# TODO: 'middleware_exceptions', 'middleware_exceptions',
'migrate_signals', 'migrate_signals',
'migration_test_data_persistence', 'migration_test_data_persistence',
'migrations', 'migrations',
'migrations2', 'migrations2',
# TODO: 'model_fields', 'model_fields',
# TODO: 'model_forms', # TODO: 'model_forms',
'model_formsets', 'model_formsets',
'model_formsets_regress', 'model_formsets_regress',
@@ -434,7 +462,7 @@ TESTS_DIRS = [
'model_package', 'model_package',
'model_regress', 'model_regress',
'modeladmin', 'modeladmin',
# TODO: 'multiple_database', 'multiple_database',
'mutually_referential', 'mutually_referential',
'nested_foreign_keys', 'nested_foreign_keys',
'no_models', 'no_models',
@@ -452,7 +480,7 @@ TESTS_DIRS = [
'properties', 'properties',
'proxy_model_inheritance', 'proxy_model_inheritance',
# TODO: 'proxy_models', # TODO: 'proxy_models',
# TODO: 'queries', 'queries',
'queryset_pickle', 'queryset_pickle',
'raw_query', 'raw_query',
'redirects_tests', 'redirects_tests',
@@ -468,7 +496,7 @@ TESTS_DIRS = [
'select_related_onetoone', 'select_related_onetoone',
'select_related_regress', 'select_related_regress',
# TODO: 'serializers', # TODO: 'serializers',
# TODO: 'servers', 'servers',
'sessions_tests', 'sessions_tests',
'settings_tests', 'settings_tests',
'shell', 'shell',
@@ -504,8 +532,7 @@ TESTS_DIRS = [
# TODO: 'urlpatterns_reverse', # TODO: 'urlpatterns_reverse',
'user_commands', 'user_commands',
# TODO: 'utils_tests', # TODO: 'utils_tests',
# not annotatable without annotation in test 'validation',
# TODO: 'validation',
'validators', 'validators',
'version', 'version',
'view_tests', 'view_tests',

View File

@@ -31,7 +31,7 @@ if sys.version_info[:2] < (3, 7):
setup( setup(
name="django-stubs", name="django-stubs",
version="0.3.0", version="0.5.0",
description='Django mypy stubs', description='Django mypy stubs',
long_description=readme, long_description=readme,
long_description_content_type='text/markdown', long_description_content_type='text/markdown',

View File

@@ -21,3 +21,15 @@ django_settings = mysettings
[file mysettings.py] [file mysettings.py]
MY_SETTING: int = 1 MY_SETTING: int = 1
[out] [out]
[CASE mypy_django_ini_in_current_directory_is_a_default]
from django.conf import settings
reveal_type(settings.MY_SETTING) # E: Revealed type is 'builtins.int'
[file mypy_django.ini]
[[mypy_django_plugin]
django_settings = mysettings
[file mysettings.py]
MY_SETTING: int = 1
[out]

View File

@@ -155,3 +155,19 @@ class MyModel(models.Model):
day = models.CharField(max_length=3, choices=((1, 'Fri'), (2, 'Sat'))) day = models.CharField(max_length=3, choices=((1, 'Fri'), (2, 'Sat')))
MyModel(day=1) MyModel(day=1)
[out] [out]
[CASE if_there_is_no_data_for_base_classes_of_fields_and_ignore_unresolved_attributes_set_to_true_to_not_fail]
from decimal import Decimal
from django.db import models
from fields2 import MoneyField
class InvoiceRow(models.Model):
base_amount = MoneyField(max_digits=10, decimal_places=2)
vat_rate = models.DecimalField(max_digits=10, decimal_places=2)
InvoiceRow(1, Decimal(0), Decimal(0))
InvoiceRow(base_amount=Decimal(0), vat_rate=Decimal(0))
InvoiceRow.objects.create(base_amount=Decimal(0), vat_rate=Decimal(0))
[out]
main:3: error: Cannot find module named 'fields2'
main:3: note: See https://mypy.readthedocs.io/en/latest/running_mypy.html#missing-imports

View File

@@ -22,13 +22,11 @@ class Publisher(models.Model):
class Book(models.Model): class Book(models.Model):
publisher = models.ForeignKey(to=Publisher, on_delete=models.CASCADE) publisher = models.ForeignKey(to=Publisher, on_delete=models.CASCADE)
class StylesheetError(Exception):
pass
owner = models.ForeignKey(db_column='model_id', to='db.Unknown', on_delete=models.CASCADE) owner = models.ForeignKey(db_column='model_id', to='db.Unknown', on_delete=models.CASCADE)
book = Book() book = Book()
reveal_type(book.publisher_id) # E: Revealed type is 'builtins.int' reveal_type(book.publisher_id) # E: Revealed type is 'builtins.int'
reveal_type(book.owner_id) # E: Revealed type is 'builtins.int' reveal_type(book.owner_id) # E: Revealed type is 'Any'
[CASE test_foreign_key_field_different_order_of_params] [CASE test_foreign_key_field_different_order_of_params]
from django.db import models from django.db import models
@@ -68,7 +66,7 @@ from django.db import models
class Publisher(models.Model): class Publisher(models.Model):
pass pass
[CASE test_to_parameter_as_string_with_application_name__fallbacks_to_any_if_model_not_present_in_dependency_graph] [CASE test_to_parameter_as_string_with_application_name_fallbacks_to_any_if_model_not_present_in_dependency_graph]
from django.db import models from django.db import models
class Book(models.Model): class Book(models.Model):
@@ -76,6 +74,9 @@ class Book(models.Model):
book = Book() book = Book()
reveal_type(book.publisher) # E: Revealed type is 'Any' reveal_type(book.publisher) # E: Revealed type is 'Any'
reveal_type(book.publisher_id) # E: Revealed type is 'Any'
Book(publisher_id=1)
Book.objects.create(publisher_id=1)
[file myapp/__init__.py] [file myapp/__init__.py]
[file myapp/models.py] [file myapp/models.py]
@@ -239,3 +240,38 @@ class ParkingSpot(BaseModel):
class Booking(BaseModel): class Booking(BaseModel):
parking_spot = models.ForeignKey(to=ParkingSpot, null=True, on_delete=models.SET_NULL) parking_spot = models.ForeignKey(to=ParkingSpot, null=True, on_delete=models.SET_NULL)
[out] [out]
[CASE if_no_related_name_is_passed_create_default_related_managers]
from django.db import models
class Publisher(models.Model):
pass
class Book(models.Model):
publisher = models.ForeignKey(to=Publisher, on_delete=models.CASCADE)
reveal_type(Publisher().book_set) # E: Revealed type is 'django.db.models.manager.RelatedManager[main.Book]'
[CASE underscore_id_attribute_has_set_type_of_primary_key_if_explicit]
from django.db import models
import datetime
class Publisher(models.Model):
mypk = models.CharField(max_length=100, primary_key=True)
class Book(models.Model):
publisher = models.ForeignKey(to=Publisher, on_delete=models.CASCADE)
reveal_type(Book().publisher_id) # E: Revealed type is 'builtins.str'
Book(publisher_id=1)
Book(publisher_id='hello')
Book(publisher_id=datetime.datetime.now()) # E: Incompatible type for "publisher_id" of "Book" (got "datetime", expected "Union[str, int, Combinable]")
Book.objects.create(publisher_id=1)
Book.objects.create(publisher_id='hello')
class Publisher2(models.Model):
mypk = models.IntegerField(primary_key=True)
class Book2(models.Model):
publisher = models.ForeignKey(to=Publisher2, on_delete=models.CASCADE)
reveal_type(Book2().publisher_id) # E: Revealed type is 'builtins.int'
Book2(publisher_id=1)
Book2(publisher_id='hello') # E: Incompatible type for "publisher_id" of "Book2" (got "str", expected "Union[int, Combinable, Literal['']]")
Book2.objects.create(publisher_id=1)
Book2.objects.create(publisher_id='hello') # E: Incompatible type for "publisher_id" of "Book2" (got "str", expected "Union[int, Combinable, Literal['']]")
[out]

View File

@@ -2,13 +2,18 @@
from django.conf import settings from django.conf import settings
reveal_type(settings.ROOT_DIR) # E: Revealed type is 'builtins.str' reveal_type(settings.ROOT_DIR) # E: Revealed type is 'builtins.str'
reveal_type(settings.APPS_DIR) # E: Revealed type is 'pathlib.Path'
reveal_type(settings.OBJ) # E: Revealed type is 'django.utils.functional.LazyObject' reveal_type(settings.OBJ) # E: Revealed type is 'django.utils.functional.LazyObject'
reveal_type(settings.NUMBERS) # E: Revealed type is 'builtins.list[builtins.str]' reveal_type(settings.NUMBERS) # E: Revealed type is 'builtins.list[builtins.str]'
reveal_type(settings.DICT) # E: Revealed type is 'builtins.dict[Any, Any]' reveal_type(settings.DICT) # E: Revealed type is 'builtins.dict[Any, Any]'
[env DJANGO_SETTINGS_MODULE=mysettings] [env DJANGO_SETTINGS_MODULE=mysettings]
[file mysettings.py] [file base.py]
SECRET_KEY = 112233 from pathlib import Path
ROOT_DIR = '/etc' ROOT_DIR = '/etc'
APPS_DIR = Path(ROOT_DIR)
[file mysettings.py]
from base import *
SECRET_KEY = 112233
NUMBERS = ['one', 'two'] NUMBERS = ['one', 'two']
DICT = {} # type: ignore DICT = {} # type: ignore
from django.utils.functional import LazyObject from django.utils.functional import LazyObject

View File

@@ -0,0 +1,56 @@
[CASE get_object_or_404_returns_proper_types]
from django.shortcuts import get_object_or_404, get_list_or_404
from django.db import models
class MyModel(models.Model):
pass
reveal_type(get_object_or_404(MyModel)) # E: Revealed type is 'main.MyModel*'
reveal_type(get_object_or_404(MyModel.objects)) # E: Revealed type is 'main.MyModel*'
reveal_type(get_object_or_404(MyModel.objects.get_queryset())) # E: Revealed type is 'main.MyModel*'
reveal_type(get_list_or_404(MyModel)) # E: Revealed type is 'builtins.list[main.MyModel*]'
reveal_type(get_list_or_404(MyModel.objects)) # E: Revealed type is 'builtins.list[main.MyModel*]'
reveal_type(get_list_or_404(MyModel.objects.get_queryset())) # E: Revealed type is 'builtins.list[main.MyModel*]'
[out]
[CASE get_user_model_returns_proper_class]
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from myapp.models import MyUser
from django.contrib.auth import get_user_model
UserModel = get_user_model()
reveal_type(UserModel.objects) # E: Revealed type is 'django.db.models.manager.Manager[myapp.models.MyUser]'
[env DJANGO_SETTINGS_MODULE=mysettings]
[file mysettings.py]
INSTALLED_APPS = ('myapp',)
AUTH_USER_MODEL = 'myapp.MyUser'
[file myapp/__init__.py]
[file myapp/models.py]
from django.db import models
class MyUser(models.Model):
pass
[out]
[CASE return_type_model_and_show_error_if_model_not_yet_imported]
from django.contrib.auth import get_user_model
UserModel = get_user_model()
reveal_type(UserModel.objects)
[env DJANGO_SETTINGS_MODULE=mysettings]
[file mysettings.py]
INSTALLED_APPS = ('myapp',)
AUTH_USER_MODEL = 'myapp.MyUser'
[file myapp/__init__.py]
[file myapp/models.py]
from django.db import models
class MyUser(models.Model):
pass
[out]
main:3: error: "myapp.MyUser" model class is not imported so far. Try to import it (under if TYPE_CHECKING) at the beginning of the current file
main:4: error: Revealed type is 'Any'
main:4: error: "Type[Model]" has no attribute "objects"