From cbba314286b7605661e3ffb2d0ecabe2138da82b Mon Sep 17 00:00:00 2001 From: Dave Halter Date: Thu, 26 May 2016 00:10:54 +0200 Subject: [PATCH] Progress and actually passing a few tests. --- jedi/api/__init__.py | 3 +- jedi/api/completion.py | 47 +++++++++++++++++++++++++----- jedi/api/helpers.py | 64 ++++++++++++++++++++++++++++++----------- jedi/api/inference.py | 15 +++++----- jedi/parser/__init__.py | 39 +++++++++++++------------ jedi/parser/fast.py | 8 +++--- jedi/parser/tree.py | 54 +++++++++++++++++++++++++++++----- 7 files changed, 169 insertions(+), 61 deletions(-) diff --git a/jedi/api/__init__.py b/jedi/api/__init__.py index df8d17e7..1c7ce8f6 100644 --- a/jedi/api/__init__.py +++ b/jedi/api/__init__.py @@ -162,8 +162,9 @@ class Script(object): self._evaluator, self._parser, self._user_context, self._pos, self.call_signatures ) + completions = completion.completions(path) debug.speed('completions end') - return completion.completions(path) + return completions def goto_definitions(self): """ diff --git a/jedi/api/completion.py b/jedi/api/completion.py index d86687c1..313bcafb 100644 --- a/jedi/api/completion.py +++ b/jedi/api/completion.py @@ -123,20 +123,38 @@ class Completion: helpers.check_error_statements(module, self._pos) grammar = self._evaluator.grammar - stack = helpers.get_stack_at_position(grammar, module, self._pos) + + # Now we set the position to the place where we try to find out what we + # have before it. + pos = self._pos + if completion_parts.name: + pos = pos[0], pos[1] - len(completion_parts.name) + + stack = helpers.get_stack_at_position(grammar, module, pos) allowed_keywords, allowed_tokens = \ helpers.get_possible_completion_types(grammar, stack) + print(allowed_keywords, allowed_tokens) completion_names = list(self._get_keyword_completion_names(allowed_keywords)) - if token.NAME in allowed_tokens: - # Differentiate between import names and other names. - completion_names += self._simple_complete(completion_parts) + if token.NAME in allowed_tokens: + # This means that we actually have to do type inference. + + symbol_names = list(stack.get_node_names(grammar)) + + if "import_stmt" in symbol_names: + if "dotted_name" in symbol_names: + completion_names += self._complete_dotted_name(stack, module) + else: + completion_names += self._simple_complete(completion_parts) + + """ completion_names = [] if names is not None: imp_names = tuple(str(n) for n in names if n.end_pos < self._pos) i = imports.Importer(self._evaluator, imp_names, module, level) completion_names = i.completion_names(self._evaluator, only_modules) + """ return completion_names @@ -170,9 +188,9 @@ class Completion: completion_names += self._simple_complete(completion_parts) return completion_names - def _get_keyword_completion_names(self, keywords): - for keyword in keywords: - yield keywords.keyword(self._evaluator, keyword).name + def _get_keyword_completion_names(self, keywords_): + for k in keywords_: + yield keywords.keyword(self._evaluator, k).name def _simple_complete(self, completion_parts): if not completion_parts.path and not completion_parts.has_dot: @@ -211,3 +229,18 @@ class Completion: names, self._parser.user_stmt() ) return completion_names + + def _complete_dotted_name(self, stack, module): + nodes = list(stack.get_nodes()) + + level = 0 + for i, node in enumerate(nodes[1:], 1): + if node in ('.', '...'): + level += len(node.value) + else: + names = [str(n) for n in nodes[i::2]] + break + + print(names, nodes) + i = imports.Importer(self._evaluator, names, module, level) + return i.completion_names(self._evaluator, only_modules=True) diff --git a/jedi/api/helpers.py b/jedi/api/helpers.py index e42ff76d..9516578a 100644 --- a/jedi/api/helpers.py +++ b/jedi/api/helpers.py @@ -13,6 +13,7 @@ from jedi.parser import tokenize, token CompletionParts = namedtuple('CompletionParts', ['path', 'has_dot', 'name']) + def get_completion_parts(path_until_cursor): """ Returns the parts for the completion @@ -42,25 +43,28 @@ def get_on_import_stmt(evaluator, user_context, user_stmt, is_like_search=False) def check_error_statements(module, pos): - for error_statement in module.error_statement_stacks: + for error_statement in module.error_statements: if error_statement.first_type in ('import_from', 'import_name') \ - and error_statement.first_pos < pos <= error_statement.next_start_pos: + and error_statement.start_pos < pos <= error_statement.end_pos: return importer_from_error_statement(error_statement, pos) return None, 0, False, False -def get_code_until(code, start_pos, end_pos): +def get_code_until(code, code_start_pos, end_pos): + """ + :param code_start_pos: is where the code starts. + """ lines = common.splitlines(code) - line_difference = end_pos[0] - start_pos[0] + line_difference = end_pos[0] - code_start_pos[0] if line_difference == 0: - end_line_length = end_pos[1] - start_pos[1] + end_line_length = end_pos[1] - code_start_pos[1] else: end_line_length = end_pos[1] - if line_difference > len(lines) or end_line_length > len(lines[-1]): + if line_difference > len(lines) or end_line_length > len(lines[line_difference]): raise ValueError("The end_pos seems to be after the code part.") - new_lines = lines[:line_difference] + [lines[-1][:end_line_length]] + new_lines = lines[:line_difference] + [lines[line_difference][:end_line_length]] return '\n'.join(new_lines) @@ -68,30 +72,56 @@ def get_stack_at_position(grammar, module, pos): """ Returns the possible node names (e.g. import_from, xor_test or yield_stmt). """ - for error_statement in module.error_statement_stacks: - if error_statement.first_pos < pos <= error_statement.next_start_pos: - code = error_statement.get_code() - code = get_code_until(code, error_statement.first_pos, pos) - break + user_stmt = module.get_statement_for_position(pos) + if user_stmt is None: + # If there's no error statement and we're just somewhere, we want + # completions for just whitespace. + code = '' + + for error_statement in module.error_statements: + if error_statement.start_pos < pos <= error_statement.end_pos: + code = error_statement.get_code(include_prefix=False) + start_pos = error_statement.start_pos + break else: - raise NotImplementedError + code = user_stmt.get_code_with_error_statements(include_prefix=False) + start_pos = user_stmt.start_pos + + # Remove indentations. + code = code.lstrip() + code = get_code_until(code, start_pos, pos) + # Remove whitespace at the end. + code = code.rstrip() class EndMarkerReached(Exception): pass def tokenize_without_endmarker(code): - for token_ in tokenize.source_tokens(code): + for token_ in tokenize.source_tokens(code, use_exact_op_types=True): if token_[0] == token.ENDMARKER: raise EndMarkerReached() else: + print(token_, token.tok_name[token_[0]]) yield token_ - p = parser.Parser(grammar, code, tokenizer=tokenize_without_endmarker(code), + print(repr(code)) + p = parser.Parser(grammar, code, start_parsing=False) try: - p.parse() + p.parse(tokenizer=tokenize_without_endmarker(code)) except EndMarkerReached: - return p.pgen_parser.stack + return Stack(p.pgen_parser.stack) + + +class Stack(list): + def get_node_names(self, grammar): + for dfa, state, (node_number, nodes) in self: + yield grammar.number2symbol[node_number] + + def get_nodes(self): + for dfa, state, (node_number, nodes) in self: + for node in nodes: + yield node def get_possible_completion_types(grammar, stack): diff --git a/jedi/api/inference.py b/jedi/api/inference.py index 56bf4e55..345ddeec 100644 --- a/jedi/api/inference.py +++ b/jedi/api/inference.py @@ -24,7 +24,7 @@ def type_inference(evaluator, parser, user_context, position, dotted_path, is_co # matched to much. return [] - if isinstance(user_stmt, tree.Import): + if isinstance(user_stmt, tree.Import) and not is_completion: i, _ = helpers.get_on_import_stmt(evaluator, user_context, user_stmt, is_completion) if i is None: @@ -36,12 +36,13 @@ def type_inference(evaluator, parser, user_context, position, dotted_path, is_co if eval_stmt is None: return [] - module = evaluator.wrap(parser.module()) - names, level, _, _ = helpers.check_error_statements(module, position) - if names: - names = [str(n) for n in names] - i = imports.Importer(evaluator, names, module, level) - return i.follow() + if not is_completion: + module = evaluator.wrap(parser.module()) + names, level, _, _ = helpers.check_error_statements(module, position) + if names: + names = [str(n) for n in names] + i = imports.Importer(evaluator, names, module, level) + return i.follow() scopes = evaluator.eval_element(eval_stmt) diff --git a/jedi/parser/__init__.py b/jedi/parser/__init__.py index d00b587b..da6bf23e 100644 --- a/jedi/parser/__init__.py +++ b/jedi/parser/__init__.py @@ -20,7 +20,6 @@ import re from jedi.parser import tree as pt from jedi.parser import tokenize -from jedi.parser import token from jedi.parser.token import (DEDENT, INDENT, ENDMARKER, NEWLINE, NUMBER, STRING, OP, ERRORTOKEN) from jedi.parser.pgen2.pgen import generate_grammar @@ -75,14 +74,13 @@ class ErrorStatement(object): ) @property - def next_start_pos(self): + def end_pos(self): s = self._next_start_pos return s[0] + self._position_modifier.line, s[1] @property - def first_pos(self): - first_type, nodes = self.stack[0] - return nodes[0].start_pos + def start_pos(self): + return next(self._iter_nodes()).start_pos @property def first_type(self): @@ -96,8 +94,15 @@ class ErrorStatement(object): return True return False - def get_code(self): - return ''.join(node.get_code() for _, nodes in self.stack for node in nodes) + def _iter_nodes(self): + for _, nodes in self.stack: + for node in nodes: + yield node + + def get_code(self, include_prefix=True): + iterator = self._iter_nodes() + first = next(iterator) + return first.get_code(include_prefix=include_prefix) + ''.join(node.get_code() for node in iterator) class ParserSyntaxError(object): @@ -144,7 +149,7 @@ class Parser(object): self._used_names = {} self._scope_names_stack = [{}] - self._error_statement_stacks = [] + self._error_statements = [] self._last_failed_start_pos = (0, 0) self._global_names = [] @@ -164,20 +169,19 @@ class Parser(object): self._start_symbol = start_symbol self._grammar = grammar - self._tokenizer = tokenizer - if tokenizer is None: - self._tokenizer = tokenize.source_tokens(source, use_exact_op_types=True) self._parsed = None if start_parsing: - self.parse() + if tokenizer is None: + tokenizer = tokenize.source_tokens(source, use_exact_op_types=True) + self.parse(tokenizer) - def parse(self): + def parse(self, tokenizer): if self._parsed is not None: return self._parsed - self._parsed = self.pgen_parser.parse(self._tokenize(self._tokenizer)) + self._parsed = self.pgen_parser.parse(self._tokenize(tokenizer)) if self._start_symbol == 'file_input' != self._parsed.type: # If there's only one statement, we get back a non-module. That's @@ -198,7 +202,7 @@ class Parser(object): raise ParseError yield typ, value, prefix, start_pos - def error_recovery(self, grammar, stack, typ, value, start_pos, prefix, + def error_recovery(self, grammar, stack, arcs, typ, value, start_pos, prefix, add_token_callback): raise ParseError @@ -308,7 +312,6 @@ class Parser(object): endmarker._start_pos = newline._start_pos break - class ParserWithRecovery(Parser): """ This class is used to parse a Python file, it then divides them into a @@ -340,7 +343,7 @@ class ParserWithRecovery(Parser): self.module.used_names = self._used_names self.module.path = module_path self.module.global_names = self._global_names - self.module.error_statement_stacks = self._error_statement_stacks + self.module.error_statements = self._error_statements def error_recovery(self, grammar, stack, arcs, typ, value, start_pos, prefix, add_token_callback): @@ -427,7 +430,7 @@ class ParserWithRecovery(Parser): self._scope_names_stack.pop() if failed_stack: err = ErrorStatement(failed_stack, arcs, value, self.position_modifier, start_pos) - self._error_statement_stacks.append(err) + self._error_statements.append(err) self._last_failed_start_pos = start_pos diff --git a/jedi/parser/fast.py b/jedi/parser/fast.py index 0c0041f9..d0cbcc89 100644 --- a/jedi/parser/fast.py +++ b/jedi/parser/fast.py @@ -45,8 +45,8 @@ class FastModule(tree.Module): return [name for m in self.modules for name in m.global_names] @property - def error_statement_stacks(self): - return [e for m in self.modules for e in m.error_statement_stacks] + def error_statements(self): + return [e for m in self.modules for e in m.error_statements] def __repr__(self): return "" % (type(self).__name__, self.name, @@ -59,8 +59,8 @@ class FastModule(tree.Module): def global_names(self, value): pass - @error_statement_stacks.setter - def error_statement_stacks(self, value): + @error_statements.setter + def error_statements(self, value): pass @used_names.setter diff --git a/jedi/parser/tree.py b/jedi/parser/tree.py index 7824a634..6c21298b 100644 --- a/jedi/parser/tree.py +++ b/jedi/parser/tree.py @@ -38,6 +38,7 @@ from inspect import cleandoc from itertools import chain import textwrap +from jedi import common from jedi._compatibility import (Python3Method, encoding, is_py3, utf8_repr, literal_eval, use_metaclass, unicode) from jedi import cache @@ -196,6 +197,28 @@ class Base(object): def nodes_to_execute(self, last_added=False): raise NotImplementedError() + def get_code_with_error_statements(self, include_prefix=False): + module = self.get_parent_until() + source = self.get_code(include_prefix=include_prefix) + start_pos, end_pos = self.start_pos, self.end_pos + # Check for error statements that are inside the node. + error_statements = [ + e for e in module.error_statements + if start_pos <= e.start_pos and end_pos >= e.end_pos + ] + lines = common.splitlines(source) + # Note: Error statements must not be sorted. The positions are only + # correct if we insert them the way that they were tokenized. + for error_statement in error_statements: + line_index = error_statement.start_pos[0] - start_pos[0] + + line = lines[line_index] + index = error_statement.start_pos[1] + line = line[:index] + error_statement.get_code() + line[index:] + lines[line_index] = line + + return '\n'.join(lines) + class Leaf(Base): __slots__ = ('position_modifier', 'value', 'parent', '_start_pos', 'prefix') @@ -246,10 +269,13 @@ class Leaf(Base): except AttributeError: # A Leaf doesn't have children. return node - def get_code(self, normalized=False): + def get_code(self, normalized=False, include_prefix=True): if normalized: return self.value - return self.prefix + self.value + if include_prefix: + return self.prefix + self.value + else: + return self.value def next_sibling(self): """ @@ -304,11 +330,11 @@ class LeafWithNewLines(Leaf): end_pos_col = len(lines[-1]) return end_pos_line, end_pos_col - @utf8_repr def __repr__(self): return "<%s: %r>" % (type(self).__name__, self.value) + class Whitespace(LeafWithNewLines): """Contains NEWLINE and ENDMARKER tokens.""" __slots__ = () @@ -452,9 +478,13 @@ class BaseNode(Base): def end_pos(self): return self.children[-1].end_pos - def get_code(self, normalized=False): - # TODO implement normalized (dependin on context). - return "".join(c.get_code(normalized) for c in self.children) + def get_code(self, normalized=False, include_prefix=True): + # TODO implement normalized (depending on context). + if include_prefix: + return "".join(c.get_code(normalized) for c in self.children) + else: + first = self.children[0].get_code(include_prefix=False) + return first + "".join(c.get_code(normalized) for c in self.children[1:]) @Python3Method def name_for_position(self, position): @@ -468,6 +498,16 @@ class BaseNode(Base): return result return None + def get_leaf_for_position(self, position): + for c in self.children: + if c.start_pos <= position <= c.end_pos: + try: + return c.get_leaf_for_position(position) + except AttributeError: + return c + + raise ValueError("Position does not exist.") + @Python3Method def get_statement_for_position(self, pos): for c in self.children: @@ -633,7 +673,7 @@ class Module(Scope): of a module. """ __slots__ = ('path', 'global_names', 'used_names', '_name', - 'error_statement_stacks') + 'error_statements') type = 'file_input' def __init__(self, children):