Add type hints to all test code (#1217)

* Add type hints to all test code

* Fixes

* Fix indentation

* Review fixes
This commit is contained in:
Marti Raudsepp
2022-10-31 11:20:10 +02:00
committed by GitHub
parent 9b4162beb1
commit e3c131bc61
16 changed files with 72 additions and 53 deletions

View File

@@ -1,7 +1,7 @@
import tempfile
import typing
import uuid
from contextlib import contextmanager
from typing import Any, Generator, List, Optional
import pytest
@@ -27,7 +27,7 @@ django_settings_module = str (required)
@contextmanager
def write_to_file(file_contents: str, suffix: typing.Optional[str] = None) -> typing.Generator[str, None, None]:
def write_to_file(file_contents: str, suffix: Optional[str] = None) -> Generator[str, None, None]:
with tempfile.NamedTemporaryFile(mode="w+", suffix=suffix) as config_file:
config_file.write(file_contents)
config_file.seek(0)
@@ -54,8 +54,7 @@ def write_to_file(file_contents: str, suffix: typing.Optional[str] = None) -> ty
),
],
)
def test_misconfiguration_handling(capsys, config_file_contents, message_part):
# type: (typing.Any, typing.List[str], str) -> None
def test_misconfiguration_handling(capsys: Any, config_file_contents: List[str], message_part: str) -> None:
"""Invalid configuration raises `SystemExit` with a precise error message."""
contents = "\n".join(config_file_contents).expandtabs(4)
with write_to_file(contents) as filename:
@@ -74,7 +73,7 @@ def test_misconfiguration_handling(capsys, config_file_contents, message_part):
pytest.param(None, id="as none"),
],
)
def test_handles_filename(capsys, filename: str):
def test_handles_filename(capsys: Any, filename: str) -> None:
with pytest.raises(SystemExit, match="2"):
DjangoPluginConfig(filename)
@@ -116,7 +115,7 @@ def test_handles_filename(capsys, filename: str):
),
],
)
def test_toml_misconfiguration_handling(capsys, config_file_contents, message_part):
def test_toml_misconfiguration_handling(capsys: Any, config_file_contents, message_part) -> None:
with write_to_file(config_file_contents, suffix=".toml") as filename:
with pytest.raises(SystemExit, match="2"):
DjangoPluginConfig(filename)

View File

@@ -79,7 +79,7 @@
# this will fail if `model` has a type other than the generic specified in the class declaration
model = TestModel
def a_method_action(self, request, queryset):
def a_method_action(self, request: HttpRequest, queryset: QuerySet) -> None:
pass
# This test is here to make sure we're not running into a mypy issue which is

View File

