Extract: Make sure params are not duplicated

This commit is contained in:
Dave Halter
2020-02-23 23:22:00 +01:00
parent da935baa99
commit 48e25c1b9b
2 changed files with 40 additions and 7 deletions

View File

@@ -383,10 +383,9 @@ def _is_not_extractable_syntax(node):
def extract_function(inference_state, path, module_context, name, pos, until_pos):
is_expression = True
nodes = _find_nodes(module_context.tree_node, pos, until_pos)
return_variables = []
context = module_context.create_context(nodes[0])
is_bound_method = context.is_bound_method()
params = list(_find_non_global_names(context, nodes))
params, return_variables = list(_find_inputs_and_outputs(context, nodes))
dct = {}
# Find variables
@@ -436,17 +435,32 @@ def extract_function(inference_state, path, module_context, name, pos, until_pos
return Refactoring(inference_state.grammar, file_to_node_changes)
def _find_inputs_and_outputs(context, nodes):
inputs = []
outputs = []
for name in _find_non_global_names(context, nodes):
if name.is_definition():
if name not in outputs:
outputs.append(name.value)
else:
if name.value not in inputs:
name_definitions = context.goto(name, name.start_pos)
if not name_definitions \
or any(not n.parent_context.is_module() or n.api_type == 'param'
for n in name_definitions):
inputs.append(name.value)
# Check if outputs are really needed:
return inputs, outputs
def _find_non_global_names(context, nodes):
for node in nodes:
try:
children = node.children
except AttributeError:
if node.type == 'name':
name_definitions = context.goto(node, node.start_pos)
if not name_definitions \
or any(not n.parent_context.is_module() or n.api_type == 'param'
for n in name_definitions):
yield node.value
yield node
else:
# We only want to check foo in foo.bar
if node.type == 'trailer' and node.children[0] == '.':

View File

@@ -63,6 +63,25 @@ class X:
def f(x, b):
#? 11 text {'new_name': 'ab'}
return x.ab(b)
# -------------------------------------------------- in-method-2
glob1 = 1
class X:
def g(self): pass
def f(self, b, c):
#? 11 text {'new_name': 'ab'}
return self.g() or self.f(b) ^ glob1 & b
# ++++++++++++++++++++++++++++++++++++++++++++++++++
glob1 = 1
class X:
def g(self): pass
def ab(self, b):
return self.g() or self.f(b) ^ glob1 & b
def f(self, b, c):
#? 11 text {'new_name': 'ab'}
return self.ab(b)
# -------------------------------------------------- in-classmethod-1
class X:
@classmethod