Make it possible to infer Callable TypeVars, fixes #1449

This commit is contained in:
Dave Halter
2019-12-12 23:22:52 +01:00
parent 536a77551b
commit e656a5f18f
6 changed files with 101 additions and 5 deletions

View File

@@ -277,6 +277,39 @@ def infer_type_vars_for_execution(function, arguments, annotation_dict):
return annotation_variable_results
def infer_return_for_callable(arguments, param_values, result_values):
result = NO_VALUES
for pv in param_values:
if pv.array_type == 'list':
type_var_dict = infer_type_vars_for_callable(arguments, pv.py__iter__())
result |= ValueSet.from_sets(
v.define_generics(type_var_dict)
if isinstance(v, (DefineGenericBase, TypeVar)) else ValueSet({v})
for v in result_values
).execute_annotation()
return result
def infer_type_vars_for_callable(arguments, lazy_params):
"""
Infers type vars for the Calllable class:
def x() -> Callable[[Callable[..., _T]], _T]: ...
"""
annotation_variable_results = {}
for (_, lazy_value), lazy_callable_param in zip(arguments.unpack(), lazy_params):
callable_param_values = lazy_callable_param.infer()
# Infer unknown type var
actual_value_set = lazy_value.infer()
for v in callable_param_values:
_merge_type_var_dicts(
annotation_variable_results,
_infer_type_vars(v, actual_value_set),
)
return annotation_variable_results
def _merge_type_var_dicts(base_dict, new_dict):
for type_var_name, values in new_dict.items():
if values:
@@ -419,16 +452,21 @@ def find_unknown_type_vars(context, node):
for subscript_node in _unpack_subscriptlist(trailer.children[1]):
check_node(subscript_node)
else:
type_var_set = context.infer_node(node)
for type_var in type_var_set:
if isinstance(type_var, TypeVar) and type_var not in found:
found.append(type_var)
found[:] = _filter_type_vars(context.infer_node(node), found)
found = [] # We're not using a set, because the order matters.
check_node(node)
return found
def _filter_type_vars(value_set, found=()):
new_found = list(found)
for type_var in value_set:
if isinstance(type_var, TypeVar) and type_var not in found:
new_found.append(type_var)
return new_found
def _unpack_subscriptlist(subscriptlist):
if subscriptlist.type == 'subscriptlist':
for subscript in subscriptlist.children[::2]:

View File

@@ -297,6 +297,9 @@ class BaseTypingValue(LazyValueWrapper):
def _get_wrapped_value(self):
return _PseudoTreeNameClass(self.parent_context, self._tree_name)
def __repr__(self):
return '%s(%s)' % (self.__class__.__name__, self._tree_name.value)
class BaseTypingValueWithGenerics(DefineGenericBase):
def __init__(self, parent_context, tree_name, generics_manager):
@@ -307,3 +310,7 @@ class BaseTypingValueWithGenerics(DefineGenericBase):
def _get_wrapped_value(self):
return _PseudoTreeNameClass(self.parent_context, self._tree_name)
def __repr__(self):
return '%s(%s%s)' % (self.__class__.__name__, self._tree_name.value,
self._generics_manager)

View File

@@ -74,6 +74,9 @@ class LazyGenericManager(_AbstractGenericManager):
return True
return False
def __repr__(self):
return '<LazyG>[%s]' % (', '.join(repr(x) for x in self.to_tuple()))
class TupleGenericManager(_AbstractGenericManager):
def __init__(self, tup):
@@ -90,3 +93,6 @@ class TupleGenericManager(_AbstractGenericManager):
def is_homogenous_tuple(self):
return False
def __repr__(self):
return '<TupG>[%s]' % (', '.join(repr(x) for x in self.to_tuple()))

View File

@@ -216,8 +216,19 @@ class TypeAlias(LazyValueWrapper):
class Callable(BaseTypingValueWithGenerics):
def py__call__(self, arguments):
"""
def x() -> Callable[[Callable[..., _T]], _T]: ...
"""
# The 0th index are the arguments.
return self._generics_manager.get_index_and_execute(1)
try:
param_values = self._generics_manager[0]
result_values = self._generics_manager[1]
except IndexError:
debug.warning('Callable[...] defined without two arguments')
return NO_VALUES
else:
from jedi.inference.gradual.annotation import infer_return_for_callable
return infer_return_for_callable(arguments, param_values, result_values)
class Tuple(LazyValueWrapper):

View File

@@ -420,6 +420,19 @@ xxx([0])[1]
#?
xxx([0])[2]
def call_pls() -> typing.Callable[[TYPE_VARX], TYPE_VARX]: ...
#? int()
call_pls()(1)
def call2_pls() -> typing.Callable[[str, typing.Callable[[int], TYPE_VARX]], TYPE_VARX]: ...
#? float()
call2_pls('')(1, lambda x: 3.0)
def call3_pls() -> typing.Callable[[typing.Callable[[int], TYPE_VARX]], typing.List[TYPE_VARX]]: ...
def the_callable() -> float: ...
#? float()
call3_pls()(the_callable)[0]
# -------------------------
# TYPE_CHECKING
# -------------------------

View File

@@ -329,3 +329,24 @@ X.attr_y.value
X().name
#? float()
X().attr_x.attr_y.value
# -----------------
# functools Python 3.8
# -----------------
# python >= 3.8
@functools.lru_cache
def x() -> int: ...
@functools.lru_cache()
def y() -> float: ...
@functools.lru_cache(8)
def z() -> str: ...
#? int()
x()
#? float()
y()
#? str()
z()