Deal a lot better with prefixes in range extractions

This commit is contained in:
Dave Halter
2020-02-25 10:23:38 +01:00
parent f8d9f498d0
commit 89398e5c87
2 changed files with 44 additions and 20 deletions

View File

@@ -233,7 +233,7 @@ def extract_variable(grammar, path, module_node, name, pos, until_pos):
raise RefactoringError(message) raise RefactoringError(message)
generated_code = name + ' = ' + _expression_nodes_to_string(nodes) generated_code = name + ' = ' + _expression_nodes_to_string(nodes)
file_to_node_changes = {path: _replace(nodes, name, generated_code)} file_to_node_changes = {path: _replace(nodes, name, generated_code, pos)}
return Refactoring(grammar, file_to_node_changes) return Refactoring(grammar, file_to_node_changes)
@@ -294,19 +294,22 @@ def _find_nodes(module_node, pos, until_pos):
return nodes return nodes
def _replace(nodes, expression_replacement, extracted, insert_before_leaf=None): def _replace(nodes, expression_replacement, extracted, pos, insert_before_leaf=None):
# Now try to replace the nodes found with a variable and move the code # Now try to replace the nodes found with a variable and move the code
# before the current statement. # before the current statement.
definition = _get_parent_definition(nodes[0]) definition = _get_parent_definition(nodes[0])
if insert_before_leaf is None: if insert_before_leaf is None:
insert_before_leaf = definition.get_first_leaf() insert_before_leaf = definition.get_first_leaf()
first_node_leaf = nodes[0].get_first_leaf()
lines = split_lines(insert_before_leaf.prefix, keepends=True)
if first_node_leaf is insert_before_leaf:
removed_line_count = nodes[0].start_pos[0] - pos[0]
lines = lines[:-removed_line_count - 1] + [lines[-1]]
lines[-1:-1] = [indent_block(extracted, lines[-1]) + '\n']
extracted_prefix = ''.join(lines)
replacement_dct = {} replacement_dct = {}
extracted_prefix = _insert_line_before(
insert_before_leaf.prefix,
extracted,
)
first_node_leaf = nodes[0].get_first_leaf()
if first_node_leaf is insert_before_leaf: if first_node_leaf is insert_before_leaf:
replacement_dct[nodes[0]] = extracted_prefix + expression_replacement replacement_dct[nodes[0]] = extracted_prefix + expression_replacement
else: else:
@@ -322,6 +325,14 @@ def _expression_nodes_to_string(nodes):
return ''.join(n.get_code(include_prefix=i != 0) for i, n in enumerate(nodes)) return ''.join(n.get_code(include_prefix=i != 0) for i, n in enumerate(nodes))
def _suite_nodes_to_string(nodes, pos):
n = nodes[0]
included_line_count = n.start_pos[0] - pos[0]
lines = split_lines(n.get_first_leaf().prefix, keepends=True)[-included_line_count - 1]
return ''.join(lines) + n.get_code(include_prefix=False) \
+ ''.join(n.get_code() for n in nodes[1:])
def _remove_indent_of_prefix(prefix): def _remove_indent_of_prefix(prefix):
r""" r"""
Removes the last indentation of a prefix, e.g. " \n \n " becomes " \n \n". Removes the last indentation of a prefix, e.g. " \n \n " becomes " \n \n".
@@ -333,12 +344,6 @@ def _get_indentation(node):
return split_lines(node.get_first_leaf().prefix)[-1] return split_lines(node.get_first_leaf().prefix)[-1]
def _insert_line_before(prefix, code):
lines = split_lines(prefix, keepends=True)
lines[-1:-1] = [indent_block(code, lines[-1]) + '\n']
return ''.join(lines)
def _get_parent_definition(node): def _get_parent_definition(node):
""" """
Returns the statement where a node is defined. Returns the statement where a node is defined.
@@ -398,7 +403,6 @@ def extract_function(inference_state, path, module_context, name, pos, until_pos
is_bound_method = context.is_bound_method() is_bound_method = context.is_bound_method()
params, return_variables = list(_find_inputs_and_outputs(context, nodes)) params, return_variables = list(_find_inputs_and_outputs(context, nodes))
dct = {}
# Find variables # Find variables
# Is a class method / method # Is a class method / method
if context.is_module(): if context.is_module():
@@ -420,7 +424,7 @@ def extract_function(inference_state, path, module_context, name, pos, until_pos
)) or [return_variables[-1]] )) or [return_variables[-1]]
output_var_str = ', '.join(return_variables) output_var_str = ', '.join(return_variables)
code_block = dedent(''.join(n.get_code() for n in nodes)) code_block = dedent(_suite_nodes_to_string(nodes, pos))
code_block += 'return ' + output_var_str + '\n' code_block += 'return ' + output_var_str + '\n'
decorator = '' decorator = ''
@@ -453,10 +457,8 @@ def extract_function(inference_state, path, module_context, name, pos, until_pos
else: else:
replacement = _get_indentation(nodes[0]) + output_var_str + ' = ' + function_call replacement = _get_indentation(nodes[0]) + output_var_str + ' = ' + function_call
dct[nodes[0]] = replacement replaced_str = _replace(nodes, replacement, function_code, pos, insert_before_leaf)
for node in nodes[1:]: file_to_node_changes = {path: replaced_str}
dct[node] = ''
file_to_node_changes = {path: _replace(nodes, replacement, function_code, insert_before_leaf)}
return Refactoring(inference_state.grammar, file_to_node_changes) return Refactoring(inference_state.grammar, file_to_node_changes)

View File

@@ -159,6 +159,7 @@ def x(z):
v1 = 3 v1 = 3
v2 = 2 v2 = 2
x = test(v1 + v2 * v3) x = test(v1 + v2 * v3)
# ++++++++++++++++++++++++++++++++++++++++++++++++++ # ++++++++++++++++++++++++++++++++++++++++++++++++++
#? 0 text {'new_name': 'a', 'until_line': 4} #? 0 text {'new_name': 'a', 'until_line': 4}
def a(test, v3): def a(test, v3):
@@ -168,4 +169,25 @@ def a(test, v3):
return x return x
a(test, v3) x = a(test, v3)
# -------------------------------------------------- with-range-2
#? 2 text {'new_name': 'a', 'until_line': 6, 'until_column': 4}
#foo
v1 = 3
v2 = 2
x, y = test(v1 + v2 * v3)
#raaaa
y
# ++++++++++++++++++++++++++++++++++++++++++++++++++
#? 2 text {'new_name': 'a', 'until_line': 6, 'until_column': 4}
def a(test, v3):
#foo
v1 = 3
v2 = 2
x, y = test(v1 + v2 * v3)
#raaaa
return y
y = a(test, v3)
y