From 5d9f29743c7da0955bacbfa1442197178fb6af22 Mon Sep 17 00:00:00 2001 From: Dave Halter Date: Sun, 16 Sep 2018 02:19:29 +0200 Subject: [PATCH] Get iter() working and a lot of other typeshed reverse engineering of type vars --- jedi/common/context.py | 2 +- jedi/evaluate/base_context.py | 5 + jedi/evaluate/context/function.py | 2 +- jedi/evaluate/param.py | 11 +- jedi/evaluate/pep0484.py | 160 +++++++++++++++++++++++++++--- 5 files changed, 158 insertions(+), 22 deletions(-) diff --git a/jedi/common/context.py b/jedi/common/context.py index 88074596..a1883b19 100644 --- a/jedi/common/context.py +++ b/jedi/common/context.py @@ -43,7 +43,7 @@ class BaseContextSet(object): aggregated |= set_._set else: aggregated |= frozenset(set_) - return cls._from_frozen_set(aggregated) + return cls._from_frozen_set(frozenset(aggregated)) def __or__(self, other): return self._from_frozen_set(self._set | other._set) diff --git a/jedi/evaluate/base_context.py b/jedi/evaluate/base_context.py index f81f15b9..37c86dba 100644 --- a/jedi/evaluate/base_context.py +++ b/jedi/evaluate/base_context.py @@ -26,6 +26,11 @@ class HelperContextMixin: def execute_evaluated(self, *value_list): return execute_evaluated(self, *value_list) + def merge_types_of_iterate(self): + return ContextSet.from_sets( + lazy_context.infer() for lazy_context in self.iterate() + ) + @Python3Method def py__getattribute__(self, name_or_str, name_context=None, position=None, search_global=False, is_goto=False, diff --git a/jedi/evaluate/context/function.py b/jedi/evaluate/context/function.py index 0150e0d4..7cd6ad0b 100644 --- a/jedi/evaluate/context/function.py +++ b/jedi/evaluate/context/function.py @@ -169,7 +169,7 @@ class FunctionExecutionContext(TreeContext): returns = get_yield_exprs(self.evaluator, funcdef) else: returns = funcdef.iter_return_stmts() - context_set = pep0484.infer_return_types(self.function_context) + context_set = pep0484.infer_return_types(self) if context_set: # If there are annotations, prefer them over anything else. # This will make it faster. diff --git a/jedi/evaluate/param.py b/jedi/evaluate/param.py index 37e6fd86..db7418cc 100644 --- a/jedi/evaluate/param.py +++ b/jedi/evaluate/param.py @@ -25,11 +25,12 @@ class ExecutedParam(object): self._lazy_context = lazy_context self.string_name = param_node.name.value - def infer(self): - pep0484_hints = pep0484.infer_param(self._execution_context, self._param_node) - doc_params = docstrings.infer_param(self._execution_context, self._param_node) - if pep0484_hints or doc_params: - return pep0484_hints | doc_params + def infer(self, use_hints=True): + if use_hints: + pep0484_hints = pep0484.infer_param(self._execution_context, self._param_node) + doc_params = docstrings.infer_param(self._execution_context, self._param_node) + if pep0484_hints or doc_params: + return pep0484_hints | doc_params return self._lazy_context.infer() diff --git a/jedi/evaluate/pep0484.py b/jedi/evaluate/pep0484.py index 2161dbb3..f029e070 100644 --- a/jedi/evaluate/pep0484.py +++ b/jedi/evaluate/pep0484.py @@ -21,6 +21,7 @@ x support for type hint comments for functions, `# type: (int, str) -> int`. import os import re +from collections import OrderedDict from parso import ParserSyntaxError, parse, split_lines from parso.python import tree @@ -28,9 +29,11 @@ from parso.python import tree from jedi._compatibility import unicode, force_unicode from jedi.evaluate.cache import evaluator_method_cache from jedi.evaluate import compiled -from jedi.evaluate.base_context import NO_CONTEXTS, ContextSet +from jedi.evaluate.base_context import NO_CONTEXTS, ContextSet, ContextWrapper +from jedi.evaluate.filters import DictFilter from jedi.evaluate.lazy_context import LazyTreeContext from jedi.evaluate.context import ModuleContext +from jedi.evaluate.context.typing import TypeVar, AnnotatedClass, AnnotatedSubClass from jedi.evaluate.helpers import is_string, execute_evaluated from jedi import debug from jedi import parser_utils @@ -42,8 +45,7 @@ def _evaluate_for_annotation(context, annotation, index=None): If index is not None, the annotation is expected to be a tuple and we're interested in that index """ - context_set = context.eval_node(_fix_forward_reference(context, annotation)) - return context_set.execute_annotation() + return context.eval_node(_fix_forward_reference(context, annotation)) def _evaluate_annotation_string(context, string, index=None): @@ -173,32 +175,33 @@ def infer_param(execution_context, param): ) # Annotations are like default params and resolve in the same way. context = execution_context.function_context.get_default_param_context() - return _evaluate_for_annotation(context, annotation) + return _evaluate_for_annotation(context, annotation).execute_annotation() def py__annotations__(funcdef): - return_annotation = funcdef.annotation - if return_annotation: - dct = {'return': return_annotation} - else: - dct = {} + dct = OrderedDict() for function_param in funcdef.get_params(): param_annotation = function_param.annotation if param_annotation is not None: dct[function_param.name.value] = param_annotation + + return_annotation = funcdef.annotation + if return_annotation: + dct['return'] = return_annotation return dct @evaluator_method_cache() -def infer_return_types(function_context): +def infer_return_types(function_execution_context): """ Infers the type of a function's return value, according to type annotations. """ - annotation = py__annotations__(function_context.tree_node).get("return", None) + all_annotations = py__annotations__(function_execution_context.tree_node) + annotation = all_annotations.get("return", None) if annotation is None: # If there is no Python 3-type annotation, look for a Python 2-type annotation - node = function_context.tree_node + node = function_execution_context.tree_node comment = parser_utils.get_following_comment_same_line(node) if comment is None: return NO_CONTEXTS @@ -208,12 +211,114 @@ def infer_return_types(function_context): return NO_CONTEXTS return _evaluate_annotation_string( - function_context.get_default_param_context(), + function_execution_context.function_context.get_default_param_context(), match.group(1).strip() ) - context = function_context.get_default_param_context() - return _evaluate_for_annotation(context, annotation) + context = function_execution_context.function_context.get_default_param_context() + return _define_type_vars( + context.eval_node(annotation), + _infer_type_vars(function_execution_context, all_annotations), + ).execute_annotation() + + +def _infer_type_vars(execution_context, annotation_dict): + """ + Some functions use type vars that are not defined by the class, but rather + only defined in the function. See for example `iter`. In those cases we + want to: + + 1. Search for undefined type vars. + 2. Infer type vars with the execution state we have. + 3. Return the union of all type vars that have been found. + """ + context = execution_context.function_context.get_default_param_context() + try: + return_annotation = annotation_dict['return'] + except KeyError: + return {} + + unknown_type_vars = list(find_annotation_variables(context, return_annotation)) + if not unknown_type_vars: + return {} + + executed_params = execution_context.get_executed_params() + annotation_variable_results = {} + # The annotation_dict is ordered. + for i, (annotation_name, annotation_node) in enumerate(annotation_dict.items()): + if annotation_name == 'return': + continue + + annotation_variables = find_annotation_variables(context, annotation_node) + if annotation_variables: + try: + param = executed_params[i] + except IndexError: + continue + + # Infer unknown type var + annotation_context_set = context.eval_node(annotation_node) + actual_context_set = param.infer(use_hints=False) + for ann in annotation_context_set: + _merge_type_var_dicts( + annotation_variable_results, + _unpack_type_vars(ann, actual_context_set), + ) + + return annotation_variable_results + + +def _define_type_vars(annotation_contexts, type_var_dict): + def remap_type_vars(type_var_contexts): + return ContextSet.from_sets( + type_var_dict.get(type_var, ContextSet(type_var)) + for type_var in type_var_contexts + ) + + if not type_var_dict: + return annotation_contexts + + context_set = ContextSet() + for annotation_context in annotation_contexts: + if isinstance(annotation_context, AnnotatedClass): + context_set |= ContextSet.from_iterable([ + AnnotatedSubClass( + annotation_context.evaluator, + annotation_context.parent_context, + annotation_context.tree_node, + tuple(remap_type_vars(tcs) + for tcs in annotation_context.get_given_types()) + ) + ]) + return context_set + + +def _merge_type_var_dicts(base_dict, new_dict): + for type_var, contexts in new_dict.items(): + try: + base_dict[type_var] = contexts + except KeyError: + base_dict[type_var] |= contexts + + +def _unpack_type_vars(annotation_context, context_set): + type_var_dict = {} + if isinstance(annotation_context, TypeVar): + return {annotation_context: context_set.py__class__()} + elif isinstance(annotation_context, AnnotatedClass): + name = annotation_context.py__name__() + if name == 'Iterable': + given = annotation_context.get_given_types() + if given: + for nested_annotation_context in given[0]: + _merge_type_var_dicts( + type_var_dict, + _unpack_type_vars( + nested_annotation_context, + context_set.merge_types_of_iterate() + ) + ) + return type_var_dict _typing_module = None @@ -337,3 +442,28 @@ def _find_type_from_comment_hint(context, node, varlist, name): if match is None: return [] return _evaluate_annotation_string(context, match.group(1).strip(), index) + + +def find_annotation_variables(context, node): + found = [] + if node.type == 'atom_expr': + trailer = node.children[-1] + if trailer.type == 'trailer' and trailer.children[0] == '[': + for subscript_node in _unpack_subscriptlist(trailer.children[1]): + type_var_set = context.eval_node(subscript_node) + for type_var in type_var_set: + from jedi.evaluate.context.typing import TypeVar + + if isinstance(type_var, TypeVar) and type_var not in found: + found.append(type_var) + return found + + +def _unpack_subscriptlist(subscriptlist): + if subscriptlist.type == 'subscriptlist': + for subscript in subscriptlist.children[::2]: + if subscript.type != 'subscript': + yield subscript + else: + if subscriptlist.type != 'subscript': + yield subscriptlist