diff --git a/jedi/api/__init__.py b/jedi/api/__init__.py index 912f911d..934d2773 100644 --- a/jedi/api/__init__.py +++ b/jedi/api/__init__.py @@ -592,8 +592,9 @@ class Script(object): raise TypeError('If you provide until_line you have to provide until_column') until_pos = until_line, until_column return refactoring.extract_function( - self._inference_state, self.path, self._module_node, new_name, - (line, column), until_pos) + self._inference_state, self.path, self._get_module_context(), + new_name, (line, column), until_pos + ) @no_py2_support def inline(self, line=None, column=None): diff --git a/jedi/api/refactoring.py b/jedi/api/refactoring.py index 74626fcf..23996351 100644 --- a/jedi/api/refactoring.py +++ b/jedi/api/refactoring.py @@ -278,23 +278,24 @@ def _find_nodes(module_node, pos, until_pos): return nodes -def _replace(nodes, expression_replacement, extracted): +def _replace(nodes, expression_replacement, extracted, insert_before_leaf=None): # Now try to replace the nodes found with a variable and move the code # before the current statement. definition = _get_parent_definition(nodes[0]) - first_definition_leaf = definition.get_first_leaf() + if insert_before_leaf is None: + insert_before_leaf = definition.get_first_leaf() replacement_dct = {} extracted_prefix = _insert_line_before( - first_definition_leaf.prefix, + insert_before_leaf.prefix, extracted, ) first_node_leaf = nodes[0].get_first_leaf() - if first_node_leaf is first_definition_leaf: + if first_node_leaf is insert_before_leaf: replacement_dct[nodes[0]] = extracted_prefix + expression_replacement else: replacement_dct[nodes[0]] = first_node_leaf.prefix + expression_replacement - replacement_dct[first_definition_leaf] = extracted_prefix + first_definition_leaf.value + replacement_dct[insert_before_leaf] = extracted_prefix + insert_before_leaf.value for node in nodes[1:]: replacement_dct[node] = '' @@ -371,20 +372,28 @@ def _is_not_extractable_syntax(node): or node.type == 'keyword' and node.value not in ('None', 'True', 'False') -def extract_function(inference_state, path, module_node, name, pos, until_pos): +def extract_function(inference_state, path, module_context, name, pos, until_pos): # 1. extract expression is_class_method = False is_method = False is_expression = True class_indentation = '' # 2. extract statements - nodes = _find_nodes(module_node, pos, until_pos) + nodes = _find_nodes(module_context.tree_node, pos, until_pos) return_variables = [] params = _find_non_global_names(nodes) dct = {} # Find variables # Is a class method / method + context = module_context.create_context(nodes[0]) + if context.is_module(): + insert_before_leaf = None # Leaf will be determined later + else: + node = context.tree_node + while node.parent.type in ('async_funcdef', 'decorated', 'async_stmt'): + node = node.parent + insert_before_leaf = node.get_first_leaf() if is_expression: code_block = 'return ' + _expression_nodes_to_string(nodes) else: @@ -409,7 +418,7 @@ def extract_function(inference_state, path, module_node, name, pos, until_pos): dct[nodes[0]] = replacement for node in nodes[1:]: dct[node] = '' - file_to_node_changes = {path: _replace(nodes, replacement, function_code)} + file_to_node_changes = {path: _replace(nodes, replacement, function_code, insert_before_leaf)} return Refactoring(inference_state.grammar, file_to_node_changes) diff --git a/test/refactor/extract_function.py b/test/refactor/extract_function.py index 7eca733a..10c87025 100644 --- a/test/refactor/extract_function.py +++ b/test/refactor/extract_function.py @@ -18,3 +18,29 @@ def ab(): ab() +# -------------------------------------------------- in-function-1 +def f(x): +#? 11 text {'new_name': 'ab'} + return x + 1 * 2 +# ++++++++++++++++++++++++++++++++++++++++++++++++++ +def ab(): + return x + 1 * 2 + + +def f(x): +#? 11 text {'new_name': 'ab'} + return ab() +# -------------------------------------------------- in-function-with-dec +@classmethod +def f(x): +#? 11 text {'new_name': 'ab'} + return x + 1 * 2 +# ++++++++++++++++++++++++++++++++++++++++++++++++++ +def ab(): + return x + 1 * 2 + + +@classmethod +def f(x): +#? 11 text {'new_name': 'ab'} + return ab()