diff --git a/jedi/api/__init__.py b/jedi/api/__init__.py index ea9ddafa..1e087cec 100644 --- a/jedi/api/__init__.py +++ b/jedi/api/__init__.py @@ -32,6 +32,7 @@ from jedi.api.environment import InterpreterEnvironment from jedi.api.project import get_default_project, Project from jedi.api.errors import parso_to_jedi_errors from jedi.api import refactoring +from jedi.api.refactoring.extract import extract_function, extract_variable from jedi.inference import InferenceState from jedi.inference import imports from jedi.inference.references import find_references @@ -570,7 +571,7 @@ class Script(object): if until_column is None: until_column = len(self._code_lines[until_line - 1]) until_pos = until_line, until_column - return refactoring.extract_variable( + return extract_variable( self._grammar, self.path, self._module_node, new_name, (line, column), until_pos) @no_py2_support @@ -591,7 +592,7 @@ class Script(object): if until_column is None: until_column = len(self._code_lines[until_line - 1]) until_pos = until_line, until_column - return refactoring.extract_function( + return extract_function( self._inference_state, self.path, self._get_module_context(), new_name, (line, column), until_pos ) diff --git a/jedi/api/refactoring/__init__.py b/jedi/api/refactoring/__init__.py index f984f937..0d0e2fad 100644 --- a/jedi/api/refactoring/__init__.py +++ b/jedi/api/refactoring/__init__.py @@ -2,24 +2,15 @@ from os.path import dirname, basename, join import os import re import difflib -from textwrap import dedent from parso import split_lines -from jedi import debug from jedi.api.exceptions import RefactoringError -from jedi.common.utils import indent_block -from jedi.parser_utils import function_is_classmethod, function_is_staticmethod -_EXPRESSION_PARTS = ( +EXPRESSION_PARTS = ( 'or_test and_test not_test comparison ' 'expr xor_expr and_expr shift_expr arith_expr term factor power atom_expr' ).split() -_EXTRACT_USE_PARENT = _EXPRESSION_PARTS + ['trailer'] -_DEFINITION_SCOPES = ('suite', 'file_input') -_VARIABLE_EXCTRACTABLE = _EXPRESSION_PARTS + \ - ('atom testlist_star_expr testlist test lambdef lambdef_nocond ' - 'keyword name number string fstring').split() class ChangedFile(object): @@ -194,7 +185,7 @@ def inline(grammar, names): path = name.get_root_context().py__file__() s = replace_code if rhs.type == 'testlist_star_expr' \ - or tree_name.parent.type in _EXPRESSION_PARTS \ + or tree_name.parent.type in EXPRESSION_PARTS \ or tree_name.parent.type == 'trailer' \ and tree_name.parent.get_next_sibling() is not None: s = '(' + replace_code + ')' @@ -224,307 +215,8 @@ def inline(grammar, names): return Refactoring(grammar, file_to_node_changes) -def extract_variable(grammar, path, module_node, name, pos, until_pos): - nodes = _find_nodes(module_node, pos, until_pos) - debug.dbg('Extracting nodes: %s', nodes) - - is_expression, message = _is_expression_with_error(nodes) - if not is_expression: - raise RefactoringError(message) - - generated_code = name + ' = ' + _expression_nodes_to_string(nodes) - file_to_node_changes = {path: _replace(nodes, name, generated_code, pos)} - return Refactoring(grammar, file_to_node_changes) - - -def _is_expression_with_error(nodes): - """ - Returns a tuple (is_expression, error_string). - """ - if any(node.type == 'name' and node.is_definition() for node in nodes): - return False, 'Cannot extract a name that defines something' - - if nodes[0].type not in _VARIABLE_EXCTRACTABLE: - return False, 'Cannot extract a "%s"' % nodes[0].type - return True, '' - - -def _find_nodes(module_node, pos, until_pos): - """ - Looks up a module and tries to find the appropriate amount of nodes that - are in there. - """ - start_node = module_node.get_leaf_for_position(pos, include_prefixes=True) - - if until_pos is None: - if start_node.type == 'operator': - next_leaf = start_node.get_next_leaf() - if next_leaf is not None and next_leaf.start_pos == pos: - start_node = next_leaf - - if _is_not_extractable_syntax(start_node): - start_node = start_node.parent - - while start_node.parent.type in _EXTRACT_USE_PARENT: - start_node = start_node.parent - - nodes = [start_node] - else: - # Get the next leaf if we are at the end of a leaf - if start_node.end_pos == pos: - next_leaf = start_node.get_next_leaf() - if next_leaf is not None: - start_node = next_leaf - - # Some syntax is not exactable, just use its parent - if _is_not_extractable_syntax(start_node): - start_node = start_node.parent - - # Find the end - end_leaf = module_node.get_leaf_for_position(until_pos, include_prefixes=True) - if end_leaf.start_pos > until_pos: - end_leaf = end_leaf.get_previous_leaf() - if end_leaf is None: - raise RefactoringError('Cannot extract anything from that') - - parent_node = start_node - while parent_node.end_pos < end_leaf.end_pos: - parent_node = parent_node.parent - - nodes = _remove_unwanted_expression_nodes(parent_node, pos, until_pos) - - # If the user marks just a return statement, we return the expression - # instead of the whole statement, because the user obviously wants to - # extract that part. - if len(nodes) == 1 and start_node.type in ('return_stmt', 'yield_expr'): - return [nodes[0].children[1]] - return nodes - - -def _replace(nodes, expression_replacement, extracted, pos, insert_before_leaf=None): - # Now try to replace the nodes found with a variable and move the code - # before the current statement. - definition = _get_parent_definition(nodes[0]) - if insert_before_leaf is None: - insert_before_leaf = definition.get_first_leaf() - first_node_leaf = nodes[0].get_first_leaf() - - lines = split_lines(insert_before_leaf.prefix, keepends=True) - if first_node_leaf is insert_before_leaf: - removed_line_count = nodes[0].start_pos[0] - pos[0] - lines = lines[:-removed_line_count - 1] + [lines[-1]] - lines[-1:-1] = [indent_block(extracted, lines[-1]) + '\n'] - extracted_prefix = ''.join(lines) - - replacement_dct = {} - if first_node_leaf is insert_before_leaf: - replacement_dct[nodes[0]] = extracted_prefix + expression_replacement - else: - replacement_dct[nodes[0]] = first_node_leaf.prefix + expression_replacement - replacement_dct[insert_before_leaf] = extracted_prefix + insert_before_leaf.value - - for node in nodes[1:]: - replacement_dct[node] = '' - return replacement_dct - - -def _expression_nodes_to_string(nodes): - return ''.join(n.get_code(include_prefix=i != 0) for i, n in enumerate(nodes)) - - -def _suite_nodes_to_string(nodes, pos): - n = nodes[0] - included_line_count = n.start_pos[0] - pos[0] - lines = split_lines(n.get_first_leaf().prefix, keepends=True)[-included_line_count - 1] - return ''.join(lines) + n.get_code(include_prefix=False) \ - + ''.join(n.get_code() for n in nodes[1:]) - - def _remove_indent_of_prefix(prefix): r""" Removes the last indentation of a prefix, e.g. " \n \n " becomes " \n \n". """ return ''.join(split_lines(prefix, keepends=True)[:-1]) - - -def _get_indentation(node): - return split_lines(node.get_first_leaf().prefix)[-1] - - -def _get_parent_definition(node): - """ - Returns the statement where a node is defined. - """ - while node is not None: - if node.parent.type in _DEFINITION_SCOPES: - return node - node = node.parent - raise NotImplementedError('We should never even get here') - - -def _remove_unwanted_expression_nodes(parent_node, pos, until_pos): - """ - This function makes it so for `1 * 2 + 3` you can extract `2 + 3`, even - though it is not part of the expression. - """ - typ = parent_node.type - is_suite_part = typ in ('suite', 'file_input') - if typ in _EXPRESSION_PARTS or is_suite_part: - nodes = parent_node.children - for i, n in enumerate(nodes): - if n.end_pos > pos: - start_index = i - if n.type == 'operator': - start_index -= 1 - break - for i, n in reversed(list(enumerate(nodes))): - if n.start_pos < until_pos: - end_index = i - if n.type == 'operator': - end_index += 1 - - # Something like `not foo or bar` should not be cut after not - for n in nodes[i:]: - if _is_not_extractable_syntax(n): - end_index += 1 - else: - break - break - nodes = nodes[start_index:end_index + 1] - if not is_suite_part: - nodes[0:1] = _remove_unwanted_expression_nodes(nodes[0], pos, until_pos) - nodes[-1:] = _remove_unwanted_expression_nodes(nodes[-1], pos, until_pos) - return nodes - return [parent_node] - - -def _is_not_extractable_syntax(node): - return node.type == 'operator' \ - or node.type == 'keyword' and node.value not in ('None', 'True', 'False') - - -def extract_function(inference_state, path, module_context, name, pos, until_pos): - nodes = _find_nodes(module_context.tree_node, pos, until_pos) - is_expression, _ = _is_expression_with_error(nodes) - context = module_context.create_context(nodes[0]) - is_bound_method = context.is_bound_method() - params, return_variables = list(_find_inputs_and_outputs(context, nodes)) - - # Find variables - # Is a class method / method - if context.is_module(): - insert_before_leaf = None # Leaf will be determined later - else: - node = _get_code_insertion_node(context.tree_node, is_bound_method) - insert_before_leaf = node.get_first_leaf() - if is_expression: - code_block = 'return ' + _expression_nodes_to_string(nodes) + '\n' - else: - # Find the actually used variables (of the defined ones). If none are - # used (e.g. if the range covers the whole function), return the last - # defined variable. - return_variables = list(_find_needed_output_variables( - context, - nodes[0].parent, - nodes[-1].end_pos, - return_variables - )) or [return_variables[-1]] - - output_var_str = ', '.join(return_variables) - code_block = dedent(_suite_nodes_to_string(nodes, pos)) - code_block += 'return ' + output_var_str + '\n' - - decorator = '' - self_param = None - if is_bound_method: - if not function_is_staticmethod(context.tree_node): - function_param_names = context.get_value().get_param_names() - if len(function_param_names): - self_param = function_param_names[0].string_name - params = [p for p in params if p != self_param] - - if function_is_classmethod(context.tree_node): - decorator = '@classmethod\n' - else: - code_block += '\n' - - function_code = '%sdef %s(%s):\n%s' % ( - decorator, - name, - ', '.join(params if self_param is None else [self_param] + params), - indent_block(code_block) - ) - - function_call = '%s(%s)' % ( - ('' if self_param is None else self_param + '.') + name, - ', '.join(params) - ) - if is_expression: - replacement = function_call - else: - replacement = _get_indentation(nodes[0]) + output_var_str + ' = ' + function_call - - replaced_str = _replace(nodes, replacement, function_code, pos, insert_before_leaf) - file_to_node_changes = {path: replaced_str} - return Refactoring(inference_state.grammar, file_to_node_changes) - - -def _find_inputs_and_outputs(context, nodes): - inputs = [] - outputs = [] - for name in _find_non_global_names(nodes): - if name.is_definition(): - if name not in outputs: - outputs.append(name.value) - else: - if name.value not in inputs: - name_definitions = context.goto(name, name.start_pos) - if not name_definitions \ - or any(not n.parent_context.is_module() or n.api_type == 'param' - for n in name_definitions): - inputs.append(name.value) - - # Check if outputs are really needed: - return inputs, outputs - - -def _find_non_global_names(nodes): - for node in nodes: - try: - children = node.children - except AttributeError: - if node.type == 'name': - yield node - else: - # We only want to check foo in foo.bar - if node.type == 'trailer' and node.children[0] == '.': - continue - - for x in _find_non_global_names(children): # Python 2... - yield x - - -def _get_code_insertion_node(node, is_bound_method): - if not is_bound_method or function_is_staticmethod(node): - while node.parent.type != 'file_input': - node = node.parent - - while node.parent.type in ('async_funcdef', 'decorated', 'async_stmt'): - node = node.parent - return node - - -def _find_needed_output_variables(context, search_node, at_least_pos, return_variables): - """ - Searches everything after at_least_pos in a node and checks if any of the - return_variables are used in there and returns those. - """ - for node in search_node.children: - if node.start_pos < at_least_pos: - continue - - return_variables = set(return_variables) - for name in _find_non_global_names([node]): - if not name.is_definition() and name.value in return_variables: - return_variables.remove(name.value) - yield name.value diff --git a/jedi/api/refactoring/extract.py b/jedi/api/refactoring/extract.py new file mode 100644 index 00000000..03dfaef6 --- /dev/null +++ b/jedi/api/refactoring/extract.py @@ -0,0 +1,315 @@ +from textwrap import dedent + +from parso import split_lines + +from jedi import debug +from jedi.api.exceptions import RefactoringError +from jedi.api.refactoring import Refactoring, EXPRESSION_PARTS +from jedi.common.utils import indent_block +from jedi.parser_utils import function_is_classmethod, function_is_staticmethod + + +_EXTRACT_USE_PARENT = EXPRESSION_PARTS + ['trailer'] +_DEFINITION_SCOPES = ('suite', 'file_input') +_VARIABLE_EXCTRACTABLE = EXPRESSION_PARTS + \ + ('atom testlist_star_expr testlist test lambdef lambdef_nocond ' + 'keyword name number string fstring').split() + + +def extract_variable(grammar, path, module_node, name, pos, until_pos): + nodes = _find_nodes(module_node, pos, until_pos) + debug.dbg('Extracting nodes: %s', nodes) + + is_expression, message = _is_expression_with_error(nodes) + if not is_expression: + raise RefactoringError(message) + + generated_code = name + ' = ' + _expression_nodes_to_string(nodes) + file_to_node_changes = {path: _replace(nodes, name, generated_code, pos)} + return Refactoring(grammar, file_to_node_changes) + + +def _is_expression_with_error(nodes): + """ + Returns a tuple (is_expression, error_string). + """ + if any(node.type == 'name' and node.is_definition() for node in nodes): + return False, 'Cannot extract a name that defines something' + + if nodes[0].type not in _VARIABLE_EXCTRACTABLE: + return False, 'Cannot extract a "%s"' % nodes[0].type + return True, '' + + +def _find_nodes(module_node, pos, until_pos): + """ + Looks up a module and tries to find the appropriate amount of nodes that + are in there. + """ + start_node = module_node.get_leaf_for_position(pos, include_prefixes=True) + + if until_pos is None: + if start_node.type == 'operator': + next_leaf = start_node.get_next_leaf() + if next_leaf is not None and next_leaf.start_pos == pos: + start_node = next_leaf + + if _is_not_extractable_syntax(start_node): + start_node = start_node.parent + + while start_node.parent.type in _EXTRACT_USE_PARENT: + start_node = start_node.parent + + nodes = [start_node] + else: + # Get the next leaf if we are at the end of a leaf + if start_node.end_pos == pos: + next_leaf = start_node.get_next_leaf() + if next_leaf is not None: + start_node = next_leaf + + # Some syntax is not exactable, just use its parent + if _is_not_extractable_syntax(start_node): + start_node = start_node.parent + + # Find the end + end_leaf = module_node.get_leaf_for_position(until_pos, include_prefixes=True) + if end_leaf.start_pos > until_pos: + end_leaf = end_leaf.get_previous_leaf() + if end_leaf is None: + raise RefactoringError('Cannot extract anything from that') + + parent_node = start_node + while parent_node.end_pos < end_leaf.end_pos: + parent_node = parent_node.parent + + nodes = _remove_unwanted_expression_nodes(parent_node, pos, until_pos) + + # If the user marks just a return statement, we return the expression + # instead of the whole statement, because the user obviously wants to + # extract that part. + if len(nodes) == 1 and start_node.type in ('return_stmt', 'yield_expr'): + return [nodes[0].children[1]] + return nodes + + +def _replace(nodes, expression_replacement, extracted, pos, insert_before_leaf=None): + # Now try to replace the nodes found with a variable and move the code + # before the current statement. + definition = _get_parent_definition(nodes[0]) + if insert_before_leaf is None: + insert_before_leaf = definition.get_first_leaf() + first_node_leaf = nodes[0].get_first_leaf() + + lines = split_lines(insert_before_leaf.prefix, keepends=True) + if first_node_leaf is insert_before_leaf: + removed_line_count = nodes[0].start_pos[0] - pos[0] + lines = lines[:-removed_line_count - 1] + [lines[-1]] + lines[-1:-1] = [indent_block(extracted, lines[-1]) + '\n'] + extracted_prefix = ''.join(lines) + + replacement_dct = {} + if first_node_leaf is insert_before_leaf: + replacement_dct[nodes[0]] = extracted_prefix + expression_replacement + else: + replacement_dct[nodes[0]] = first_node_leaf.prefix + expression_replacement + replacement_dct[insert_before_leaf] = extracted_prefix + insert_before_leaf.value + + for node in nodes[1:]: + replacement_dct[node] = '' + return replacement_dct + + +def _expression_nodes_to_string(nodes): + return ''.join(n.get_code(include_prefix=i != 0) for i, n in enumerate(nodes)) + + +def _suite_nodes_to_string(nodes, pos): + n = nodes[0] + included_line_count = n.start_pos[0] - pos[0] + lines = split_lines(n.get_first_leaf().prefix, keepends=True)[-included_line_count - 1] + return ''.join(lines) + n.get_code(include_prefix=False) \ + + ''.join(n.get_code() for n in nodes[1:]) + + +def _get_indentation(node): + return split_lines(node.get_first_leaf().prefix)[-1] + + +def _get_parent_definition(node): + """ + Returns the statement where a node is defined. + """ + while node is not None: + if node.parent.type in _DEFINITION_SCOPES: + return node + node = node.parent + raise NotImplementedError('We should never even get here') + + +def _remove_unwanted_expression_nodes(parent_node, pos, until_pos): + """ + This function makes it so for `1 * 2 + 3` you can extract `2 + 3`, even + though it is not part of the expression. + """ + typ = parent_node.type + is_suite_part = typ in ('suite', 'file_input') + if typ in EXPRESSION_PARTS or is_suite_part: + nodes = parent_node.children + for i, n in enumerate(nodes): + if n.end_pos > pos: + start_index = i + if n.type == 'operator': + start_index -= 1 + break + for i, n in reversed(list(enumerate(nodes))): + if n.start_pos < until_pos: + end_index = i + if n.type == 'operator': + end_index += 1 + + # Something like `not foo or bar` should not be cut after not + for n in nodes[i:]: + if _is_not_extractable_syntax(n): + end_index += 1 + else: + break + break + nodes = nodes[start_index:end_index + 1] + if not is_suite_part: + nodes[0:1] = _remove_unwanted_expression_nodes(nodes[0], pos, until_pos) + nodes[-1:] = _remove_unwanted_expression_nodes(nodes[-1], pos, until_pos) + return nodes + return [parent_node] + + +def _is_not_extractable_syntax(node): + return node.type == 'operator' \ + or node.type == 'keyword' and node.value not in ('None', 'True', 'False') + + +def extract_function(inference_state, path, module_context, name, pos, until_pos): + nodes = _find_nodes(module_context.tree_node, pos, until_pos) + is_expression, _ = _is_expression_with_error(nodes) + context = module_context.create_context(nodes[0]) + is_bound_method = context.is_bound_method() + params, return_variables = list(_find_inputs_and_outputs(context, nodes)) + + # Find variables + # Is a class method / method + if context.is_module(): + insert_before_leaf = None # Leaf will be determined later + else: + node = _get_code_insertion_node(context.tree_node, is_bound_method) + insert_before_leaf = node.get_first_leaf() + if is_expression: + code_block = 'return ' + _expression_nodes_to_string(nodes) + '\n' + else: + # Find the actually used variables (of the defined ones). If none are + # used (e.g. if the range covers the whole function), return the last + # defined variable. + return_variables = list(_find_needed_output_variables( + context, + nodes[0].parent, + nodes[-1].end_pos, + return_variables + )) or [return_variables[-1]] + + output_var_str = ', '.join(return_variables) + code_block = dedent(_suite_nodes_to_string(nodes, pos)) + code_block += 'return ' + output_var_str + '\n' + + decorator = '' + self_param = None + if is_bound_method: + if not function_is_staticmethod(context.tree_node): + function_param_names = context.get_value().get_param_names() + if len(function_param_names): + self_param = function_param_names[0].string_name + params = [p for p in params if p != self_param] + + if function_is_classmethod(context.tree_node): + decorator = '@classmethod\n' + else: + code_block += '\n' + + function_code = '%sdef %s(%s):\n%s' % ( + decorator, + name, + ', '.join(params if self_param is None else [self_param] + params), + indent_block(code_block) + ) + + function_call = '%s(%s)' % ( + ('' if self_param is None else self_param + '.') + name, + ', '.join(params) + ) + if is_expression: + replacement = function_call + else: + replacement = _get_indentation(nodes[0]) + output_var_str + ' = ' + function_call + + replaced_str = _replace(nodes, replacement, function_code, pos, insert_before_leaf) + file_to_node_changes = {path: replaced_str} + return Refactoring(inference_state.grammar, file_to_node_changes) + + +def _find_inputs_and_outputs(context, nodes): + inputs = [] + outputs = [] + for name in _find_non_global_names(nodes): + if name.is_definition(): + if name not in outputs: + outputs.append(name.value) + else: + if name.value not in inputs: + name_definitions = context.goto(name, name.start_pos) + if not name_definitions \ + or any(not n.parent_context.is_module() or n.api_type == 'param' + for n in name_definitions): + inputs.append(name.value) + + # Check if outputs are really needed: + return inputs, outputs + + +def _find_non_global_names(nodes): + for node in nodes: + try: + children = node.children + except AttributeError: + if node.type == 'name': + yield node + else: + # We only want to check foo in foo.bar + if node.type == 'trailer' and node.children[0] == '.': + continue + + for x in _find_non_global_names(children): # Python 2... + yield x + + +def _get_code_insertion_node(node, is_bound_method): + if not is_bound_method or function_is_staticmethod(node): + while node.parent.type != 'file_input': + node = node.parent + + while node.parent.type in ('async_funcdef', 'decorated', 'async_stmt'): + node = node.parent + return node + + +def _find_needed_output_variables(context, search_node, at_least_pos, return_variables): + """ + Searches everything after at_least_pos in a node and checks if any of the + return_variables are used in there and returns those. + """ + for node in search_node.children: + if node.start_pos < at_least_pos: + continue + + return_variables = set(return_variables) + for name in _find_non_global_names([node]): + if not name.is_definition() and name.value in return_variables: + return_variables.remove(name.value) + yield name.value