add get_next_by_FOO, get_previous_by_FOO for date,datetime fields

This commit is contained in:
Maxim Kurnikov
2019-07-19 03:08:50 +03:00
parent 1721c997be
commit 5bb1bc250d
3 changed files with 63 additions and 4 deletions

View File

@@ -2,13 +2,14 @@ from abc import ABCMeta, abstractmethod
from abc import ABCMeta, abstractmethod
from typing import cast
from django.db.models.fields import DateTimeField, DateField
from django.db.models.fields.related import ForeignKey
from django.db.models.fields.reverse_related import ManyToManyRel, ManyToOneRel, OneToOneRel
from mypy.newsemanal.semanal import NewSemanticAnalyzer
from mypy.nodes import MDEF, SymbolTableNode, TypeInfo, Var
from mypy.nodes import MDEF, SymbolTableNode, TypeInfo, Var, Argument, ARG_NAMED_OPT, ARG_STAR2
from mypy.plugin import ClassDefContext
from mypy.plugins import common
from mypy.types import Instance
from mypy.types import Instance, TypeOfAny, AnyType
from mypy_django_plugin.django.context import DjangoContext
from mypy_django_plugin.lib import fullnames, helpers
@@ -140,12 +141,13 @@ class AddManagers(ModelClassInitializer):
continue
class AddFieldChoicesDisplayMethods(ModelClassInitializer):
class AddExtraFieldMethods(ModelClassInitializer):
def run(self):
model_cls = self.django_context.get_model_class_by_fullname(self.model_classdef.fullname)
if model_cls is None:
return
# get_FOO_display for choices
for field in self.django_context.get_model_fields(model_cls):
if field.choices:
info = self.lookup_typeinfo_or_incomplete_defn_error('builtins.str')
@@ -155,6 +157,26 @@ class AddFieldChoicesDisplayMethods(ModelClassInitializer):
args=[],
return_type=return_type)
# get_next_by, get_previous_by for Date, DateTime
for field in self.django_context.get_model_fields(model_cls):
if isinstance(field, (DateField, DateTimeField)) and not field.null:
return_type = Instance(self.model_classdef.info, [])
common.add_method(self.ctx,
name='get_next_by_{}'.format(field.attname),
args=[Argument(Var('kwargs', AnyType(TypeOfAny.explicit)),
AnyType(TypeOfAny.explicit),
initializer=None,
kind=ARG_STAR2)],
return_type=return_type)
common.add_method(self.ctx,
name='get_previous_by_{}'.format(field.attname),
args=[Argument(Var('kwargs', AnyType(TypeOfAny.explicit)),
AnyType(TypeOfAny.explicit),
initializer=None,
kind=ARG_STAR2)],
return_type=return_type)
def process_model_class(ctx: ClassDefContext,
django_context: DjangoContext) -> None:
@@ -163,7 +185,7 @@ def process_model_class(ctx: ClassDefContext,
AddDefaultPrimaryKey,
AddRelatedModelsId,
AddManagers,
AddFieldChoicesDisplayMethods,
AddExtraFieldMethods,
]
for initializer_cls in initializers:
try: