mirror of
https://github.com/davidhalter/jedi.git
synced 2025-12-06 14:04:26 +08:00
349 lines
11 KiB
Python
349 lines
11 KiB
Python
"""
Filters are objects that you can use to filter names in different scopes. They
are needed for name resolution.
"""
|
|
from abc import abstractmethod
|
|
from typing import List, MutableMapping, Type
|
|
import weakref
|
|
|
|
from parso.tree import search_ancestor
|
|
from parso.python.tree import Name, UsedNamesMapping
|
|
|
|
from jedi.inference import flow_analysis
|
|
from jedi.inference.base_value import ValueSet, ValueWrapper, \
|
|
LazyValueWrapper
|
|
from jedi.parser_utils import get_cached_parent_scope
|
|
from jedi.inference.utils import to_list
|
|
from jedi.inference.names import TreeNameDefinition, ParamName, \
|
|
AnonymousParamName, AbstractNameDefinition, NameWrapper
|
|
|
|
# Cache of definition names, keyed weakly by a used-names mapping (see
# ``_get_definition_names``, which is the reader and writer of this cache).
# NOTE(review): the annotation declares the values as ``List[Name]``, but
# ``_get_definition_names`` stores one dict per mapping whose values are tuples
# of names -- confirm and adjust the annotation.
_definition_name_cache: MutableMapping[UsedNamesMapping, List[Name]]
_definition_name_cache = weakref.WeakKeyDictionary()
|
|
|
|
|
|
class AbstractFilter:
    """
    Base class for filters: objects that can look up names via ``get`` and
    list all of them via ``values``.

    NOTE: this class does not use an ABC metaclass, so ``@abstractmethod`` does
    not prevent instantiation here; the base implementations simply raise
    ``NotImplementedError``.
    """
    # When set, only names that start before this position are kept.
    _until_position = None

    def _filter(self, names):
        """
        Drop every name that does not start before ``_until_position``.

        If no position is set, ``names`` is returned unchanged (the very same
        object); otherwise a new list is returned.
        """
        limit = self._until_position
        if limit is None:
            return names
        return [candidate for candidate in names if candidate.start_pos < limit]

    @abstractmethod
    def get(self, name):
        """Return the names matching ``name``; must be overridden."""
        raise NotImplementedError

    @abstractmethod
    def values(self):
        """Return all names of this filter; must be overridden."""
        raise NotImplementedError
|
|
|
|
|
|
class FilterWrapper:
    """
    Wraps another filter and passes every name it returns through
    ``name_wrapper_class``.

    Only the annotation of ``name_wrapper_class`` is declared here; no value is
    assigned in this class.
    """
    name_wrapper_class: Type[NameWrapper]

    def __init__(self, wrapped_filter):
        """Remember the filter whose results are going to be wrapped."""
        self._wrapped_filter = wrapped_filter

    def wrap_names(self, names):
        """Return a list with ``name_wrapper_class(name)`` for each name."""
        wrapped = []
        for name in names:
            # The class attribute is looked up per element on purpose, so that
            # an empty input never touches ``name_wrapper_class``.
            wrapped.append(self.name_wrapper_class(name))
        return wrapped

    def get(self, name):
        """Look ``name`` up in the wrapped filter and wrap the results."""
        found = self._wrapped_filter.get(name)
        return self.wrap_names(found)

    def values(self):
        """Wrap all names of the wrapped filter."""
        everything = self._wrapped_filter.values()
        return self.wrap_names(everything)
|
|
|
|
|
|
def _get_definition_names(used_names, name_key):
    """
    Return a tuple of the names stored under ``name_key`` in ``used_names``
    that are definitions (``is_definition(include_setitem=True)``).

    Results are memoized in ``_definition_name_cache``: one dict per
    ``used_names`` object, mapping a name key to its tuple of definitions.
    """
    try:
        per_module = _definition_name_cache[used_names]
    except KeyError:
        per_module = {}
        _definition_name_cache[used_names] = per_module

    if name_key in per_module:
        return per_module[name_key]

    candidates = used_names.get(name_key, ())
    definitions = tuple(
        candidate
        for candidate in candidates
        if candidate.is_definition(include_setitem=True)
    )
    per_module[name_key] = definitions
    return definitions
