diff --git a/jedi/inference/gradual/annotation.py b/jedi/inference/gradual/annotation.py index cfde9f0f..0c47454a 100644 --- a/jedi/inference/gradual/annotation.py +++ b/jedi/inference/gradual/annotation.py @@ -14,7 +14,8 @@ from jedi.inference.cache import inference_state_method_cache from jedi.inference.base_value import ValueSet, NO_VALUES from jedi.inference.gradual.typing import TypeVar, LazyGenericClass, \ AbstractAnnotatedClass -from jedi.inference.gradual.typing import GenericClass +from jedi.inference.gradual.typing import GenericClass, \ + TypingClassValueWithIndex from jedi.inference.helpers import is_string from jedi.inference.compiled import builtin_from_name from jedi.inference.param import get_executed_param_names @@ -270,7 +271,6 @@ def infer_type_vars_for_execution(function, arguments, annotation_dict): annotation_variable_results, _infer_type_vars(ann, actual_value_set), ) - return annotation_variable_results @@ -283,7 +283,7 @@ def _merge_type_var_dicts(base_dict, new_dict): base_dict[type_var_name] = values -def _infer_type_vars(annotation_value, value_set): +def _infer_type_vars(annotation_value, value_set, is_class_value=False): """ This function tries to find information about undefined type vars and returns a dict from type var name to value set. @@ -298,7 +298,23 @@ def _infer_type_vars(annotation_value, value_set): """ type_var_dict = {} if isinstance(annotation_value, TypeVar): - return {annotation_value.py__name__(): value_set.py__class__()} + if not is_class_value: + return {annotation_value.py__name__(): value_set.py__class__()} + return {annotation_value.py__name__(): value_set} + elif isinstance(annotation_value, TypingClassValueWithIndex): + name = annotation_value.py__name__() + if name == 'Type': + given = annotation_value.get_generics() + if given: + for nested_annotation_value in given[0]: + _merge_type_var_dicts( + type_var_dict, + _infer_type_vars( + nested_annotation_value, + value_set, + is_class_value=True, + ) + ) elif isinstance(annotation_value, LazyGenericClass): name = annotation_value.py__name__() if name == 'Iterable': diff --git a/jedi/inference/gradual/typing.py b/jedi/inference/gradual/typing.py index d401b9e2..c8f64db9 100644 --- a/jedi/inference/gradual/typing.py +++ b/jedi/inference/gradual/typing.py @@ -219,7 +219,10 @@ class _TypingClassMixin(ClassMixin): class TypingClassValueWithIndex(_TypingClassMixin, TypingValueWithIndex): - pass + + @inference_state_method_cache() + def get_generics(self): + return list(_iter_over_arguments(self._index_value, self._context_of_index)) class TypingClassValue(_TypingClassMixin, TypingValue): diff --git a/test/completion/pep0484_typing.py b/test/completion/pep0484_typing.py index 163050a8..ae49d10d 100644 --- a/test/completion/pep0484_typing.py +++ b/test/completion/pep0484_typing.py @@ -363,6 +363,17 @@ in_out1(str()) #? in_out1() +def type_in_out1(x: typing.Type[TYPE_VARX]) -> TYPE_VARX: ... + +#? int() +type_in_out1(int) +#? str() +type_in_out1(str) +#? float() +type_in_out1(float) +#? +type_in_out1() + def in_out2(x: TYPE_VAR_CONSTRAINTSX) -> TYPE_VAR_CONSTRAINTSX: ... #? int() @@ -377,6 +388,18 @@ in_out2() #? float() in_out2(1.0) +def type_in_out2(x: typing.Type[TYPE_VAR_CONSTRAINTSX]) -> TYPE_VAR_CONSTRAINTSX: ... + +#? int() +type_in_out2(int) +#? str() +type_in_out2(str) +#? str() int() +type_in_out2() +# TODO this should actually be str() int(), because of the constraints. +#? float() +type_in_out2(float) + # ------------------------- # TYPE_CHECKING # -------------------------