diff --git a/jedi/inference/gradual/annotation.py b/jedi/inference/gradual/annotation.py index eadcfc2e..81780010 100644 --- a/jedi/inference/gradual/annotation.py +++ b/jedi/inference/gradual/annotation.py @@ -196,13 +196,43 @@ def py__annotations__(funcdef): return dct +def resolve_forward_references(context, all_annotations): + def resolve(node): + if node is None or node.type != 'string': + return node + + node = _get_forward_reference_node( + context, + context.inference_state.compiled_subprocess.safe_literal_eval( + node.value, + ), + ) + + if node is None: + # There was a string, but it's not a valid annotation + return None + + # The forward reference tree has an additional root node ('eval_input') + # that we don't want. Extract the node we do want, that is equivalent to + # the nodes returned by `py__annotations__` for a non-quoted node. + node = node.children[0] + + return node + + return {name: resolve(node) for name, node in all_annotations.items()} + + @inference_state_method_cache() def infer_return_types(function, arguments): """ Infers the type of a function's return value, according to type annotations. """ - all_annotations = py__annotations__(function.tree_node) + context = function.get_default_param_context() + all_annotations = resolve_forward_references( + context, + py__annotations__(function.tree_node), + ) annotation = all_annotations.get("return", None) if annotation is None: # If there is no Python 3-type annotation, look for an annotation @@ -217,11 +247,10 @@ def infer_return_types(function, arguments): return NO_VALUES return _infer_annotation_string( - function.get_default_param_context(), + context, match.group(1).strip() ).execute_annotation() - context = function.get_default_param_context() unknown_type_vars = find_unknown_type_vars(context, annotation) annotation_values = infer_annotation(context, annotation) if not unknown_type_vars: diff --git a/test/completion/pep0484_generic_passthroughs.py b/test/completion/pep0484_generic_passthroughs.py index 49b133a8..7c7b8820 100644 --- a/test/completion/pep0484_generic_passthroughs.py +++ b/test/completion/pep0484_generic_passthroughs.py @@ -58,6 +58,16 @@ def typed_bound_generic_passthrough(x: TList) -> TList: return x +# Forward references are more likely with custom types, however this aims to +# test just the handling of the quoted type rather than any other part of the +# machinery. +def typed_quoted_return_generic_passthrough(x: T) -> 'List[T]': + return [x] + +def typed_quoted_input_generic_passthrough(x: 'Tuple[T]') -> T: + x + return x[0] + for a in untyped_passthrough(untyped_list_str): #? str() @@ -146,6 +156,23 @@ for q in typed_bound_generic_passthrough(typed_list_str): q +for r in typed_quoted_return_generic_passthrough("something"): + #? str() + r + +for s in typed_quoted_return_generic_passthrough(42): + #? int() + s + + +#? str() +typed_quoted_input_generic_passthrough(("something",)) + +#? int() +typed_quoted_input_generic_passthrough((42,)) + + + class CustomList(List): def get_first(self): return self[0]