mirror of
https://github.com/davidhalter/jedi.git
synced 2025-12-06 14:04:26 +08:00
Get a first extract test mostly working
This commit is contained in:
@@ -567,8 +567,8 @@ class Script(object):
|
||||
else:
|
||||
if until_line is None:
|
||||
until_line = line
|
||||
if until_line is None:
|
||||
raise TypeError('If you provide until_line you have to provide until_column')
|
||||
if until_column is None:
|
||||
until_column = len(self._code_lines[until_line - 1])
|
||||
until_pos = until_line, until_column
|
||||
return refactoring.extract_variable(
|
||||
self._grammar, self.path, self._module_node, new_name, (line, column), until_pos)
|
||||
@@ -588,8 +588,8 @@ class Script(object):
|
||||
else:
|
||||
if until_line is None:
|
||||
until_line = line
|
||||
if until_line is None:
|
||||
raise TypeError('If you provide until_line you have to provide until_column')
|
||||
if until_column is None:
|
||||
until_column = len(self._code_lines[until_line - 1])
|
||||
until_pos = until_line, until_column
|
||||
return refactoring.extract_function(
|
||||
self._inference_state, self.path, self._get_module_context(),
|
||||
|
||||
@@ -2,6 +2,7 @@ from os.path import dirname, basename, join
|
||||
import os
|
||||
import re
|
||||
import difflib
|
||||
from textwrap import dedent
|
||||
|
||||
from parso import split_lines
|
||||
|
||||
@@ -227,17 +228,24 @@ def extract_variable(grammar, path, module_node, name, pos, until_pos):
|
||||
nodes = _find_nodes(module_node, pos, until_pos)
|
||||
debug.dbg('Extracting nodes: %s', nodes)
|
||||
|
||||
if any(node.type == 'name' and node.is_definition() for node in nodes):
|
||||
raise RefactoringError('Cannot extract a name that defines something')
|
||||
|
||||
if nodes[0].type not in _VARIABLE_EXCTRACTABLE:
|
||||
raise RefactoringError('Cannot extract a "%s"' % nodes[0].type)
|
||||
is_expression, message = _is_expression_with_issue(nodes)
|
||||
if not is_expression:
|
||||
raise RefactoringError(message)
|
||||
|
||||
generated_code = name + ' = ' + _expression_nodes_to_string(nodes)
|
||||
file_to_node_changes = {path: _replace(nodes, name, generated_code)}
|
||||
return Refactoring(grammar, file_to_node_changes)
|
||||
|
||||
|
||||
def _is_expression_with_issue(nodes):
|
||||
if any(node.type == 'name' and node.is_definition() for node in nodes):
|
||||
return False, 'Cannot extract a name that defines something'
|
||||
|
||||
if nodes[0].type not in _VARIABLE_EXCTRACTABLE:
|
||||
return False, 'Cannot extract a "%s"' % nodes[0].type
|
||||
return True, ''
|
||||
|
||||
|
||||
def _find_nodes(module_node, pos, until_pos):
|
||||
start_node = module_node.get_leaf_for_position(pos, include_prefixes=True)
|
||||
|
||||
@@ -347,7 +355,9 @@ def _remove_unwanted_expression_nodes(parent_node, pos, until_pos):
|
||||
This function makes it so for `1 * 2 + 3` you can extract `2 + 3`, even
|
||||
though it is not part of the expression.
|
||||
"""
|
||||
if parent_node.type in _EXPRESSION_PARTS:
|
||||
typ = parent_node.type
|
||||
is_suite_part = typ in ('suite', 'file_input')
|
||||
if typ in _EXPRESSION_PARTS or is_suite_part:
|
||||
nodes = parent_node.children
|
||||
for i, n in enumerate(nodes):
|
||||
if n.end_pos > pos:
|
||||
@@ -369,8 +379,9 @@ def _remove_unwanted_expression_nodes(parent_node, pos, until_pos):
|
||||
break
|
||||
break
|
||||
nodes = nodes[start_index:end_index + 1]
|
||||
nodes[0:1] = _remove_unwanted_expression_nodes(nodes[0], pos, until_pos)
|
||||
nodes[-1:] = _remove_unwanted_expression_nodes(nodes[-1], pos, until_pos)
|
||||
if not is_suite_part:
|
||||
nodes[0:1] = _remove_unwanted_expression_nodes(nodes[0], pos, until_pos)
|
||||
nodes[-1:] = _remove_unwanted_expression_nodes(nodes[-1], pos, until_pos)
|
||||
return nodes
|
||||
return [parent_node]
|
||||
|
||||
@@ -381,8 +392,8 @@ def _is_not_extractable_syntax(node):
|
||||
|
||||
|
||||
def extract_function(inference_state, path, module_context, name, pos, until_pos):
|
||||
is_expression = True
|
||||
nodes = _find_nodes(module_context.tree_node, pos, until_pos)
|
||||
is_expression, _ = _is_expression_with_issue(nodes)
|
||||
context = module_context.create_context(nodes[0])
|
||||
is_bound_method = context.is_bound_method()
|
||||
params, return_variables = list(_find_inputs_and_outputs(context, nodes))
|
||||
@@ -398,9 +409,19 @@ def extract_function(inference_state, path, module_context, name, pos, until_pos
|
||||
if is_expression:
|
||||
code_block = 'return ' + _expression_nodes_to_string(nodes) + '\n'
|
||||
else:
|
||||
raise 1
|
||||
# Find the actually used variables (of the defined ones). If none are
|
||||
# used (e.g. if the range covers the whole function), return the last
|
||||
# defined variable.
|
||||
return_variables = list(_find_needed_output_variables(
|
||||
context,
|
||||
nodes[0].parent,
|
||||
nodes[-1].end_pos,
|
||||
return_variables
|
||||
)) or [return_variables[-1]]
|
||||
|
||||
output_var_str = ', '.join(return_variables)
|
||||
code_block += '\nreturn ' + output_var_str + '\n'
|
||||
code_block = dedent(''.join(n.get_code() for n in nodes))
|
||||
code_block += 'return ' + output_var_str + '\n'
|
||||
|
||||
decorator = ''
|
||||
self_param = None
|
||||
@@ -442,7 +463,7 @@ def extract_function(inference_state, path, module_context, name, pos, until_pos
|
||||
def _find_inputs_and_outputs(context, nodes):
|
||||
inputs = []
|
||||
outputs = []
|
||||
for name in _find_non_global_names(context, nodes):
|
||||
for name in _find_non_global_names(nodes):
|
||||
if name.is_definition():
|
||||
if name not in outputs:
|
||||
outputs.append(name.value)
|
||||
@@ -458,7 +479,7 @@ def _find_inputs_and_outputs(context, nodes):
|
||||
return inputs, outputs
|
||||
|
||||
|
||||
def _find_non_global_names(context, nodes):
|
||||
def _find_non_global_names(nodes):
|
||||
for node in nodes:
|
||||
try:
|
||||
children = node.children
|
||||
@@ -470,7 +491,7 @@ def _find_non_global_names(context, nodes):
|
||||
if node.type == 'trailer' and node.children[0] == '.':
|
||||
continue
|
||||
|
||||
for x in _find_non_global_names(context, children): # Python 2...
|
||||
for x in _find_non_global_names(children): # Python 2...
|
||||
yield x
|
||||
|
||||
|
||||
@@ -482,3 +503,19 @@ def _get_code_insertion_node(node, is_bound_method):
|
||||
while node.parent.type in ('async_funcdef', 'decorated', 'async_stmt'):
|
||||
node = node.parent
|
||||
return node
|
||||
|
||||
|
||||
def _find_needed_output_variables(context, search_node, at_least_pos, return_variables):
|
||||
"""
|
||||
Searches everything after at_least_pos in a node and checks if any of the
|
||||
return_variables are used in there and returns those.
|
||||
"""
|
||||
for node in search_node.children:
|
||||
if node.start_pos < at_least_pos:
|
||||
continue
|
||||
|
||||
return_variables = set(return_variables)
|
||||
for name in _find_non_global_names([node]):
|
||||
if not name.is_definition() and name.value in return_variables:
|
||||
return_variables.remove(name.value)
|
||||
yield name.value
|
||||
|
||||
@@ -154,3 +154,18 @@ def x(z):
|
||||
def y(x):
|
||||
#? 15 text {'new_name': 'f'}
|
||||
return f(x, z)
|
||||
# -------------------------------------------------- with-range-1
|
||||
#? 0 text {'new_name': 'a', 'until_line': 4}
|
||||
v1 = 3
|
||||
v2 = 2
|
||||
x = test(v1 + v2 * v3)
|
||||
# ++++++++++++++++++++++++++++++++++++++++++++++++++
|
||||
#? 0 text {'new_name': 'a', 'until_line': 4}
|
||||
def a(test, v3):
|
||||
v1 = 3
|
||||
v2 = 2
|
||||
x = test(v1 + v2 * v3)
|
||||
return x
|
||||
|
||||
|
||||
a(test, v3)
|
||||
|
||||
Reference in New Issue
Block a user