|
|
|
|
|
|
class AbstractUsedNamesFilter(AbstractFilter):
    """
    Filter based on the used names of the root node of a parser scope.

    Names are taken from ``get_used_names()`` of the root node, reduced to
    definitions, passed through ``_filter`` and finally converted to
    ``name_class`` objects.
    """
    name_class = TreeNameDefinition

    def __init__(self, parent_context, parser_scope):
        """
        :param parent_context: Context handed to ``name_class`` for every name.
        :param parser_scope: Tree node whose root node supplies the used names.
        """
        self._parser_scope = parser_scope
        root_node = self._parser_scope.get_root_node()
        self._module_node = root_node
        self._used_names = root_node.get_used_names()
        self.parent_context = parent_context

    def get(self, name):
        """Return the converted, filtered definitions stored under ``name``."""
        definitions = _get_definition_names(self._used_names, name)
        return self._convert_names(self._filter(definitions))

    def _convert_names(self, names):
        """Build one ``name_class`` object per tree name."""
        converted = []
        for name in names:
            converted.append(self.name_class(self.parent_context, name))
        return converted

    def values(self):
        """Return the converted, filtered definitions of all used names."""
        def definitions():
            # Lazily walk over all keys, so filtering and conversion stay
            # interleaved.
            for name_key in self._used_names:
                yield from self._filter(
                    _get_definition_names(self._used_names, name_key),
                )

        return self._convert_names(definitions())

    def __repr__(self):
        class_name = self.__class__.__name__
        return '<%s: %s>' % (class_name, self.parent_context)
|
|
|
|
|
|
class ParserTreeFilter(AbstractUsedNamesFilter):
    """
    Used-names filter that additionally checks that a name belongs to the
    filter's own scope and is reachable according to flow analysis.
    """

    def __init__(self, parent_context, node_context=None, until_position=None,
                 origin_scope=None):
        """
        node_context is an option to specify a second value for use cases
        like the class mro where the parent class of a new name would be the
        value, but for some type inference it's important to have a local
        value of the other classes.

        If ``node_context`` is omitted, ``parent_context`` is used for it.
        """
        node_context = parent_context if node_context is None else node_context
        super().__init__(parent_context, node_context.tree_node)
        self._node_context = node_context
        self._origin_scope = origin_scope
        self._until_position = until_position

    def _filter(self, names):
        """
        Apply the position filter of the base class, then keep only names of
        this scope and run the flow check on the rest. Always returns a list.
        """
        before_position = super()._filter(names)
        in_scope = [
            candidate
            for candidate in before_position
            if self._is_name_reachable(candidate)
        ]
        return list(self._check_flows(in_scope))

    def _is_name_reachable(self, name):
        """
        Return whether ``name`` belongs to this filter's parser scope.

        Names whose parent is a ``trailer`` are never accepted. For names of
        classes and functions the scope lookup starts at the definition node
        instead of the name itself.
        """
        parent = name.parent
        if parent.type == 'trailer':
            return False

        if parent.type in ('classdef', 'funcdef'):
            base_node = parent
        else:
            base_node = name
        scope = get_cached_parent_scope(self._used_names, base_node)
        return scope == self._parser_scope

    def _check_flows(self, names):
        """
        Yield names from last to first start position, skipping the ones
        reported as ``UNREACHABLE`` and stopping after the first one that is
        reported as ``REACHABLE``.
        """
        latest_first = sorted(names, key=lambda n: n.start_pos, reverse=True)
        for name in latest_first:
            status = flow_analysis.reachability_check(
                context=self._node_context,
                value_scope=self._parser_scope,
                node=name,
                origin_scope=self._origin_scope
            )
            if status is not flow_analysis.UNREACHABLE:
                yield name

            if status is flow_analysis.REACHABLE:
                # Earlier names are not looked at once this one is REACHABLE.
                return
|
|
|
|
|
|
class _FunctionExecutionFilter(ParserTreeFilter):
    """
    Parser tree filter for function executions: names that belong to a
    ``param`` node are converted by ``_convert_param``, which subclasses have
    to implement.
    """

    def __init__(self, parent_context, function_value, until_position, origin_scope):
        """Set up the tree filter and remember ``function_value``."""
        super().__init__(
            parent_context,
            until_position=until_position,
            origin_scope=origin_scope,
        )
        self._function_value = function_value

    def _convert_param(self, param, name):
        """Create the name object for a parameter; must be overridden."""
        raise NotImplementedError

    @to_list
    def _convert_names(self, names):
        """
        Convert tree names: parameters go through ``_convert_param``, all
        other names become ``TreeNameDefinition`` objects.
        """
        for name in names:
            param = search_ancestor(name, 'param')
            # Here we don't need to check if the param is a default/annotation,
            # because those are not definitions and never make it to this
            # point.
            if not param:
                yield TreeNameDefinition(self.parent_context, name)
                continue
            yield self._convert_param(param, name)
|
|
|
|
|
|
class FunctionExecutionFilter(_FunctionExecutionFilter):
    """Function execution filter whose parameter names carry the arguments."""

    def __init__(self, *args, arguments, **kwargs):
        """
        Same parameters as the parent class, plus the keyword-only
        ``arguments`` that are handed to every ``ParamName``.
        """
        super().__init__(*args, **kwargs)
        self._arguments = arguments

    def _convert_param(self, param, name):
        """Return a ``ParamName`` for ``name``; ``param`` itself is unused."""
        function_value = self._function_value
        return ParamName(function_value, name, self._arguments)
