diff --git a/jedi/api/refactoring.py b/jedi/api/refactoring.py index e550f44a..c86a25ce 100644 --- a/jedi/api/refactoring.py +++ b/jedi/api/refactoring.py @@ -233,7 +233,7 @@ def extract_variable(grammar, path, module_node, name, pos, until_pos): raise RefactoringError(message) generated_code = name + ' = ' + _expression_nodes_to_string(nodes) - file_to_node_changes = {path: _replace(nodes, name, generated_code)} + file_to_node_changes = {path: _replace(nodes, name, generated_code, pos)} return Refactoring(grammar, file_to_node_changes) @@ -294,19 +294,22 @@ def _find_nodes(module_node, pos, until_pos): return nodes -def _replace(nodes, expression_replacement, extracted, insert_before_leaf=None): +def _replace(nodes, expression_replacement, extracted, pos, insert_before_leaf=None): # Now try to replace the nodes found with a variable and move the code # before the current statement. definition = _get_parent_definition(nodes[0]) if insert_before_leaf is None: insert_before_leaf = definition.get_first_leaf() + first_node_leaf = nodes[0].get_first_leaf() + + lines = split_lines(insert_before_leaf.prefix, keepends=True) + if first_node_leaf is insert_before_leaf: + removed_line_count = nodes[0].start_pos[0] - pos[0] + lines = lines[:-removed_line_count - 1] + [lines[-1]] + lines[-1:-1] = [indent_block(extracted, lines[-1]) + '\n'] + extracted_prefix = ''.join(lines) replacement_dct = {} - extracted_prefix = _insert_line_before( - insert_before_leaf.prefix, - extracted, - ) - first_node_leaf = nodes[0].get_first_leaf() if first_node_leaf is insert_before_leaf: replacement_dct[nodes[0]] = extracted_prefix + expression_replacement else: @@ -322,6 +325,14 @@ def _expression_nodes_to_string(nodes): return ''.join(n.get_code(include_prefix=i != 0) for i, n in enumerate(nodes)) +def _suite_nodes_to_string(nodes, pos): + n = nodes[0] + included_line_count = n.start_pos[0] - pos[0] + lines = split_lines(n.get_first_leaf().prefix, keepends=True)[-included_line_count - 1] + return ''.join(lines) + n.get_code(include_prefix=False) \ + + ''.join(n.get_code() for n in nodes[1:]) + + def _remove_indent_of_prefix(prefix): r""" Removes the last indentation of a prefix, e.g. " \n \n " becomes " \n \n". @@ -333,12 +344,6 @@ def _get_indentation(node): return split_lines(node.get_first_leaf().prefix)[-1] -def _insert_line_before(prefix, code): - lines = split_lines(prefix, keepends=True) - lines[-1:-1] = [indent_block(code, lines[-1]) + '\n'] - return ''.join(lines) - - def _get_parent_definition(node): """ Returns the statement where a node is defined. @@ -398,7 +403,6 @@ def extract_function(inference_state, path, module_context, name, pos, until_pos is_bound_method = context.is_bound_method() params, return_variables = list(_find_inputs_and_outputs(context, nodes)) - dct = {} # Find variables # Is a class method / method if context.is_module(): @@ -420,7 +424,7 @@ def extract_function(inference_state, path, module_context, name, pos, until_pos )) or [return_variables[-1]] output_var_str = ', '.join(return_variables) - code_block = dedent(''.join(n.get_code() for n in nodes)) + code_block = dedent(_suite_nodes_to_string(nodes, pos)) code_block += 'return ' + output_var_str + '\n' decorator = '' @@ -453,10 +457,8 @@ def extract_function(inference_state, path, module_context, name, pos, until_pos else: replacement = _get_indentation(nodes[0]) + output_var_str + ' = ' + function_call - dct[nodes[0]] = replacement - for node in nodes[1:]: - dct[node] = '' - file_to_node_changes = {path: _replace(nodes, replacement, function_code, insert_before_leaf)} + replaced_str = _replace(nodes, replacement, function_code, pos, insert_before_leaf) + file_to_node_changes = {path: replaced_str} return Refactoring(inference_state.grammar, file_to_node_changes) diff --git a/test/refactor/extract_function.py b/test/refactor/extract_function.py index a53794ac..ab589165 100644 --- a/test/refactor/extract_function.py +++ b/test/refactor/extract_function.py @@ -159,6 +159,7 @@ def x(z): v1 = 3 v2 = 2 x = test(v1 + v2 * v3) + # ++++++++++++++++++++++++++++++++++++++++++++++++++ #? 0 text {'new_name': 'a', 'until_line': 4} def a(test, v3): @@ -168,4 +169,25 @@ def a(test, v3): return x -a(test, v3) +x = a(test, v3) +# -------------------------------------------------- with-range-2 +#? 2 text {'new_name': 'a', 'until_line': 6, 'until_column': 4} +#foo +v1 = 3 +v2 = 2 +x, y = test(v1 + v2 * v3) +#raaaa +y +# ++++++++++++++++++++++++++++++++++++++++++++++++++ +#? 2 text {'new_name': 'a', 'until_line': 6, 'until_column': 4} +def a(test, v3): + #foo + v1 = 3 + v2 = 2 + x, y = test(v1 + v2 * v3) + #raaaa + return y + + +y = a(test, v3) +y