1
0
forked from VimPlug/jedi

Fix the selection of overloaded functions. Now it's at least partially working

This commit is contained in:
Dave Halter
2018-08-26 23:04:54 +02:00
parent 5261cdf4a1
commit 4a7bded98d
3 changed files with 62 additions and 41 deletions

View File

@@ -40,7 +40,10 @@ class BaseContextSet(object):
return cls.from_set(aggregated) return cls.from_set(aggregated)
def __or__(self, other): def __or__(self, other):
return type(self).from_set(self._set | other._set) return self.from_set(self._set | other._set)
def __and__(self, other):
return self.from_set(self._set & other._set)
def __iter__(self): def __iter__(self):
for element in self._set: for element in self._set:
@@ -56,11 +59,11 @@ class BaseContextSet(object):
return '%s(%s)' % (self.__class__.__name__, ', '.join(str(s) for s in self._set)) return '%s(%s)' % (self.__class__.__name__, ', '.join(str(s) for s in self._set))
def filter(self, filter_func): def filter(self, filter_func):
return type(self).from_iterable(filter(filter_func, self._set)) return self.from_iterable(filter(filter_func, self._set))
def __getattr__(self, name): def __getattr__(self, name):
def mapper(*args, **kwargs): def mapper(*args, **kwargs):
return type(self).from_sets( return self.from_sets(
getattr(context, name)(*args, **kwargs) getattr(context, name)(*args, **kwargs)
for context in self._set for context in self._set
) )

View File

@@ -68,32 +68,6 @@ class AbstractFunction(TreeContext):
def get_function_execution(self, arguments=None): def get_function_execution(self, arguments=None):
raise NotImplementedError raise NotImplementedError
def py__call__(self, arguments):
function_execution = self.get_function_execution(arguments)
return self.infer_function_execution(function_execution)
def infer_function_execution(self, function_execution):
"""
Created to be used by inheritance.
"""
is_coroutine = self.tree_node.parent.type == 'async_stmt'
is_generator = bool(get_yield_exprs(self.evaluator, self.tree_node))
if is_coroutine:
if is_generator:
if self.evaluator.environment.version_info < (3, 6):
return NO_CONTEXTS
return ContextSet(asynchronous.AsyncGenerator(self.evaluator, function_execution))
else:
if self.evaluator.environment.version_info < (3, 5):
return NO_CONTEXTS
return ContextSet(asynchronous.Coroutine(self.evaluator, function_execution))
else:
if is_generator:
return ContextSet(iterable.Generator(self.evaluator, function_execution))
else:
return function_execution.get_return_values()
def py__name__(self): def py__name__(self):
return self.name.string_name return self.name.string_name
@@ -123,6 +97,10 @@ class FunctionContext(use_metaclass(CachedMetaClass, AbstractFunction)):
) )
return function return function
def py__call__(self, arguments):
function_execution = self.get_function_execution(arguments)
return function_execution.infer()
def get_function_execution(self, arguments=None): def get_function_execution(self, arguments=None):
if arguments is None: if arguments is None:
arguments = AnonymousArguments() arguments = AnonymousArguments()
@@ -268,6 +246,29 @@ class FunctionExecutionContext(TreeContext):
def get_executed_params(self): def get_executed_params(self):
return self.var_args.get_executed_params(self) return self.var_args.get_executed_params(self)
def infer(self):
"""
Created to be used by inheritance.
"""
evaluator = self.evaluator
is_coroutine = self.tree_node.parent.type == 'async_stmt'
is_generator = bool(get_yield_exprs(evaluator, self.tree_node))
if is_coroutine:
if is_generator:
if evaluator.environment.version_info < (3, 6):
return NO_CONTEXTS
return ContextSet(asynchronous.AsyncGenerator(evaluator, self))
else:
if evaluator.environment.version_info < (3, 5):
return NO_CONTEXTS
return ContextSet(asynchronous.Coroutine(evaluator, self))
else:
if is_generator:
return ContextSet(iterable.Generator(evaluator, self))
else:
return self.get_return_values()
class OverloadedFunctionContext(object): class OverloadedFunctionContext(object):
def __init__(self, function, overloaded_functions): def __init__(self, function, overloaded_functions):
@@ -275,11 +276,16 @@ class OverloadedFunctionContext(object):
self._overloaded_functions = overloaded_functions self._overloaded_functions = overloaded_functions
def py__call__(self, arguments): def py__call__(self, arguments):
return ContextSet.from_sets( context_set = ContextSet()
f.py__call__(arguments=arguments) debug.dbg("Execute overloaded function %s", self._function, color='BLUE')
for f in self._overloaded_functions for f in self._overloaded_functions:
if signature_matches(f, arguments) signature = parser_utils.get_call_signature(f.tree_node)
) if signature_matches(f, arguments):
debug.dbg("Overloading - signature %s matches", signature, color='BLUE')
context_set |= f.py__call__(arguments=arguments)
else:
debug.dbg("Overloading - signature %s doesn't match", signature, color='BLUE')
return context_set
def __getattr__(self, name): def __getattr__(self, name):
return getattr(self._function, name) return getattr(self._function, name)
@@ -289,7 +295,6 @@ def signature_matches(function_context, arguments):
unpacked_arguments = arguments.unpack() unpacked_arguments = arguments.unpack()
for param_node in function_context.tree_node.get_params(): for param_node in function_context.tree_node.get_params():
key, argument = next(unpacked_arguments, (None, None)) key, argument = next(unpacked_arguments, (None, None))
print(param_node)
if argument is None: if argument is None:
# This signature has an parameter more than arguments were given. # This signature has an parameter more than arguments were given.
return False return False
@@ -299,15 +304,19 @@ def signature_matches(function_context, arguments):
return False return False
if param_node.annotation is not None: if param_node.annotation is not None:
annotation_result = function_context.evaluator.eval_node( annotation_result = function_context.evaluator.eval_element(
function_context.parent_context, function_context.parent_context,
param_node.annotation param_node.annotation
) )
print(annotation_result) return has_same_class(argument.infer().py__class__(), annotation_result)
return True return True
def has_same_class(context_set1, context_set2):
return bool(context_set1 & context_set2)
def _find_overload_functions(context, tree_node): def _find_overload_functions(context, tree_node):
def _is_overload_decorated(funcdef): def _is_overload_decorated(funcdef):
if funcdef.parent.type == 'decorated': if funcdef.parent.type == 'decorated':

