Implement magic method return values, fixes #1577

This commit is contained in:
Dave Halter
2020-05-15 19:09:44 +02:00
parent be594f1498
commit 41c146a6f3
4 changed files with 61 additions and 11 deletions

View File

@@ -31,6 +31,25 @@ from jedi.inference.context import CompForContext
from jedi.inference.value.decorator import Decoratee from jedi.inference.value.decorator import Decoratee
from jedi.plugins import plugin_manager from jedi.plugins import plugin_manager
operator_to_magic_method = {
'+': '__add__',
'-': '__sub__',
'*': '__mul__',
'/': '__div__',
'//': '__floordiv__',
'%': '__mod__',
'**': '__pow__',
'<<': '__lshift__',
'>>': '__rshift__',
'&': '__and__',
'|': '__or__',
'^': '__xor__',
}
reverse_operator_to_magic_method = {
k: '__r' + v[2:] for k, v in operator_to_magic_method.items()
}
def _limit_value_infers(func): def _limit_value_infers(func):
""" """
@@ -538,12 +557,8 @@ def _is_annotation_name(name):
return False return False
def _is_tuple(value): def _is_list_or_tuple(value):
return isinstance(value, iterable.Sequence) and value.array_type == 'tuple' return value.array_type in ('tuple', 'list')
def _is_list(value):
return isinstance(value, iterable.Sequence) and value.array_type == 'list'
def _bool_to_value(inference_state, bool_): def _bool_to_value(inference_state, bool_):
@@ -584,7 +599,7 @@ def _infer_comparison_part(inference_state, context, left, operator, right):
elif str_operator == '+': elif str_operator == '+':
if l_is_num and r_is_num or is_string(left) and is_string(right): if l_is_num and r_is_num or is_string(left) and is_string(right):
return left.execute_operation(right, str_operator) return left.execute_operation(right, str_operator)
elif _is_tuple(left) and _is_tuple(right) or _is_list(left) and _is_list(right): elif _is_list_or_tuple(left) and _is_list_or_tuple(right):
return ValueSet([iterable.MergedArray(inference_state, (left, right))]) return ValueSet([iterable.MergedArray(inference_state, (left, right))])
elif str_operator == '-': elif str_operator == '-':
if l_is_num and r_is_num: if l_is_num and r_is_num:
@@ -637,6 +652,20 @@ def _infer_comparison_part(inference_state, context, left, operator, right):
analysis.add(context, 'type-error-operation', operator, analysis.add(context, 'type-error-operation', operator,
message % (left, right)) message % (left, right))
if left.is_class() or right.is_class():
return NO_VALUES
method_name = operator_to_magic_method[str_operator]
magic_methods = left.py__getattribute__(method_name)
if not magic_methods:
reverse_method_name = reverse_operator_to_magic_method[str_operator]
magic_methods = left.py__getattribute__(reverse_method_name)
if magic_methods:
result = magic_methods.execute_with_values(right)
if result:
return result
result = ValueSet([left, right]) result = ValueSet([left, right])
debug.dbg('Used operator %s resulting in %s', operator, result) debug.dbg('Used operator %s resulting in %s', operator, result)
return result return result

View File

@@ -241,7 +241,7 @@ args_func(*1)[0]
args_func(*iter([1]))[0] args_func(*iter([1]))[0]
# different types # different types
e = args_func(*[1+"", {}]) e = args_func(*[1 if UNDEFINED else "", {}])
#? int() str() #? int() str()
e[0] e[0]
#? dict() #? dict()

View File

@@ -5,7 +5,10 @@ def positional_only_call(a, /, b):
a a
#? int() #? int()
b b
return a + b if UNDEFINED:
return a
else:
return b
#? int() str() #? int() str()
@@ -13,7 +16,10 @@ positional_only_call('', 1)
def positional_only_call2(a, /, b=3): def positional_only_call2(a, /, b=3):
return a + b if UNDEFINED:
return a
else:
return b
#? int() #? int()
positional_only_call2(1) positional_only_call2(1)

View File

@@ -56,7 +56,7 @@ a
(3 ** 3) (3 ** 3)
#? int() str() #? int() str()
(3 ** 'a') (3 ** 'a')
#? int() str() #? int()
(3 + 'a') (3 + 'a')
#? bool() #? bool()
(3 == 'a') (3 == 'a')
@@ -147,3 +147,18 @@ a = foobarbaz + 'hello'
#? int() float() #? int() float()
{'hello': 1, 'bar': 1.0}[a] {'hello': 1, 'bar': 1.0}[a]
# -----------------
# stubs
# -----------------
from datetime import datetime, timedelta
#?
(datetime - timedelta)
#? datetime()
(datetime() - timedelta())
#? timedelta()
(timedelta() - datetime())
#? timedelta()
(timedelta() - timedelta())