diff --git a/jedi/api/refactoring/extract.py b/jedi/api/refactoring/extract.py index f9e2397b..6a7ec7ba 100644 --- a/jedi/api/refactoring/extract.py +++ b/jedi/api/refactoring/extract.py @@ -210,7 +210,7 @@ def extract_function(inference_state, path, module_context, name, pos, until_pos is_expression, _ = _is_expression_with_error(nodes) context = module_context.create_context(nodes[0]) is_bound_method = context.is_bound_method() - params, return_variables = list(_find_inputs_and_outputs(context, nodes)) + params, return_variables = list(_find_inputs_and_outputs(module_context, context, nodes)) # Find variables # Is a class method / method @@ -280,7 +280,22 @@ def extract_function(inference_state, path, module_context, name, pos, until_pos return Refactoring(inference_state.grammar, file_to_node_changes) -def _find_inputs_and_outputs(context, nodes): +def _is_name_input(module_context, names, first, last): + for name in names: + if name.api_type == 'param' or not name.parent_context.is_module(): + if name.get_root_context() is not module_context: + print('true') + return True + if name.start_pos is None or not (first <= name.start_pos < last): + print('true1', first, name.start_pos, last) + return True + return False + + +def _find_inputs_and_outputs(module_context, context, nodes): + first = nodes[0].start_pos + last = nodes[-1].end_pos + inputs = [] outputs = [] for name in _find_non_global_names(nodes): @@ -291,8 +306,7 @@ def _find_inputs_and_outputs(context, nodes): if name.value not in inputs: name_definitions = context.goto(name, name.start_pos) if not name_definitions \ - or any(not n.parent_context.is_module() or n.api_type == 'param' - for n in name_definitions): + or _is_name_input(module_context, name_definitions, first, last): inputs.append(name.value) # Check if outputs are really needed: diff --git a/test/refactor/extract_function.py b/test/refactor/extract_function.py index 103f7d6e..e8b0491f 100644 --- a/test/refactor/extract_function.py +++ b/test/refactor/extract_function.py @@ -227,7 +227,7 @@ def x(v1): # ++++++++++++++++++++++++++++++++++++++++++++++++++ import os # comment1 -def a(v1, v2, v3): +def a(v1, v3): v2 = 2 if 1: x, y = os.listdir(v1 + v2 * v3) @@ -239,7 +239,7 @@ def a(v1, v2, v3): def x(v1): #foo #? 2 text {'new_name': 'a', 'until_line': 9, 'until_column': 5} - x, y = a(v1, v2, v3) + x, y = a(v1, v3) #bar return x, y # -------------------------------------------------- with-range-func-2 @@ -259,7 +259,7 @@ x import os # comment1 # comment2 -def a(v1, v2, v3): +def a(v1, v3): #foo v2 = 2 if 1: @@ -270,7 +270,7 @@ def a(v1, v2, v3): def x(v1): #? 2 text {'new_name': 'a', 'until_line': 10, 'until_column': 0} - y = a(v1, v2, v3) + y = a(v1, v3) return y x # -------------------------------------------------- with-range-func-3 @@ -298,3 +298,18 @@ def x(v1): #bar return x x +# -------------------------------------------------- in-class-1 +class X1: + #? 11 text {'new_name': 'f', 'until_line': 4} + a = 3 + c = a + 2 +# ++++++++++++++++++++++++++++++++++++++++++++++++++ +def f(): + a = 3 + c = a + 2 + return c + + +class X1: + #? 11 text {'new_name': 'f', 'until_line': 4} + c = f()