diff --git a/jedi/inference/gradual/annotation.py b/jedi/inference/gradual/annotation.py index 9f53a267..81780010 100644 --- a/jedi/inference/gradual/annotation.py +++ b/jedi/inference/gradual/annotation.py @@ -196,6 +196,32 @@ def py__annotations__(funcdef): return dct +def resolve_forward_references(context, all_annotations): + def resolve(node): + if node is None or node.type != 'string': + return node + + node = _get_forward_reference_node( + context, + context.inference_state.compiled_subprocess.safe_literal_eval( + node.value, + ), + ) + + if node is None: + # There was a string, but it's not a valid annotation + return None + + # The forward reference tree has an additional root node ('eval_input') + # that we don't want. Extract the node we do want, that is equivalent to + # the nodes returned by `py__annotations__` for a non-quoted node. + node = node.children[0] + + return node + + return {name: resolve(node) for name, node in all_annotations.items()} + + @inference_state_method_cache() def infer_return_types(function, arguments): """ @@ -203,7 +229,10 @@ def infer_return_types(function, arguments): according to type annotations. """ context = function.get_default_param_context() - all_annotations = py__annotations__(function.tree_node) + all_annotations = resolve_forward_references( + context, + py__annotations__(function.tree_node), + ) annotation = all_annotations.get("return", None) if annotation is None: # If there is no Python 3-type annotation, look for an annotation @@ -222,18 +251,6 @@ def infer_return_types(function, arguments): match.group(1).strip() ).execute_annotation() - elif annotation.type == 'string': - annotation = _get_forward_reference_node( - context, - context.inference_state.compiled_subprocess.safe_literal_eval( - annotation.value, - ), - ) - # The forward reference tree has an additional root node ('eval_input') - # that we don't want. Extract the node we do want, that is equivalent to - # the nodes returned by `py__annotations__` for a non-quoted annotation. - annotation = annotation.children[0] - unknown_type_vars = find_unknown_type_vars(context, annotation) annotation_values = infer_annotation(context, annotation) if not unknown_type_vars: diff --git a/test/completion/pep0484_generic_passthroughs.py b/test/completion/pep0484_generic_passthroughs.py index e250b1a4..7c7b8820 100644 --- a/test/completion/pep0484_generic_passthroughs.py +++ b/test/completion/pep0484_generic_passthroughs.py @@ -61,9 +61,13 @@ def typed_bound_generic_passthrough(x: TList) -> TList: # Forward references are more likely with custom types, however this aims to # test just the handling of the quoted type rather than any other part of the # machinery. -def typed_quoted_generic_passthrough(x: T) -> 'List[T]': +def typed_quoted_return_generic_passthrough(x: T) -> 'List[T]': return [x] +def typed_quoted_input_generic_passthrough(x: 'Tuple[T]') -> T: + x + return x[0] + for a in untyped_passthrough(untyped_list_str): #? str() @@ -152,15 +156,22 @@ for q in typed_bound_generic_passthrough(typed_list_str): q -for r in typed_quoted_generic_passthrough("something"): +for r in typed_quoted_return_generic_passthrough("something"): #? str() r -for s in typed_quoted_generic_passthrough(42): +for s in typed_quoted_return_generic_passthrough(42): #? int() s +#? str() +typed_quoted_input_generic_passthrough(("something",)) + +#? int() +typed_quoted_input_generic_passthrough((42,)) + + class CustomList(List): def get_first(self):