Fix extract issues when self is involved

This commit is contained in:
Dave Halter
2020-02-23 11:50:05 +01:00
parent 48c4262f66
commit cc8483a07a
2 changed files with 24 additions and 13 deletions

View File

@@ -385,6 +385,7 @@ def extract_function(inference_state, path, module_context, name, pos, until_pos
nodes = _find_nodes(module_context.tree_node, pos, until_pos)
return_variables = []
context = module_context.create_context(nodes[0])
is_bound_method = context.is_bound_method()
params = list(_find_non_global_names(context, nodes))
dct = {}
@@ -393,7 +394,7 @@ def extract_function(inference_state, path, module_context, name, pos, until_pos
if context.is_module():
insert_before_leaf = None # Leaf will be determined later
else:
node = _get_code_insertion_node(context)
node = _get_code_insertion_node(context.tree_node, is_bound_method)
insert_before_leaf = node.get_first_leaf()
if is_expression:
code_block = 'return ' + _expression_nodes_to_string(nodes) + '\n'
@@ -403,15 +404,26 @@ def extract_function(inference_state, path, module_context, name, pos, until_pos
code_block += '\nreturn ' + output_var_str + '\n'
decorator = ''
if context.is_bound_method():
self_param = None
if is_bound_method:
if not function_is_staticmethod(context.tree_node):
function_param_names = context.get_value().get_param_names()
if len(function_param_names):
self_param = function_param_names[0].string_name
if function_is_classmethod(context.tree_node):
decorator = '@classmethod\n'
else:
code_block += '\n'
function_call = '%s(%s)' % (name, ', '.join(params))
function_code = '%sdef %s:\n%s' % (decorator, function_call, indent_block(code_block))
function_code = '%sdef %s(%s):\n%s' % (
decorator, name, ', '.join(params), indent_block(code_block)
)
function_call = '%s(%s)' % (
('' if self_param is None else self_param + '.') + name,
', '.join(p for p in params if p != self_param)
)
if is_expression:
replacement = function_call
else:
@@ -436,9 +448,8 @@ def _find_non_global_names(context, nodes):
yield x
def _get_code_insertion_node(context):
node = context.tree_node
if not context.is_bound_method() or function_is_staticmethod(node):
def _get_code_insertion_node(node, is_bound_method):
if not is_bound_method or function_is_staticmethod(node):
while node.parent.type != 'file_input':
node = node.parent

View File

@@ -48,19 +48,19 @@ def f(x):
class X:
def z(self): pass
def f(x):
def f(x, b):
#? 11 text {'new_name': 'ab'}
return x + 1 * 2
return x + b * 2
# ++++++++++++++++++++++++++++++++++++++++++++++++++
class X:
def z(self): pass
def ab(x):
return x + 1 * 2
def ab(x, b):
return x + b * 2
def f(x):
def f(x, b):
#? 11 text {'new_name': 'ab'}
return ab(x)
return x.ab(b)
# -------------------------------------------------- in-classmethod-1
class X:
@classmethod