diff --git a/jedi/evaluate/context/function.py b/jedi/evaluate/context/function.py index eb9e1e6d..c9735a3d 100644 --- a/jedi/evaluate/context/function.py +++ b/jedi/evaluate/context/function.py @@ -16,7 +16,6 @@ from jedi.evaluate.base_context import ContextualizedNode, NO_CONTEXTS, \ ContextSet, TreeContext, ContextWrapper from jedi.evaluate.lazy_context import LazyKnownContexts, LazyKnownContext, \ LazyTreeContext -from jedi.evaluate.context.typing import TypeVar from jedi.evaluate.context import iterable from jedi.evaluate.context import asynchronous from jedi import parser_utils @@ -127,6 +126,9 @@ class FunctionContext(use_metaclass(CachedMetaClass, AbstractFunction)): def get_default_param_context(self): return self.parent_context + def get_matching_functions(self, arguments): + yield self + class MethodContext(FunctionContext): def __init__(self, evaluator, class_context, *args, **kwargs): @@ -302,18 +304,22 @@ class OverloadedFunctionContext(ContextWrapper): self._overloaded_functions = overloaded_functions def py__call__(self, arguments): - context_set = ContextSet() debug.dbg("Execute overloaded function %s", self._wrapped_context, color='BLUE') + return ContextSet.from_sets( + matching_function.py__call__(arguments=arguments) + for matching_function in self.get_matching_functions(arguments) + ) + + def get_matching_functions(self, arguments): for f in self._overloaded_functions: signature = parser_utils.get_call_signature(f.tree_node) if signature_matches(f, arguments): debug.dbg("Overloading match: %s@%s", signature, f.tree_node.start_pos[0], color='BLUE') - context_set |= f.py__call__(arguments=arguments) + yield f else: debug.dbg("Overloading no match: %s@%s (%s)", signature, f.tree_node.start_pos[0], arguments, color='BLUE') - return context_set def signature_matches(function_context, arguments): @@ -369,11 +375,17 @@ def _find_overload_functions(context, tree_node): until_position=tree_node.start_pos ) names = filter.get(tree_node.name.value) + assert isinstance(names, list) + if not names: + break + found = False for name in names: funcdef = name.tree_name.parent if funcdef.type == 'funcdef' and _is_overload_decorated(funcdef): + tree_node = funcdef + found = True yield funcdef - # TODO this is probably not good enough? Why are we always breaking? - break # By default break + if not found: + break diff --git a/jedi/evaluate/context/instance.py b/jedi/evaluate/context/instance.py index a55887bd..957364c9 100644 --- a/jedi/evaluate/context/instance.py +++ b/jedi/evaluate/context/instance.py @@ -358,6 +358,13 @@ class BoundMethod(AbstractFunction): function_execution = self.get_function_execution(arguments) return function_execution.infer() + def get_matching_functions(self, arguments): + for func in self._function.get_matching_functions(arguments): + if func is self: + yield self + else: + yield BoundMethod(self.instance, self.class_context, func) + def __repr__(self): return '<%s: %s>' % (self.__class__.__name__, self._function) diff --git a/jedi/evaluate/context/iterable.py b/jedi/evaluate/context/iterable.py index d59f2e3f..0090676c 100644 --- a/jedi/evaluate/context/iterable.py +++ b/jedi/evaluate/context/iterable.py @@ -32,13 +32,12 @@ from jedi.evaluate.lazy_context import LazyKnownContext, LazyKnownContexts, \ from jedi.evaluate.helpers import get_int_or_none, is_string, \ predefine_names, evaluate_call_of_leaf, reraise_getitem_errors, \ SimpleGetItemNotFound -from jedi.evaluate.utils import safe_property -from jedi.evaluate.utils import to_list +from jedi.evaluate.utils import safe_property, to_list from jedi.evaluate.cache import evaluator_method_cache from jedi.evaluate.helpers import execute_evaluated from jedi.evaluate.filters import ParserTreeFilter, BuiltinOverwrite, \ publish_method -from jedi.evaluate.base_context import ContextSet, NO_CONTEXTS, Context, \ +from jedi.evaluate.base_context import ContextSet, NO_CONTEXTS, \ TreeContext, ContextualizedNode, iterate_contexts from jedi.parser_utils import get_comp_fors @@ -188,9 +187,37 @@ class Sequence(BuiltinOverwrite, IterableMixin): @memoize_method def get_object(self): - compiled_obj = compiled.builtin_from_name(self.evaluator, self.array_type) - only_obj, = execute_evaluated(compiled_obj, self) - return only_obj + klass = compiled.builtin_from_name(self.evaluator, self.array_type) + return self._annotate_class(klass) + + def _annotate_class(self, klass): + from jedi.evaluate.pep0484 import define_type_vars_for_execution + + instance, = klass.execute_evaluated(self) + + for init_function in self._get_init_functions(instance): + # Just take the first result, it should always be one, because we + # control the typeshed code. + execution_context = init_function.get_function_execution() + return define_type_vars_for_execution( + ContextSet(klass), + execution_context, + klass.find_annotation_variables() + ) + return instance + assert "Should never land here, probably an issue with typeshed changes" + + def _get_init_functions(self, instance): + from jedi.evaluate.context.function import OverloadedFunctionContext + from jedi.evaluate import arguments + for init in instance.py__getattribute__('__init__'): + try: + method = init.get_matching_functions + except AttributeError: + continue + else: + for x in method(arguments.ValuesArguments([ContextSet(self)])): + yield x def py__bool__(self): return None # We don't know the length, because of appends. diff --git a/jedi/evaluate/context/klass.py b/jedi/evaluate/context/klass.py index 278f5cad..05d3b736 100644 --- a/jedi/evaluate/context/klass.py +++ b/jedi/evaluate/context/klass.py @@ -45,6 +45,7 @@ from jedi.evaluate import compiled from jedi.evaluate.lazy_context import LazyKnownContext from jedi.evaluate.filters import ParserTreeFilter, TreeNameDefinition, \ ContextName +from jedi.evaluate.arguments import unpack_arglist from jedi.evaluate.base_context import ContextSet, iterator_to_context_set, \ TreeContext @@ -166,6 +167,21 @@ class ClassContext(use_metaclass(CachedMetaClass, TreeContext)): """ api_type = u'class' + @evaluator_method_cache() + def find_annotation_variables(self): + found = [] + arglist = self.tree_node.get_super_arglist() + if arglist is None: + return [] + + for stars, node in unpack_arglist(arglist): + if stars: + continue # These are not relevant for this search. + + from jedi.evaluate.pep0484 import find_unknown_type_vars + found += find_unknown_type_vars(self.parent_context, node) + return found + @evaluator_method_cache(default=()) def py__bases__(self): arglist = self.tree_node.get_super_arglist() diff --git a/jedi/evaluate/context/typing.py b/jedi/evaluate/context/typing.py index 0522bc9e..40b97a09 100644 --- a/jedi/evaluate/context/typing.py +++ b/jedi/evaluate/context/typing.py @@ -11,7 +11,7 @@ from jedi.evaluate.base_context import ContextSet, NO_CONTEXTS, Context, \ iterator_to_context_set, HelperContextMixin from jedi.evaluate.lazy_context import LazyKnownContexts, LazyKnownContext from jedi.evaluate.context.iterable import SequenceLiteralContext -from jedi.evaluate.arguments import repack_with_argument_clinic, unpack_arglist +from jedi.evaluate.arguments import repack_with_argument_clinic from jedi.evaluate.utils import to_list from jedi.evaluate.filters import FilterWrapper, NameWrapper, \ AbstractTreeName, AbstractNameDefinition, ContextName @@ -499,36 +499,6 @@ class _AbstractAnnotatedClass(ClassContext): # not a direct lookup on the class. yield self.get_type_var_filter() - @evaluator_method_cache() - def find_annotation_variables(self): - found = [] - arglist = self.tree_node.get_super_arglist() - if arglist is None: - return [] - - for stars, node in unpack_arglist(arglist): - if stars: - continue # These are not relevant for this search. - - if node.type == 'atom_expr': - trailer = node.children[-1] - if trailer.type == 'trailer' and trailer.children[0] == '[': - for subscript_node in self._unpack_subscriptlist(trailer.children[1]): - type_var_set = self.parent_context.eval_node(subscript_node) - for type_var in type_var_set: - if isinstance(type_var, TypeVar) and type_var not in found: - found.append(type_var) - return found - - def _unpack_subscriptlist(self, subscriptlist): - if subscriptlist.type == 'subscriptlist': - for subscript in subscriptlist.children[::2]: - if subscript.type != 'subscript': - yield subscript - else: - if subscriptlist.type != 'subscript': - yield subscriptlist - def is_same_class(self, other): if not isinstance(other, _AbstractAnnotatedClass): return False diff --git a/jedi/evaluate/param.py b/jedi/evaluate/param.py index db7418cc..ba1d1715 100644 --- a/jedi/evaluate/param.py +++ b/jedi/evaluate/param.py @@ -5,7 +5,6 @@ from jedi.evaluate import analysis from jedi.evaluate.lazy_context import LazyKnownContext, \ LazyTreeContext, LazyUnknownContext from jedi.evaluate import docstrings -from jedi.evaluate import pep0484 from jedi.evaluate.context import iterable @@ -27,6 +26,7 @@ class ExecutedParam(object): def infer(self, use_hints=True): if use_hints: + from jedi.evaluate import pep0484 pep0484_hints = pep0484.infer_param(self._execution_context, self._param_node) doc_params = docstrings.infer_param(self._execution_context, self._param_node) if pep0484_hints or doc_params: diff --git a/jedi/evaluate/pep0484.py b/jedi/evaluate/pep0484.py index cac971c7..42146310 100644 --- a/jedi/evaluate/pep0484.py +++ b/jedi/evaluate/pep0484.py @@ -29,8 +29,7 @@ from parso.python import tree from jedi._compatibility import unicode, force_unicode from jedi.evaluate.cache import evaluator_method_cache from jedi.evaluate import compiled -from jedi.evaluate.base_context import NO_CONTEXTS, ContextSet, ContextWrapper -from jedi.evaluate.filters import DictFilter +from jedi.evaluate.base_context import NO_CONTEXTS, ContextSet from jedi.evaluate.lazy_context import LazyTreeContext from jedi.evaluate.context import ModuleContext from jedi.evaluate.context.typing import TypeVar, AnnotatedClass, AnnotatedSubClass @@ -214,14 +213,30 @@ def infer_return_types(function_execution_context): function_execution_context.function_context.get_default_param_context(), match.group(1).strip() ) + if annotation is None: + return NO_CONTEXTS context = function_execution_context.function_context.get_default_param_context() - return _define_type_vars( + unknown_type_vars = list(find_unknown_type_vars(context, annotation)) + if not unknown_type_vars: + return context.eval_node(annotation) + + return define_type_vars_for_execution( context.eval_node(annotation), - _infer_type_vars(function_execution_context, all_annotations), + function_execution_context, + unknown_type_vars, ).execute_annotation() +def define_type_vars_for_execution(to_define_contexts, execution_context, + unknown_type_vars): + all_annotations = py__annotations__(execution_context.tree_node) + return _define_type_vars( + to_define_contexts, + _infer_type_vars(execution_context, all_annotations), + ) + + def _infer_type_vars(execution_context, annotation_dict): """ Some functions use type vars that are not defined by the class, but rather @@ -233,32 +248,19 @@ def _infer_type_vars(execution_context, annotation_dict): 3. Return the union of all type vars that have been found. """ context = execution_context.function_context.get_default_param_context() - try: - return_annotation = annotation_dict['return'] - except KeyError: - return {} - unknown_type_vars = list(find_annotation_variables(context, return_annotation)) - if not unknown_type_vars: - return {} - - executed_params = execution_context.get_executed_params() annotation_variable_results = {} - # The annotation_dict is ordered. - for i, (annotation_name, annotation_node) in enumerate(annotation_dict.items()): - if annotation_name == 'return': + for executed_param in execution_context.get_executed_params(): + try: + annotation_node = annotation_dict[executed_param.string_name] + except KeyError: continue - annotation_variables = find_annotation_variables(context, annotation_node) + annotation_variables = find_unknown_type_vars(context, annotation_node) if annotation_variables: - try: - param = executed_params[i] - except IndexError: - continue - # Infer unknown type var annotation_context_set = context.eval_node(annotation_node) - actual_context_set = param.infer(use_hints=False) + actual_context_set = executed_param.infer(use_hints=False) for ann in annotation_context_set: _merge_type_var_dicts( annotation_variable_results, @@ -444,16 +446,21 @@ def _find_type_from_comment_hint(context, node, varlist, name): return _evaluate_annotation_string(context, match.group(1).strip(), index) -def find_annotation_variables(context, node): - found = [] - if node.type == 'atom_expr': - trailer = node.children[-1] - if trailer.type == 'trailer' and trailer.children[0] == '[': - for subscript_node in _unpack_subscriptlist(trailer.children[1]): - type_var_set = context.eval_node(subscript_node) - for type_var in type_var_set: - if isinstance(type_var, TypeVar) and type_var not in found: - found.append(type_var) +def find_unknown_type_vars(context, node): + def check_node(node): + if node.type == 'atom_expr': + trailer = node.children[-1] + if trailer.type == 'trailer' and trailer.children[0] == '[': + for subscript_node in _unpack_subscriptlist(trailer.children[1]): + check_node(subscript_node) + else: + type_var_set = context.eval_node(node) + for type_var in type_var_set: + if isinstance(type_var, TypeVar): + found.add(type_var) + + found = set() + check_node(node) return found