From cc8483a07ac37a61dfb5a06ff07ed211fd022f37 Mon Sep 17 00:00:00 2001 From: Dave Halter Date: Sun, 23 Feb 2020 11:50:05 +0100 Subject: [PATCH] Fix extract issues when self is involved --- jedi/api/refactoring.py | 25 ++++++++++++++++++------- test/refactor/extract_function.py | 12 ++++++------ 2 files changed, 24 insertions(+), 13 deletions(-) diff --git a/jedi/api/refactoring.py b/jedi/api/refactoring.py index f39aeb6d..04c2e934 100644 --- a/jedi/api/refactoring.py +++ b/jedi/api/refactoring.py @@ -385,6 +385,7 @@ def extract_function(inference_state, path, module_context, name, pos, until_pos nodes = _find_nodes(module_context.tree_node, pos, until_pos) return_variables = [] context = module_context.create_context(nodes[0]) + is_bound_method = context.is_bound_method() params = list(_find_non_global_names(context, nodes)) dct = {} @@ -393,7 +394,7 @@ def extract_function(inference_state, path, module_context, name, pos, until_pos if context.is_module(): insert_before_leaf = None # Leaf will be determined later else: - node = _get_code_insertion_node(context) + node = _get_code_insertion_node(context.tree_node, is_bound_method) insert_before_leaf = node.get_first_leaf() if is_expression: code_block = 'return ' + _expression_nodes_to_string(nodes) + '\n' @@ -403,15 +404,26 @@ def extract_function(inference_state, path, module_context, name, pos, until_pos code_block += '\nreturn ' + output_var_str + '\n' decorator = '' - if context.is_bound_method(): + self_param = None + if is_bound_method: + if not function_is_staticmethod(context.tree_node): + function_param_names = context.get_value().get_param_names() + if len(function_param_names): + self_param = function_param_names[0].string_name + if function_is_classmethod(context.tree_node): decorator = '@classmethod\n' else: code_block += '\n' - function_call = '%s(%s)' % (name, ', '.join(params)) - function_code = '%sdef %s:\n%s' % (decorator, function_call, indent_block(code_block)) + function_code = '%sdef %s(%s):\n%s' % ( + decorator, name, ', '.join(params), indent_block(code_block) + ) + function_call = '%s(%s)' % ( + ('' if self_param is None else self_param + '.') + name, + ', '.join(p for p in params if p != self_param) + ) if is_expression: replacement = function_call else: @@ -436,9 +448,8 @@ def _find_non_global_names(context, nodes): yield x -def _get_code_insertion_node(context): - node = context.tree_node - if not context.is_bound_method() or function_is_staticmethod(node): +def _get_code_insertion_node(node, is_bound_method): + if not is_bound_method or function_is_staticmethod(node): while node.parent.type != 'file_input': node = node.parent diff --git a/test/refactor/extract_function.py b/test/refactor/extract_function.py index 3941fabe..3e4330cf 100644 --- a/test/refactor/extract_function.py +++ b/test/refactor/extract_function.py @@ -48,19 +48,19 @@ def f(x): class X: def z(self): pass - def f(x): + def f(x, b): #? 11 text {'new_name': 'ab'} - return x + 1 * 2 + return x + b * 2 # ++++++++++++++++++++++++++++++++++++++++++++++++++ class X: def z(self): pass - def ab(x): - return x + 1 * 2 + def ab(x, b): + return x + b * 2 - def f(x): + def f(x, b): #? 11 text {'new_name': 'ab'} - return ab(x) + return x.ab(b) # -------------------------------------------------- in-classmethod-1 class X: @classmethod