Fix an issue with contexts.

This commit is contained in:
Dave Halter
2017-01-06 00:08:01 +01:00
parent ae8e43d3c7
commit 1f15ee8bc7
4 changed files with 34 additions and 17 deletions

View File

@@ -165,9 +165,18 @@ class AbstractUsedNamesFilter(AbstractFilter):
class ParserTreeFilter(AbstractUsedNamesFilter): class ParserTreeFilter(AbstractUsedNamesFilter):
def __init__(self, evaluator, context, parser_scope, until_position=None, def __init__(self, evaluator, context, node_context=None, until_position=None,
origin_scope=None, ): origin_scope=None):
super(ParserTreeFilter, self).__init__(context, parser_scope) """
node_context is an option to specify a second context for use cases
like the class mro where the parent class of a new name would be the
context, but for some type inference it's important to have a local
context of the other classes.
"""
if node_context is None:
node_context = context
super(ParserTreeFilter, self).__init__(context, node_context.tree_node)
self._node_context = node_context
self._origin_scope = origin_scope self._origin_scope = origin_scope
self._until_position = until_position self._until_position = until_position
@@ -188,7 +197,7 @@ class ParserTreeFilter(AbstractUsedNamesFilter):
def _check_flows(self, names): def _check_flows(self, names):
for name in sorted(names, key=lambda name: name.start_pos, reverse=True): for name in sorted(names, key=lambda name: name.start_pos, reverse=True):
check = flow_analysis.reachability_check( check = flow_analysis.reachability_check(
self.context, self._parser_scope, name, self._origin_scope self._node_context, self._parser_scope, name, self._origin_scope
) )
if check is not flow_analysis.UNREACHABLE: if check is not flow_analysis.UNREACHABLE:
yield name yield name
@@ -200,12 +209,12 @@ class ParserTreeFilter(AbstractUsedNamesFilter):
class FunctionExecutionFilter(ParserTreeFilter): class FunctionExecutionFilter(ParserTreeFilter):
param_name = ParamName param_name = ParamName
def __init__(self, evaluator, context, parser_scope, def __init__(self, evaluator, context, node_context=None,
until_position=None, origin_scope=None): until_position=None, origin_scope=None):
super(FunctionExecutionFilter, self).__init__( super(FunctionExecutionFilter, self).__init__(
evaluator, evaluator,
context, context,
parser_scope, node_context,
until_position, until_position,
origin_scope origin_scope
) )

View File

@@ -322,7 +322,7 @@ class InstanceClassFilter(filters.ParserTreeFilter):
super(InstanceClassFilter, self).__init__( super(InstanceClassFilter, self).__init__(
evaluator=evaluator, evaluator=evaluator,
context=context, context=context,
parser_scope=class_context.tree_node, node_context=class_context,
origin_scope=origin_scope origin_scope=origin_scope
) )
self._class_context = class_context self._class_context = class_context

View File

@@ -190,7 +190,7 @@ class CompForContext(context.TreeContext):
return self.tree_node return self.tree_node
def get_filters(self, search_global, until_position=None, origin_scope=None): def get_filters(self, search_global, until_position=None, origin_scope=None):
yield ParserTreeFilter(self.evaluator, self, self.tree_node) yield ParserTreeFilter(self.evaluator, self)
class Comprehension(AbstractSequence): class Comprehension(AbstractSequence):

View File

@@ -166,16 +166,20 @@ class ClassContext(use_metaclass(CachedMetaClass, context.TreeContext)):
def get_filters(self, search_global, until_position=None, origin_scope=None, is_instance=False): def get_filters(self, search_global, until_position=None, origin_scope=None, is_instance=False):
if search_global: if search_global:
yield ParserTreeFilter(self.evaluator, self, self.tree_node, until_position, origin_scope=origin_scope) yield ParserTreeFilter(
self.evaluator,
context=self,
until_position=until_position,
origin_scope=origin_scope
)
else: else:
for scope in self.py__mro__(): for scope in self.py__mro__():
print(scope)
if isinstance(scope, compiled.CompiledObject): if isinstance(scope, compiled.CompiledObject):
for filter in scope.get_filters(is_instance=is_instance): for filter in scope.get_filters(is_instance=is_instance):
yield filter yield filter
else: else:
yield ClassFilter( yield ClassFilter(
self.evaluator, self, scope.tree_node, self.evaluator, self, node_context=scope,
origin_scope=origin_scope) origin_scope=origin_scope)
def is_class(self): def is_class(self):
@@ -225,7 +229,12 @@ class FunctionContext(use_metaclass(CachedMetaClass, context.TreeContext)):
def get_filters(self, search_global, until_position=None, origin_scope=None): def get_filters(self, search_global, until_position=None, origin_scope=None):
if search_global: if search_global:
yield ParserTreeFilter(self.evaluator, self, self.tree_node, until_position, origin_scope=origin_scope) yield ParserTreeFilter(
self.evaluator,
context=self,
until_position=until_position,
origin_scope=origin_scope
)
else: else:
scope = self.py__class__() scope = self.py__class__()
for filter in scope.get_filters(search_global=False, origin_scope=origin_scope): for filter in scope.get_filters(search_global=False, origin_scope=origin_scope):
@@ -372,8 +381,8 @@ class FunctionExecutionContext(Executed):
yield result yield result
def get_filters(self, search_global, until_position=None, origin_scope=None): def get_filters(self, search_global, until_position=None, origin_scope=None):
yield self.function_execution_filter(self.evaluator, self, self.tree_node, yield self.function_execution_filter(self.evaluator, self,
until_position, until_position=until_position,
origin_scope=origin_scope) origin_scope=origin_scope)
@memoize_default(default=NO_DEFAULT) @memoize_default(default=NO_DEFAULT)
@@ -419,9 +428,8 @@ class ModuleContext(use_metaclass(CachedMetaClass, context.TreeContext)):
def get_filters(self, search_global, until_position=None, origin_scope=None): def get_filters(self, search_global, until_position=None, origin_scope=None):
yield ParserTreeFilter( yield ParserTreeFilter(
self.evaluator, self.evaluator,
self, context=self,
self.tree_node, until_position=until_position,
until_position,
origin_scope=origin_scope origin_scope=origin_scope
) )
yield GlobalNameFilter(self, self.tree_node) yield GlobalNameFilter(self, self.tree_node)