diff --git a/jedi/parser/pytree.py b/jedi/parser/pytree.py index e6447dac..e5b09367 100644 --- a/jedi/parser/pytree.py +++ b/jedi/parser/pytree.py @@ -45,19 +45,6 @@ python_grammar_no_print_statement = python_grammar.copy() del python_grammar_no_print_statement.keywords["print"] -from jedi.parser import representation as pr -_ast_mapping = { - 'simple_stmt': pr.ExprStmt, - 'classdef': pr.Class, - 'funcdef': pr.Function, - 'file_input': pr.SubModule, - 'import_name': pr.Import, - 'import_from': pr.Import, -} - -ast_mapping = dict((getattr(python_symbols, k), v) for k, v in _ast_mapping.items()) - - def type_repr(type_num): global _type_reprs if not _type_reprs: @@ -147,6 +134,20 @@ def convert(grammar, raw_node): grammar rule produces a new complete node, so that the tree is build strictly bottom-up. """ + + from jedi.parser import representation as pr + _ast_mapping = { + 'simple_stmt': pr.ExprStmt, + 'classdef': pr.Class, + 'funcdef': pr.Function, + 'file_input': pr.SubModule, + 'import_name': pr.Import, + 'import_from': pr.Import, + } + + ast_mapping = dict((getattr(python_symbols, k), v) for k, v in _ast_mapping.items()) + + #import pdb; pdb.set_trace() type, value, context, children = raw_node if type in grammar.number2symbol: diff --git a/jedi/parser/representation.py b/jedi/parser/representation.py index ae040631..0f81239b 100644 --- a/jedi/parser/representation.py +++ b/jedi/parser/representation.py @@ -45,11 +45,17 @@ from jedi import common from jedi import debug from jedi import cache from jedi.parser import tokenize +from jedi.parser.pytree import python_symbols, Node SCOPE_CONTENTS = 'asserts', 'subscopes', 'imports', 'statements', 'returns' +def is_node(node, symbol_name): + return isinstance(node, Node) \ + and getattr(python_symbols, symbol_name) == node.type + + def filter_after_position(names, position): """ Removes all names after a certain position. If position is None, just @@ -978,8 +984,15 @@ class Statement(Simple, DocstringMixin): self.expression_list() def get_defined_names(self): - if isinstance(self.children[0], Import): - return self.children[0].get_defined_names() + first = self.children[0] # children[1] is always a newline. + if isinstance(first, Import): + return first.get_defined_names() + elif is_node(first, 'expr_stmt'): + names = [] + for i in range(0, len(first.children) - 2, 2): + if first.children[i + 1].value == '=': + names.append(first.children[i]) + return names return [] """Get the names for the statement.""" if self._set_vars is None: