Further import completion improvements.

This commit is contained in:
Dave Halter
2016-05-28 02:08:43 +02:00
parent e4fe2a6d09
commit 4714b464a6
5 changed files with 162 additions and 29 deletions

View File

@@ -58,7 +58,7 @@ class Completion:
def __init__(self, evaluator, parser, user_context, position, call_signatures_method): def __init__(self, evaluator, parser, user_context, position, call_signatures_method):
self._evaluator = evaluator self._evaluator = evaluator
self._parser = parser self._parser = parser
self._module = parser.module() self._module = evaluator.wrap(parser.module())
self._user_context = user_context self._user_context = user_context
self._pos = position self._pos = position
self._call_signatures_method = call_signatures_method self._call_signatures_method = call_signatures_method
@@ -136,7 +136,7 @@ class Completion:
allowed_keywords, allowed_tokens = \ allowed_keywords, allowed_tokens = \
helpers.get_possible_completion_types(grammar, stack) helpers.get_possible_completion_types(grammar, stack)
print(allowed_keywords, allowed_tokens) print(allowed_keywords, [token.tok_name[a] for a in allowed_tokens])
completion_names = list(self._get_keyword_completion_names(allowed_keywords)) completion_names = list(self._get_keyword_completion_names(allowed_keywords))
if token.NAME in allowed_tokens: if token.NAME in allowed_tokens:
@@ -145,12 +145,39 @@ class Completion:
symbol_names = list(stack.get_node_names(grammar)) symbol_names = list(stack.get_node_names(grammar))
print(symbol_names) print(symbol_names)
nodes = list(stack.get_nodes())
last_symbol = symbol_names[-1]
if "import_stmt" in symbol_names: if "import_stmt" in symbol_names:
if symbol_names[-1] == "dotted_name": level = 0
completion_names += self._complete_dotted_name(stack, self._module) only_modules = True
elif symbol_names[-1] == "import_name": level, names = self._parse_dotted_names(nodes)
names = list(stack.get_nodes())[1::2] if "import_from" in symbol_names:
completion_names += self._get_importer_names(names) if 'import' in nodes:
only_modules = False
'''
if last_symbol == "dotted_name":
elif last_symbol == "import_from":
# No names are given yet, but the dots for level might be
# there.
if 'import' in nodes:
print(nodes[1])
raise NotImplementedError
else:
raise NotImplementedError
elif last_symbol == "import_name":
names = nodes[1::2]
completion_names += self._get_importer_names(names)
'''
else:
assert "import_name" in symbol_names
print(names, level)
completion_names += self._get_importer_names(
names,
level,
only_modules
)
else: else:
completion_names += self._simple_complete(completion_parts) completion_names += self._simple_complete(completion_parts)
@@ -238,18 +265,20 @@ class Completion:
) )
return completion_names return completion_names
def _complete_dotted_name(self, stack, module): def _parse_dotted_names(self, nodes):
nodes = list(stack.get_nodes())
level = 0 level = 0
for i, node in enumerate(nodes[1:], 1): names = []
for node in nodes[1:]:
if node in ('.', '...'): if node in ('.', '...'):
level += len(node.value) if not names:
level += len(node.value)
elif node.type == 'dotted_name':
names += node.children[::2]
elif node.type == 'name':
names.append(node)
else: else:
names = nodes[i::2]
break break
return level, names
return self._get_importer_names(names, level=level)
def _get_importer_names(self, names, level=0, only_modules=True): def _get_importer_names(self, names, level=0, only_modules=True):
names = [str(n) for n in names] names = [str(n) for n in names]

View File

@@ -73,7 +73,7 @@ def get_stack_at_position(grammar, module, pos):
Returns the possible node names (e.g. import_from, xor_test or yield_stmt). Returns the possible node names (e.g. import_from, xor_test or yield_stmt).
""" """
user_stmt = module.get_statement_for_position(pos) user_stmt = module.get_statement_for_position(pos)
if user_stmt is None: if user_stmt is None or user_stmt.type == 'whitespace':
# If there's no error statement and we're just somewhere, we want # If there's no error statement and we're just somewhere, we want
# completions for just whitespace. # completions for just whitespace.
code = '' code = ''
@@ -81,16 +81,23 @@ def get_stack_at_position(grammar, module, pos):
for error_statement in module.error_statements: for error_statement in module.error_statements:
if error_statement.start_pos < pos <= error_statement.end_pos: if error_statement.start_pos < pos <= error_statement.end_pos:
code = error_statement.get_code(include_prefix=False) code = error_statement.get_code(include_prefix=False)
start_pos = error_statement.start_pos node = error_statement
break break
else:
raise NotImplementedError
else: else:
code = user_stmt.get_code_with_error_statements(include_prefix=False) code = user_stmt.get_code_with_error_statements(include_prefix=False)
start_pos = user_stmt.start_pos node = user_stmt
# Remove indentations. # Make sure we include the whitespace after the statement as well, since it
code = code.lstrip() # could be where we would want to complete.
code = get_code_until(code, start_pos, pos) print('a', repr(code), node, repr(node.get_next_leaf().prefix))
# Remove whitespace at the end. code += node.get_next_leaf().prefix
code = get_code_until(code, node.start_pos, pos)
# Remove whitespace at the end. Necessary, because the tokenizer will parse
# an error token (there's no new line at the end in our case). This doesn't
# alter any truth about the valid tokens at that position.
code = code.rstrip() code = code.rstrip()
class EndMarkerReached(Exception): class EndMarkerReached(Exception):
@@ -101,7 +108,6 @@ def get_stack_at_position(grammar, module, pos):
if token_[0] == token.ENDMARKER: if token_[0] == token.ENDMARKER:
raise EndMarkerReached() raise EndMarkerReached()
else: else:
print(token_, token.tok_name[token_[0]])
yield token_ yield token_
print(repr(code)) print(repr(code))