|
|
|
|
|
|
class AnonymousFunctionExecutionFilter(_FunctionExecutionFilter):
    """Function execution filter that has no arguments for its parameters."""

    def _convert_param(self, param, name):
        """
        Return an ``AnonymousParamName`` for ``name``; ``param`` itself is
        unused.
        """
        function_value = self._function_value
        return AnonymousParamName(function_value, name)
|
|
|
|
|
|
class GlobalNameFilter(AbstractUsedNamesFilter):
    """
    Filter that only keeps names whose parent node is a ``global_stmt``.

    Unlike the base class, lookups read the used names directly and are not
    reduced to definitions first.
    """

    def get(self, name):
        """Return the converted ``global`` names stored under ``name``."""
        try:
            candidates = self._used_names[name]
        except KeyError:
            return []
        else:
            return self._convert_names(self._filter(candidates))

    @to_list
    def _filter(self, names):
        """Keep the names that are direct children of a ``global_stmt``."""
        for candidate in names:
            if candidate.parent.type != 'global_stmt':
                continue
            yield candidate

    def values(self):
        """Return the converted ``global`` names of all used names."""
        def global_names():
            for name_list in self._used_names.values():
                yield from self._filter(name_list)

        return self._convert_names(global_names())
|
|
|
|
|
|
class DictFilter(AbstractFilter):
    """
    Filter backed by a plain mapping from name to value.

    Subclasses can override ``_convert`` to turn a ``(name, value)`` pair into
    something else; a ``KeyError`` raised by ``_convert`` means "no result".
    """

    def __init__(self, dct):
        """Store the mapping that lookups are served from."""
        self._dct = dct

    def get(self, name):
        """
        Return a list with the converted entry for ``name``, or an empty list
        if the key is missing or ``_convert`` raises ``KeyError``.
        """
        try:
            converted = self._convert(name, self._dct[name])
        except KeyError:
            return []
        return list(self._filter([converted]))

    def values(self):
        """
        Return all converted entries, passed through ``_filter``. Entries for
        which ``_convert`` raises ``KeyError`` are skipped.
        """
        def yielder():
            for key, raw_value in self._dct.items():
                try:
                    converted = self._convert(key, raw_value)
                except KeyError:
                    continue
                yield converted

        return self._filter(yielder())

    def _convert(self, name, value):
        """Hook for subclasses; the default returns ``value`` unchanged."""
        return value

    def __repr__(self):
        joined_keys = ', '.join(self._dct.keys())
        class_name = self.__class__.__name__
        return '<%s: for {%s}>' % (class_name, joined_keys)
|
|
|
|
|
|
class MergedFilter:
    """Combines several filters; results are concatenated in filter order."""

    def __init__(self, *filters):
        """Store the filters (as a tuple) that are going to be queried."""
        self._filters = filters

    def get(self, name):
        """Return the results of ``get(name)`` of all filters as one list."""
        merged = []
        for filter_ in self._filters:
            merged.extend(filter_.get(name))
        return merged

    def values(self):
        """Return the results of ``values()`` of all filters as one list."""
        merged = []
        for filter_ in self._filters:
            merged.extend(filter_.values())
        return merged

    def __repr__(self):
        inner = ', '.join(map(str, self._filters))
        return '%s(%s)' % (self.__class__.__name__, inner)
|
|
|
|
|
|
class _BuiltinMappedMethod(ValueWrapper):
    """
    ``Generator.__next__`` ``dict.values`` methods and so on.

    Wraps ``builtin_func``, but calling it runs ``method`` with ``value`` and
    the call arguments instead.
    """
    api_type = 'function'

    def __init__(self, value, method, builtin_func):
        """
        :param value: Passed as first argument to ``method`` on every call.
        :param method: Callable that is executed by ``py__call__``.
        :param builtin_func: The value that is wrapped by this object.
        """
        super().__init__(builtin_func)
        self._value = value
        self._method = method

    def py__call__(self, arguments):
        """Return ``method(value, arguments)``."""
        # TODO add TypeError if params are given/or not correct.
        method = self._method
        return method(self._value, arguments)