@@ -1,9 +1,11 @@
- case: login_required_bare
main: |
from typing import Any
from django.contrib.auth.decorators import login_required
from django.http import HttpRequest, HttpResponse
@login_required
def view_func(request): ...
reveal_type(view_func) # N: Revealed type is "def (request: Any) -> Any"
def view_func(request: HttpRequest) -> HttpResponse: ...
reveal_type(view_func) # N: Revealed type is "def (request: django.http.request.HttpRequest) -> django.http.response.HttpResponse"
- case: login_required_fancy
main: |
from django.contrib.auth.decorators import login_required
@@ -15,29 +17,33 @@
- case: login_required_weird
main: |
from django.contrib.auth.decorators import login_required
from django.http import HttpRequest, HttpResponse
# This is non-conventional usage, but covered in Django tests, so we allow it.
def view_func(request): ...
def view_func(request: HttpRequest) -> HttpResponse: ...
wrapped_view = login_required(view_func, redirect_field_name='a', login_url='b')
reveal_type(wrapped_view) # N: Revealed type is "def (request: Any) -> Any"
reveal_type(wrapped_view) # N: Revealed type is "def (request: django.http.request.HttpRequest) -> django.http.response.HttpResponse"
- case: login_required_incorrect_return
main: |
from typing import Any
from django.contrib.auth.decorators import login_required
@login_required() # E: Value of type variable "_VIEW" of function cannot be "Callable[[Any], str]"
def view_func2(request) -> str: ...
def view_func2(request: Any) -> str: ...
- case: user_passes_test
main: |
from django.contrib.auth.decorators import user_passes_test
from django.http import HttpRequest, HttpResponse
@user_passes_test(lambda u: u.get_username().startswith('super'))
def view_func(request): ...
reveal_type(view_func) # N: Revealed type is "def (request: Any) -> Any"
def view_func(request: HttpRequest) -> HttpResponse: ...
reveal_type(view_func) # N: Revealed type is "def (request: django.http.request.HttpRequest) -> django.http.response.HttpResponse"
- case: user_passes_test_bare_is_error
main: |
from django.http.response import HttpResponse
from django.http import HttpRequest, HttpResponse
from django.contrib.auth.decorators import user_passes_test
@user_passes_test # E: Argument 1 to "user_passes_test" has incompatible type "Callable[[Any], HttpResponse]"; expected "Callable[[Union[AbstractBaseUser, AnonymousUser]], bool]"
def view_func(request) -> HttpResponse: ...
@user_passes_test # E: Argument 1 to "user_passes_test" has incompatible type "Callable[[HttpRequest], HttpResponse]"; expected "Callable[[Union[AbstractBaseUser, AnonymousUser]], bool]"
def view_func(request: HttpRequest) -> HttpResponse: ...
- case: permission_required
main: |
from django.contrib.auth.decorators import permission_required
from django.http import HttpRequest, HttpResponse
@permission_required('polls.can_vote')
def view_func(request): ...
def view_func(request: HttpRequest) -> HttpResponse: ...

View File

@@ -12,10 +12,12 @@
reveal_type(func) # N: Revealed type is "def (x: builtins.int) -> builtins.list[Any]"
- case: non_atomic_requests_bare
main: |
from typing import Any
from django.db.transaction import non_atomic_requests
from django.http import HttpRequest, HttpResponse
@non_atomic_requests
def view_func(request): ...
reveal_type(view_func) # N: Revealed type is "def (request: Any) -> Any"
def view_func(request: HttpRequest) -> HttpResponse: ...
reveal_type(view_func) # N: Revealed type is "def (request: django.http.request.HttpRequest) -> django.http.response.HttpResponse"
- case: non_atomic_requests_args
main: |

View File

@@ -543,7 +543,7 @@
from django.db import models
class User(models.Model):
pass
def get_user_model_name():
def get_user_model_name() -> str:
return 'myapp.User'
class Profile(models.Model):
user = models.ForeignKey(to=get_user_model_name(), on_delete=models.CASCADE)
@@ -703,7 +703,7 @@
def custom(self) -> None:
pass
def TransactionManager():
def TransactionManager() -> BaseManager:
return BaseManager.from_queryset(TransactionQuerySet)()
class Transaction(models.Model):

View File

@@ -164,12 +164,12 @@
qs = User.objects.annotate(Count('id'))
annotated_user = qs.get()
def animals_only(param: Animal):
def animals_only(param: Animal) -> None:
pass
# Make sure that even though attr access falls back to Any, the type is still checked
animals_only(annotated_user) # E: Argument 1 to "animals_only" has incompatible type "WithAnnotations[myapp__models__User]"; expected "Animal"
def users_allowed(param: User):
def users_allowed(param: User) -> None:
# But this function accepts only the original User type, so any attr access is not allowed within this function
param.foo # E: "User" has no attribute "foo"
# Passing in the annotated User to a function taking a (unannotated) User is OK

View File

@@ -19,7 +19,7 @@
main: |
from myapp.models import User
async def main():
async def main() -> None:
async for user in User.objects.all():
reveal_type(user) # N: Revealed type is "myapp.models.User"
installed_apps:

View File

