diff --git a/jedi/api/refactoring.py b/jedi/api/refactoring.py index 1d5ae853..b316aa92 100644 --- a/jedi/api/refactoring.py +++ b/jedi/api/refactoring.py @@ -8,11 +8,11 @@ from parso import split_lines from jedi.api.exceptions import RefactoringError from jedi.common.utils import indent_block -_INLINE_NEEDS_BRACKET = ( - 'xor_expr and_expr shift_expr arith_expr term factor power atom_expr ' +_EXPRESSION_PARTS = ( 'or_test and_test not_test comparison' + 'xor_expr and_expr shift_expr arith_expr term factor power atom_expr ' ).split() -_EXTRACT_USE_PARENT = _INLINE_NEEDS_BRACKET + ['trailer'] +_EXTRACT_USE_PARENT = _EXPRESSION_PARTS + ['trailer'] _DEFINITION_SCOPES = ('suite', 'file_input') _NON_EXCTRACABLE = ('param',) @@ -189,7 +189,7 @@ def inline(grammar, names): path = name.get_root_context().py__file__() s = replace_code if rhs.type == 'testlist_star_expr' \ - or tree_name.parent.type in _INLINE_NEEDS_BRACKET \ + or tree_name.parent.type in _EXPRESSION_PARTS \ or tree_name.parent.type == 'trailer' \ and tree_name.parent.get_next_sibling() is not None: s = '(' + replace_code + ')' @@ -242,7 +242,8 @@ def extract_variable(grammar, path, module_node, new_name, pos, until_pos): parent_node = start_leaf while parent_node.end_pos < end_leaf.end_pos: parent_node = parent_node.parent - nodes = [parent_node] + + nodes = _remove_unwanted_expression_nodes(parent_node, pos, until_pos) if any(node.type == 'name' and node.is_definition() for node in nodes): raise RefactoringError('Cannot extract a definition of a name') if any(node.type in _NON_EXCTRACABLE for node in nodes) \ @@ -293,3 +294,27 @@ def _get_parent_definition(node): return node node = node.parent raise NotImplementedError('We should never even get here') + + +def _remove_unwanted_expression_nodes(parent_node, pos, until_pos): + """ + This function makes it so for `1 * 2 + 3` you can extract `2 + 3`, even + though it is not part of the expression. + """ + if parent_node.type in _EXPRESSION_PARTS: + nodes = parent_node.children + for i, n in enumerate(nodes): + if n.end_pos > pos: + start_index = i + if n.type == 'operator': + start_index -= 1 + break + for i, n in reversed(list(enumerate(nodes))): + if n.start_pos <= until_pos: + end_index = i + if n.type == 'operator': + end_index += 1 + break + print(nodes, start_index, end_index) + return nodes[start_index:end_index + 1] + return [parent_node] diff --git a/test/refactor/extract_variable.py b/test/refactor/extract_variable.py index f80c59af..4b4c6fb4 100644 --- a/test/refactor/extract_variable.py +++ b/test/refactor/extract_variable.py @@ -147,10 +147,10 @@ y + 1, 3 #? 1 text {'new_name': 'x', 'until_column': 1} x = y x + 1, 3 -# -------------------------------------------------- range-5 -#? 0 text {'new_name': 'x', 'until_column': 1} -y + 1, 3 +# -------------------------------------------------- addition-1 +#? 4 text {'new_name': 'x', 'until_column': 9} +z = y + 1 + 2+ 3, 3 # ++++++++++++++++++++++++++++++++++++++++++++++++++ -#? 0 text {'new_name': 'x', 'until_column': 1} -x = y -x + 1, 3 +#? 4 text {'new_name': 'x', 'until_column': 9} +x = y + 1 +z = x + 2+ 3, 3