diff --git a/jedi/inference/value/instance.py b/jedi/inference/value/instance.py index ac05288f..63f220e0 100644 --- a/jedi/inference/value/instance.py +++ b/jedi/inference/value/instance.py @@ -150,6 +150,16 @@ class AbstractInstanceValue(Value): args = ValuesArguments([index_value_set]) return ValueSet.from_sets(name.infer().execute(args) for name in names) + def py__iter__(self, contextualized_node=None): + iter_slot_names = self.get_function_slot_names('__iter__') + if not iter_slot_names: + return super().py__iter__(contextualized_node) + + def iterate(): + for generator in self.execute_function_slots(iter_slot_names): + yield from generator.py__next__(contextualized_node) + return iterate() + def __repr__(self): return "<%s of %s>" % (self.__class__.__name__, self.class_value) @@ -254,16 +264,6 @@ class _BaseTreeInstance(AbstractInstanceValue): or self.get_function_slot_names('__getattribute__')) return self.execute_function_slots(names, name) - def py__iter__(self, contextualized_node=None): - iter_slot_names = self.get_function_slot_names('__iter__') - if not iter_slot_names: - return super().py__iter__(contextualized_node) - - def iterate(): - for generator in self.execute_function_slots(iter_slot_names): - yield from generator.py__next__(contextualized_node) - return iterate() - def py__next__(self, contextualized_node=None): name = u'__next__' next_slot_names = self.get_function_slot_names(name) diff --git a/test/test_api/test_interpreter.py b/test/test_api/test_interpreter.py index ecf10e97..b7be471f 100644 --- a/test/test_api/test_interpreter.py +++ b/test/test_api/test_interpreter.py @@ -600,16 +600,31 @@ def test_dict_getitem(code, types): @pytest.mark.parametrize('class_is_findable', [False, True]) -def test__getitem__(class_is_findable): - class GetitemCls: +@pytest.mark.parametrize( + 'code, expected', [ + ('DunderCls()[0]', 'int'), + ('next(DunderCls())', 'float'), + ('for x in DunderCls(): x', 'str'), + ] +) +def test_dunders(class_is_findable, code, expected): + from typing import Iterator + + class DunderCls: def __getitem__(self, key) -> int: pass - if not class_is_findable: - GetitemCls.__name__ = 'asdf' + def __iter__(self, key) -> Iterator[str]: + pass - n, = jedi.Interpreter('GetitemCls()[0]', [locals()]).infer() - assert n.name == 'int' + def __next__(self, key) -> float: + pass + + if not class_is_findable: + DunderCls.__name__ = 'asdf' + + n, = jedi.Interpreter(code, [locals()]).infer() + assert n.name == expected def foo():