View File

@@ -67,10 +67,10 @@ class ErrorStatement(object):
self._next_start_pos = next_start_pos self._next_start_pos = next_start_pos
def __repr__(self): def __repr__(self):
return '<%s next: %s@%s>' % ( return '<%s %s@%s>' % (
type(self).__name__, type(self).__name__,
repr(self.next_token), repr(self.next_token),
self.next_start_pos self.end_pos
) )
@property @property
@@ -104,6 +104,32 @@ class ErrorStatement(object):
first = next(iterator) first = next(iterator)
return first.get_code(include_prefix=include_prefix) + ''.join(node.get_code() for node in iterator) return first.get_code(include_prefix=include_prefix) + ''.join(node.get_code() for node in iterator)
def get_next_leaf(self):
for child in self.parent.children:
if child.start_pos == self.end_pos:
return child.first_leaf()
if child.start_pos > self.end_pos:
raise NotImplementedError('Node not found, must be in error statements.')
raise ValueError("Doesn't have a next leaf")
def set_parent(self, root_node):
"""
Used by the parser at the end of parsing. The error statements parents
have to be calculated at the end, because they are basically ripped out
of the stack at which time its parents don't yet exist..
"""
start_pos = self.start_pos
for c in root_node.children:
if c.start_pos < start_pos <= c.end_pos:
return self.set_parent(c)
self.parent = root_node
class ErrorToken(tree.LeafWithNewLines):
type = 'error_token'
class ParserSyntaxError(object): class ParserSyntaxError(object):
def __init__(self, message, position): def __init__(self, message, position):
@@ -193,6 +219,9 @@ class Parser(object):
if self._added_newline: if self._added_newline:
self.remove_last_newline() self.remove_last_newline()
for e in self._error_statements:
e.set_parent(self.get_parsed_node())
def get_parsed_node(self): def get_parsed_node(self):
return self._parsed return self._parsed
@@ -381,13 +410,23 @@ class ParserWithRecovery(Parser):
stack[index] stack[index]
#print('err', token.tok_name[typ], repr(value), start_pos, len(stack), index) #print('err', token.tok_name[typ], repr(value), start_pos, len(stack), index)
self._stack_removal(grammar, stack, arcs, index + 1, value, start_pos) if self._stack_removal(grammar, stack, arcs, index + 1, value, start_pos):
#add_token_callback(typ, value, prefix, start_pos)
pass
else:
#error_leaf = ErrorToken(self.position_modifier, value, start_pos, prefix)
#stack = [(None, [error_leaf])]
# TODO document the shizzle!
#self._error_statements.append(ErrorStatement(stack, None, None,
# self.position_modifier, error_leaf.end_pos))
return
if typ == INDENT: if typ == INDENT:
# For every deleted INDENT we have to delete a DEDENT as well. # For every deleted INDENT we have to delete a DEDENT as well.
# Otherwise the parser will get into trouble and DEDENT too early. # Otherwise the parser will get into trouble and DEDENT too early.
self._omit_dedent_list.append(self._indent_counter) self._omit_dedent_list.append(self._indent_counter)
if value in ('import', 'class', 'def', 'try', 'while', 'return'): if value in ('import', 'class', 'def', 'try', 'while', 'return', '\n'):
# Those can always be new statements. # Those can always be new statements.
add_token_callback(typ, value, prefix, start_pos) add_token_callback(typ, value, prefix, start_pos)
elif typ == DEDENT and symbol == 'suite': elif typ == DEDENT and symbol == 'suite':
@@ -435,6 +474,7 @@ class ParserWithRecovery(Parser):
self._last_failed_start_pos = start_pos self._last_failed_start_pos = start_pos
stack[start_index:] = [] stack[start_index:] = []
return failed_stack
def _tokenize(self, tokenizer): def _tokenize(self, tokenizer):
for typ, value, start_pos, prefix in tokenizer: for typ, value, start_pos, prefix in tokenizer:

