diff --git a/mypy_django_plugin/transformers/models.py b/mypy_django_plugin/transformers/models.py index 300760e..47a67e3 100644 --- a/mypy_django_plugin/transformers/models.py +++ b/mypy_django_plugin/transformers/models.py @@ -2,13 +2,14 @@ from abc import ABCMeta, abstractmethod from abc import ABCMeta, abstractmethod from typing import cast +from django.db.models.fields import DateTimeField, DateField from django.db.models.fields.related import ForeignKey from django.db.models.fields.reverse_related import ManyToManyRel, ManyToOneRel, OneToOneRel from mypy.newsemanal.semanal import NewSemanticAnalyzer -from mypy.nodes import MDEF, SymbolTableNode, TypeInfo, Var +from mypy.nodes import MDEF, SymbolTableNode, TypeInfo, Var, Argument, ARG_NAMED_OPT, ARG_STAR2 from mypy.plugin import ClassDefContext from mypy.plugins import common -from mypy.types import Instance +from mypy.types import Instance, TypeOfAny, AnyType from mypy_django_plugin.django.context import DjangoContext from mypy_django_plugin.lib import fullnames, helpers @@ -140,12 +141,13 @@ class AddManagers(ModelClassInitializer): continue -class AddFieldChoicesDisplayMethods(ModelClassInitializer): +class AddExtraFieldMethods(ModelClassInitializer): def run(self): model_cls = self.django_context.get_model_class_by_fullname(self.model_classdef.fullname) if model_cls is None: return + # get_FOO_display for choices for field in self.django_context.get_model_fields(model_cls): if field.choices: info = self.lookup_typeinfo_or_incomplete_defn_error('builtins.str') @@ -155,6 +157,26 @@ class AddFieldChoicesDisplayMethods(ModelClassInitializer): args=[], return_type=return_type) + # get_next_by, get_previous_by for Date, DateTime + for field in self.django_context.get_model_fields(model_cls): + if isinstance(field, (DateField, DateTimeField)) and not field.null: + return_type = Instance(self.model_classdef.info, []) + common.add_method(self.ctx, + name='get_next_by_{}'.format(field.attname), + args=[Argument(Var('kwargs', AnyType(TypeOfAny.explicit)), + AnyType(TypeOfAny.explicit), + initializer=None, + kind=ARG_STAR2)], + return_type=return_type) + common.add_method(self.ctx, + name='get_previous_by_{}'.format(field.attname), + args=[Argument(Var('kwargs', AnyType(TypeOfAny.explicit)), + AnyType(TypeOfAny.explicit), + initializer=None, + kind=ARG_STAR2)], + return_type=return_type) + + def process_model_class(ctx: ClassDefContext, django_context: DjangoContext) -> None: @@ -163,7 +185,7 @@ def process_model_class(ctx: ClassDefContext, AddDefaultPrimaryKey, AddRelatedModelsId, AddManagers, - AddFieldChoicesDisplayMethods, + AddExtraFieldMethods, ] for initializer_cls in initializers: try: diff --git a/test-data/typecheck/fields/test_related.yml b/test-data/typecheck/fields/test_related.yml index 9263028..e3cdaa5 100644 --- a/test-data/typecheck/fields/test_related.yml +++ b/test-data/typecheck/fields/test_related.yml @@ -183,6 +183,7 @@ - case: models_imported_inside_init_file_one_to_one_field main: | from myapp2.models import Profile + reveal_type(Profile().user) # N: Revealed type is 'myapp.models.user.User*' reveal_type(Profile().user.profile) # N: Revealed type is 'myapp2.models.Profile' installed_apps: - myapp diff --git a/test-data/typecheck/models/test_extra_methods.yml b/test-data/typecheck/models/test_extra_methods.yml index a920ada..846e8ab 100644 --- a/test-data/typecheck/models/test_extra_methods.yml +++ b/test-data/typecheck/models/test_extra_methods.yml @@ -18,3 +18,39 @@ class MyUser(models.Model): name = models.CharField(max_length=100) gender = models.CharField(max_length=100, choices=GENDER_CHOICES) + +- case: date_datetime_fields_have_get_next_by_get_previous_by + main: | + from myapp.models import MyUser + reveal_type(MyUser().get_next_by_date()) # N: Revealed type is 'myapp.models.MyUser' + reveal_type(MyUser().get_next_by_datetime()) # N: Revealed type is 'myapp.models.MyUser' + reveal_type(MyUser().get_previous_by_date()) # N: Revealed type is 'myapp.models.MyUser' + reveal_type(MyUser().get_previous_by_datetime()) # N: Revealed type is 'myapp.models.MyUser' + + # accept arbitrary kwargs + MyUser().get_next_by_date(arg1=1, arg2=2) + installed_apps: + - myapp + files: + - path: myapp/__init__.py + - path: myapp/models.py + content: | + from django.db import models + class MyUser(models.Model): + date = models.DateField() + datetime = models.DateTimeField() + +- case: get_next_by_get_previous_by_absent_if_null_true + main: | + from myapp.models import MyUser + MyUser().get_next_by_date() # E: "MyUser" has no attribute "get_next_by_date" + MyUser().get_previous_by_date() # E: "MyUser" has no attribute "get_previous_by_date" + installed_apps: + - myapp + files: + - path: myapp/__init__.py + - path: myapp/models.py + content: | + from django.db import models + class MyUser(models.Model): + date = models.DateField(null=True)