1
0
forked from VimPlug/jedi

Refactor params and what execution contexts need

This commit is contained in:
Dave Halter
2019-09-01 14:14:42 +02:00
parent 59f26ad6ab
commit edb17b8e7c
13 changed files with 142 additions and 123 deletions

View File

@@ -34,7 +34,7 @@ from jedi.inference import usages
from jedi.inference.arguments import try_iter_content
from jedi.inference.helpers import get_module_names, infer_call_of_leaf
from jedi.inference.sys_path import transform_path_to_dotted
from jedi.inference.names import TreeNameDefinition, ParamName
from jedi.inference.names import TreeNameDefinition, SimpleParamName
from jedi.inference.syntax_tree import tree_name_to_values
from jedi.inference.value import ModuleValue
from jedi.inference.base_value import ValueSet
@@ -504,14 +504,13 @@ def names(source=None, path=None, encoding='utf-8', all_scopes=False,
return definitions and is_def or references and not is_def
def create_name(name):
context = module_context.create_context(name)
if name.parent.type == 'param':
cls = ParamName
func = tree.search_ancestor(name, 'funcdef', 'lambdef')
func = context.get_root_context().create_value(func)
return SimpleParamName(func, name)
else:
cls = TreeNameDefinition
return cls(
module_context.create_context(name),
name
)
return TreeNameDefinition(context, name)
# Set line/column to a random position, because they don't matter.
script = Script(source, line=1, column=0, path=path, encoding=encoding, environment=environment)

View File

@@ -74,7 +74,7 @@ from jedi.inference import imports
from jedi.inference import recursion
from jedi.inference.cache import inference_state_function_cache
from jedi.inference import helpers
from jedi.inference.names import TreeNameDefinition, ParamName
from jedi.inference.names import TreeNameDefinition, SimpleParamName
from jedi.inference.base_value import ContextualizedName, ContextualizedNode, \
ValueSet, NO_VALUES, iterate_values
from jedi.inference.value import ClassValue, FunctionValue
@@ -305,7 +305,9 @@ class InferenceState(object):
if is_simple_name:
return [TreeNameDefinition(context, name)]
elif type_ == 'param':
return [ParamName(context, name)]
func = tree.search_ancestor(name, 'funcdef', 'lambdef')
func = context.get_root_context().create_value(func)
return [SimpleParamName(func, name)]
elif type_ in ('import_from', 'import_name'):
module_names = imports.goto_import(context, name)
return module_names

View File

@@ -145,8 +145,8 @@ class _AbstractArgumentsMixin(object):
def unpack(self, funcdef=None):
raise NotImplementedError
def get_executed_param_names_and_issues(self, execution_context):
return get_executed_param_names_and_issues(execution_context, self)
def get_executed_param_names_and_issues(self, function_value):
return get_executed_param_names_and_issues(function_value, self)
def get_calling_nodes(self):
return []
@@ -160,12 +160,12 @@ class AbstractArguments(_AbstractArgumentsMixin):
class AnonymousArguments(AbstractArguments):
@memoize_method
def get_executed_param_names_and_issues(self, execution_context):
def get_executed_param_names_and_issues(self, function_value):
from jedi.inference.dynamic_params import search_param_names
return search_param_names(
execution_context.inference_state,
execution_context,
execution_context.tree_node
function_value.inference_state,
function_value,
function_value.tree_node
), []
def __repr__(self):

View File

@@ -269,27 +269,22 @@ def _execute_array_values(inference_state, array):
@inference_state_method_cache()
def infer_param(execution_context, param):
from jedi.inference.value.instance import InstanceArguments
from jedi.inference.value import FunctionExecutionContext
def infer_param(function_value, param):
def infer_docstring(docstring):
return ValueSet(
p
for param_str in _search_param_in_docstr(docstring, param.name.value)
for p in _infer_for_statement_string(module_context, param_str)
)
module_context = execution_context.get_root_context()
module_context = function_value.get_root_context()
func = param.get_parent_function()
if func.type == 'lambdef':
return NO_VALUES
types = infer_docstring(execution_context.py__doc__())
if isinstance(execution_context, FunctionExecutionContext) \
and isinstance(execution_context.var_args, InstanceArguments) \
and execution_context.function_value.py__name__() == '__init__':
class_value = execution_context.var_args.instance.class_value
types |= infer_docstring(class_value.py__doc__())
types = infer_docstring(function_value.py__doc__())
if function_value.is_bound_method() \
and function_value.py__name__() == '__init__':
types |= infer_docstring(function_value.class_context.py__doc__())
debug.dbg('Found param types for docstring: %s', types, color='BLUE')
return types

View File

@@ -56,7 +56,7 @@ class DynamicExecutedParamName(ParamNameWrapper):
@debug.increase_indent
def search_param_names(inference_state, execution_context, funcdef):
def search_param_names(inference_state, function_value, funcdef):
"""
A dynamic search for param values. If you try to complete a type:
@@ -70,28 +70,28 @@ def search_param_names(inference_state, execution_context, funcdef):
is.
"""
if not settings.dynamic_params:
return create_default_params(execution_context, funcdef)
return create_default_params(function_value, funcdef)
inference_state.dynamic_params_depth += 1
try:
path = execution_context.get_root_context().py__file__()
path = function_value.get_root_context().py__file__()
if path is not None and is_stdlib_path(path):
# We don't want to search for usages in the stdlib. Usually people
# don't work with it (except if you are a core maintainer, sorry).
# This makes everything slower. Just disable it and run the tests,
# you will see the slowdown, especially in 3.6.
return create_default_params(execution_context, funcdef)
return create_default_params(function_value, funcdef)
if funcdef.type == 'lambdef':
string_name = _get_lambda_name(funcdef)
if string_name is None:
return create_default_params(execution_context, funcdef)
return create_default_params(function_value, funcdef)
else:
string_name = funcdef.name.value
debug.dbg('Dynamic param search in %s.', string_name, color='MAGENTA')
try:
module_context = execution_context.get_root_context()
module_context = function_value.get_root_context()
function_executions = _search_function_executions(
inference_state,
module_context,
@@ -106,7 +106,7 @@ def search_param_names(inference_state, execution_context, funcdef):
params = [DynamicExecutedParamName(executed_param_names)
for executed_param_names in zipped_param_names]
else:
return create_default_params(execution_context, funcdef)
return create_default_params(function_value, funcdef)
finally:
debug.dbg('Dynamic param result finished', color='MAGENTA')
return params

View File

@@ -145,7 +145,7 @@ class ParserTreeFilter(AbstractUsedNamesFilter):
class FunctionExecutionFilter(ParserTreeFilter):
param_name = ParamName
def __init__(self, parent_context, node_context=None,
def __init__(self, parent_context, function_value, node_context=None,
until_position=None, origin_scope=None):
super(FunctionExecutionFilter, self).__init__(
parent_context,
@@ -153,13 +153,14 @@ class FunctionExecutionFilter(ParserTreeFilter):
until_position,
origin_scope
)
self._function_value = function_value
@to_list
def _convert_names(self, names):
for name in names:
param = search_ancestor(name, 'param')
if param:
yield self.param_name(self.parent_context, name)
yield self.param_name(self._function_value, name, self.parent_context.var_args)
else:
yield TreeNameDefinition(self.parent_context, name)

View File

@@ -13,7 +13,7 @@ from jedi._compatibility import force_unicode, Parameter
from jedi.inference.cache import inference_state_method_cache
from jedi.inference.base_value import ValueSet, NO_VALUES
from jedi.inference.gradual.typing import TypeVar, LazyGenericClass, \
AbstractAnnotatedClass, TypingClassValueWithIndex
AbstractAnnotatedClass
from jedi.inference.gradual.typing import GenericClass
from jedi.inference.helpers import is_string
from jedi.inference.compiled import builtin_from_name
@@ -107,11 +107,11 @@ def _split_comment_param_declaration(decl_text):
@inference_state_method_cache()
def infer_param(execution_context, param, ignore_stars=False):
values = _infer_param(execution_context, param)
def infer_param(function_value, param, ignore_stars=False):
values = _infer_param(function_value, param)
if ignore_stars:
return values
inference_state = execution_context.inference_state
inference_state = function_value.inference_state
if param.star_count == 1:
tuple_ = builtin_from_name(inference_state, 'tuple')
return ValueSet([GenericClass(
@@ -128,7 +128,7 @@ def infer_param(execution_context, param, ignore_stars=False):
return values
def _infer_param(execution_context, param):
def _infer_param(function_value, param):
"""
Infers the type of a function parameter, using type annotations.
"""
@@ -161,7 +161,7 @@ def _infer_param(execution_context, param):
params_comments, all_params
)
from jedi.inference.value.instance import InstanceArguments
if isinstance(execution_context.var_args, InstanceArguments):
if function_value.is_bound_method():
if index == 0:
# Assume it's self, which is already handled
return NO_VALUES
@@ -171,11 +171,11 @@ def _infer_param(execution_context, param):
param_comment = params_comments[index]
return _infer_annotation_string(
execution_context.function_value.get_default_param_context(),
function_value.get_default_param_context(),
param_comment
)
# Annotations are like default params and resolve in the same way.
context = execution_context.function_value.get_default_param_context()
context = function_value.get_default_param_context()
return infer_annotation(context, annotation)

View File

@@ -242,8 +242,24 @@ class BaseTreeParamName(ParamNameInterface, AbstractTreeName):
output += '=' + default.get_code(include_prefix=False)
return output
def get_public_name(self):
name = self.string_name
if name.startswith('__'):
# Params starting with __ are an equivalent to positional only
# variables in typeshed.
name = name[2:]
return name
def goto(self, **kwargs):
return [self]
class SimpleParamName(BaseTreeParamName):
def __init__(self, function_value, tree_name):
super(BaseTreeParamName, self).__init__(
function_value.get_default_param_context(), tree_name)
self.function_value = function_value
class ParamName(BaseTreeParamName):
def _get_param_node(self):
return search_ancestor(self.tree_name, 'param')
@@ -254,7 +270,7 @@ class ParamName(BaseTreeParamName):
def infer_annotation(self, execute_annotation=True, ignore_stars=False):
from jedi.inference.gradual.annotation import infer_param
values = infer_param(
self.parent_context, self._get_param_node(),
self.function_value, self._get_param_node(),
ignore_stars=ignore_stars)
if execute_annotation:
values = values.execute_annotation()
@@ -264,20 +280,12 @@ class ParamName(BaseTreeParamName):
node = self.default_node
if node is None:
return NO_VALUES
return self.parent_context.parent_context.infer_node(node)
return self.parent_context.infer_node(node)
@property
def default_node(self):
return self._get_param_node().default
def get_public_name(self):
name = self.string_name
if name.startswith('__'):
# Params starting with __ are an equivalent to positional only
# variables in typeshed.
name = name[2:]
return name
def get_kind(self):
tree_param = self._get_param_node()
if tree_param.star_count == 1: # *args
@@ -311,14 +319,24 @@ class ParamName(BaseTreeParamName):
if values:
return values
doc_params = docstrings.infer_param(self.parent_context, self._get_param_node())
if doc_params:
doc_params = docstrings.infer_param(self.function_value, self._get_param_node())
return doc_params
class ParamName(SimpleParamName):
def __init__(self, function_value, tree_name, arguments):
super(ParamName, self).__init__(function_value, tree_name)
self.arguments = arguments
def infer(self):
values = super(ParamName, self).infer()
if values:
return values
return self.get_executed_param_name().infer()
def get_executed_param_name(self):
params_names, _ = self.parent_context.get_executed_param_names_and_issues()
params_names, _ = self.arguments.get_executed_param_names_and_issues(self.function_value)
return params_names[self._get_param_node().position_index]

View File

@@ -20,8 +20,11 @@ def _add_argument_issue(error_name, lazy_value, message):
class ExecutedParamName(ParamName):
"""Fake a param and give it values."""
def __init__(self, execution_context, param_node, lazy_value, is_default=False):
super(ExecutedParamName, self).__init__(execution_context, param_node.name)
def __init__(self, function_value, arguments, param_node, lazy_value, is_default=False):
# The arguments parameter is not needed, because it's an executed param
# name.
super(ExecutedParamName, self).__init__(
function_value, param_node.name, arguments=None)
self._lazy_value = lazy_value
self._is_default = is_default
@@ -48,13 +51,13 @@ class ExecutedParamName(ParamName):
@property
def var_args(self):
return self.parent_context.var_args
return self.arguments
def __repr__(self):
return '<%s: %s>' % (self.__class__.__name__, self.string_name)
def get_executed_param_names_and_issues(execution_context, arguments):
def get_executed_param_names_and_issues(function_value, arguments):
def too_many_args(argument):
m = _error_argument_count(funcdef, len(unpacked_va))
# Just report an error for the first param that is not needed (like
@@ -74,11 +77,11 @@ def get_executed_param_names_and_issues(execution_context, arguments):
issues = [] # List[Optional[analysis issue]]
result_params = []
param_dict = {}
funcdef = execution_context.tree_node
funcdef = function_value.tree_node
# Default params are part of the value where the function was defined.
# This means that they might have access on class variables that the
# function itself doesn't have.
default_param_context = execution_context.function_value.get_default_param_context()
default_param_context = function_value.get_default_param_context()
for param in funcdef.get_params():
param_dict[param.name.value] = param
@@ -114,7 +117,8 @@ def get_executed_param_names_and_issues(execution_context, arguments):
contextualized_node.node, message=m)
)
else:
keys_used[key] = ExecutedParamName(execution_context, key_param, argument)
keys_used[key] = ExecutedParamName(
function_value, arguments, key_param, argument)
key, argument = next(var_arg_iterator, (None, None))
try:
@@ -134,13 +138,13 @@ def get_executed_param_names_and_issues(execution_context, arguments):
var_arg_iterator.push_back((key, argument))
break
lazy_value_list.append(argument)
seq = iterable.FakeTuple(execution_context.inference_state, lazy_value_list)
seq = iterable.FakeTuple(function_value.inference_state, lazy_value_list)
result_arg = LazyKnownValue(seq)
elif param.star_count == 2:
if argument is not None:
too_many_args(argument)
# **kwargs param
dct = iterable.FakeDict(execution_context.inference_state, dict(non_matching_keys))
dct = iterable.FakeDict(function_value.inference_state, dict(non_matching_keys))
result_arg = LazyKnownValue(dct)
non_matching_keys = {}
else:
@@ -167,8 +171,7 @@ def get_executed_param_names_and_issues(execution_context, arguments):
result_arg = argument
result_params.append(ExecutedParamName(
execution_context, param, result_arg,
is_default=is_default
function_value, arguments, param, result_arg, is_default=is_default
))
if not isinstance(result_arg, LazyUnknownValue):
keys_used[param.name.value] = result_params[-1]
@@ -221,22 +224,22 @@ def _error_argument_count(funcdef, actual_count):
% (funcdef.name, before, len(params), actual_count))
def _create_default_param(execution_context, param):
def _create_default_param(function_value, arguments, param):
if param.star_count == 1:
result_arg = LazyKnownValue(
iterable.FakeTuple(execution_context.inference_state, [])
iterable.FakeTuple(function_value.inference_state, [])
)
elif param.star_count == 2:
result_arg = LazyKnownValue(
iterable.FakeDict(execution_context.inference_state, {})
iterable.FakeDict(function_value.inference_state, {})
)
elif param.default is None:
result_arg = LazyUnknownValue()
else:
result_arg = LazyTreeValue(execution_context.parent_context, param.default)
return ExecutedParamName(execution_context, param, result_arg)
result_arg = LazyTreeValue(function_value.parent_context, param.default)
return ExecutedParamName(function_value, arguments, param, result_arg)
def create_default_params(execution_context, funcdef):
return [_create_default_param(execution_context, p)
def create_default_params(function_value, funcdef):
return [_create_default_param(function_value, None, p)
for p in funcdef.get_params()]

View File

@@ -1,5 +1,7 @@
from jedi._compatibility import Parameter
from jedi.cache import memoize_method
from jedi import debug
from jedi import parser_utils
class _SignatureMixin(object):
@@ -91,6 +93,25 @@ class TreeSignature(AbstractSignature):
params = process_params(params)
return params
def matches_signature(self, arguments):
executed_param_names, issues = \
arguments.get_executed_param_names_and_issues(self._function_value)
if issues:
return False
matches = all(executed_param_name.matches_signature()
for executed_param_name in executed_param_names)
if debug.enable_notice:
tree_node = self._function_value.tree_node
signature = parser_utils.get_call_signature(tree_node)
if matches:
debug.dbg("Overloading match: %s@%s (%s)",
signature, tree_node.start_pos[0], arguments, color='BLUE')
else:
debug.dbg("Overloading no match: %s@%s (%s)",
signature, tree_node.start_pos[0], arguments, color='BLUE')
return matches
class BuiltinSignature(AbstractSignature):
def __init__(self, value, return_string, is_bound=False):

View File

@@ -69,8 +69,8 @@ class FunctionMixin(object):
return ValueSet([BoundMethod(instance, self)])
def get_param_names(self):
function_execution = self.as_context()
return [ParamName(function_execution, param.name)
arguments = AnonymousArguments()
return [ParamName(self, param.name, arguments)
for param in self.tree_node.get_params()]
@property
@@ -282,29 +282,11 @@ class FunctionExecutionContext(ValueContext, TreeContextMixin):
)
def get_filters(self, until_position=None, origin_scope=None):
yield self.function_execution_filter(self,
until_position=until_position,
origin_scope=origin_scope)
yield FunctionExecutionFilter(
self, self._value, until_position=until_position, origin_scope=origin_scope)
def get_executed_param_names_and_issues(self):
return self.var_args.get_executed_param_names_and_issues(self)
def matches_signature(self):
executed_param_names, issues = self.get_executed_param_names_and_issues()
if issues:
return False
matches = all(executed_param_name.matches_signature()
for executed_param_name in executed_param_names)
if debug.enable_notice:
signature = parser_utils.get_call_signature(self.tree_node)
if matches:
debug.dbg("Overloading match: %s@%s (%s)",
signature, self.tree_node.start_pos[0], self.var_args, color='BLUE')
else:
debug.dbg("Overloading no match: %s@%s (%s)",
signature, self.tree_node.start_pos[0], self.var_args, color='BLUE')
return matches
return self.var_args.get_executed_param_names_and_issues(self._value)
def infer(self):
"""
@@ -355,18 +337,12 @@ class OverloadedFunctionValue(FunctionMixin, ValueWrapper):
def py__call__(self, arguments):
debug.dbg("Execute overloaded function %s", self._wrapped_value, color='BLUE')
function_executions = []
value_set = NO_VALUES
matched = False
for f in self._overloaded_functions:
function_execution = f.as_context(arguments)
for signature in self.get_signatures():
function_execution = signature.value.as_context(arguments)
function_executions.append(function_execution)
if function_execution.matches_signature():
matched = True
if signature.matches_signature(arguments):
return function_execution.infer()
if matched:
return value_set
if self.inference_state.is_analysis:
# In this case we want precision.
return NO_VALUES

View File

@@ -22,8 +22,9 @@ from jedi.parser_utils import get_parent_scope
class InstanceExecutedParamName(ParamName):
def __init__(self, instance, execution_context, tree_param):
super(InstanceExecutedParamName, self).__init__(execution_context, tree_param.name)
def __init__(self, instance, function_value, tree_param):
super(InstanceExecutedParamName, self).__init__(
function_value, tree_param.name, arguments=None)
self._instance = instance
def infer(self):
@@ -37,22 +38,22 @@ class AnonymousInstanceArguments(AnonymousArguments):
def __init__(self, instance):
self._instance = instance
def get_executed_param_names_and_issues(self, execution_context):
def get_executed_param_names_and_issues(self, function_value):
from jedi.inference.dynamic_params import search_param_names
tree_params = execution_context.tree_node.get_params()
tree_params = function_value.tree_node.get_params()
if not tree_params:
return [], []
self_param = InstanceExecutedParamName(
self._instance, execution_context, tree_params[0])
self._instance, function_value, tree_params[0])
if len(tree_params) == 1:
# If the only param is self, we don't need to try to find
# executions of this function, we have all the params already.
return [self_param], []
executed_param_names = list(search_param_names(
execution_context.inference_state,
execution_context,
execution_context.tree_node
function_value.inference_state,
function_value,
function_value.tree_node
))
executed_param_names[0] = self_param
return executed_param_names, []
@@ -289,16 +290,16 @@ class TreeInstance(AbstractInstanceValue):
from jedi.inference.gradual.annotation import py__annotations__, \
infer_type_vars_for_execution
args = InstanceArguments(self, self.var_args)
for signature in self.class_value.py__getattribute__('__init__').get_signatures():
# Just take the first result, it should always be one, because we
# control the typeshed code.
bound_method = BoundMethod(self, signature.value)
execution = bound_method.as_context(self.var_args)
if not execution.matches_signature():
if not signature.matches_signature(args):
# First check if the signature even matches, if not we don't
# need to infer anything.
continue
bound_method = BoundMethod(self, signature.value)
execution = bound_method.as_context(self.var_args)
all_annotations = py__annotations__(execution.tree_node)
type_var_dict = infer_type_vars_for_execution(execution, all_annotations)
if type_var_dict:
@@ -574,8 +575,8 @@ class InstanceArguments(TreeArgumentsWrapper):
for values in self._wrapped_arguments.unpack(func):
yield values
def get_executed_param_names_and_issues(self, execution_context):
def get_executed_param_names_and_issues(self, function_value):
if isinstance(self._wrapped_arguments, AnonymousInstanceArguments):
return self._wrapped_arguments.get_executed_param_names_and_issues(execution_context)
return self._wrapped_arguments.get_executed_param_names_and_issues(function_value)
return super(InstanceArguments, self).get_executed_param_names_and_issues(execution_context)
return super(InstanceArguments, self).get_executed_param_names_and_issues(function_value)

View File

@@ -561,6 +561,9 @@ class FakeDict(_DictMixin, Sequence):
def exact_key_items(self):
return self._dct.items()
def __repr__(self):
return '<%s: %s>' % (self.__class__.__name__, self._dct)
class MergedArray(Sequence):
def __init__(self, inference_state, arrays):