diff --git a/debug.py b/debug.py index 1b1f1cdd..f93317f0 100644 --- a/debug.py +++ b/debug.py @@ -14,7 +14,6 @@ WARNING = object() ERROR = object() debug_function = None -#debug_function = print_to_stdout ignored_modules = ['parsing', 'builtin'] @@ -41,3 +40,6 @@ def print_to_stdout(level, *args): """ The default debug function """ msg = (Fore.GREEN + 'dbg: ' if level == NOTICE else Fore.RED + 'warning: ') print(msg + ', '.join(str(a) for a in args) + Fore.RESET) + + +#debug_function = print_to_stdout diff --git a/dynamic.py b/dynamic.py index d4c1744a..118f5a3c 100644 --- a/dynamic.py +++ b/dynamic.py @@ -144,6 +144,19 @@ def dec(func): return wrapper +def _scan_array(arr, search_name): + """ Returns the function Call that match func_name in an Array. """ + result = [] + for sub in arr: + for s in sub: + if isinstance(s, parsing.Array): + result += _scan_array(s, search_name) + elif isinstance(s, parsing.Call): + n = s.name + if isinstance(n, parsing.Name) and search_name in n.names: + result.append(s) + return result + #@dec @evaluate.memoize_default([]) def _check_array_additions(compare_array, module, is_list): @@ -155,19 +168,6 @@ def _check_array_additions(compare_array, module, is_list): if not settings.dynamic_array_additions: return [] - def scan_array(arr, search_name): - """ Returns the function Calls that match func_name """ - result = [] - for sub in arr: - for s in sub: - if isinstance(s, parsing.Array): - result += scan_array(s, search_name) - elif isinstance(s, parsing.Call): - n = s.name - if isinstance(n, parsing.Name) and search_name in n.names: - result.append(s) - return result - def check_calls(calls, add_name): """ Calls are processed here. The part before the call is searched and @@ -245,7 +245,7 @@ def _check_array_additions(compare_array, module, is_list): if evaluate.follow_statement.push_stmt(stmt): # check recursion continue - res += check_calls(scan_array(stmt.get_assignment_calls(), n), n) + res += check_calls(_scan_array(stmt.get_assignment_calls(), n), n) evaluate.follow_statement.pop_stmt() return res @@ -291,3 +291,45 @@ class ArrayInstance(parsing.Base): is_list = str(self.instance.name) == 'list' items += _check_array_additions(self.instance, module, is_list) return items + + +def get_names(definitions, search_name, modules): + def check_call(call): + result = [] + follow = [] # There might be multiple search_name's in one call_path + call_path = list(call.generate_call_path()) + for i, name in enumerate(call_path): + if name == search_name: + follow.append(call_path[:i + 1]) + + for f in follow: + scope = call.parent_stmt().parent() + evaluate.statement_path = [] + position = call.parent_stmt().start_pos + if len(f) > 1: + f, search = f[:-1], f[-1] + else: + search = None + scopes = evaluate.follow_call_path(iter(f), scope, position) + follow_res = evaluate.goto(scopes, search, statement_path_offset=0) + + # compare to see if they match + if True in [r in definitions for r in follow_res]: + l = f[-1] # the NamePart object + result.append((l, l.start_pos, l.end_pos)) + + return result + + # TODO check modules in the same directoy + names = [] + for m in modules: + try: + stmts = m.used_names[search_name] + except KeyError: + continue + #TODO check heritage of statements + for stmt in stmts: + for call in _scan_array(stmt.get_assignment_calls(), search_name): + names += check_call(call) + print 'n', names + return names diff --git a/evaluate.py b/evaluate.py index 2d3c51d8..8b43f067 100644 --- a/evaluate.py +++ b/evaluate.py @@ -90,6 +90,13 @@ def clear_caches(): follow_statement.reset() +def statement_path_check(function): + def wrapper(stmt, *args, **kwargs): + statement_path.append(stmt) + return function(stmt, *args, **kwargs) + return wrapper + + def memoize_default(default=None): """ This is a typical memoization decorator, BUT there is one difference: @@ -731,7 +738,7 @@ class Generator(use_metaclass(CachedMetaClass, parsing.Base)): none_pos = (0, 0) executes_generator = ('__next__', 'send') for n in ('close', 'throw') + executes_generator: - name = parsing.Name([n], none_pos, none_pos) + name = parsing.Name([n, none_pos], none_pos, none_pos) if n in executes_generator: name.parent = weakref.ref(self) names.append(name) @@ -1227,6 +1234,7 @@ def assign_tuples(tup, results, seek_name): @helpers.RecursionDecorator +@statement_path_check @memoize_default(default=[]) def follow_statement(stmt, seek_name=None): """ @@ -1238,8 +1246,6 @@ def follow_statement(stmt, seek_name=None): :param stmt: A `parsing.Statement`. :param seek_name: A string. """ - statement_path.append(stmt) # important to know for the goto function - debug.dbg('follow_stmt %s (%s)' % (stmt, seek_name)) call_list = stmt.get_assignment_calls() debug.dbg('calls: %s' % call_list) @@ -1425,3 +1431,50 @@ def follow_path(path, scope, position=None): result = imports.strip_imports(get_scopes_for_name(scope, current, position=position)) return follow_paths(path, set(result), position=position) + + +def goto(scopes, search_name=None, statement_path_offset=1): + if search_name is None: + try: + definitions = [statement_path[statement_path_offset]] + except IndexError: + definitions = [] + for s in scopes: + if isinstance(s, imports.ImportPath): + s = s.follow()[0] + try: + s = statement_path[0] + except IndexError: + pass + definitions.append(s) + else: + def remove_unreal_imports(names): + """ + These imports are only virtual, because of multi-line imports. + """ + new_names = [] + for n in names: + par = n.parent() + # This is a special case: If the Import is "virtual" (which + # means the position is not defined), follow those modules. + if isinstance(par, parsing.Import) and not par.start_pos[0]: + module_count = 0 + for scope in imports.ImportPath(par).follow(): + if isinstance(scope, parsing.Import): + temp = scope.get_defined_names() + new_names += remove_unreal_imports(temp) + elif isinstance(scope, parsing.Module) \ + and not module_count: + # only first module (others are star imports) + module_count += 1 + new_names.append(scope.get_module_name(n.names)) + else: + new_names.append(n) + return new_names + + names = [] + for s in scopes: + names += s.get_defined_names() + names = remove_unreal_imports(names) + definitions = [n for n in names if n.names[-1] == search_name] + return definitions diff --git a/functions.py b/functions.py index 0debf320..36b4beed 100644 --- a/functions.py +++ b/functions.py @@ -265,6 +265,10 @@ def get_definition(source, line, column, source_path): def goto(source, line, column, source_path): + return _goto(source, line, column, source_path) + +def _goto(source, line, column, source_path): + """ for internal use """ pos = (line, column) f = modules.ModuleWithCursor(source_path, source=source, position=pos) @@ -325,6 +329,30 @@ def goto(source, line, column, source_path): return d +def get_related_names(source, line, column, source_path): + pos = (line, column) + f = modules.ModuleWithCursor(source_path, source=source, position=pos) + + goto_path = f.get_path_under_cursor() + goto_path, dot, search_name = _get_completion_parts(goto_path) + + # define goto path the right way + if not dot: + goto_path = search_name + search_name_new = None + else: + search_name_new = search_name + + scopes = _prepare_goto(source, pos, source_path, f, goto_path) + print scopes, search_name + definitions = evaluate.goto(scopes, search_name_new) + module = set([d.get_parent_until() for d in definitions]) + module.add(f.parser.module) + dynamic.get_names(definitions, search_name, module) + _clear_caches() + return + + def set_debug_function(func_cb): """ You can define a callback debug function to get all the debug messages. diff --git a/parsing.py b/parsing.py index 64d85152..fab49177 100644 --- a/parsing.py +++ b/parsing.py @@ -551,7 +551,8 @@ class Import(Simple): return [self.alias] if len(self.namespace) > 1: o = self.namespace - n = Name([o.names[0]], o.start_pos, o.end_pos, parent=o.parent()) + n = Name([(o.names[0], o.start_pos)], o.start_pos, o.end_pos, + parent=o.parent()) return [n] else: return [self.namespace] @@ -988,7 +989,14 @@ class NamePart(str): A string. Sometimes it is important to know if the string belongs to a name or not. """ - pass + def __new__(cls, s, start_pos): + self = super(NamePart, cls).__new__(cls, s) + self.start_pos = start_pos + return self + + @property + def end_pos(self): + return self.start_pos[0], self.start_pos[1] + len(self) class Name(Simple): @@ -1000,7 +1008,8 @@ class Name(Simple): """ def __init__(self, names, start_pos, end_pos, parent=None): super(Name, self).__init__(start_pos, end_pos) - self.names = tuple(NamePart(n) for n in names) + self.names = tuple(n if isinstance(n, NamePart) else NamePart(*n) + for n in names) if parent is not None: self.parent = weakref.ref(parent) @@ -1104,7 +1113,7 @@ class PyFuzzyParser(object): """ def append(el): names.append(el) - self.module.temp_used_names.append(el) + self.module.temp_used_names.append(el[0]) names = [] if pre_used_token is None: @@ -1114,7 +1123,7 @@ class PyFuzzyParser(object): else: token_type, tok = pre_used_token - append(tok) + append((tok, self.start_pos)) first_pos = self.start_pos while True: token_type, tok = self.next() @@ -1123,7 +1132,7 @@ class PyFuzzyParser(object): token_type, tok = self.next() if token_type != tokenize.NAME: break - append(tok) + append((tok, self.start_pos)) n = Name(names, first_pos, self.end_pos) if names else None return (n, token_type, tok) @@ -1209,7 +1218,7 @@ class PyFuzzyParser(object): if token_type != tokenize.NAME: return None - fname = Name([fname], self.start_pos, self.end_pos) + fname = Name([(fname, self.start_pos)], self.start_pos, self.end_pos) token_type, open = self.next() if open != '(': @@ -1252,7 +1261,7 @@ class PyFuzzyParser(object): % (self.start_pos[0], tokenize.tok_name[token_type], cname)) return None - cname = Name([cname], self.start_pos, self.end_pos) + cname = Name([(cname, self.start_pos)], self.start_pos, self.end_pos) super = [] token_type, next = self.next() @@ -1436,9 +1445,9 @@ class PyFuzzyParser(object): if not isinstance(stmt, Param): for tok_name in self.module.temp_used_names: try: - self.module.used_names[tok_name].append(stmt) + self.module.used_names[tok_name].add(stmt) except KeyError: - self.module.used_names[tok_name] = [stmt] + self.module.used_names[tok_name] = set([stmt]) self.module.temp_used_names = [] if is_return: # add returns to the scope diff --git a/test/completion/goto.py b/test/completion/goto.py index 07efa368..faea2fa9 100644 --- a/test/completion/goto.py +++ b/test/completion/goto.py @@ -122,15 +122,10 @@ def func(): func().b() # ----------------- -# get_definition +# on itself # ----------------- #! 7 ['class ClassDef'] class ClassDef(): """ abc """ pass - -##? 6 ClassDef2() -class ClassDef2(): - """ abc """ - pass