diff --git a/jedi/api.py b/jedi/api.py index 99a306d5..6ed1f0dc 100644 --- a/jedi/api.py +++ b/jedi/api.py @@ -61,13 +61,16 @@ class Script(object): path = source_path lines = source.splitlines() - line = len(lines) if line is None else line - column = len(lines[-1]) if column is None else column + if source and source[-1] == '\n': + lines.append('') + + self._line = max(len(lines), 1) if line is None else line + self._column = len(lines[-1]) if column is None else column api_classes._clear_caches() debug.reset_time() self.source = modules.source_to_unicode(source, source_encoding) - self.pos = line, column + self.pos = self._line, self._column self._module = modules.ModuleWithCursor( path, source=self.source, position=self.pos) self._source_path = path @@ -101,13 +104,24 @@ class Script(object): :return: Completion objects, sorted by name and __ comes last. :rtype: list of :class:`api_classes.Completion` """ + def get_completions(user_stmt, bs): + if isinstance(user_stmt, pr.Import): + context = self._module.get_context() + next(context) # skip the path + if next(context) == 'from': + # completion is just "import" if before stands from .. + return ((k, bs) for k in keywords.keyword_names('import')) + return self._simple_complete(path, like) + debug.speed('completions start') path = self._module.get_path_until_cursor() if re.search('^\.|\.\.$', path): return [] path, dot, like = self._get_completion_parts() - completions = self._simple_complete(path, like) + user_stmt = self._user_stmt(True) + bs = builtin.Builtin.scope + completions = get_completions(user_stmt, bs) if not dot: # named params have no dots for call_def in self.call_signatures(): @@ -115,17 +129,7 @@ class Script(object): for p in call_def.params: completions.append((p.get_name(), p)) - # Do the completion if there is no path before and no import stmt. - u = self._user_stmt(True) - bs = builtin.Builtin.scope - if isinstance(u, pr.Import): - completion_line = self._module.get_position_line() - if (u.relative_count > 0 or u.from_ns) and not re.search( - r'(,|from)\s*$|import\s+$', completion_line): - completions += ((k, bs) for k - in keywords.keyword_names('import')) - - if not path and not isinstance(u, pr.Import): + if not path and not isinstance(user_stmt, pr.Import): # add keywords completions += ((k, bs) for k in keywords.keyword_names( all=True)) @@ -140,7 +144,7 @@ class Script(object): and n.lower().startswith(like.lower()) \ or n.startswith(like): if not evaluate.filter_private_variable(s, - self._user_stmt(True) or self._parser.user_scope, n): + user_stmt or self._parser.user_scope, n): new = api_classes.Completion(c, needs_dot, len(like), s) k = (new.name, new.complete) # key if k in comp_dct and settings.no_completion_duplicates: @@ -196,11 +200,9 @@ class Script(object): if is_completion and not user_stmt: # for statements like `from x import ` (cursor not in statement) - line = self._module.get_position_line() - pos = self.pos[0], len(line) - len(re.search(' *$', line).group(0)) - # check the last statement - last_stmt = self._parser.module.get_statement_for_position(pos, - include_imports=True) + pos = next(self._module.get_context(yield_positions=True)) + last_stmt = pos and self._parser.module.get_statement_for_position( + pos, include_imports=True) if isinstance(last_stmt, pr.Import): user_stmt = last_stmt return user_stmt diff --git a/jedi/modules.py b/jedi/modules.py index 657650fb..15737752 100644 --- a/jedi/modules.py +++ b/jedi/modules.py @@ -122,8 +122,8 @@ class ModuleWithCursor(Module): def get_path_until_cursor(self): """ Get the path under the cursor. """ if self._path_until_cursor is None: # small caching - self._path_until_cursor = self._get_path_until_cursor() - self._start_cursor_pos = self._start_cursor_pos_temp + self._path_until_cursor, self._start_cursor_pos = \ + self._get_path_until_cursor(self.position) return self._path_until_cursor def _get_path_until_cursor(self, start_pos=None): @@ -149,9 +149,8 @@ class ModuleWithCursor(Module): return line[::-1] self._is_first = True - self._line_temp, self._column_temp = start_pos or self.position + self._line_temp, self._column_temp = start_cursor = start_pos self._first_line = self.get_line(self._line_temp)[:self._column_temp] - self._start_cursor_pos_temp = self.position open_brackets = ['(', '[', '{'] close_brackets = [')', ']', '}'] @@ -191,10 +190,10 @@ class ModuleWithCursor(Module): self._column_temp = self._line_length - end[1] break - x = self.position[0] - end[0] + 1 + x = start_pos[0] - end[0] + 1 l = self.get_line(x) - l = self._first_line if x == self.position[0] else l - self._start_cursor_pos_temp = x, len(l) - end[1] + l = self._first_line if x == start_pos[0] else l + start_cursor = x, len(l) - end[1] self._column_temp = self._line_length - end[1] string += tok last_type = token_type @@ -202,7 +201,7 @@ class ModuleWithCursor(Module): debug.warning("Tokenize couldn't finish", sys.exc_info) # string can still contain spaces at the end - return string[::-1].strip() + return string[::-1].strip(), start_cursor def get_path_under_cursor(self): """ @@ -222,27 +221,44 @@ class ModuleWithCursor(Module): return (before.group(0) if before is not None else '') \ + (after.group(0) if after is not None else '') - def get_context(self): + def get_context(self, yield_positions=False): pos = self._start_cursor_pos - while pos > (1, 0): + while True: # remove non important white space line = self.get_line(pos[0]) - while pos[1] > 0 and line[pos[1] - 1].isspace(): - pos = pos[0], pos[1] - 1 + while True: + if pos[1] == 0: + line = self.get_line(pos[0] - 1) + if line and line[-1] == '\\': + pos = pos[0] - 1, len(line) - 1 + continue + else: + break + + if line[pos[1] - 1].isspace(): + pos = pos[0], pos[1] - 1 + else: + break try: - yield self._get_path_until_cursor(start_pos=pos) + result, pos = self._get_path_until_cursor(start_pos=pos) + if yield_positions: + yield pos + else: + yield result except StopIteration: - yield '' - pos = self._line_temp, self._column_temp - - while True: - yield '' + if yield_positions: + yield None + else: + yield '' def get_line(self, line_nr): if not self._line_cache: self._line_cache = self.source.splitlines() - if not self.source: # ''.splitlines() == [] + if self.source: + if self.source[-1] == '\n': + self._line_cache.append('') + else: # ''.splitlines() == [] self._line_cache = [''] if line_nr == 0: diff --git a/test/test_integration_import.py b/test/test_integration_import.py index 121b6253..401b7945 100644 --- a/test/test_integration_import.py +++ b/test/test_integration_import.py @@ -56,6 +56,19 @@ def test_goto_following_on_imports(): assert (g[0].line, g[0].column) != (0, 0) +def test_after_from(): + def check(source, result, column=None): + completions = Script(source, column=column).completions() + assert [c.name for c in completions] == result + + check('from os ', ['import']) + check('\nfrom os ', ['import']) + check('\nfrom os import whatever', ['import'], len('from os im')) + + check('from os\\\n', ['import']) + check('from os \\\n', ['import']) + + def test_follow_definition(): """ github issue #45 """ c = Script("from datetime import timedelta; timedelta").completions()