again import tests

This commit is contained in:
David Halter
2012-08-02 20:09:45 +02:00
parent 24f81ea75c
commit d8c0b8f8e9
6 changed files with 69 additions and 10 deletions

View File

@@ -18,13 +18,15 @@ class ImportPath(object):
pass pass
def __init__(self, import_stmt, is_like_search=False): def __init__(self, import_stmt, is_like_search=False):
""" replace """ self.import_stmt = import_stmt
#print import_stmt
self.import_path = [] self.import_path = []
if import_stmt.from_ns: if import_stmt.from_ns:
self.import_path += import_stmt.from_ns.names self.import_path += import_stmt.from_ns.names
if import_stmt.namespace: if import_stmt.namespace:
self.import_path += import_stmt.namespace.names if self.is_nested_import():
self.import_path.append(import_stmt.namespace.names[0])
else:
self.import_path += import_stmt.namespace.names
self.is_like_search = is_like_search self.is_like_search = is_like_search
if is_like_search: if is_like_search:
@@ -33,6 +35,27 @@ class ImportPath(object):
self.file_path = os.path.dirname(import_stmt.get_parent_until().path) self.file_path = os.path.dirname(import_stmt.get_parent_until().path)
def __repr__(self):
return '<%s: %s>' % (self.__class__.__name__, self.import_stmt)
def is_nested_import(self):
"""
This checks for the special case of nested imports, without aliases and
from statement:
>>> import foo.bar
"""
return not self.import_stmt.alias and not self.import_stmt.from_ns \
and len(self.import_stmt.namespace.names) > 1
def get_nested_import(self, parent):
i = self.import_stmt
zero = (1,0)
n = parsing.Name(i.namespace.names[1:], zero, zero)
new = parsing.Import(zero, zero, n)
new.parent = parent
debug.dbg('Generated a nested import: %s' % new)
return new
def get_defined_names(self): def get_defined_names(self):
names = [] names = []
for scope in self.follow(): for scope in self.follow():
@@ -43,10 +66,12 @@ class ImportPath(object):
for s, n in evaluate.get_names_for_scope(scope, for s, n in evaluate.get_names_for_scope(scope,
include_builtin=False): include_builtin=False):
names += n names += n
if isinstance(scope, parsing.Module) \ #print s, n, n[0].parent
and scope.path.endswith('__init__.py'): #if isinstance(scope, parsing.Module) \
names += \ # and scope.path.endswith('__init__.py'):
self.get_module_names([os.path.dirname(scope.path)]) # names += \
# self.get_module_names([os.path.dirname(scope.path)])
# print names
return names return names
def get_module_names(self, search_path=None): def get_module_names(self, search_path=None):
@@ -72,6 +97,9 @@ class ImportPath(object):
for scope in scopes: for scope in scopes:
new += remove_star_imports(scope) new += remove_star_imports(scope)
scopes += new scopes += new
if self.is_nested_import():
scopes.append(self.get_nested_import(scope))
else: else:
scopes = [ImportPath.GlobalNamespace] scopes = [ImportPath.GlobalNamespace]
debug.dbg('after import', scopes) debug.dbg('after import', scopes)

View File

@@ -499,7 +499,14 @@ class Import(Simple):
return [] return []
if self.star: if self.star:
return [self] return [self]
return [self.alias] if self.alias else [self.namespace] if self.alias:
return [self.alias]
if len(self.namespace) > 1:
o = self.namespace
n = Name([o.names[0]], o.start_pos, o.end_pos, parent=o.parent)
return [n]
else:
return [self.namespace]
class Statement(Simple): class Statement(Simple):
@@ -901,9 +908,10 @@ class Name(Simple):
So a name like "module.class.function" So a name like "module.class.function"
would result in an array of [module, class, function] would result in an array of [module, class, function]
""" """
def __init__(self, names, start_pos, end_pos): def __init__(self, names, start_pos, end_pos, parent=None):
super(Name, self).__init__(start_pos, end_pos) super(Name, self).__init__(start_pos, end_pos)
self.names = tuple(NamePart(n) for n in names) self.names = tuple(NamePart(n) for n in names)
self.parent = parent
def get_code(self): def get_code(self):
""" Returns the names in a full string format """ """ Returns the names in a full string format """

View File

@@ -78,6 +78,9 @@ import import_tree
#! ['module mod1'] #! ['module mod1']
import import_tree.mod1 import import_tree.mod1
#! ['module mod1']
import import_tree.pkg.mod1
#! ['module mod1'] #! ['module mod1']
from import_tree import mod1 from import_tree import mod1

View File

@@ -0,0 +1 @@
a = list

View File

@@ -0,0 +1 @@
a = 1.0

View File

@@ -22,7 +22,7 @@ def scope_basic():
#? [] #? []
import_tree.mod1 import_tree.mod1
def scope_nested(): def scope_pkg():
import import_tree.mod1 import import_tree.mod1
#? str() #? str()
@@ -34,6 +34,24 @@ def scope_nested():
#? int() #? int()
import_tree.mod1.a import_tree.mod1.a
def scope_nested():
import import_tree.pkg.mod1
#? str()
import_tree.a
#? list
import_tree.pkg.a
#? ['a', 'pkg']
import_tree.
#? float()
import_tree.pkg.mod1.a
#? ['a', 'pkg']
import_tree.
# ----------------- # -----------------
# std lib modules # std lib modules
# ----------------- # -----------------