Get a first extract test mostly working

This commit is contained in:
Dave Halter
2020-02-24 10:12:38 +01:00
parent f527138e6c
commit f8d9f498d0
3 changed files with 70 additions and 18 deletions

View File

@@ -567,8 +567,8 @@ class Script(object):
else:
if until_line is None:
until_line = line
if until_line is None:
raise TypeError('If you provide until_line you have to provide until_column')
if until_column is None:
until_column = len(self._code_lines[until_line - 1])
until_pos = until_line, until_column
return refactoring.extract_variable(
self._grammar, self.path, self._module_node, new_name, (line, column), until_pos)
@@ -588,8 +588,8 @@ class Script(object):
else:
if until_line is None:
until_line = line
if until_line is None:
raise TypeError('If you provide until_line you have to provide until_column')
if until_column is None:
until_column = len(self._code_lines[until_line - 1])
until_pos = until_line, until_column
return refactoring.extract_function(
self._inference_state, self.path, self._get_module_context(),

View File

@@ -2,6 +2,7 @@ from os.path import dirname, basename, join
import os
import re
import difflib
from textwrap import dedent
from parso import split_lines
@@ -227,17 +228,24 @@ def extract_variable(grammar, path, module_node, name, pos, until_pos):
nodes = _find_nodes(module_node, pos, until_pos)
debug.dbg('Extracting nodes: %s', nodes)
if any(node.type == 'name' and node.is_definition() for node in nodes):
raise RefactoringError('Cannot extract a name that defines something')
if nodes[0].type not in _VARIABLE_EXCTRACTABLE:
raise RefactoringError('Cannot extract a "%s"' % nodes[0].type)
is_expression, message = _is_expression_with_issue(nodes)
if not is_expression:
raise RefactoringError(message)
generated_code = name + ' = ' + _expression_nodes_to_string(nodes)
file_to_node_changes = {path: _replace(nodes, name, generated_code)}
return Refactoring(grammar, file_to_node_changes)
def _is_expression_with_issue(nodes):
if any(node.type == 'name' and node.is_definition() for node in nodes):
return False, 'Cannot extract a name that defines something'
if nodes[0].type not in _VARIABLE_EXCTRACTABLE:
return False, 'Cannot extract a "%s"' % nodes[0].type
return True, ''
def _find_nodes(module_node, pos, until_pos):
start_node = module_node.get_leaf_for_position(pos, include_prefixes=True)
@@ -347,7 +355,9 @@ def _remove_unwanted_expression_nodes(parent_node, pos, until_pos):
This function makes it so for `1 * 2 + 3` you can extract `2 + 3`, even
though it is not part of the expression.
"""
if parent_node.type in _EXPRESSION_PARTS:
typ = parent_node.type
is_suite_part = typ in ('suite', 'file_input')
if typ in _EXPRESSION_PARTS or is_suite_part:
nodes = parent_node.children
for i, n in enumerate(nodes):
if n.end_pos > pos:
@@ -369,8 +379,9 @@ def _remove_unwanted_expression_nodes(parent_node, pos, until_pos):
break
break
nodes = nodes[start_index:end_index + 1]
nodes[0:1] = _remove_unwanted_expression_nodes(nodes[0], pos, until_pos)
nodes[-1:] = _remove_unwanted_expression_nodes(nodes[-1], pos, until_pos)
if not is_suite_part:
nodes[0:1] = _remove_unwanted_expression_nodes(nodes[0], pos, until_pos)
nodes[-1:] = _remove_unwanted_expression_nodes(nodes[-1], pos, until_pos)
return nodes
return [parent_node]
@@ -381,8 +392,8 @@ def _is_not_extractable_syntax(node):
def extract_function(inference_state, path, module_context, name, pos, until_pos):
is_expression = True
nodes = _find_nodes(module_context.tree_node, pos, until_pos)
is_expression, _ = _is_expression_with_issue(nodes)
context = module_context.create_context(nodes[0])
is_bound_method = context.is_bound_method()
params, return_variables = list(_find_inputs_and_outputs(context, nodes))
@@ -398,9 +409,19 @@ def extract_function(inference_state, path, module_context, name, pos, until_pos
if is_expression:
code_block = 'return ' + _expression_nodes_to_string(nodes) + '\n'
else:
raise 1
# Find the actually used variables (of the defined ones). If none are
# used (e.g. if the range covers the whole function), return the last
# defined variable.
return_variables = list(_find_needed_output_variables(
context,
nodes[0].parent,
nodes[-1].end_pos,
return_variables
)) or [return_variables[-1]]
output_var_str = ', '.join(return_variables)
code_block += '\nreturn ' + output_var_str + '\n'
code_block = dedent(''.join(n.get_code() for n in nodes))
code_block += 'return ' + output_var_str + '\n'
decorator = ''
self_param = None
@@ -442,7 +463,7 @@ def extract_function(inference_state, path, module_context, name, pos, until_pos
def _find_inputs_and_outputs(context, nodes):
inputs = []
outputs = []
for name in _find_non_global_names(context, nodes):
for name in _find_non_global_names(nodes):
if name.is_definition():
if name not in outputs:
outputs.append(name.value)
@@ -458,7 +479,7 @@ def _find_inputs_and_outputs(context, nodes):
return inputs, outputs
def _find_non_global_names(context, nodes):
def _find_non_global_names(nodes):
for node in nodes:
try:
children = node.children
@@ -470,7 +491,7 @@ def _find_non_global_names(context, nodes):
if node.type == 'trailer' and node.children[0] == '.':
continue
for x in _find_non_global_names(context, children): # Python 2...
for x in _find_non_global_names(children): # Python 2...
yield x
@@ -482,3 +503,19 @@ def _get_code_insertion_node(node, is_bound_method):
while node.parent.type in ('async_funcdef', 'decorated', 'async_stmt'):
node = node.parent
return node
def _find_needed_output_variables(context, search_node, at_least_pos, return_variables):
"""
Searches everything after at_least_pos in a node and checks if any of the
return_variables are used in there and returns those.
"""
for node in search_node.children:
if node.start_pos < at_least_pos:
continue
return_variables = set(return_variables)
for name in _find_non_global_names([node]):
if not name.is_definition() and name.value in return_variables:
return_variables.remove(name.value)
yield name.value

View File

@@ -154,3 +154,18 @@ def x(z):
def y(x):
#? 15 text {'new_name': 'f'}
return f(x, z)
# -------------------------------------------------- with-range-1
#? 0 text {'new_name': 'a', 'until_line': 4}
v1 = 3
v2 = 2
x = test(v1 + v2 * v3)
# ++++++++++++++++++++++++++++++++++++++++++++++++++
#? 0 text {'new_name': 'a', 'until_line': 4}
def a(test, v3):
v1 = 3
v2 = 2
x = test(v1 + v2 * v3)
return x
a(test, v3)