diff --git a/jedi/api/refactoring.py b/jedi/api/refactoring.py index 7372887c..6e541f61 100644 --- a/jedi/api/refactoring.py +++ b/jedi/api/refactoring.py @@ -8,6 +8,7 @@ from parso import split_lines from jedi import debug from jedi.api.exceptions import RefactoringError from jedi.common.utils import indent_block +from jedi.parser_utils import function_is_classmethod, function_is_staticmethod _EXPRESSION_PARTS = ( 'or_test and_test not_test comparison ' @@ -380,8 +381,6 @@ def _is_not_extractable_syntax(node): def extract_function(inference_state, path, module_context, name, pos, until_pos): - is_class_method = False - is_method = False is_expression = True nodes = _find_nodes(module_context.tree_node, pos, until_pos) return_variables = [] @@ -405,13 +404,14 @@ def extract_function(inference_state, path, module_context, name, pos, until_pos output_var_str = ', '.join(return_variables) code_block += '\nreturn ' + output_var_str + '\n' - if not context.is_bound_method(): + decorator = '' + if context.is_bound_method(): + if function_is_classmethod(context.tree_node): + decorator = '@classmethod\n' + else: code_block += '\n' function_call = '%s(%s)' % (name, ', '.join(params)) - decorator = '' - if is_class_method: - decorator = '@classmethod\n' function_code = '%sdef %s:\n%s' % (decorator, function_call, indent_block(code_block)) if is_expression: diff --git a/test/refactor/extract_function.py b/test/refactor/extract_function.py index 26332e68..9bc242f3 100644 --- a/test/refactor/extract_function.py +++ b/test/refactor/extract_function.py @@ -61,3 +61,19 @@ class X: def f(x): #? 11 text {'new_name': 'ab'} return ab() +# -------------------------------------------------- in-classmethod-1 +class X: + @classmethod + def f(x): + #? 16 text {'new_name': 'ab'} + return 25 +# ++++++++++++++++++++++++++++++++++++++++++++++++++ +class X: + @classmethod + def ab(): + return 25 + + @classmethod + def f(x): + #? 16 text {'new_name': 'ab'} + return x.ab()