@@ -46,7 +46,7 @@
_T = TypeVar('_T', bound=models.Model)
class Base(Generic[_T]):
def __init__(self, model_cls: Type[_T]):
def __init__(self, model_cls: Type[_T]) -> None:
self.model_cls = model_cls
reveal_type(self.model_cls._default_manager) # N: Revealed type is "django.db.models.manager.BaseManager[_T`1]"
class MyModel(models.Model):
@@ -71,7 +71,7 @@
_T = TypeVar('_T', bound=models.Model)
class Base(Generic[_T]):
def __init__(self, model_cls: Type[_T]):
def __init__(self, model_cls: Type[_T]) -> None:
self.model_cls = model_cls
reveal_type(self.model_cls._base_manager) # N: Revealed type is "django.db.models.manager.BaseManager[_T`1]"
class MyModel(models.Model):
@@ -350,11 +350,12 @@
- path: myapp/__init__.py
- path: myapp/models.py
content: |
from typing import Any
from django.db import models
class MyManager(models.Manager):
def get_instance(self) -> int:
pass
def get_instance_untyped(self, name):
def get_instance_untyped(self, name: str):
pass
class User(models.Model):
objects = MyManager()
@@ -389,21 +390,22 @@
installed_apps:
- myapp
out: |
myapp/models:4: error: Return type "MyModel" of "create" incompatible with return type "_T" in supertype "BaseManager"
myapp/models:5: error: Incompatible return value type (got "_T", expected "MyModel")
myapp/models:5: error: Return type "MyModel" of "create" incompatible with return type "_T" in supertype "BaseManager"
myapp/models:6: error: Incompatible return value type (got "_T", expected "MyModel")
files:
- path: myapp/__init__.py
- path: myapp/models.py
content: |
from typing import Any
from django.db import models
class MyModelManager(models.Manager):
def create(self, **kwargs) -> 'MyModel':
return super().create(**kwargs)
def create(self, **kwargs: Any) -> 'MyModel':
return super().create(**kwargs)
class MyModel(models.Model):
objects = MyModelManager()
objects = MyModelManager()
- case: override_manager_create2
@@ -416,15 +418,15 @@
- path: myapp/__init__.py
- path: myapp/models.py
content: |
from typing import Any
from django.db import models
class MyModelManager(models.Manager['MyModel']):
def create(self, **kwargs) -> 'MyModel':
return super().create(**kwargs)
def create(self, **kwargs: Any) -> 'MyModel':
return super().create(**kwargs)
class MyModel(models.Model):
objects = MyModelManager()
objects = MyModelManager()
- case: regression_manager_scope_foreign
main: |
@@ -488,7 +490,7 @@
class InvisibleUnresolvable(AbstractUnresolvable):
text = models.TextField()
def process_booking(user: User):
def process_booking(user: User) -> None:
reveal_type(User.objects)
reveal_type(User._default_manager)

View File

@@ -124,7 +124,7 @@
- path: myapp/models.py
content: |
from django.db import models
def return_int():
def return_int() -> int:
return 0
class MyModel(models.Model):
id = models.IntegerField(primary_key=True, default=return_int)

View File

@@ -14,9 +14,10 @@
content: |
from django.db import models
class MyModel(models.Model):
def __str__(self):
def __str__(self) -> str:
reveal_type(self.id) # N: Revealed type is "builtins.int"
reveal_type(self.pk) # N: Revealed type is "builtins.int"
return ''
- case: test_access_to_id_field_through_self_if_primary_key_is_defined
@@ -36,9 +37,10 @@
from django.db import models
class MyModel(models.Model):
id = models.CharField(max_length=10, primary_key=True)
def __str__(self):
def __str__(self) -> str:
reveal_type(self.id) # N: Revealed type is "builtins.str"
reveal_type(self.pk) # N: Revealed type is "builtins.str"
return self.id
- case: test_access_to_id_field_through_self_if_primary_key_has_different_name
@@ -60,7 +62,8 @@
from django.db import models
class MyModel(models.Model):
primary = models.CharField(max_length=10, primary_key=True)
def __str__(self):
def __str__(self) -> str:
reveal_type(self.primary) # N: Revealed type is "builtins.str"
reveal_type(self.pk) # N: Revealed type is "builtins.str"
self.id # E: "MyModel" has no attribute "id"
return self.primary

View File

@@ -14,7 +14,7 @@
- case: async_client_methods
main: |
from django.test.client import AsyncClient
async def main():
async def main() -> None:
client = AsyncClient()
response = await client.get('foo')
reveal_type(response.asgi_request) # N: Revealed type is "django.core.handlers.asgi.ASGIRequest"