View File

@@ -37,6 +37,7 @@ import re
from inspect import cleandoc from inspect import cleandoc
from itertools import chain from itertools import chain
import textwrap import textwrap
import abc
from jedi import common from jedi import common
from jedi._compatibility import (Python3Method, encoding, is_py3, utf8_repr, from jedi._compatibility import (Python3Method, encoding, is_py3, utf8_repr,
@@ -194,6 +195,7 @@ class Base(object):
# Default is not being a scope. Just inherit from Scope. # Default is not being a scope. Just inherit from Scope.
return False return False
@abc.abstractmethod
def nodes_to_execute(self, last_added=False): def nodes_to_execute(self, last_added=False):
raise NotImplementedError() raise NotImplementedError()
@@ -238,6 +240,12 @@ class Leaf(Base):
def start_pos(self, value): def start_pos(self, value):
self._start_pos = value[0] - self.position_modifier.line, value[1] self._start_pos = value[0] - self.position_modifier.line, value[1]
def get_start_pos_of_prefix(self):
try:
return self.get_previous().end_pos
except IndexError:
return 1, 0 # It's the first leaf.
@property @property
def end_pos(self): def end_pos(self):
return (self._start_pos[0] + self.position_modifier.line, return (self._start_pos[0] + self.position_modifier.line,
@@ -250,6 +258,8 @@ class Leaf(Base):
def get_previous(self): def get_previous(self):
""" """
Returns the previous leaf in the parser tree. Returns the previous leaf in the parser tree.
Raises an IndexError if it's the first element.
# TODO rename to get_previous_leaf
""" """
node = self node = self
while True: while True:
@@ -269,6 +279,33 @@ class Leaf(Base):
except AttributeError: # A Leaf doesn't have children. except AttributeError: # A Leaf doesn't have children.
return node return node
def first_leaf(self):
return self
def get_next_leaf(self):
"""
Returns the previous leaf in the parser tree.
Raises an IndexError if it's the last element.
"""
node = self
while True:
c = node.parent.children
i = c.index(self)
if i == len(c) - 1:
node = node.parent
if node.parent is None:
raise IndexError('Cannot access the next element of the last one.')
else:
node = c[i + 1]
break
while True:
try:
node = node.children[0]
except AttributeError: # A Leaf doesn't have children.
return node
def get_code(self, normalized=False, include_prefix=True): def get_code(self, normalized=False, include_prefix=True):
if normalized: if normalized:
return self.value return self.value
@@ -474,6 +511,9 @@ class BaseNode(Base):
def start_pos(self): def start_pos(self):
return self.children[0].start_pos return self.children[0].start_pos
def get_start_pos_of_prefix(self):
return self.children[0].get_start_pos_of_prefix()
@property @property
def end_pos(self): def end_pos(self):
return self.children[-1].end_pos return self.children[-1].end_pos
@@ -498,9 +538,14 @@ class BaseNode(Base):
return result return result
return None return None
def get_leaf_for_position(self, position): def get_leaf_for_position(self, position, include_prefixes=False):
for c in self.children: for c in self.children:
if c.start_pos <= position <= c.end_pos: if include_prefixes:
start_pos = c.get_start_pos_with_prefix()
else:
start_pos = c.start_pos
if start_pos <= position <= c.end_pos:
try: try:
return c.get_leaf_for_position(position) return c.get_leaf_for_position(position)
except AttributeError: except AttributeError:
@@ -528,6 +573,17 @@ class BaseNode(Base):
except AttributeError: except AttributeError:
return self.children[0] return self.children[0]
def get_next_leaf(self):
"""
Raises an IndexError if it's the last node. (Would be the module)
"""
c = self.parent.children
index = c.index(self)
if index == len(c) - 1:
return self.get_next_leaf()
else:
return c[index + 1]
@utf8_repr @utf8_repr
def __repr__(self): def __repr__(self):
code = self.get_code().replace('\n', ' ').strip() code = self.get_code().replace('\n', ' ').strip()

View File

@@ -63,6 +63,8 @@ import datetime.date
#? 21 ['import'] #? 21 ['import']
from import_tree.pkg import pkg from import_tree.pkg import pkg
#? ['mod1', 'mod2', 'random', 'pkg', 'rename1', 'rename2', 'recurse_class1', 'recurse_class2', 'invisible_pkg', 'flow_import']
from import_tree.pkg import pkg,
#? 22 ['mod1'] #? 22 ['mod1']
from import_tree.pkg. import mod1 from import_tree.pkg. import mod1
#? 17 ['mod1', 'mod2', 'random', 'pkg', 'rename1', 'rename2', 'recurse_class1', 'recurse_class2', 'invisible_pkg', 'flow_import'] #? 17 ['mod1', 'mod2', 'random', 'pkg', 'rename1', 'rename2', 'recurse_class1', 'recurse_class2', 'invisible_pkg', 'flow_import']