Fix some more context issues

Dave Halter
2019-08-19 19:33:12 +02:00
parent f54617867d
commit b19ba12566
7 changed files with 41 additions and 43 deletions

View File

@@ -120,6 +120,10 @@ class HelperValueMixin(object):
             return class2.is_same_class(self)
         return self == class2
 
+    @memoize_method
+    def as_context(self):
+        return self._as_context()
+
 
 class Value(HelperValueMixin, BaseValue):
     """
@@ -222,10 +226,6 @@ class Value(HelperValueMixin, BaseValue):
         # The root value knows if it's a stub or not.
         return self.parent_context.is_stub()
 
-    @memoize_method
-    def as_context(self):
-        return self._as_context()
-
     def _as_context(self):
         raise NotImplementedError('Not all values need to be converted to contexts')

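For orientation, here is a minimal standalone sketch of what hoisting a memoized as_context() from Value up to HelperValueMixin buys: every value subclass gets a single, cached conversion to its context. The classes and the memoize_method stand-in below are invented for illustration and are not jedi's real implementations.

    from functools import lru_cache

    def memoize_method(func):
        # Stand-in for jedi's memoize_method decorator: caches per (instance, args).
        return lru_cache(maxsize=None)(func)

    class HelperValueMixin(object):
        @memoize_method
        def as_context(self):
            # Shared entry point: a value is converted to a context at most once.
            return self._as_context()

        def _as_context(self):
            raise NotImplementedError('Not all values need to be converted to contexts')

    class ModuleContext(object):
        def __init__(self, value):
            self._value = value

    class ModuleValue(HelperValueMixin):
        def _as_context(self):
            return ModuleContext(self)

    module = ModuleValue()
    assert module.as_context() is module.as_context()  # cached conversion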
View File

@@ -21,7 +21,7 @@ from jedi import debug
 from jedi import parser_utils
 
 
-def infer_annotation(value, annotation):
+def infer_annotation(context, annotation):
     """
     Inferes an annotation node. This means that it inferes the part of
     `int` here:
@@ -30,7 +30,7 @@ def infer_annotation(value, annotation):
     Also checks for forward references (strings)
     """
-    value_set = value.infer_node(annotation)
+    value_set = context.infer_node(annotation)
     if len(value_set) != 1:
         debug.warning("Inferred typing index %s should lead to 1 object, "
                       " not %s" % (annotation, value_set))
@@ -38,18 +38,18 @@ def infer_annotation(value, annotation):
     inferred_value = list(value_set)[0]
     if is_string(inferred_value):
-        result = _get_forward_reference_node(value, inferred_value.get_safe_value())
+        result = _get_forward_reference_node(context, inferred_value.get_safe_value())
         if result is not None:
-            return value.infer_node(result)
+            return context.infer_node(result)
     return value_set
 
 
-def _infer_annotation_string(value, string, index=None):
-    node = _get_forward_reference_node(value, string)
+def _infer_annotation_string(context, string, index=None):
+    node = _get_forward_reference_node(context, string)
     if node is None:
         return NO_VALUES
 
-    value_set = value.infer_node(node)
+    value_set = context.infer_node(node)
     if index is not None:
         value_set = value_set.filter(
             lambda value: value.array_type == u'tuple'  # noqa
@@ -58,9 +58,9 @@ def _infer_annotation_string(value, string, index=None):
     return value_set
 
 
-def _get_forward_reference_node(value, string):
+def _get_forward_reference_node(context, string):
     try:
-        new_node = value.inference_state.grammar.parse(
+        new_node = context.inference_state.grammar.parse(
             force_unicode(string),
             start_symbol='eval_input',
             error_recovery=False
@@ -69,9 +69,9 @@ def _get_forward_reference_node(value, string):
         debug.warning('Annotation not parsed: %s' % string)
         return None
     else:
-        module = value.tree_node.get_root_node()
+        module = context.tree_node.get_root_node()
         parser_utils.move(new_node, module.end_pos[0])
-        new_node.parent = value.tree_node
+        new_node.parent = context.tree_node
         return new_node
@@ -173,8 +173,8 @@ def _infer_param(execution_context, param):
             param_comment
         )
     # Annotations are like default params and resolve in the same way.
-    value = execution_context.function_value.get_default_param_context()
-    return infer_annotation(value, annotation)
+    context = execution_context.function_value.get_default_param_context()
+    return infer_annotation(context, annotation)
 
 
 def py__annotations__(funcdef):
@@ -216,9 +216,9 @@ def infer_return_types(function_execution_context):
     if annotation is None:
         return NO_VALUES
 
-    value = function_execution_context.function_value.get_default_param_context()
-    unknown_type_vars = list(find_unknown_type_vars(value, annotation))
-    annotation_values = infer_annotation(value, annotation)
+    context = function_execution_context.function_value.get_default_param_context()
+    unknown_type_vars = list(find_unknown_type_vars(context, annotation))
+    annotation_values = infer_annotation(context, annotation)
     if not unknown_type_vars:
         return annotation_values.execute_annotation()
@@ -241,7 +241,7 @@ def infer_type_vars_for_execution(execution_context, annotation_dict):
     2. Infer type vars with the execution state we have.
     3. Return the union of all type vars that have been found.
     """
-    value = execution_context.function_value.get_default_param_context()
+    context = execution_context.function_value.get_default_param_context()
     annotation_variable_results = {}
 
     executed_params, _ = execution_context.get_executed_params_and_issues()
@@ -251,10 +251,10 @@ def infer_type_vars_for_execution(execution_context, annotation_dict):
         except KeyError:
             continue
 
-        annotation_variables = find_unknown_type_vars(value, annotation_node)
+        annotation_variables = find_unknown_type_vars(context, annotation_node)
         if annotation_variables:
             # Infer unknown type var
-            annotation_value_set = value.infer_node(annotation_node)
+            annotation_value_set = context.infer_node(annotation_node)
             star_count = executed_param._param_node.star_count
             actual_value_set = executed_param.infer(use_hints=False)
             if star_count == 1:
@@ -337,22 +337,22 @@ def _infer_type_vars(annotation_value, value_set):
     return type_var_dict
 
 
-def find_type_from_comment_hint_for(value, node, name):
-    return _find_type_from_comment_hint(value, node, node.children[1], name)
+def find_type_from_comment_hint_for(context, node, name):
+    return _find_type_from_comment_hint(context, node, node.children[1], name)
 
 
-def find_type_from_comment_hint_with(value, node, name):
+def find_type_from_comment_hint_with(context, node, name):
     assert len(node.children[1].children) == 3, \
         "Can only be here when children[1] is 'foo() as f'"
     varlist = node.children[1].children[2]
-    return _find_type_from_comment_hint(value, node, varlist, name)
+    return _find_type_from_comment_hint(context, node, varlist, name)
 
 
-def find_type_from_comment_hint_assign(value, node, name):
-    return _find_type_from_comment_hint(value, node, node.children[0], name)
+def find_type_from_comment_hint_assign(context, node, name):
+    return _find_type_from_comment_hint(context, node, node.children[0], name)
 
 
-def _find_type_from_comment_hint(value, node, varlist, name):
+def _find_type_from_comment_hint(context, node, varlist, name):
     index = None
     if varlist.type in ("testlist_star_expr", "exprlist", "testlist"):
         # something like "a, b = 1, 2"
@@ -373,11 +373,11 @@ def _find_type_from_comment_hint(value, node, varlist, name):
     if match is None:
         return []
     return _infer_annotation_string(
-        value, match.group(1).strip(), index
+        context, match.group(1).strip(), index
     ).execute_annotation()
 
 
-def find_unknown_type_vars(value, node):
+def find_unknown_type_vars(context, node):
     def check_node(node):
         if node.type in ('atom_expr', 'power'):
             trailer = node.children[-1]
@@ -385,7 +385,7 @@ def find_unknown_type_vars(value, node):
             for subscript_node in _unpack_subscriptlist(trailer.children[1]):
                 check_node(subscript_node)
         else:
-            type_var_set = value.infer_node(node)
+            type_var_set = context.infer_node(node)
             for type_var in type_var_set:
                 if isinstance(type_var, TypeVar) and type_var not in found:
                     found.append(type_var)

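The renames in this file are mechanical (value to context), but they document an interface change: the first argument must now offer the context API (infer_node, tree_node, inference_state) rather than the value API. The toy mirror below reproduces only the forward-reference control flow of infer_annotation(); ToyContext and its string-keyed namespace are invented and far simpler than jedi's real Context objects.

    class ToyContext(object):
        def __init__(self, namespace):
            self._namespace = namespace

        def infer_node(self, node):
            # jedi infers a parso tree node here; the toy just looks a name up.
            return {self._namespace[node]} if node in self._namespace else set()

    def infer_annotation(context, annotation):
        value_set = context.infer_node(annotation)
        if len(value_set) != 1:
            return value_set
        inferred, = value_set
        if isinstance(inferred, str):
            # Forward reference: the annotation evaluated to a string, so
            # infer that string against the same context a second time.
            return context.infer_node(inferred)
        return value_set

    ctx = ToyContext({'int': int, 'MyClass': 'Later', 'Later': float})
    assert infer_annotation(ctx, 'int') == {int}
    assert infer_annotation(ctx, 'MyClass') == {float}  # resolved via the string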
View File

@@ -7,6 +7,7 @@ from jedi._compatibility import FileNotFoundError, cast_path
 from jedi.parser_utils import get_cached_code_lines
 from jedi.inference.base_value import ValueSet, NO_VALUES
 from jedi.inference.gradual.stub_value import TypingModuleWrapper, StubModuleValue
+from jedi.inference.context import ModuleContext
 
 _jedi_path = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
 TYPESHED_PATH = os.path.join(_jedi_path, 'third_party', 'typeshed')
@@ -235,7 +236,7 @@ def _load_from_typeshed(inference_state, python_value_set, parent_module_context
     if len(import_names) == 1:
         map_ = _cache_stub_file_map(inference_state.grammar.version_info)
         import_name = _IMPORT_MAP.get(import_name, import_name)
-    elif isinstance(parent_module_context, StubModuleValue):
+    elif isinstance(parent_module_context, ModuleContext):
         if not parent_module_context.is_package:
             # Only if it's a package (= a folder) something can be
             # imported.

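The isinstance target switches from StubModuleValue to ModuleContext because parent_module_context is, after this refactoring, a context object. Below is a hedged sketch of that dispatch shape only; the helper function and both classes are invented stand-ins, not jedi's real ones.

    class ModuleContext(object):
        def __init__(self, is_package):
            self.is_package = is_package

    def can_have_stub_submodules(parent_module_context):
        # Only a package (a directory on disk) can have submodules imported
        # from it; a plain module cannot.
        if isinstance(parent_module_context, ModuleContext):
            return parent_module_context.is_package
        return False

    assert can_have_stub_submodules(ModuleContext(is_package=True))
    assert not can_have_stub_submodules(ModuleContext(is_package=False))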
View File

@@ -563,7 +563,6 @@ class AnnotatedClassContext(ClassContext):
         yield self._value.get_type_var_filter()
 
 
 class AbstractAnnotatedClass(ClassMixin, ValueWrapper):
     def get_type_var_filter(self):
         return TypeVarFilter(self.get_generics(), self.list_type_vars())

View File

@@ -153,7 +153,7 @@ class MethodValue(FunctionValue):
         self.class_value = class_value
 
     def get_default_param_context(self):
-        return self.class_value
+        return self.class_value.as_context()
 
     def get_qualified_names(self):
         # Need to implement this, because the parent value of a method

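This is the consumer side of the new as_context() helper: a method now resolves its default parameters in the class's context instead of against the class value directly. A minimal self-contained sketch with invented stand-in classes:

    class ClassContext(object):
        def __init__(self, class_value):
            self._class_value = class_value

    class ClassValue(object):
        def as_context(self):
            return ClassContext(self)

    class MethodValue(object):
        def __init__(self, class_value):
            self.class_value = class_value

        def get_default_param_context(self):
            # Callers now uniformly receive a context, never a raw value.
            return self.class_value.as_context()

    method = MethodValue(ClassValue())
    assert isinstance(method.get_default_param_context(), ClassContext)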
View File

@@ -473,7 +473,7 @@ class InstanceClassFilter(AbstractFilter):
         ]
 
     def __repr__(self):
-        return '<%s for %s>' % (self.__class__.__name__, self._class_filter.value)
+        return '<%s for %s>' % (self.__class__.__name__, self._class_filter.context)
 
 
 class SelfAttributeFilter(ClassFilter):

View File

@@ -292,10 +292,8 @@ class DictComprehension(ComprehensionMixin, Sequence):
     def py__simple_getitem__(self, index):
         for keys, values in self._iterate():
             for k in keys:
-                # TODO remove this isinstance.
-                if isinstance(k, compiled.CompiledObject):
-                    # Be careful in the future if refactoring, index could be a
-                    # slice.
-                    if k.get_safe_value(default=object()) == index:
-                        return values
+                # Be careful in the future if refactoring, index could be a
+                # slice object.
+                if k.get_safe_value(default=object()) == index:
+                    return values
         raise SimpleGetItemNotFound()
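After dropping the isinstance gate, every inferred key is asked for its concrete value via get_safe_value() and compared against the index. The toy below reproduces just that lookup loop; ToyKey and the free-standing function are invented for illustration and are not jedi's real classes.

    class SimpleGetItemNotFound(Exception):
        pass

    class ToyKey(object):
        def __init__(self, raw=None):
            self._raw = raw

        def get_safe_value(self, default=None):
            return self._raw if self._raw is not None else default

    def simple_getitem(pairs, index):
        for keys, values in pairs:
            for k in keys:
                # A fresh object() sentinel never equals the index, so keys
                # without a safe value simply never match.
                if k.get_safe_value(default=object()) == index:
                    return values
        raise SimpleGetItemNotFound()

    pairs = [([ToyKey('a')], 'first'), ([ToyKey(None)], 'skipped'), ([ToyKey('b')], 'second')]
    assert simple_getitem(pairs, 'b') == 'second'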