From 48e25c1b9b0237eda3f8ce36186508b74b7e88b5 Mon Sep 17 00:00:00 2001 From: Dave Halter Date: Sun, 23 Feb 2020 23:22:00 +0100 Subject: [PATCH] Extract: Make sure params are not duplicated --- jedi/api/refactoring.py | 28 +++++++++++++++++++++------- test/refactor/extract_function.py | 19 +++++++++++++++++++ 2 files changed, 40 insertions(+), 7 deletions(-) diff --git a/jedi/api/refactoring.py b/jedi/api/refactoring.py index 4e936ff8..55c8b026 100644 --- a/jedi/api/refactoring.py +++ b/jedi/api/refactoring.py @@ -383,10 +383,9 @@ def _is_not_extractable_syntax(node): def extract_function(inference_state, path, module_context, name, pos, until_pos): is_expression = True nodes = _find_nodes(module_context.tree_node, pos, until_pos) - return_variables = [] context = module_context.create_context(nodes[0]) is_bound_method = context.is_bound_method() - params = list(_find_non_global_names(context, nodes)) + params, return_variables = list(_find_inputs_and_outputs(context, nodes)) dct = {} # Find variables @@ -436,17 +435,32 @@ def extract_function(inference_state, path, module_context, name, pos, until_pos return Refactoring(inference_state.grammar, file_to_node_changes) +def _find_inputs_and_outputs(context, nodes): + inputs = [] + outputs = [] + for name in _find_non_global_names(context, nodes): + if name.is_definition(): + if name not in outputs: + outputs.append(name.value) + else: + if name.value not in inputs: + name_definitions = context.goto(name, name.start_pos) + if not name_definitions \ + or any(not n.parent_context.is_module() or n.api_type == 'param' + for n in name_definitions): + inputs.append(name.value) + + # Check if outputs are really needed: + return inputs, outputs + + def _find_non_global_names(context, nodes): for node in nodes: try: children = node.children except AttributeError: if node.type == 'name': - name_definitions = context.goto(node, node.start_pos) - if not name_definitions \ - or any(not n.parent_context.is_module() or n.api_type == 'param' - for n in name_definitions): - yield node.value + yield node else: # We only want to check foo in foo.bar if node.type == 'trailer' and node.children[0] == '.': diff --git a/test/refactor/extract_function.py b/test/refactor/extract_function.py index 62b27ce4..565691fa 100644 --- a/test/refactor/extract_function.py +++ b/test/refactor/extract_function.py @@ -63,6 +63,25 @@ class X: def f(x, b): #? 11 text {'new_name': 'ab'} return x.ab(b) +# -------------------------------------------------- in-method-2 +glob1 = 1 +class X: + def g(self): pass + + def f(self, b, c): + #? 11 text {'new_name': 'ab'} + return self.g() or self.f(b) ^ glob1 & b +# ++++++++++++++++++++++++++++++++++++++++++++++++++ +glob1 = 1 +class X: + def g(self): pass + + def ab(self, b): + return self.g() or self.f(b) ^ glob1 & b + + def f(self, b, c): + #? 11 text {'new_name': 'ab'} + return self.ab(b) # -------------------------------------------------- in-classmethod-1 class X: @classmethod