diff --git a/jedi/inference/context.py b/jedi/inference/context.py index b60c2967..e500707d 100644 --- a/jedi/inference/context.py +++ b/jedi/inference/context.py @@ -308,11 +308,14 @@ class ModuleContext(TreeContextMixin, ValueContext): until_position=until_position, origin_scope=origin_scope ), - GlobalNameFilter(self, self.tree_node), + self.get_global_filter(), ) for f in filters: # Python 2... yield f + def get_global_filter(self): + return GlobalNameFilter(self, self.tree_node) + @property def string_names(self): return self._value.string_names diff --git a/jedi/inference/names.py b/jedi/inference/names.py index 9335c446..1202f7e8 100644 --- a/jedi/inference/names.py +++ b/jedi/inference/names.py @@ -564,7 +564,7 @@ class ImportName(AbstractNameDefinition): return m # It's almost always possible to find the import or to not find it. The # importing returns only one value, pretty much always. - return next(iter(import_values)) + return next(iter(import_values)).as_context() @memoize_method def infer(self): diff --git a/jedi/inference/references.py b/jedi/inference/references.py index 22fe4e92..ce61cbf6 100644 --- a/jedi/inference/references.py +++ b/jedi/inference/references.py @@ -1,4 +1,5 @@ from jedi.inference import imports +from jedi.inference.filters import ParserTreeFilter def _resolve_names(definition_names, avoid_names=()): @@ -25,17 +26,68 @@ def _dictionarize(names): ) +def _find_defining_names(module_context, tree_name): + found_names = _find_names(module_context, tree_name) + + found_names |= set(_find_global_variables(found_names, tree_name.value)) + for name in list(found_names): + if name.api_type == 'param' or name.tree_name is None \ + or name.tree_name.parent.type == 'trailer': + continue + found_names |= set(_add_names_in_same_context(name.parent_context, name.string_name)) + return set(_resolve_names(found_names)) + + def _find_names(module_context, tree_name): name = module_context.create_name(tree_name) found_names = set(name.goto()) found_names.add(name) - return _dictionarize(_resolve_names(found_names)) + + return set(_resolve_names(found_names)) + + +def _add_names_in_same_context(context, string_name): + if context.tree_node is None: + return + + until_position = None + while True: + filter_ = ParserTreeFilter( + parent_context=context, + until_position=until_position, + ) + names = set(filter_.get(string_name)) + if not names: + break + for name in names: + yield name + ordered = sorted(names, key=lambda x: x.start_pos) + until_position = ordered[0].start_pos + + +def _find_global_variables(names, search_name): + for name in names: + if name.tree_name is None: + continue + module_context = name.get_root_context() + try: + method = module_context.get_global_filter + except AttributeError: + continue + else: + for global_name in method().get(search_name): + yield global_name + c = module_context.create_context(global_name.tree_name) + for name in _add_names_in_same_context(c, global_name.string_name): + yield name def find_references(module_context, tree_name): search_name = tree_name.value - found_names = _find_names(module_context, tree_name) - module_contexts = set(d.get_root_context() for d in found_names.values()) + found_names = _find_defining_names(module_context, tree_name) + found_names_dct = _dictionarize(found_names) + + module_contexts = set(d.get_root_context() for d in found_names) module_contexts = set(m for m in module_contexts if not m.is_compiled()) non_matching_reference_maps = {} @@ -45,14 +97,14 @@ def find_references(module_context, tree_name): ) for module_context in potential_modules: for name_leaf in module_context.tree_node.get_used_names().get(search_name, []): - new = _find_names(module_context, name_leaf) - if any(tree_name in found_names for tree_name in new): - found_names.update(new) + new = _dictionarize(_find_names(module_context, name_leaf)) + if any(tree_name in found_names_dct for tree_name in new): + found_names_dct.update(new) for tree_name in new: for dct in non_matching_reference_maps.get(tree_name, []): # A reference that was previously searched for matches # with a now found name. Merge. - found_names.update(dct) + found_names_dct.update(dct) try: del non_matching_reference_maps[tree_name] except KeyError: @@ -60,4 +112,4 @@ def find_references(module_context, tree_name): else: for name in new: non_matching_reference_maps.setdefault(name, []).append(new) - return found_names.values() + return found_names_dct.values() diff --git a/jedi/inference/value/function.py b/jedi/inference/value/function.py index 90610165..dfe190f9 100644 --- a/jedi/inference/value/function.py +++ b/jedi/inference/value/function.py @@ -97,9 +97,6 @@ class FunctionMixin(object): class FunctionValue(use_metaclass(CachedMetaClass, FunctionMixin, FunctionAndClassBase)): - def is_function(self): - return True - @classmethod def from_context(cls, context, tree_node): def create(tree_node): @@ -162,6 +159,9 @@ class MethodValue(FunctionValue): class BaseFunctionExecutionContext(ValueContext, TreeContextMixin): + def is_function_execution(self): + return True + def _infer_annotations(self): raise NotImplementedError diff --git a/test/completion/usages.py b/test/completion/usages.py index 2f8d1802..7cd6b725 100644 --- a/test/completion/usages.py +++ b/test/completion/usages.py @@ -3,34 +3,34 @@ Renaming tests. This means search for references. I always leave a little bit of space to add room for additions, because the results always contain position informations. """ -#< 4 (0,4), (3,0), (5,0), (17,0), (12,4), (14,5), (15,0) -def abc(): pass +#< 4 (0,4), (3,0), (5,0), (12,4), (14,5), (15,0), (17,0), (19,0) +def abcd(): pass -#< 0 (-3,4), (0,0), (2,0), (14,0), (9,4), (11,5), (12,0) -abc.d.a.bsaasd.abc.d +#< 0 (-3,4), (0,0), (2,0), (9,4), (11,5), (12,0), (14,0), (16,0) +abcd.d.a.bsaasd.abcd.d -abc +abcd # unicode chars shouldn't be a problem. -x['smörbröd'].abc +x['smörbröd'].abcd # With the new parser these statements are not recognized as stateents, because # they are not valid Python. if 1: - abc = + abcd = else: - (abc) = -abc = -#< (-17,4), (-14,0), (-12,0), (0,0), (-2,0), (-3,5), (-5,4) -abc + (abcd) = +abcd = +#< (-17,4), (-14,0), (-12,0), (0,0), (2,0), (-2,0), (-3,5), (-5,4) +abcd -abc = 5 +abcd = 5 Abc = 3 -#< 6 (0,6), (2,4), (5,8), (17,0) +#< 6 (-3,0), (0,6), (2,4), (5,8), (17,0) class Abc(): - #< (-2,6), (0,4), (3,8), (15,0) + #< (-5,0), (-2,6), (0,4), (2,8), (3,8), (15,0) Abc def Abc(self): @@ -65,11 +65,11 @@ set_object_var.var = 1 response = 5 -#< 0 (0,0), (1,0), (2,0), (4,0) +#< 0 (-2,0), (0,0), (1,0), (2,0), (4,0) response = HttpResponse(mimetype='application/pdf') response['Content-Disposition'] = 'attachment; filename=%s.pdf' % id response.write(pdf) -#< (-4,0), (-3,0), (-2,0), (0,0) +#< (-6,0), (-4,0), (-3,0), (-2,0), (0,0) response @@ -215,18 +215,18 @@ class TestProperty: self.prop @property - #< 13 (0,8), (4,5) + #< 13 (0,8), (4,5), (6,8), (11,13) def rw_prop(self): return self._rw_prop - #< 8 (-4,8), (0,5) + #< 8 (-4,8), (0,5), (2,8), (7,13) @rw_prop.setter - #< 8 (0,8), (5,13) + #< 8 (-6,8), (-2,5), (0,8), (5,13) def rw_prop(self, value): self._rw_prop = value def b(self): - #< 13 (-5,8), (0,13) + #< 13 (-11,8), (-7,5), (-5,8), (0,13) self.rw_prop # ----------------- @@ -287,9 +287,9 @@ x = 32 [x for x in x] #< 0 (0,0), (2,1), (2,12) -x = 32 +y = 32 #< 12 (-2,0), (0,1), (0,12) -[x for b in x] +[y for b in y] #< 1 (0,1), (0,7) @@ -297,13 +297,13 @@ x = 32 #< 7 (0,1), (0,7) [x for x in something] -x = 3 +z = 3 #< 1 (0,1), (0,10) -{x:1 for x in something} +{z:1 for z in something} #< 10 (0,1), (0,10) -{x:1 for x in something} +{z:1 for z in something} -def x(): +def whatever_func(): zzz = 3 if UNDEFINED: zzz = 5 @@ -314,3 +314,31 @@ def x(): #< (0, 8), (1, 4), (-3, 12), (-6, 8), (-8, 4) zzz zzz + +# ----------------- +# global +# ----------------- + +def global_usage1(): + #< (0, 4), (4, 11), (6, 4), (9, 8), (12, 4) + my_global + +def global_definition(): + #< (-4, 4), (0, 11), (2, 4), (5, 8), (8, 4) + global my_global + #< 4 (-6, 4), (-2, 11), (0, 4), (3, 8), (6, 4) + my_global = 3 + if WHATEVER: + #< 8 (-9, 4), (-5, 11), (-3, 4), (0, 8), (3, 4) + my_global = 4 + +def global_usage2() + my_global + +def not_global(my_global): + my_global + +class DefinitelyNotGlobal: + def my_global(self): + def my_global(self): + pass