View File

@@ -11,7 +11,7 @@ from jedi.evaluate.cache import evaluator_method_cache
from jedi.evaluate.arguments import AbstractArguments, AnonymousArguments, \ from jedi.evaluate.arguments import AbstractArguments, AnonymousArguments, \
ValuesArguments ValuesArguments
from jedi.evaluate.context.function import FunctionExecutionContext, \ from jedi.evaluate.context.function import FunctionExecutionContext, \
FunctionContext, AbstractFunction FunctionContext, AbstractFunction, OverloadedFunctionContext
from jedi.evaluate.context.klass import ClassContext, apply_py__get__, ClassFilter from jedi.evaluate.context.klass import ClassContext, apply_py__get__, ClassFilter
from jedi.evaluate.context import iterable from jedi.evaluate.context import iterable
from jedi.parser_utils import get_parent_scope from jedi.parser_utils import get_parent_scope
@@ -325,11 +325,14 @@ class BoundMethod(AbstractFunction):
def py__class__(self): def py__class__(self):
return compiled.get_special_object(self.evaluator, u'BOUND_METHOD_CLASS') return compiled.get_special_object(self.evaluator, u'BOUND_METHOD_CLASS')
def get_function_execution(self, arguments=None): def _get_arguments(self, arguments):
if arguments is None: if arguments is None:
arguments = AnonymousInstanceArguments(self._instance) arguments = AnonymousInstanceArguments(self._instance)
arguments = InstanceArguments(self._instance, arguments) return InstanceArguments(self._instance, arguments)
def get_function_execution(self, arguments=None):
arguments = self._get_arguments(arguments)
if isinstance(self._function, compiled.CompiledObject): if isinstance(self._function, compiled.CompiledObject):
# This is kind of weird, because it's coming from a compiled object # This is kind of weird, because it's coming from a compiled object
@@ -340,6 +343,12 @@ class BoundMethod(AbstractFunction):
return self._function.get_function_execution(arguments) return self._function.get_function_execution(arguments)
def py__call__(self, arguments):
if isinstance(self._function, OverloadedFunctionContext):
return self._function.py__call__(self._get_arguments(arguments))
function_execution = self.get_function_execution(arguments)
return function_execution.infer()
def __repr__(self): def __repr__(self):
return '<%s: %s>' % (self.__class__.__name__, self._function) return '<%s: %s>' % (self.__class__.__name__, self._function)
@@ -376,7 +385,7 @@ class LazyInstanceClassName(object):
@iterator_to_context_set @iterator_to_context_set
def infer(self): def infer(self):
for result_context in self._class_member_name.infer(): for result_context in self._class_member_name.infer():
if isinstance(result_context, FunctionContext): if isinstance(result_context, (FunctionContext, OverloadedFunctionContext)):
# Classes are never used to resolve anything within the # Classes are never used to resolve anything within the
# functions. Only other functions and modules will resolve # functions. Only other functions and modules will resolve
# those things. # those things.