Merge pull request #1794 from PeterJCLaw/fix-quoted-generic-forward-refs

Fix quoted generic annotations
This commit is contained in:
Dave Halter
2021-07-25 20:02:38 +02:00
committed by GitHub
2 changed files with 59 additions and 3 deletions

View File

@@ -196,13 +196,43 @@ def py__annotations__(funcdef):
return dct
def resolve_forward_references(context, all_annotations):
def resolve(node):
if node is None or node.type != 'string':
return node
node = _get_forward_reference_node(
context,
context.inference_state.compiled_subprocess.safe_literal_eval(
node.value,
),
)
if node is None:
# There was a string, but it's not a valid annotation
return None
# The forward reference tree has an additional root node ('eval_input')
# that we don't want. Extract the node we do want, that is equivalent to
# the nodes returned by `py__annotations__` for a non-quoted node.
node = node.children[0]
return node
return {name: resolve(node) for name, node in all_annotations.items()}
@inference_state_method_cache()
def infer_return_types(function, arguments):
"""
Infers the type of a function's return value,
according to type annotations.
"""
all_annotations = py__annotations__(function.tree_node)
context = function.get_default_param_context()
all_annotations = resolve_forward_references(
context,
py__annotations__(function.tree_node),
)
annotation = all_annotations.get("return", None)
if annotation is None:
# If there is no Python 3-type annotation, look for an annotation
@@ -217,11 +247,10 @@ def infer_return_types(function, arguments):
return NO_VALUES
return _infer_annotation_string(
function.get_default_param_context(),
context,
match.group(1).strip()
).execute_annotation()
context = function.get_default_param_context()
unknown_type_vars = find_unknown_type_vars(context, annotation)
annotation_values = infer_annotation(context, annotation)
if not unknown_type_vars:

View File

@@ -58,6 +58,16 @@ def typed_bound_generic_passthrough(x: TList) -> TList:
return x
# Forward references are more likely with custom types, however this aims to
# test just the handling of the quoted type rather than any other part of the
# machinery.
def typed_quoted_return_generic_passthrough(x: T) -> 'List[T]':
return [x]
def typed_quoted_input_generic_passthrough(x: 'Tuple[T]') -> T:
x
return x[0]
for a in untyped_passthrough(untyped_list_str):
#? str()
@@ -146,6 +156,23 @@ for q in typed_bound_generic_passthrough(typed_list_str):
q
for r in typed_quoted_return_generic_passthrough("something"):
#? str()
r
for s in typed_quoted_return_generic_passthrough(42):
#? int()
s
#? str()
typed_quoted_input_generic_passthrough(("something",))
#? int()
typed_quoted_input_generic_passthrough((42,))
class CustomList(List):
def get_first(self):
return self[0]