diff --git a/jedi/inference/gradual/annotation.py b/jedi/inference/gradual/annotation.py index 957ec660..c2abacfb 100644 --- a/jedi/inference/gradual/annotation.py +++ b/jedi/inference/gradual/annotation.py @@ -37,12 +37,15 @@ def infer_annotation(context, annotation): "Inferred typing index %s should lead to 1 object, not %s" % (annotation, value_set)) return value_set - inferred_value = list(value_set)[0] - if is_string(inferred_value): - result = _get_forward_reference_node(context, inferred_value.get_safe_value()) - if result is not None: - return context.infer_node(result) - return value_set + strings_removed = NO_VALUES + for part in value_set: + if is_string(part): + result = _get_forward_reference_node(context, part.get_safe_value()) + if result is not None: + strings_removed |= context.infer_node(result) + continue + strings_removed |= ValueSet([part]) + return strings_removed def _infer_annotation_string(context, string, index=None): diff --git a/test/completion/pep0484_basic.py b/test/completion/pep0484_basic.py index 35edcb8c..f1d7677e 100644 --- a/test/completion/pep0484_basic.py +++ b/test/completion/pep0484_basic.py @@ -230,3 +230,11 @@ def use_type_with_annotation() -> type[int]: ... #? int use_type_with_annotation() + +def union_with_forward_references(x: int | "str", y: "int" | str, z: "int | str"): + #? int() str() + x + #? int() str() + y + #? int() str() + z