mirror of
https://github.com/davidhalter/django-stubs.git
synced 2025-12-10 22:11:54 +08:00
Add type hints to all test code (#1217)
* Add type hints to all test code * Fixes * Fix indentation * Review fixes
This commit is contained in:
@@ -1,7 +1,7 @@
|
||||
import tempfile
|
||||
import typing
|
||||
import uuid
|
||||
from contextlib import contextmanager
|
||||
from typing import Any, Generator, List, Optional
|
||||
|
||||
import pytest
|
||||
|
||||
@@ -27,7 +27,7 @@ django_settings_module = str (required)
|
||||
|
||||
|
||||
@contextmanager
|
||||
def write_to_file(file_contents: str, suffix: typing.Optional[str] = None) -> typing.Generator[str, None, None]:
|
||||
def write_to_file(file_contents: str, suffix: Optional[str] = None) -> Generator[str, None, None]:
|
||||
with tempfile.NamedTemporaryFile(mode="w+", suffix=suffix) as config_file:
|
||||
config_file.write(file_contents)
|
||||
config_file.seek(0)
|
||||
@@ -54,8 +54,7 @@ def write_to_file(file_contents: str, suffix: typing.Optional[str] = None) -> ty
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_misconfiguration_handling(capsys, config_file_contents, message_part):
|
||||
# type: (typing.Any, typing.List[str], str) -> None
|
||||
def test_misconfiguration_handling(capsys: Any, config_file_contents: List[str], message_part: str) -> None:
|
||||
"""Invalid configuration raises `SystemExit` with a precise error message."""
|
||||
contents = "\n".join(config_file_contents).expandtabs(4)
|
||||
with write_to_file(contents) as filename:
|
||||
@@ -74,7 +73,7 @@ def test_misconfiguration_handling(capsys, config_file_contents, message_part):
|
||||
pytest.param(None, id="as none"),
|
||||
],
|
||||
)
|
||||
def test_handles_filename(capsys, filename: str):
|
||||
def test_handles_filename(capsys: Any, filename: str) -> None:
|
||||
with pytest.raises(SystemExit, match="2"):
|
||||
DjangoPluginConfig(filename)
|
||||
|
||||
@@ -116,7 +115,7 @@ def test_handles_filename(capsys, filename: str):
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_toml_misconfiguration_handling(capsys, config_file_contents, message_part):
|
||||
def test_toml_misconfiguration_handling(capsys: Any, config_file_contents, message_part) -> None:
|
||||
with write_to_file(config_file_contents, suffix=".toml") as filename:
|
||||
with pytest.raises(SystemExit, match="2"):
|
||||
DjangoPluginConfig(filename)
|
||||
|
||||
@@ -79,7 +79,7 @@
|
||||
# this will fail if `model` has a type other than the generic specified in the class declaration
|
||||
model = TestModel
|
||||
|
||||
def a_method_action(self, request, queryset):
|
||||
def a_method_action(self, request: HttpRequest, queryset: QuerySet) -> None:
|
||||
pass
|
||||
|
||||
# This test is here to make sure we're not running into a mypy issue which is
|
||||
|
||||
@@ -1,9 +1,11 @@
|
||||
- case: login_required_bare
|
||||
main: |
|
||||
from typing import Any
|
||||
from django.contrib.auth.decorators import login_required
|
||||
from django.http import HttpRequest, HttpResponse
|
||||
@login_required
|
||||
def view_func(request): ...
|
||||
reveal_type(view_func) # N: Revealed type is "def (request: Any) -> Any"
|
||||
def view_func(request: HttpRequest) -> HttpResponse: ...
|
||||
reveal_type(view_func) # N: Revealed type is "def (request: django.http.request.HttpRequest) -> django.http.response.HttpResponse"
|
||||
- case: login_required_fancy
|
||||
main: |
|
||||
from django.contrib.auth.decorators import login_required
|
||||
@@ -15,29 +17,33 @@
|
||||
- case: login_required_weird
|
||||
main: |
|
||||
from django.contrib.auth.decorators import login_required
|
||||
from django.http import HttpRequest, HttpResponse
|
||||
# This is non-conventional usage, but covered in Django tests, so we allow it.
|
||||
def view_func(request): ...
|
||||
def view_func(request: HttpRequest) -> HttpResponse: ...
|
||||
wrapped_view = login_required(view_func, redirect_field_name='a', login_url='b')
|
||||
reveal_type(wrapped_view) # N: Revealed type is "def (request: Any) -> Any"
|
||||
reveal_type(wrapped_view) # N: Revealed type is "def (request: django.http.request.HttpRequest) -> django.http.response.HttpResponse"
|
||||
- case: login_required_incorrect_return
|
||||
main: |
|
||||
from typing import Any
|
||||
from django.contrib.auth.decorators import login_required
|
||||
@login_required() # E: Value of type variable "_VIEW" of function cannot be "Callable[[Any], str]"
|
||||
def view_func2(request) -> str: ...
|
||||
def view_func2(request: Any) -> str: ...
|
||||
- case: user_passes_test
|
||||
main: |
|
||||
from django.contrib.auth.decorators import user_passes_test
|
||||
from django.http import HttpRequest, HttpResponse
|
||||
@user_passes_test(lambda u: u.get_username().startswith('super'))
|
||||
def view_func(request): ...
|
||||
reveal_type(view_func) # N: Revealed type is "def (request: Any) -> Any"
|
||||
def view_func(request: HttpRequest) -> HttpResponse: ...
|
||||
reveal_type(view_func) # N: Revealed type is "def (request: django.http.request.HttpRequest) -> django.http.response.HttpResponse"
|
||||
- case: user_passes_test_bare_is_error
|
||||
main: |
|
||||
from django.http.response import HttpResponse
|
||||
from django.http import HttpRequest, HttpResponse
|
||||
from django.contrib.auth.decorators import user_passes_test
|
||||
@user_passes_test # E: Argument 1 to "user_passes_test" has incompatible type "Callable[[Any], HttpResponse]"; expected "Callable[[Union[AbstractBaseUser, AnonymousUser]], bool]"
|
||||
def view_func(request) -> HttpResponse: ...
|
||||
@user_passes_test # E: Argument 1 to "user_passes_test" has incompatible type "Callable[[HttpRequest], HttpResponse]"; expected "Callable[[Union[AbstractBaseUser, AnonymousUser]], bool]"
|
||||
def view_func(request: HttpRequest) -> HttpResponse: ...
|
||||
- case: permission_required
|
||||
main: |
|
||||
from django.contrib.auth.decorators import permission_required
|
||||
from django.http import HttpRequest, HttpResponse
|
||||
@permission_required('polls.can_vote')
|
||||
def view_func(request): ...
|
||||
def view_func(request: HttpRequest) -> HttpResponse: ...
|
||||
|
||||
@@ -12,10 +12,12 @@
|
||||
reveal_type(func) # N: Revealed type is "def (x: builtins.int) -> builtins.list[Any]"
|
||||
- case: non_atomic_requests_bare
|
||||
main: |
|
||||
from typing import Any
|
||||
from django.db.transaction import non_atomic_requests
|
||||
from django.http import HttpRequest, HttpResponse
|
||||
@non_atomic_requests
|
||||
def view_func(request): ...
|
||||
reveal_type(view_func) # N: Revealed type is "def (request: Any) -> Any"
|
||||
def view_func(request: HttpRequest) -> HttpResponse: ...
|
||||
reveal_type(view_func) # N: Revealed type is "def (request: django.http.request.HttpRequest) -> django.http.response.HttpResponse"
|
||||
|
||||
- case: non_atomic_requests_args
|
||||
main: |
|
||||
|
||||
@@ -543,7 +543,7 @@
|
||||
from django.db import models
|
||||
class User(models.Model):
|
||||
pass
|
||||
def get_user_model_name():
|
||||
def get_user_model_name() -> str:
|
||||
return 'myapp.User'
|
||||
class Profile(models.Model):
|
||||
user = models.ForeignKey(to=get_user_model_name(), on_delete=models.CASCADE)
|
||||
@@ -703,7 +703,7 @@
|
||||
def custom(self) -> None:
|
||||
pass
|
||||
|
||||
def TransactionManager():
|
||||
def TransactionManager() -> BaseManager:
|
||||
return BaseManager.from_queryset(TransactionQuerySet)()
|
||||
|
||||
class Transaction(models.Model):
|
||||
|
||||
@@ -164,12 +164,12 @@
|
||||
qs = User.objects.annotate(Count('id'))
|
||||
annotated_user = qs.get()
|
||||
|
||||
def animals_only(param: Animal):
|
||||
def animals_only(param: Animal) -> None:
|
||||
pass
|
||||
# Make sure that even though attr access falls back to Any, the type is still checked
|
||||
animals_only(annotated_user) # E: Argument 1 to "animals_only" has incompatible type "WithAnnotations[myapp__models__User]"; expected "Animal"
|
||||
|
||||
def users_allowed(param: User):
|
||||
def users_allowed(param: User) -> None:
|
||||
# But this function accepts only the original User type, so any attr access is not allowed within this function
|
||||
param.foo # E: "User" has no attribute "foo"
|
||||
# Passing in the annotated User to a function taking a (unannotated) User is OK
|
||||
|
||||
@@ -19,7 +19,7 @@
|
||||
main: |
|
||||
from myapp.models import User
|
||||
|
||||
async def main():
|
||||
async def main() -> None:
|
||||
async for user in User.objects.all():
|
||||
reveal_type(user) # N: Revealed type is "myapp.models.User"
|
||||
installed_apps:
|
||||
|
||||
@@ -46,7 +46,7 @@
|
||||
|
||||
_T = TypeVar('_T', bound=models.Model)
|
||||
class Base(Generic[_T]):
|
||||
def __init__(self, model_cls: Type[_T]):
|
||||
def __init__(self, model_cls: Type[_T]) -> None:
|
||||
self.model_cls = model_cls
|
||||
reveal_type(self.model_cls._default_manager) # N: Revealed type is "django.db.models.manager.BaseManager[_T`1]"
|
||||
class MyModel(models.Model):
|
||||
@@ -71,7 +71,7 @@
|
||||
|
||||
_T = TypeVar('_T', bound=models.Model)
|
||||
class Base(Generic[_T]):
|
||||
def __init__(self, model_cls: Type[_T]):
|
||||
def __init__(self, model_cls: Type[_T]) -> None:
|
||||
self.model_cls = model_cls
|
||||
reveal_type(self.model_cls._base_manager) # N: Revealed type is "django.db.models.manager.BaseManager[_T`1]"
|
||||
class MyModel(models.Model):
|
||||
@@ -350,11 +350,12 @@
|
||||
- path: myapp/__init__.py
|
||||
- path: myapp/models.py
|
||||
content: |
|
||||
from typing import Any
|
||||
from django.db import models
|
||||
class MyManager(models.Manager):
|
||||
def get_instance(self) -> int:
|
||||
pass
|
||||
def get_instance_untyped(self, name):
|
||||
def get_instance_untyped(self, name: str):
|
||||
pass
|
||||
class User(models.Model):
|
||||
objects = MyManager()
|
||||
@@ -389,21 +390,22 @@
|
||||
installed_apps:
|
||||
- myapp
|
||||
out: |
|
||||
myapp/models:4: error: Return type "MyModel" of "create" incompatible with return type "_T" in supertype "BaseManager"
|
||||
myapp/models:5: error: Incompatible return value type (got "_T", expected "MyModel")
|
||||
myapp/models:5: error: Return type "MyModel" of "create" incompatible with return type "_T" in supertype "BaseManager"
|
||||
myapp/models:6: error: Incompatible return value type (got "_T", expected "MyModel")
|
||||
files:
|
||||
- path: myapp/__init__.py
|
||||
- path: myapp/models.py
|
||||
content: |
|
||||
from typing import Any
|
||||
from django.db import models
|
||||
class MyModelManager(models.Manager):
|
||||
|
||||
def create(self, **kwargs) -> 'MyModel':
|
||||
return super().create(**kwargs)
|
||||
def create(self, **kwargs: Any) -> 'MyModel':
|
||||
return super().create(**kwargs)
|
||||
|
||||
|
||||
class MyModel(models.Model):
|
||||
objects = MyModelManager()
|
||||
objects = MyModelManager()
|
||||
|
||||
|
||||
- case: override_manager_create2
|
||||
@@ -416,15 +418,15 @@
|
||||
- path: myapp/__init__.py
|
||||
- path: myapp/models.py
|
||||
content: |
|
||||
from typing import Any
|
||||
from django.db import models
|
||||
class MyModelManager(models.Manager['MyModel']):
|
||||
|
||||
def create(self, **kwargs) -> 'MyModel':
|
||||
return super().create(**kwargs)
|
||||
def create(self, **kwargs: Any) -> 'MyModel':
|
||||
return super().create(**kwargs)
|
||||
|
||||
class MyModel(models.Model):
|
||||
|
||||
objects = MyModelManager()
|
||||
objects = MyModelManager()
|
||||
|
||||
- case: regression_manager_scope_foreign
|
||||
main: |
|
||||
@@ -488,7 +490,7 @@
|
||||
class InvisibleUnresolvable(AbstractUnresolvable):
|
||||
text = models.TextField()
|
||||
|
||||
def process_booking(user: User):
|
||||
def process_booking(user: User) -> None:
|
||||
reveal_type(User.objects)
|
||||
reveal_type(User._default_manager)
|
||||
|
||||
|
||||
@@ -124,7 +124,7 @@
|
||||
- path: myapp/models.py
|
||||
content: |
|
||||
from django.db import models
|
||||
def return_int():
|
||||
def return_int() -> int:
|
||||
return 0
|
||||
class MyModel(models.Model):
|
||||
id = models.IntegerField(primary_key=True, default=return_int)
|
||||
|
||||
@@ -14,9 +14,10 @@
|
||||
content: |
|
||||
from django.db import models
|
||||
class MyModel(models.Model):
|
||||
def __str__(self):
|
||||
def __str__(self) -> str:
|
||||
reveal_type(self.id) # N: Revealed type is "builtins.int"
|
||||
reveal_type(self.pk) # N: Revealed type is "builtins.int"
|
||||
return ''
|
||||
|
||||
|
||||
- case: test_access_to_id_field_through_self_if_primary_key_is_defined
|
||||
@@ -36,9 +37,10 @@
|
||||
from django.db import models
|
||||
class MyModel(models.Model):
|
||||
id = models.CharField(max_length=10, primary_key=True)
|
||||
def __str__(self):
|
||||
def __str__(self) -> str:
|
||||
reveal_type(self.id) # N: Revealed type is "builtins.str"
|
||||
reveal_type(self.pk) # N: Revealed type is "builtins.str"
|
||||
return self.id
|
||||
|
||||
|
||||
- case: test_access_to_id_field_through_self_if_primary_key_has_different_name
|
||||
@@ -60,7 +62,8 @@
|
||||
from django.db import models
|
||||
class MyModel(models.Model):
|
||||
primary = models.CharField(max_length=10, primary_key=True)
|
||||
def __str__(self):
|
||||
def __str__(self) -> str:
|
||||
reveal_type(self.primary) # N: Revealed type is "builtins.str"
|
||||
reveal_type(self.pk) # N: Revealed type is "builtins.str"
|
||||
self.id # E: "MyModel" has no attribute "id"
|
||||
return self.primary
|
||||
|
||||
@@ -14,7 +14,7 @@
|
||||
- case: async_client_methods
|
||||
main: |
|
||||
from django.test.client import AsyncClient
|
||||
async def main():
|
||||
async def main() -> None:
|
||||
client = AsyncClient()
|
||||
response = await client.get('foo')
|
||||
reveal_type(response.asgi_request) # N: Revealed type is "django.core.handlers.asgi.ASGIRequest"
|
||||
|
||||
@@ -49,7 +49,7 @@
|
||||
files:
|
||||
- path: extras/extra_module.py
|
||||
content: |
|
||||
def extra_fn():
|
||||
def extra_fn() -> None:
|
||||
pass
|
||||
|
||||
- case: add_mypypath_env_var_to_package_search
|
||||
@@ -61,5 +61,5 @@
|
||||
files:
|
||||
- path: extras/extra_module.py
|
||||
content: |
|
||||
def extra_fn():
|
||||
def extra_fn() -> None:
|
||||
pass
|
||||
|
||||
@@ -29,6 +29,7 @@
|
||||
main: |
|
||||
from typing import Any
|
||||
from django import forms
|
||||
from django.http import HttpRequest, HttpResponse
|
||||
from django.views.generic.edit import FormView
|
||||
|
||||
class MyForm(forms.ModelForm):
|
||||
@@ -37,25 +38,28 @@
|
||||
pass
|
||||
class MyView(FormView):
|
||||
form_class = MyForm
|
||||
def post(self, request, *args: Any, **kwds: Any):
|
||||
def post(self, request: HttpRequest, *args: Any, **kwds: Any) -> HttpResponse:
|
||||
form_class = self.get_form_class()
|
||||
reveal_type(form_class) # N: Revealed type is "Type[main.MyForm]"
|
||||
reveal_type(self.get_form(None)) # N: Revealed type is "main.MyForm"
|
||||
reveal_type(self.get_form()) # N: Revealed type is "main.MyForm"
|
||||
reveal_type(self.get_form(form_class)) # N: Revealed type is "main.MyForm"
|
||||
reveal_type(self.get_form(MyForm2)) # N: Revealed type is "main.MyForm2"
|
||||
return HttpResponse()
|
||||
|
||||
- case: updateview_form_valid_has_form_save
|
||||
main: |
|
||||
from django import forms
|
||||
from django.http import HttpResponse
|
||||
from django.views.generic.edit import UpdateView
|
||||
|
||||
class MyForm(forms.ModelForm):
|
||||
pass
|
||||
class MyView(UpdateView):
|
||||
form_class = MyForm
|
||||
def form_valid(self, form: forms.BaseModelForm):
|
||||
def form_valid(self, form: forms.BaseModelForm) -> HttpResponse:
|
||||
reveal_type(form.save) # N: Revealed type is "def (commit: builtins.bool =) -> Any"
|
||||
return HttpResponse()
|
||||
|
||||
- case: successmessagemixin_compatible_with_formmixin
|
||||
main: |
|
||||
|
||||
@@ -32,7 +32,7 @@
|
||||
def atomic_method1(self, abc: int) -> str:
|
||||
pass
|
||||
@transaction.atomic(savepoint=True)
|
||||
def atomic_method2(self):
|
||||
def atomic_method2(self) -> None:
|
||||
pass
|
||||
@transaction.atomic(using="db", savepoint=True)
|
||||
def atomic_method3(self, myparam: str) -> int:
|
||||
|
||||
@@ -8,6 +8,7 @@
|
||||
reveal_type(TestView()) # N: Revealed type is "main.TestView"
|
||||
- case: method_decorator_function
|
||||
main: |
|
||||
from typing import Any
|
||||
from django.views.generic.base import View
|
||||
from django.utils.decorators import method_decorator
|
||||
from django.contrib.auth.decorators import login_required
|
||||
@@ -15,6 +16,6 @@
|
||||
from django.http.request import HttpRequest
|
||||
class TestView(View):
|
||||
@method_decorator(login_required)
|
||||
def dispatch(self, request: HttpRequest, *args, **kwargs) -> HttpResponseBase:
|
||||
def dispatch(self, request: HttpRequest, *args: Any, **kwargs: Any) -> HttpResponseBase:
|
||||
return super().dispatch(request, *args, **kwargs)
|
||||
reveal_type(dispatch) # N: Revealed type is "def (self: main.TestView, request: django.http.request.HttpRequest, *args: Any, **kwargs: Any) -> django.http.response.HttpResponseBase"
|
||||
|
||||
@@ -14,22 +14,24 @@
|
||||
|
||||
- case: dispatch_http_response
|
||||
main: |
|
||||
from django.http import HttpResponse
|
||||
from typing import Any
|
||||
from django.http import HttpRequest, HttpResponse
|
||||
from django.views.generic.base import View
|
||||
|
||||
class MyView(View):
|
||||
def dispatch(self, request, *args, **kwargs) -> HttpResponse:
|
||||
def dispatch(self, request: HttpRequest, *args: Any, **kwargs: Any) -> HttpResponse:
|
||||
response: HttpResponse
|
||||
return response
|
||||
|
||||
|
||||
- case: dispatch_streaming_http_response
|
||||
main: |
|
||||
from django.http import StreamingHttpResponse
|
||||
from typing import Any
|
||||
from django.http import HttpRequest, StreamingHttpResponse
|
||||
from django.views.generic.base import View
|
||||
|
||||
class MyView(View):
|
||||
def dispatch(self, request, *args, **kwargs) -> StreamingHttpResponse:
|
||||
def dispatch(self, request: HttpRequest, *args: Any, **kwargs: Any) -> StreamingHttpResponse:
|
||||
response: StreamingHttpResponse
|
||||
return response
|
||||
|
||||
|
||||
Reference in New Issue
Block a user