diff --git a/imports.py b/imports.py index 6b4b83a1..22c6b5e8 100644 --- a/imports.py +++ b/imports.py @@ -18,13 +18,15 @@ class ImportPath(object): pass def __init__(self, import_stmt, is_like_search=False): - """ replace """ - #print import_stmt + self.import_stmt = import_stmt self.import_path = [] if import_stmt.from_ns: self.import_path += import_stmt.from_ns.names if import_stmt.namespace: - self.import_path += import_stmt.namespace.names + if self.is_nested_import(): + self.import_path.append(import_stmt.namespace.names[0]) + else: + self.import_path += import_stmt.namespace.names self.is_like_search = is_like_search if is_like_search: @@ -33,6 +35,27 @@ class ImportPath(object): self.file_path = os.path.dirname(import_stmt.get_parent_until().path) + def __repr__(self): + return '<%s: %s>' % (self.__class__.__name__, self.import_stmt) + + def is_nested_import(self): + """ + This checks for the special case of nested imports, without aliases and + from statement: + >>> import foo.bar + """ + return not self.import_stmt.alias and not self.import_stmt.from_ns \ + and len(self.import_stmt.namespace.names) > 1 + + def get_nested_import(self, parent): + i = self.import_stmt + zero = (1,0) + n = parsing.Name(i.namespace.names[1:], zero, zero) + new = parsing.Import(zero, zero, n) + new.parent = parent + debug.dbg('Generated a nested import: %s' % new) + return new + def get_defined_names(self): names = [] for scope in self.follow(): @@ -43,10 +66,12 @@ class ImportPath(object): for s, n in evaluate.get_names_for_scope(scope, include_builtin=False): names += n - if isinstance(scope, parsing.Module) \ - and scope.path.endswith('__init__.py'): - names += \ - self.get_module_names([os.path.dirname(scope.path)]) + #print s, n, n[0].parent + #if isinstance(scope, parsing.Module) \ + # and scope.path.endswith('__init__.py'): + # names += \ + # self.get_module_names([os.path.dirname(scope.path)]) + # print names return names def get_module_names(self, search_path=None): @@ -72,6 +97,9 @@ class ImportPath(object): for scope in scopes: new += remove_star_imports(scope) scopes += new + + if self.is_nested_import(): + scopes.append(self.get_nested_import(scope)) else: scopes = [ImportPath.GlobalNamespace] debug.dbg('after import', scopes) diff --git a/parsing.py b/parsing.py index dc36b5e4..13491bfe 100644 --- a/parsing.py +++ b/parsing.py @@ -499,7 +499,14 @@ class Import(Simple): return [] if self.star: return [self] - return [self.alias] if self.alias else [self.namespace] + if self.alias: + return [self.alias] + if len(self.namespace) > 1: + o = self.namespace + n = Name([o.names[0]], o.start_pos, o.end_pos, parent=o.parent) + return [n] + else: + return [self.namespace] class Statement(Simple): @@ -901,9 +908,10 @@ class Name(Simple): So a name like "module.class.function" would result in an array of [module, class, function] """ - def __init__(self, names, start_pos, end_pos): + def __init__(self, names, start_pos, end_pos, parent=None): super(Name, self).__init__(start_pos, end_pos) self.names = tuple(NamePart(n) for n in names) + self.parent = parent def get_code(self): """ Returns the names in a full string format """ diff --git a/test/completion/goto.py b/test/completion/goto.py index 912fad48..3fccfa6f 100644 --- a/test/completion/goto.py +++ b/test/completion/goto.py @@ -78,6 +78,9 @@ import import_tree #! ['module mod1'] import import_tree.mod1 +#! ['module mod1'] +import import_tree.pkg.mod1 + #! ['module mod1'] from import_tree import mod1 diff --git a/test/completion/import_tree/pkg/__init__.py b/test/completion/import_tree/pkg/__init__.py new file mode 100644 index 00000000..e8e57d1c --- /dev/null +++ b/test/completion/import_tree/pkg/__init__.py @@ -0,0 +1 @@ +a = list diff --git a/test/completion/import_tree/pkg/mod1.py b/test/completion/import_tree/pkg/mod1.py new file mode 100644 index 00000000..2b6777c8 --- /dev/null +++ b/test/completion/import_tree/pkg/mod1.py @@ -0,0 +1 @@ +a = 1.0 diff --git a/test/completion/imports.py b/test/completion/imports.py index b9db69f0..e7403cb1 100644 --- a/test/completion/imports.py +++ b/test/completion/imports.py @@ -22,7 +22,7 @@ def scope_basic(): #? [] import_tree.mod1 -def scope_nested(): +def scope_pkg(): import import_tree.mod1 #? str() @@ -34,6 +34,24 @@ def scope_nested(): #? int() import_tree.mod1.a +def scope_nested(): + import import_tree.pkg.mod1 + + #? str() + import_tree.a + + #? list + import_tree.pkg.a + + #? ['a', 'pkg'] + import_tree. + + #? float() + import_tree.pkg.mod1.a + + #? ['a', 'pkg'] + import_tree. + # ----------------- # std lib modules # -----------------