|
|
|
|
|
|
class SpecialMethodFilter(DictFilter):
    """
    A filter for methods that are defined in this module on the corresponding
    classes like Generator (for __next__, etc).
    """
    class SpecialMethodName(AbstractNameDefinition):
        """
        Name of a single special method. It infers to a
        ``_BuiltinMappedMethod`` that wraps the function found on the builtin
        value, but runs ``callable_`` when it is called.
        """
        api_type = 'function'

        def __init__(self, parent_context, string_name, callable_, builtin_value):
            """
            :param parent_context: Stored as ``parent_context`` and handed to
                the ``_BuiltinMappedMethod`` as its value.
            :param string_name: Name that is looked up on ``builtin_value``.
            :param callable_: Callable executed when the method is called.
            :param builtin_value: Value whose filters are searched for
                ``string_name``.
            """
            self.parent_context = parent_context
            self.string_name = string_name
            self._callable = callable_
            self._builtin_value = builtin_value

        def infer(self):
            """
            Return a ``ValueSet`` with one ``_BuiltinMappedMethod`` built from
            the first value that can be found for ``string_name`` in the
            filters of the builtin value.

            If none of the filters provides a value, an empty ``ValueSet`` is
            returned. (Previously that case failed with an unbound local
            variable, or with ``StopIteration`` if the first name inferred to
            nothing.)
            """
            for filter_ in self._builtin_value.get_filters():
                # We can take the first index, because on builtin methods there's
                # always only going to be one name. The same is true for the
                # inferred values.
                for name in filter_.get(self.string_name):
                    for builtin_func in name.infer():
                        return ValueSet([
                            _BuiltinMappedMethod(
                                self.parent_context, self._callable, builtin_func
                            )
                        ])
            return ValueSet([])

    def __init__(self, value, dct, builtin_value):
        """
        :param value: Used as ``parent_context`` of the created names.
        :param dct: Mapping from method name to the callable implementing it.
        :param builtin_value: Value on which the builtin function is looked up.
        """
        super().__init__(dct)
        self.value = value
        # This value is what will be used to introspect the name, where as the
        # other value will be used to execute the function.
        #
        # We distinguish, because we have to.
        self._builtin_value = builtin_value

    def _convert(self, name, value):
        """Wrap a ``(name, callable)`` entry in a ``SpecialMethodName``."""
        return self.SpecialMethodName(self.value, name, value, self._builtin_value)
|
|
|
|
|
|
class _OverwriteMeta(type):
    """
    Metaclass that collects the methods registered via
    ``registered_overwritten_methods`` into ``cls.overwritten_methods``.
    """

    def __init__(cls, name, bases, dct):
        """
        Build ``cls.overwritten_methods`` as a fresh dict: first the entries
        of the direct bases (in reversed order, so earlier bases win), then
        the entries registered on the attributes of the class itself.
        """
        super().__init__(name, bases, dct)

        collected = {}
        for base_cls in cls.__bases__[::-1]:
            # Bases without the attribute contribute nothing.
            collected.update(getattr(base_cls, 'overwritten_methods', {}))

        for attribute in cls.__dict__.values():
            # Only objects that were decorated carry a registry.
            collected.update(
                getattr(attribute, 'registered_overwritten_methods', {})
            )
        cls.overwritten_methods = collected
|
|
|
|
|
|
class _AttributeOverwriteMixin:
    """
    Mixin that puts a ``SpecialMethodFilter`` for ``overwritten_methods`` in
    front of the filters of the wrapped value.
    """

    def get_filters(self, *args, **kwargs):
        """
        Yield the filter for the overwritten methods first, followed by all
        filters of ``_wrapped_value`` (which receives the given arguments).
        """
        special_filter = SpecialMethodFilter(
            self,
            self.overwritten_methods,
            self._wrapped_value,
        )
        yield special_filter
        yield from self._wrapped_value.get_filters(*args, **kwargs)
|
|
|
|
|
|
class LazyAttributeOverwrite(
    _AttributeOverwriteMixin,
    LazyValueWrapper,
    metaclass=_OverwriteMeta,
):
    """
    ``LazyValueWrapper`` with overwritten special methods; only the inference
    state is stored on construction.
    """

    def __init__(self, inference_state):
        """Store ``inference_state`` on the instance."""
        self.inference_state = inference_state
|
|
|
|
|
|
class AttributeOverwrite(
    _AttributeOverwriteMixin,
    ValueWrapper,
    metaclass=_OverwriteMeta,
):
    """``ValueWrapper`` with overwritten special methods."""
|
|
|
|
|
|
def publish_method(method_name):
    """
    Return a decorator that registers the decorated function under
    ``method_name`` in its own ``registered_overwritten_methods`` dict.

    The function itself is returned unchanged, so the decorator can be
    stacked to publish one function under several names.
    """
    def decorator(func):
        attributes = vars(func)
        if 'registered_overwritten_methods' not in attributes:
            attributes['registered_overwritten_methods'] = {}
        attributes['registered_overwritten_methods'][method_name] = func
        return func

    return decorator
|