From a92c28840b254695e5e58e744e48f9f301be6c95 Mon Sep 17 00:00:00 2001 From: Dave Halter Date: Wed, 26 Feb 2020 09:31:33 +0100 Subject: [PATCH] Fix: Extract can now deal with return statements at the end --- jedi/api/refactoring/extract.py | 39 +++++++++++++++++++++---------- test/refactor/extract_function.py | 26 +++++++++++++++++++++ 2 files changed, 53 insertions(+), 12 deletions(-) diff --git a/jedi/api/refactoring/extract.py b/jedi/api/refactoring/extract.py index 6a7ec7ba..d0ef1015 100644 --- a/jedi/api/refactoring/extract.py +++ b/jedi/api/refactoring/extract.py @@ -207,6 +207,8 @@ def _is_not_extractable_syntax(node): def extract_function(inference_state, path, module_context, name, pos, until_pos): nodes = _find_nodes(module_context.tree_node, pos, until_pos) + assert len(nodes) + is_expression, _ = _is_expression_with_error(nodes) context = module_context.create_context(nodes[0]) is_bound_method = context.is_bound_method() @@ -223,15 +225,17 @@ def extract_function(inference_state, path, module_context, name, pos, until_pos code_block = 'return ' + _expression_nodes_to_string(nodes) + '\n' remaining_prefix = None else: - # Find the actually used variables (of the defined ones). If none are - # used (e.g. if the range covers the whole function), return the last - # defined variable. - return_variables = list(_find_needed_output_variables( - context, - nodes[0].parent, - nodes[-1].end_pos, - return_variables - )) or [return_variables[-1]] + has_ending_return_stmt = _is_node_ending_return_stmt(nodes[-1]) + if not has_ending_return_stmt: + # Find the actually used variables (of the defined ones). If none are + # used (e.g. if the range covers the whole function), return the last + # defined variable. + return_variables = list(_find_needed_output_variables( + context, + nodes[0].parent, + nodes[-1].end_pos, + return_variables + )) or [return_variables[-1]] remaining_prefix, code_block = _suite_nodes_to_string(nodes, pos) after_leaf = nodes[-1].get_next_leaf() @@ -239,8 +243,9 @@ def extract_function(inference_state, path, module_context, name, pos, until_pos code_block += first code_block = dedent(code_block) - output_var_str = ', '.join(return_variables) - code_block += 'return ' + output_var_str + '\n' + if not has_ending_return_stmt: + output_var_str = ', '.join(return_variables) + code_block += 'return ' + output_var_str + '\n' decorator = '' self_param = None @@ -270,7 +275,10 @@ def extract_function(inference_state, path, module_context, name, pos, until_pos if is_expression: replacement = function_call else: - replacement = output_var_str + ' = ' + function_call + '\n' + if has_ending_return_stmt: + replacement = 'return ' + function_call + '\n' + else: + replacement = output_var_str + ' = ' + function_call + '\n' replacement_dct = _replace(nodes, replacement, function_code, pos, insert_before_leaf, remaining_prefix) @@ -353,3 +361,10 @@ def _find_needed_output_variables(context, search_node, at_least_pos, return_var if not name.is_definition() and name.value in return_variables: return_variables.remove(name.value) yield name.value + + +def _is_node_ending_return_stmt(node): + t = node.type + if t == 'simple_stmt': + return _is_node_ending_return_stmt(node.children[0]) + return t == 'return_stmt' diff --git a/test/refactor/extract_function.py b/test/refactor/extract_function.py index 55200507..cea730df 100644 --- a/test/refactor/extract_function.py +++ b/test/refactor/extract_function.py @@ -345,3 +345,29 @@ class X: #? 11 text {'new_name': 'ab', 'until_line': 12, 'until_column': 28} x = self.ab(b) # bar +# -------------------------------------------------- in-method-range-2 +glob1 = 1 +class X: + # comment + + def f(self, b, c): + #? 11 text {'new_name': 'ab', 'until_line': 11, 'until_column': 10} + #foo + local1 = 3 + local2 = 4 + return local1 * glob1 * b + # bar +# ++++++++++++++++++++++++++++++++++++++++++++++++++ +glob1 = 1 +class X: + # comment + + def ab(self, b): + #foo + local1 = 3 + local2 = 4 + return local1 * glob1 * b + + def f(self, b, c): + #? 11 text {'new_name': 'ab', 'until_line': 11, 'until_column': 10} + return self.ab(b)