Add include_setitem for get_defined_names, is_definition and get_definition

This commit is contained in:
Dave Halter
2019-08-28 22:38:01 +02:00
parent e8653a49ff
commit 8f83e9b3c5
2 changed files with 25 additions and 20 deletions

View File

@@ -200,13 +200,13 @@ class Name(_LeafWithoutNewlines):
return "<%s: %s@%s,%s>" % (type(self).__name__, self.value,
self.line, self.column)
def is_definition(self):
def is_definition(self, include_setitem=False):
"""
Returns True if the name is being defined.
"""
return self.get_definition() is not None
return self.get_definition(include_setitem=include_setitem) is not None
def get_definition(self, import_name_always=False):
def get_definition(self, import_name_always=False, include_setitem=False):
"""
Returns None if there's on definition for a name.
@@ -234,7 +234,7 @@ class Name(_LeafWithoutNewlines):
if node.type == 'suite':
return None
if node.type in _GET_DEFINITION_TYPES:
if self in node.get_defined_names():
if self in node.get_defined_names(include_setitem):
return node
if import_name_always and node.type in _IMPORTS:
return node
@@ -772,8 +772,8 @@ class ForStmt(Flow):
"""
return self.children[3]
def get_defined_names(self):
return _defined_names(self.children[1])
def get_defined_names(self, include_setitem=False):
return _defined_names(self.children[1], include_setitem)
class TryStmt(Flow):
@@ -796,7 +796,7 @@ class WithStmt(Flow):
type = 'with_stmt'
__slots__ = ()
def get_defined_names(self):
def get_defined_names(self, include_setitem=False):
"""
Returns the a list of `Name` that the with statement defines. The
defined names are set after `as`.
@@ -805,7 +805,7 @@ class WithStmt(Flow):
for with_item in self.children[1:-2:2]:
# Check with items for 'as' names.
if with_item.type == 'with_item':
names += _defined_names(with_item.children[2])
names += _defined_names(with_item.children[2], include_setitem)
return names
def get_test_node_from_name(self, name):
@@ -846,7 +846,7 @@ class ImportFrom(Import):
type = 'import_from'
__slots__ = ()
def get_defined_names(self):
def get_defined_names(self, include_setitem=False):
"""
Returns the a list of `Name` that the import defines. The
defined names are set after `import` or in case an alias - `as` - is
@@ -917,7 +917,7 @@ class ImportName(Import):
type = 'import_name'
__slots__ = ()
def get_defined_names(self):
def get_defined_names(self, include_setitem=False):
"""
Returns the a list of `Name` that the import defines. The defined names
is always the first name after `import` or in case an alias - `as` - is
@@ -1018,7 +1018,7 @@ class YieldExpr(PythonBaseNode):
__slots__ = ()
def _defined_names(current):
def _defined_names(current, include_setitem):
"""
A helper function to find the defined names in statements, for loops and
list comprehensions.
@@ -1026,15 +1026,15 @@ def _defined_names(current):
names = []
if current.type in ('testlist_star_expr', 'testlist_comp', 'exprlist', 'testlist'):
for child in current.children[::2]:
names += _defined_names(child)
names += _defined_names(child, include_setitem)
elif current.type in ('atom', 'star_expr'):
names += _defined_names(current.children[1])
names += _defined_names(current.children[1], include_setitem)
elif current.type in ('power', 'atom_expr'):
if current.children[-2] != '**': # Just if there's no operation
trailer = current.children[-1]
if trailer.children[0] == '.':
names.append(trailer.children[1])
elif trailer.children[0] == '[':
elif trailer.children[0] == '[' and include_setitem:
for node in current.children[-2::-1]:
if node.type == 'trailer':
names.append(node.children[1])
@@ -1051,18 +1051,18 @@ class ExprStmt(PythonBaseNode, DocstringMixin):
type = 'expr_stmt'
__slots__ = ()
def get_defined_names(self):
def get_defined_names(self, include_setitem=False):
"""
Returns a list of `Name` defined before the `=` sign.
"""
names = []
if self.children[1].type == 'annassign':
names = _defined_names(self.children[0])
names = _defined_names(self.children[0], include_setitem)
return [
name
for i in range(0, len(self.children) - 2, 2)
if '=' in self.children[i + 1].value
for name in _defined_names(self.children[i])
for name in _defined_names(self.children[i], include_setitem)
] + names
def get_rhs(self):
@@ -1155,7 +1155,7 @@ class Param(PythonBaseNode):
else:
return self._tfpdef()
def get_defined_names(self):
def get_defined_names(self, include_setitem=False):
return [self.name]
@property
@@ -1213,12 +1213,12 @@ class SyncCompFor(PythonBaseNode):
type = 'sync_comp_for'
__slots__ = ()
def get_defined_names(self):
def get_defined_names(self, include_setitem=False):
"""
Returns the a list of `Name` that the comprehension defines.
"""
# allow async for
return _defined_names(self.children[1])
return _defined_names(self.children[1], include_setitem)
# This is simply here so an older Jedi version can work with this new parso