View File

@@ -49,7 +49,7 @@
files:
- path: extras/extra_module.py
content: |
def extra_fn():
def extra_fn() -> None:
pass
- case: add_mypypath_env_var_to_package_search
@@ -61,5 +61,5 @@
files:
- path: extras/extra_module.py
content: |
def extra_fn():
def extra_fn() -> None:
pass

View File

@@ -29,6 +29,7 @@
main: |
from typing import Any
from django import forms
from django.http import HttpRequest, HttpResponse
from django.views.generic.edit import FormView
class MyForm(forms.ModelForm):
@@ -37,25 +38,28 @@
pass
class MyView(FormView):
form_class = MyForm
def post(self, request, *args: Any, **kwds: Any):
def post(self, request: HttpRequest, *args: Any, **kwds: Any) -> HttpResponse:
form_class = self.get_form_class()
reveal_type(form_class) # N: Revealed type is "Type[main.MyForm]"
reveal_type(self.get_form(None)) # N: Revealed type is "main.MyForm"
reveal_type(self.get_form()) # N: Revealed type is "main.MyForm"
reveal_type(self.get_form(form_class)) # N: Revealed type is "main.MyForm"
reveal_type(self.get_form(MyForm2)) # N: Revealed type is "main.MyForm2"
return HttpResponse()
- case: updateview_form_valid_has_form_save
main: |
from django import forms
from django.http import HttpResponse
from django.views.generic.edit import UpdateView
class MyForm(forms.ModelForm):
pass
class MyView(UpdateView):
form_class = MyForm
def form_valid(self, form: forms.BaseModelForm):
def form_valid(self, form: forms.BaseModelForm) -> HttpResponse:
reveal_type(form.save) # N: Revealed type is "def (commit: builtins.bool =) -> Any"
return HttpResponse()
- case: successmessagemixin_compatible_with_formmixin
main: |

View File

@@ -32,7 +32,7 @@
def atomic_method1(self, abc: int) -> str:
pass
@transaction.atomic(savepoint=True)
def atomic_method2(self):
def atomic_method2(self) -> None:
pass
@transaction.atomic(using="db", savepoint=True)
def atomic_method3(self, myparam: str) -> int:

View File

@@ -8,6 +8,7 @@
reveal_type(TestView()) # N: Revealed type is "main.TestView"
- case: method_decorator_function
main: |
from typing import Any
from django.views.generic.base import View
from django.utils.decorators import method_decorator
from django.contrib.auth.decorators import login_required
@@ -15,6 +16,6 @@
from django.http.request import HttpRequest
class TestView(View):
@method_decorator(login_required)
def dispatch(self, request: HttpRequest, *args, **kwargs) -> HttpResponseBase:
def dispatch(self, request: HttpRequest, *args: Any, **kwargs: Any) -> HttpResponseBase:
return super().dispatch(request, *args, **kwargs)
reveal_type(dispatch) # N: Revealed type is "def (self: main.TestView, request: django.http.request.HttpRequest, *args: Any, **kwargs: Any) -> django.http.response.HttpResponseBase"

View File

@@ -14,22 +14,24 @@
- case: dispatch_http_response
main: |
from django.http import HttpResponse
from typing import Any
from django.http import HttpRequest, HttpResponse
from django.views.generic.base import View
class MyView(View):
def dispatch(self, request, *args, **kwargs) -> HttpResponse:
def dispatch(self, request: HttpRequest, *args: Any, **kwargs: Any) -> HttpResponse:
response: HttpResponse
return response
- case: dispatch_streaming_http_response
main: |
from django.http import StreamingHttpResponse
from typing import Any
from django.http import HttpRequest, StreamingHttpResponse
from django.views.generic.base import View
class MyView(View):
def dispatch(self, request, *args, **kwargs) -> StreamingHttpResponse:
def dispatch(self, request: HttpRequest, *args: Any, **kwargs: Any) -> StreamingHttpResponse:
response: StreamingHttpResponse
return response