mirror of
https://github.com/davidhalter/jedi.git
synced 2025-12-07 06:24:27 +08:00
Get a first extract test mostly working
This commit is contained in:
@@ -567,8 +567,8 @@ class Script(object):
|
|||||||
else:
|
else:
|
||||||
if until_line is None:
|
if until_line is None:
|
||||||
until_line = line
|
until_line = line
|
||||||
if until_line is None:
|
if until_column is None:
|
||||||
raise TypeError('If you provide until_line you have to provide until_column')
|
until_column = len(self._code_lines[until_line - 1])
|
||||||
until_pos = until_line, until_column
|
until_pos = until_line, until_column
|
||||||
return refactoring.extract_variable(
|
return refactoring.extract_variable(
|
||||||
self._grammar, self.path, self._module_node, new_name, (line, column), until_pos)
|
self._grammar, self.path, self._module_node, new_name, (line, column), until_pos)
|
||||||
@@ -588,8 +588,8 @@ class Script(object):
|
|||||||
else:
|
else:
|
||||||
if until_line is None:
|
if until_line is None:
|
||||||
until_line = line
|
until_line = line
|
||||||
if until_line is None:
|
if until_column is None:
|
||||||
raise TypeError('If you provide until_line you have to provide until_column')
|
until_column = len(self._code_lines[until_line - 1])
|
||||||
until_pos = until_line, until_column
|
until_pos = until_line, until_column
|
||||||
return refactoring.extract_function(
|
return refactoring.extract_function(
|
||||||
self._inference_state, self.path, self._get_module_context(),
|
self._inference_state, self.path, self._get_module_context(),
|
||||||
|
|||||||
@@ -2,6 +2,7 @@ from os.path import dirname, basename, join
|
|||||||
import os
|
import os
|
||||||
import re
|
import re
|
||||||
import difflib
|
import difflib
|
||||||
|
from textwrap import dedent
|
||||||
|
|
||||||
from parso import split_lines
|
from parso import split_lines
|
||||||
|
|
||||||
@@ -227,17 +228,24 @@ def extract_variable(grammar, path, module_node, name, pos, until_pos):
|
|||||||
nodes = _find_nodes(module_node, pos, until_pos)
|
nodes = _find_nodes(module_node, pos, until_pos)
|
||||||
debug.dbg('Extracting nodes: %s', nodes)
|
debug.dbg('Extracting nodes: %s', nodes)
|
||||||
|
|
||||||
if any(node.type == 'name' and node.is_definition() for node in nodes):
|
is_expression, message = _is_expression_with_issue(nodes)
|
||||||
raise RefactoringError('Cannot extract a name that defines something')
|
if not is_expression:
|
||||||
|
raise RefactoringError(message)
|
||||||
if nodes[0].type not in _VARIABLE_EXCTRACTABLE:
|
|
||||||
raise RefactoringError('Cannot extract a "%s"' % nodes[0].type)
|
|
||||||
|
|
||||||
generated_code = name + ' = ' + _expression_nodes_to_string(nodes)
|
generated_code = name + ' = ' + _expression_nodes_to_string(nodes)
|
||||||
file_to_node_changes = {path: _replace(nodes, name, generated_code)}
|
file_to_node_changes = {path: _replace(nodes, name, generated_code)}
|
||||||
return Refactoring(grammar, file_to_node_changes)
|
return Refactoring(grammar, file_to_node_changes)
|
||||||
|
|
||||||
|
|
||||||
|
def _is_expression_with_issue(nodes):
|
||||||
|
if any(node.type == 'name' and node.is_definition() for node in nodes):
|
||||||
|
return False, 'Cannot extract a name that defines something'
|
||||||
|
|
||||||
|
if nodes[0].type not in _VARIABLE_EXCTRACTABLE:
|
||||||
|
return False, 'Cannot extract a "%s"' % nodes[0].type
|
||||||
|
return True, ''
|
||||||
|
|
||||||
|
|
||||||
def _find_nodes(module_node, pos, until_pos):
|
def _find_nodes(module_node, pos, until_pos):
|
||||||
start_node = module_node.get_leaf_for_position(pos, include_prefixes=True)
|
start_node = module_node.get_leaf_for_position(pos, include_prefixes=True)
|
||||||
|
|
||||||
@@ -347,7 +355,9 @@ def _remove_unwanted_expression_nodes(parent_node, pos, until_pos):
|
|||||||
This function makes it so for `1 * 2 + 3` you can extract `2 + 3`, even
|
This function makes it so for `1 * 2 + 3` you can extract `2 + 3`, even
|
||||||
though it is not part of the expression.
|
though it is not part of the expression.
|
||||||
"""
|
"""
|
||||||
if parent_node.type in _EXPRESSION_PARTS:
|
typ = parent_node.type
|
||||||
|
is_suite_part = typ in ('suite', 'file_input')
|
||||||
|
if typ in _EXPRESSION_PARTS or is_suite_part:
|
||||||
nodes = parent_node.children
|
nodes = parent_node.children
|
||||||
for i, n in enumerate(nodes):
|
for i, n in enumerate(nodes):
|
||||||
if n.end_pos > pos:
|
if n.end_pos > pos:
|
||||||
@@ -369,6 +379,7 @@ def _remove_unwanted_expression_nodes(parent_node, pos, until_pos):
|
|||||||
break
|
break
|
||||||
break
|
break
|
||||||
nodes = nodes[start_index:end_index + 1]
|
nodes = nodes[start_index:end_index + 1]
|
||||||
|
if not is_suite_part:
|
||||||
nodes[0:1] = _remove_unwanted_expression_nodes(nodes[0], pos, until_pos)
|
nodes[0:1] = _remove_unwanted_expression_nodes(nodes[0], pos, until_pos)
|
||||||
nodes[-1:] = _remove_unwanted_expression_nodes(nodes[-1], pos, until_pos)
|
nodes[-1:] = _remove_unwanted_expression_nodes(nodes[-1], pos, until_pos)
|
||||||
return nodes
|
return nodes
|
||||||
@@ -381,8 +392,8 @@ def _is_not_extractable_syntax(node):
|
|||||||
|
|
||||||
|
|
||||||
def extract_function(inference_state, path, module_context, name, pos, until_pos):
|
def extract_function(inference_state, path, module_context, name, pos, until_pos):
|
||||||
is_expression = True
|
|
||||||
nodes = _find_nodes(module_context.tree_node, pos, until_pos)
|
nodes = _find_nodes(module_context.tree_node, pos, until_pos)
|
||||||
|
is_expression, _ = _is_expression_with_issue(nodes)
|
||||||
context = module_context.create_context(nodes[0])
|
context = module_context.create_context(nodes[0])
|
||||||
is_bound_method = context.is_bound_method()
|
is_bound_method = context.is_bound_method()
|
||||||
params, return_variables = list(_find_inputs_and_outputs(context, nodes))
|
params, return_variables = list(_find_inputs_and_outputs(context, nodes))
|
||||||
@@ -398,9 +409,19 @@ def extract_function(inference_state, path, module_context, name, pos, until_pos
|
|||||||
if is_expression:
|
if is_expression:
|
||||||
code_block = 'return ' + _expression_nodes_to_string(nodes) + '\n'
|
code_block = 'return ' + _expression_nodes_to_string(nodes) + '\n'
|
||||||
else:
|
else:
|
||||||
raise 1
|
# Find the actually used variables (of the defined ones). If none are
|
||||||
|
# used (e.g. if the range covers the whole function), return the last
|
||||||
|
# defined variable.
|
||||||
|
return_variables = list(_find_needed_output_variables(
|
||||||
|
context,
|
||||||
|
nodes[0].parent,
|
||||||
|
nodes[-1].end_pos,
|
||||||
|
return_variables
|
||||||
|
)) or [return_variables[-1]]
|
||||||
|
|
||||||
output_var_str = ', '.join(return_variables)
|
output_var_str = ', '.join(return_variables)
|
||||||
code_block += '\nreturn ' + output_var_str + '\n'
|
code_block = dedent(''.join(n.get_code() for n in nodes))
|
||||||
|
code_block += 'return ' + output_var_str + '\n'
|
||||||
|
|
||||||
decorator = ''
|
decorator = ''
|
||||||
self_param = None
|
self_param = None
|
||||||
@@ -442,7 +463,7 @@ def extract_function(inference_state, path, module_context, name, pos, until_pos
|
|||||||
def _find_inputs_and_outputs(context, nodes):
|
def _find_inputs_and_outputs(context, nodes):
|
||||||
inputs = []
|
inputs = []
|
||||||
outputs = []
|
outputs = []
|
||||||
for name in _find_non_global_names(context, nodes):
|
for name in _find_non_global_names(nodes):
|
||||||
if name.is_definition():
|
if name.is_definition():
|
||||||
if name not in outputs:
|
if name not in outputs:
|
||||||
outputs.append(name.value)
|
outputs.append(name.value)
|
||||||
@@ -458,7 +479,7 @@ def _find_inputs_and_outputs(context, nodes):
|
|||||||
return inputs, outputs
|
return inputs, outputs
|
||||||
|
|
||||||
|
|
||||||
def _find_non_global_names(context, nodes):
|
def _find_non_global_names(nodes):
|
||||||
for node in nodes:
|
for node in nodes:
|
||||||
try:
|
try:
|
||||||
children = node.children
|
children = node.children
|
||||||
@@ -470,7 +491,7 @@ def _find_non_global_names(context, nodes):
|
|||||||
if node.type == 'trailer' and node.children[0] == '.':
|
if node.type == 'trailer' and node.children[0] == '.':
|
||||||
continue
|
continue
|
||||||
|
|
||||||
for x in _find_non_global_names(context, children): # Python 2...
|
for x in _find_non_global_names(children): # Python 2...
|
||||||
yield x
|
yield x
|
||||||
|
|
||||||
|
|
||||||
@@ -482,3 +503,19 @@ def _get_code_insertion_node(node, is_bound_method):
|
|||||||
while node.parent.type in ('async_funcdef', 'decorated', 'async_stmt'):
|
while node.parent.type in ('async_funcdef', 'decorated', 'async_stmt'):
|
||||||
node = node.parent
|
node = node.parent
|
||||||
return node
|
return node
|
||||||
|
|
||||||
|
|
||||||
|
def _find_needed_output_variables(context, search_node, at_least_pos, return_variables):
|
||||||
|
"""
|
||||||
|
Searches everything after at_least_pos in a node and checks if any of the
|
||||||
|
return_variables are used in there and returns those.
|
||||||
|
"""
|
||||||
|
for node in search_node.children:
|
||||||
|
if node.start_pos < at_least_pos:
|
||||||
|
continue
|
||||||
|
|
||||||
|
return_variables = set(return_variables)
|
||||||
|
for name in _find_non_global_names([node]):
|
||||||
|
if not name.is_definition() and name.value in return_variables:
|
||||||
|
return_variables.remove(name.value)
|
||||||
|
yield name.value
|
||||||
|
|||||||
@@ -154,3 +154,18 @@ def x(z):
|
|||||||
def y(x):
|
def y(x):
|
||||||
#? 15 text {'new_name': 'f'}
|
#? 15 text {'new_name': 'f'}
|
||||||
return f(x, z)
|
return f(x, z)
|
||||||
|
# -------------------------------------------------- with-range-1
|
||||||
|
#? 0 text {'new_name': 'a', 'until_line': 4}
|
||||||
|
v1 = 3
|
||||||
|
v2 = 2
|
||||||
|
x = test(v1 + v2 * v3)
|
||||||
|
# ++++++++++++++++++++++++++++++++++++++++++++++++++
|
||||||
|
#? 0 text {'new_name': 'a', 'until_line': 4}
|
||||||
|
def a(test, v3):
|
||||||
|
v1 = 3
|
||||||
|
v2 = 2
|
||||||
|
x = test(v1 + v2 * v3)
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
a(test, v3)
|
||||||
|
|||||||
Reference in New Issue
Block a user