Refactoring: Make Import.get_all_import_names return NameParts.

This commit is contained in:
Dave Halter
2014-09-19 01:40:05 +02:00
parent 83d2af5138
commit b2342c76be
5 changed files with 13 additions and 28 deletions

View File

@@ -28,14 +28,14 @@ def get_on_import_stmt(evaluator, user_context, user_stmt, is_like_search=False)
import_names = user_stmt.get_all_import_names() import_names = user_stmt.get_all_import_names()
kill_count = -1 kill_count = -1
cur_name_part = None cur_name_part = None
for i in import_names: for name in import_names:
if user_stmt.alias == i: if user_stmt.alias_name_part == name:
continue continue
for name_part in i.names:
if name_part.end_pos >= user_context.position: if name.end_pos >= user_context.position:
if not cur_name_part: if not cur_name_part:
cur_name_part = name_part cur_name_part = name
kill_count += 1 kill_count += 1
context = user_context.get_context() context = user_context.get_context()
just_from = next(context) == 'from' just_from = next(context) == 'from'

View File

@@ -70,11 +70,10 @@ def usages(evaluator, definitions, mods):
if isinstance(stmt, pr.Import): if isinstance(stmt, pr.Import):
count = 0 count = 0
imps = [] imps = []
for i in stmt.get_all_import_names(): for name in stmt.get_all_import_names():
for name_part in i.names: count += 1
count += 1 if unicode(name) == search_name:
if unicode(name_part) == search_name: imps.append((count, name))
imps.append((count, name_part))
for used_count, name_part in imps: for used_count, name_part in imps:
i = imports.ImportWrapper(evaluator, stmt, kill_count=count - used_count, i = imports.ImportWrapper(evaluator, stmt, kill_count=count - used_count,

View File

@@ -323,7 +323,7 @@ class Evaluator(object):
if stmt.alias_name_part == call_path[0]: if stmt.alias_name_part == call_path[0]:
return [call_path[0]] return [call_path[0]]
names = stmt.get_all_import_name_parts() names = stmt.get_all_import_names()
# Filter names that are after our Name # Filter names that are after our Name
removed_names = len(names) - names.index(call_path[0]) - 1 removed_names = len(names) - names.index(call_path[0]) - 1
i = imports.ImportWrapper(self, stmt, kill_count=removed_names) i = imports.ImportWrapper(self, stmt, kill_count=removed_names)

View File

@@ -226,7 +226,7 @@ def get_module_name_parts(module):
for stmt_or_import in statements_or_imports: for stmt_or_import in statements_or_imports:
if isinstance(stmt_or_import, pr.Import): if isinstance(stmt_or_import, pr.Import):
for name in stmt_or_import.get_all_import_names(): for name in stmt_or_import.get_all_import_names():
name_parts.update(name.names) name_parts.add(name)
else: else:
# Running this ensures that all the expression lists are generated # Running this ensures that all the expression lists are generated
# and the parents are all set. (Important for Lambdas) Howeer, this # and the parents are all set. (Important for Lambdas) Howeer, this

View File

@@ -831,20 +831,6 @@ class Import(Simple):
return [self.namespace] return [self.namespace]
def get_all_import_names(self): def get_all_import_names(self):
n = []
if self.from_ns:
n.append(self.from_ns)
if self.namespace:
n.append(self.namespace)
if self.alias:
n.append(self.alias)
return n
def get_all_import_name_parts(self):
"""
TODO refactor and use this method, because NamePart will not exist in
the future.
"""
n = [] n = []
if self.from_ns: if self.from_ns:
n += self.from_ns.names n += self.from_ns.names