Add include_setitem for get_defined_names, is_definition and get_definition

This commit is contained in:
Dave Halter
2019-08-28 22:38:01 +02:00
parent e8653a49ff
commit 8f83e9b3c5
2 changed files with 25 additions and 20 deletions
+5
View File
@@ -3,6 +3,11 @@
Changelog Changelog
--------- ---------
0.5.2 (2019-XX-XX)
++++++++++++++++++
- Add include_setitem to get_definition/is_definition and get_defined_names (#66)
0.5.1 (2019-07-13) 0.5.1 (2019-07-13)
++++++++++++++++++ ++++++++++++++++++
+20 -20
View File
@@ -200,13 +200,13 @@ class Name(_LeafWithoutNewlines):
return "<%s: %s@%s,%s>" % (type(self).__name__, self.value, return "<%s: %s@%s,%s>" % (type(self).__name__, self.value,
self.line, self.column) self.line, self.column)
def is_definition(self): def is_definition(self, include_setitem=False):
""" """
Returns True if the name is being defined. Returns True if the name is being defined.
""" """
return self.get_definition() is not None return self.get_definition(include_setitem=include_setitem) is not None
def get_definition(self, import_name_always=False): def get_definition(self, import_name_always=False, include_setitem=False):
""" """
Returns None if there's on definition for a name. Returns None if there's on definition for a name.
@@ -234,7 +234,7 @@ class Name(_LeafWithoutNewlines):
if node.type == 'suite': if node.type == 'suite':
return None return None
if node.type in _GET_DEFINITION_TYPES: if node.type in _GET_DEFINITION_TYPES:
if self in node.get_defined_names(): if self in node.get_defined_names(include_setitem):
return node return node
if import_name_always and node.type in _IMPORTS: if import_name_always and node.type in _IMPORTS:
return node return node
@@ -772,8 +772,8 @@ class ForStmt(Flow):
""" """
return self.children[3] return self.children[3]
def get_defined_names(self): def get_defined_names(self, include_setitem=False):
return _defined_names(self.children[1]) return _defined_names(self.children[1], include_setitem)
class TryStmt(Flow): class TryStmt(Flow):
@@ -796,7 +796,7 @@ class WithStmt(Flow):
type = 'with_stmt' type = 'with_stmt'
__slots__ = () __slots__ = ()
def get_defined_names(self): def get_defined_names(self, include_setitem=False):
""" """
Returns the a list of `Name` that the with statement defines. The Returns the a list of `Name` that the with statement defines. The
defined names are set after `as`. defined names are set after `as`.
@@ -805,7 +805,7 @@ class WithStmt(Flow):
for with_item in self.children[1:-2:2]: for with_item in self.children[1:-2:2]:
# Check with items for 'as' names. # Check with items for 'as' names.
if with_item.type == 'with_item': if with_item.type == 'with_item':
names += _defined_names(with_item.children[2]) names += _defined_names(with_item.children[2], include_setitem)
return names return names
def get_test_node_from_name(self, name): def get_test_node_from_name(self, name):
@@ -846,7 +846,7 @@ class ImportFrom(Import):
type = 'import_from' type = 'import_from'
__slots__ = () __slots__ = ()
def get_defined_names(self): def get_defined_names(self, include_setitem=False):
""" """
Returns the a list of `Name` that the import defines. The Returns the a list of `Name` that the import defines. The
defined names are set after `import` or in case an alias - `as` - is defined names are set after `import` or in case an alias - `as` - is
@@ -917,7 +917,7 @@ class ImportName(Import):
type = 'import_name' type = 'import_name'
__slots__ = () __slots__ = ()
def get_defined_names(self): def get_defined_names(self, include_setitem=False):
""" """
Returns the a list of `Name` that the import defines. The defined names Returns the a list of `Name` that the import defines. The defined names
is always the first name after `import` or in case an alias - `as` - is is always the first name after `import` or in case an alias - `as` - is
@@ -1018,7 +1018,7 @@ class YieldExpr(PythonBaseNode):
__slots__ = () __slots__ = ()
def _defined_names(current): def _defined_names(current, include_setitem):
""" """
A helper function to find the defined names in statements, for loops and A helper function to find the defined names in statements, for loops and
list comprehensions. list comprehensions.
@@ -1026,15 +1026,15 @@ def _defined_names(current):
names = [] names = []
if current.type in ('testlist_star_expr', 'testlist_comp', 'exprlist', 'testlist'): if current.type in ('testlist_star_expr', 'testlist_comp', 'exprlist', 'testlist'):
for child in current.children[::2]: for child in current.children[::2]:
names += _defined_names(child) names += _defined_names(child, include_setitem)
elif current.type in ('atom', 'star_expr'): elif current.type in ('atom', 'star_expr'):
names += _defined_names(current.children[1]) names += _defined_names(current.children[1], include_setitem)
elif current.type in ('power', 'atom_expr'): elif current.type in ('power', 'atom_expr'):
if current.children[-2] != '**': # Just if there's no operation if current.children[-2] != '**': # Just if there's no operation
trailer = current.children[-1] trailer = current.children[-1]
if trailer.children[0] == '.': if trailer.children[0] == '.':
names.append(trailer.children[1]) names.append(trailer.children[1])
elif trailer.children[0] == '[': elif trailer.children[0] == '[' and include_setitem:
for node in current.children[-2::-1]: for node in current.children[-2::-1]:
if node.type == 'trailer': if node.type == 'trailer':
names.append(node.children[1]) names.append(node.children[1])
@@ -1051,18 +1051,18 @@ class ExprStmt(PythonBaseNode, DocstringMixin):
type = 'expr_stmt' type = 'expr_stmt'
__slots__ = () __slots__ = ()
def get_defined_names(self): def get_defined_names(self, include_setitem=False):
""" """
Returns a list of `Name` defined before the `=` sign. Returns a list of `Name` defined before the `=` sign.
""" """
names = [] names = []
if self.children[1].type == 'annassign': if self.children[1].type == 'annassign':
names = _defined_names(self.children[0]) names = _defined_names(self.children[0], include_setitem)
return [ return [
name name
for i in range(0, len(self.children) - 2, 2) for i in range(0, len(self.children) - 2, 2)
if '=' in self.children[i + 1].value if '=' in self.children[i + 1].value
for name in _defined_names(self.children[i]) for name in _defined_names(self.children[i], include_setitem)
] + names ] + names
def get_rhs(self): def get_rhs(self):
@@ -1155,7 +1155,7 @@ class Param(PythonBaseNode):
else: else:
return self._tfpdef() return self._tfpdef()
def get_defined_names(self): def get_defined_names(self, include_setitem=False):
return [self.name] return [self.name]
@property @property
@@ -1213,12 +1213,12 @@ class SyncCompFor(PythonBaseNode):
type = 'sync_comp_for' type = 'sync_comp_for'
__slots__ = () __slots__ = ()
def get_defined_names(self): def get_defined_names(self, include_setitem=False):
""" """
Returns the a list of `Name` that the comprehension defines. Returns the a list of `Name` that the comprehension defines.
""" """
# allow async for # allow async for
return _defined_names(self.children[1]) return _defined_names(self.children[1], include_setitem)
# This is simply here so an older Jedi version can work with this new parso # This is simply here so an older Jedi version can work with this new parso