diff --git a/jedi/helpers.py b/jedi/helpers.py index 2c057ef8..259b99bb 100644 --- a/jedi/helpers.py +++ b/jedi/helpers.py @@ -1,5 +1,4 @@ import copy -import contextlib import parsing import evaluate @@ -199,19 +198,39 @@ def generate_param_array(args_tuple, parent_stmt=None): return arr +def check_arr_index(arr, pos): + positions = arr.arr_el_pos + for index, comma_pos in enumerate(positions): + if pos < comma_pos: + return index + return len(positions) + + +def array_for_pos(arr, pos): + if arr.start_pos >= pos \ + or arr.end_pos[0] is not None and pos >= arr.end_pos: + return None, None + + result = arr + for sub in arr: + for s in sub: + if isinstance(s, parsing.Array): + result = array_for_pos(s, pos)[0] or result + elif isinstance(s, parsing.Call): + if s.execution: + result = array_for_pos(s.execution, pos)[0] or result + if s.next: + result = array_for_pos(s.next, pos)[0] or result + + return result, check_arr_index(result, pos) + + def scan_array_for_pos(arr, pos, overwrite_after=False): """ Returns the function Call that matches the position before `arr`. This is somehow stupid, probably only the name of the function. :param overwrite_after: Overwrite every statement after the found array. """ - def check_arr_index(): - positions = arr.arr_el_pos - for index, comma_pos in enumerate(positions): - if pos < comma_pos: - return index - return len(positions) - call = None stop = False for sub in arr.values: @@ -228,7 +247,7 @@ def scan_array_for_pos(arr, pos, overwrite_after=False): # check parts of calls while s is not None: if s.start_pos >= pos: - return call, check_arr_index(), stop + return call, check_arr_index(arr, pos), stop elif s.execution is not None: end = s.execution.end_pos if s.execution.start_pos < pos and \ @@ -254,4 +273,4 @@ def scan_array_for_pos(arr, pos, overwrite_after=False): # The third return is just necessary for recursion inside, because # it needs to know when to stop iterating. - return call, check_arr_index(), stop + return call, check_arr_index(arr, pos), stop diff --git a/jedi/refactoring.py b/jedi/refactoring.py index b496c03c..c36232cb 100644 --- a/jedi/refactoring.py +++ b/jedi/refactoring.py @@ -96,15 +96,11 @@ def extract(script, new_name): if user_stmt: pos = script.pos line_index = pos[0] - 1 - import parsing - assert isinstance(user_stmt, parsing.Statement) - call, index, stop = helpers.scan_array_for_pos( - user_stmt.get_assignment_calls(), pos) - assert isinstance(call, parsing.Call) - exe = call.execution - if exe: - s = exe.start_pos[0], exe.start_pos[1] + 1 - positions = [s] + call.execution.arr_el_pos + [exe.end_pos] + arr, index = helpers.array_for_pos(user_stmt.get_assignment_calls(), + pos) + if arr: + s = arr.start_pos[0], arr.start_pos[1] + 1 + positions = [s] + arr.arr_el_pos + [arr.end_pos] start_pos = positions[index] end_pos = positions[index + 1][0], positions[index + 1][1] - 1