Extract the annotation name upfront

We almost always need this and this simplifies the code within
each branch. This also means we'll be able to the name to determine
the branching.
This commit is contained in:
Peter Law
2020-02-22 19:40:10 +00:00
parent 36b4b797c1
commit 6efafb348e

View File

@@ -333,13 +333,14 @@ def _infer_type_vars(annotation_value, value_set, is_class_value=False):
unpacks the `Iterable`.
"""
type_var_dict = {}
annotation_name = annotation_value.py__name__()
if isinstance(annotation_value, TypeVar):
if not is_class_value:
return {annotation_value.py__name__(): value_set.py__class__()}
return {annotation_value.py__name__(): value_set}
return {annotation_name: value_set.py__class__()}
return {annotation_name: value_set}
elif isinstance(annotation_value, TypingClassValueWithIndex):
name = annotation_value.py__name__()
if name == 'Type':
if annotation_name == 'Type':
given = annotation_value.get_generics()
if given:
for nested_annotation_value in given[0]:
@@ -351,7 +352,7 @@ def _infer_type_vars(annotation_value, value_set, is_class_value=False):
is_class_value=True,
)
)
elif name == 'Callable':
elif annotation_name == 'Callable':
given = annotation_value.get_generics()
if len(given) == 2:
for nested_annotation_value in given[1]:
@@ -363,8 +364,7 @@ def _infer_type_vars(annotation_value, value_set, is_class_value=False):
)
)
elif isinstance(annotation_value, GenericClass):
name = annotation_value.py__name__()
if name == 'Iterable':
if annotation_name == 'Iterable':
given = annotation_value.get_generics()
if given:
for nested_annotation_value in given[0]:
@@ -375,7 +375,7 @@ def _infer_type_vars(annotation_value, value_set, is_class_value=False):
value_set.merge_types_of_iterate(),
)
)
elif name == 'Mapping':
elif annotation_name == 'Mapping':
given = annotation_value.get_generics()
if len(given) == 2:
for value in value_set: