diff --git a/jedi/evaluate/filters.py b/jedi/evaluate/filters.py index 5db0c95a..70fa03a7 100644 --- a/jedi/evaluate/filters.py +++ b/jedi/evaluate/filters.py @@ -11,6 +11,7 @@ from jedi.evaluate import flow_analysis from jedi.evaluate.base_context import ContextSet, Context from jedi.parser_utils import get_parent_scope from jedi.evaluate.utils import to_list +from jedi.evaluate.cache import evaluator_function_cache class AbstractNameDefinition(object): @@ -188,28 +189,44 @@ class NameWrapper(object): return '%s(%s)' % (self.__class__.__name__, self._wrapped_name) +@evaluator_function_cache() +def _get_definition_names(evaluator, module_node, name_key): + try: + names = module_node.get_used_names()[name_key] + except KeyError: + return [] + return [name for name in names if name.is_definition()] + + class AbstractUsedNamesFilter(AbstractFilter): name_class = TreeNameDefinition def __init__(self, context, parser_scope): self._parser_scope = parser_scope - self._used_names = self._parser_scope.get_root_node().get_used_names() + self._module_node = self._parser_scope.get_root_node() + self._used_names = self._module_node.get_used_names() self.context = context def get(self, name): - try: - names = self._used_names[name] - except KeyError: - return [] - - return self._convert_names(self._filter(names)) + #print(self, self.context, name, type(self).__name__) + #import traceback, sys; traceback.print_stack(file=sys.stdout) + return self._convert_names(self._filter( + _get_definition_names(self.context.evaluator, self._module_node, name) + )) def _convert_names(self, names): return [self.name_class(self.context, name) for name in names] def values(self): - return self._convert_names(name for name_list in self._used_names.values() - for name in self._filter(name_list)) + evaluator = self.context.evaluator + module_node = self._module_node + return self._convert_names( + name + for name_key in self._used_names + for name in self._filter( + _get_definition_names(evaluator, module_node, name_key) + ) + ) def __repr__(self): return '<%s: %s>' % (self.__class__.__name__, self.context) @@ -238,8 +255,6 @@ class ParserTreeFilter(AbstractUsedNamesFilter): return list(self._check_flows(names)) def _is_name_reachable(self, name): - if not name.is_definition(): - return False parent = name.parent if parent.type == 'trailer': return False @@ -288,12 +303,25 @@ class GlobalNameFilter(AbstractUsedNamesFilter): def __init__(self, context, parser_scope): super(GlobalNameFilter, self).__init__(context, parser_scope) + def get(self, name): + try: + names = self._used_names[name] + except KeyError: + return [] + return self._convert_names(self._filter(names)) + @to_list def _filter(self, names): for name in names: if name.parent.type == 'global_stmt': yield name + def values(self): + return self._convert_names( + name for name_list in self._used_names.values() + for name in self._filter(name_list) + ) + class DictFilter(AbstractFilter): def __init__(self, dct):