diff --git a/jedi/api/__init__.py b/jedi/api/__init__.py index 934d2773..ea9ddafa 100644 --- a/jedi/api/__init__.py +++ b/jedi/api/__init__.py @@ -567,8 +567,8 @@ class Script(object): else: if until_line is None: until_line = line - if until_line is None: - raise TypeError('If you provide until_line you have to provide until_column') + if until_column is None: + until_column = len(self._code_lines[until_line - 1]) until_pos = until_line, until_column return refactoring.extract_variable( self._grammar, self.path, self._module_node, new_name, (line, column), until_pos) @@ -588,8 +588,8 @@ class Script(object): else: if until_line is None: until_line = line - if until_line is None: - raise TypeError('If you provide until_line you have to provide until_column') + if until_column is None: + until_column = len(self._code_lines[until_line - 1]) until_pos = until_line, until_column return refactoring.extract_function( self._inference_state, self.path, self._get_module_context(), diff --git a/jedi/api/refactoring.py b/jedi/api/refactoring.py index 1f4ef917..e550f44a 100644 --- a/jedi/api/refactoring.py +++ b/jedi/api/refactoring.py @@ -2,6 +2,7 @@ from os.path import dirname, basename, join import os import re import difflib +from textwrap import dedent from parso import split_lines @@ -227,17 +228,24 @@ def extract_variable(grammar, path, module_node, name, pos, until_pos): nodes = _find_nodes(module_node, pos, until_pos) debug.dbg('Extracting nodes: %s', nodes) - if any(node.type == 'name' and node.is_definition() for node in nodes): - raise RefactoringError('Cannot extract a name that defines something') - - if nodes[0].type not in _VARIABLE_EXCTRACTABLE: - raise RefactoringError('Cannot extract a "%s"' % nodes[0].type) + is_expression, message = _is_expression_with_issue(nodes) + if not is_expression: + raise RefactoringError(message) generated_code = name + ' = ' + _expression_nodes_to_string(nodes) file_to_node_changes = {path: _replace(nodes, name, generated_code)} return Refactoring(grammar, file_to_node_changes) +def _is_expression_with_issue(nodes): + if any(node.type == 'name' and node.is_definition() for node in nodes): + return False, 'Cannot extract a name that defines something' + + if nodes[0].type not in _VARIABLE_EXCTRACTABLE: + return False, 'Cannot extract a "%s"' % nodes[0].type + return True, '' + + def _find_nodes(module_node, pos, until_pos): start_node = module_node.get_leaf_for_position(pos, include_prefixes=True) @@ -347,7 +355,9 @@ def _remove_unwanted_expression_nodes(parent_node, pos, until_pos): This function makes it so for `1 * 2 + 3` you can extract `2 + 3`, even though it is not part of the expression. """ - if parent_node.type in _EXPRESSION_PARTS: + typ = parent_node.type + is_suite_part = typ in ('suite', 'file_input') + if typ in _EXPRESSION_PARTS or is_suite_part: nodes = parent_node.children for i, n in enumerate(nodes): if n.end_pos > pos: @@ -369,8 +379,9 @@ def _remove_unwanted_expression_nodes(parent_node, pos, until_pos): break break nodes = nodes[start_index:end_index + 1] - nodes[0:1] = _remove_unwanted_expression_nodes(nodes[0], pos, until_pos) - nodes[-1:] = _remove_unwanted_expression_nodes(nodes[-1], pos, until_pos) + if not is_suite_part: + nodes[0:1] = _remove_unwanted_expression_nodes(nodes[0], pos, until_pos) + nodes[-1:] = _remove_unwanted_expression_nodes(nodes[-1], pos, until_pos) return nodes return [parent_node] @@ -381,8 +392,8 @@ def _is_not_extractable_syntax(node): def extract_function(inference_state, path, module_context, name, pos, until_pos): - is_expression = True nodes = _find_nodes(module_context.tree_node, pos, until_pos) + is_expression, _ = _is_expression_with_issue(nodes) context = module_context.create_context(nodes[0]) is_bound_method = context.is_bound_method() params, return_variables = list(_find_inputs_and_outputs(context, nodes)) @@ -398,9 +409,19 @@ def extract_function(inference_state, path, module_context, name, pos, until_pos if is_expression: code_block = 'return ' + _expression_nodes_to_string(nodes) + '\n' else: - raise 1 + # Find the actually used variables (of the defined ones). If none are + # used (e.g. if the range covers the whole function), return the last + # defined variable. + return_variables = list(_find_needed_output_variables( + context, + nodes[0].parent, + nodes[-1].end_pos, + return_variables + )) or [return_variables[-1]] + output_var_str = ', '.join(return_variables) - code_block += '\nreturn ' + output_var_str + '\n' + code_block = dedent(''.join(n.get_code() for n in nodes)) + code_block += 'return ' + output_var_str + '\n' decorator = '' self_param = None @@ -442,7 +463,7 @@ def extract_function(inference_state, path, module_context, name, pos, until_pos def _find_inputs_and_outputs(context, nodes): inputs = [] outputs = [] - for name in _find_non_global_names(context, nodes): + for name in _find_non_global_names(nodes): if name.is_definition(): if name not in outputs: outputs.append(name.value) @@ -458,7 +479,7 @@ def _find_inputs_and_outputs(context, nodes): return inputs, outputs -def _find_non_global_names(context, nodes): +def _find_non_global_names(nodes): for node in nodes: try: children = node.children @@ -470,7 +491,7 @@ def _find_non_global_names(context, nodes): if node.type == 'trailer' and node.children[0] == '.': continue - for x in _find_non_global_names(context, children): # Python 2... + for x in _find_non_global_names(children): # Python 2... yield x @@ -482,3 +503,19 @@ def _get_code_insertion_node(node, is_bound_method): while node.parent.type in ('async_funcdef', 'decorated', 'async_stmt'): node = node.parent return node + + +def _find_needed_output_variables(context, search_node, at_least_pos, return_variables): + """ + Searches everything after at_least_pos in a node and checks if any of the + return_variables are used in there and returns those. + """ + for node in search_node.children: + if node.start_pos < at_least_pos: + continue + + return_variables = set(return_variables) + for name in _find_non_global_names([node]): + if not name.is_definition() and name.value in return_variables: + return_variables.remove(name.value) + yield name.value diff --git a/test/refactor/extract_function.py b/test/refactor/extract_function.py index adf435be..a53794ac 100644 --- a/test/refactor/extract_function.py +++ b/test/refactor/extract_function.py @@ -154,3 +154,18 @@ def x(z): def y(x): #? 15 text {'new_name': 'f'} return f(x, z) +# -------------------------------------------------- with-range-1 +#? 0 text {'new_name': 'a', 'until_line': 4} +v1 = 3 +v2 = 2 +x = test(v1 + v2 * v3) +# ++++++++++++++++++++++++++++++++++++++++++++++++++ +#? 0 text {'new_name': 'a', 'until_line': 4} +def a(test, v3): + v1 = 3 + v2 = 2 + x = test(v1 + v2 * v3) + return x + + +a(test, v3)