Add a Definition.is_definition function to be able to check if a name is a definition or not.

This commit is contained in:
Dave Halter
2014-09-08 23:44:35 +02:00
parent 851717a968
commit 0dcb91d236
3 changed files with 27 additions and 8 deletions

View File

@@ -695,12 +695,15 @@ def names(source=None, path=None, encoding='utf-8', all_scopes=False,
:param references: If True lists all the names that are not listed by :param references: If True lists all the names that are not listed by
``definitions=True``. E.g. ``a = b`` returns ``b``. ``definitions=True``. E.g. ``a = b`` returns ``b``.
""" """
def def_ref_filter(_def):
is_def = _def.is_definition()
return definitions and is_def or references and not is_def
# Set line/column to a random position, because they don't matter. # Set line/column to a random position, because they don't matter.
script = Script(source, line=1, column=0, path=path, encoding=encoding) script = Script(source, line=1, column=0, path=path, encoding=encoding)
defs = [classes.Definition(script._evaluator, name_part) defs = [classes.Definition(script._evaluator, name_part)
for name_part in get_module_name_parts(script._parser.module()) for name_part in get_module_name_parts(script._parser.module())]
if 390 < name_part.start_pos[0] < 399] return sorted(filter(def_ref_filter, defs), key=lambda x: (x.line, x.column))
return sorted(defs, key=lambda x: (x.line, x.column))
def preload_module(*modules): def preload_module(*modules):

View File

@@ -699,6 +699,19 @@ class Definition(use_metaclass(CachedMetaClass, BaseDefinition)):
iterable = list(iterable) iterable = list(iterable)
return list(chain.from_iterable(iterable)) return list(chain.from_iterable(iterable))
def is_definition(self):
"""
Returns True, if defined as a name in a statement, function or class.
Returns False, if it's a reference to such a definition.
"""
if not isinstance(self._definition, pr.NamePart):
# Currently only handle NameParts. Once we have a proper API, this
# will be the standard anyway.
raise NotImplementedError
stmt_or_imp = self._definition.get_parent_until((pr.ExprStmt, pr.Import))
exp_list = stmt_or_imp.expression_list()
return not exp_list or self._definition.start_pos < exp_list[0].start_pos
def __eq__(self, other): def __eq__(self, other):
return self._start_pos == other._start_pos \ return self._start_pos == other._start_pos \
and self.module_path == other.module_path \ and self.module_path == other.module_path \

View File

@@ -200,8 +200,11 @@ class TestGotoAssignments(TestCase):
function. They are not really different in functionality, but really function. They are not really different in functionality, but really
different as an implementation. different as an implementation.
""" """
def test_basic(self): def test_repetition(self):
refs = names('a = 1; a', references=True, definitions=False) defs = names('a = 1; a', references=True, definitions=False)
assert len(refs) == 1 # Repeat on the same variable. Shouldn't change once we're on a
ass = refs[0].goto_assignments() # definition.
assert ass[0].description == '' for _ in range(3):
assert len(defs) == 1
ass = defs[0].goto_assignments()
assert ass[0].description == 'a = 1'