diff --git a/jedi/inference/base_value.py b/jedi/inference/base_value.py index a2346977..ed20ddfb 100644 --- a/jedi/inference/base_value.py +++ b/jedi/inference/base_value.py @@ -120,6 +120,10 @@ class HelperValueMixin(object): return class2.is_same_class(self) return self == class2 + @memoize_method + def as_context(self): + return self._as_context() + class Value(HelperValueMixin, BaseValue): """ @@ -222,10 +226,6 @@ class Value(HelperValueMixin, BaseValue): # The root value knows if it's a stub or not. return self.parent_context.is_stub() - @memoize_method - def as_context(self): - return self._as_context() - def _as_context(self): raise NotImplementedError('Not all values need to be converted to contexts') diff --git a/jedi/inference/gradual/annotation.py b/jedi/inference/gradual/annotation.py index fd6c71b9..f291a1e4 100644 --- a/jedi/inference/gradual/annotation.py +++ b/jedi/inference/gradual/annotation.py @@ -21,7 +21,7 @@ from jedi import debug from jedi import parser_utils -def infer_annotation(value, annotation): +def infer_annotation(context, annotation): """ Inferes an annotation node. This means that it inferes the part of `int` here: @@ -30,7 +30,7 @@ def infer_annotation(value, annotation): Also checks for forward references (strings) """ - value_set = value.infer_node(annotation) + value_set = context.infer_node(annotation) if len(value_set) != 1: debug.warning("Inferred typing index %s should lead to 1 object, " " not %s" % (annotation, value_set)) @@ -38,18 +38,18 @@ def infer_annotation(value, annotation): inferred_value = list(value_set)[0] if is_string(inferred_value): - result = _get_forward_reference_node(value, inferred_value.get_safe_value()) + result = _get_forward_reference_node(context, inferred_value.get_safe_value()) if result is not None: - return value.infer_node(result) + return context.infer_node(result) return value_set -def _infer_annotation_string(value, string, index=None): - node = _get_forward_reference_node(value, string) +def _infer_annotation_string(context, string, index=None): + node = _get_forward_reference_node(context, string) if node is None: return NO_VALUES - value_set = value.infer_node(node) + value_set = context.infer_node(node) if index is not None: value_set = value_set.filter( lambda value: value.array_type == u'tuple' # noqa @@ -58,9 +58,9 @@ def _infer_annotation_string(value, string, index=None): return value_set -def _get_forward_reference_node(value, string): +def _get_forward_reference_node(context, string): try: - new_node = value.inference_state.grammar.parse( + new_node = context.inference_state.grammar.parse( force_unicode(string), start_symbol='eval_input', error_recovery=False @@ -69,9 +69,9 @@ def _get_forward_reference_node(value, string): debug.warning('Annotation not parsed: %s' % string) return None else: - module = value.tree_node.get_root_node() + module = context.tree_node.get_root_node() parser_utils.move(new_node, module.end_pos[0]) - new_node.parent = value.tree_node + new_node.parent = context.tree_node return new_node @@ -173,8 +173,8 @@ def _infer_param(execution_context, param): param_comment ) # Annotations are like default params and resolve in the same way. - value = execution_context.function_value.get_default_param_context() - return infer_annotation(value, annotation) + context = execution_context.function_value.get_default_param_context() + return infer_annotation(context, annotation) def py__annotations__(funcdef): @@ -216,9 +216,9 @@ def infer_return_types(function_execution_context): if annotation is None: return NO_VALUES - value = function_execution_context.function_value.get_default_param_context() - unknown_type_vars = list(find_unknown_type_vars(value, annotation)) - annotation_values = infer_annotation(value, annotation) + context = function_execution_context.function_value.get_default_param_context() + unknown_type_vars = list(find_unknown_type_vars(context, annotation)) + annotation_values = infer_annotation(context, annotation) if not unknown_type_vars: return annotation_values.execute_annotation() @@ -241,7 +241,7 @@ def infer_type_vars_for_execution(execution_context, annotation_dict): 2. Infer type vars with the execution state we have. 3. Return the union of all type vars that have been found. """ - value = execution_context.function_value.get_default_param_context() + context = execution_context.function_value.get_default_param_context() annotation_variable_results = {} executed_params, _ = execution_context.get_executed_params_and_issues() @@ -251,10 +251,10 @@ def infer_type_vars_for_execution(execution_context, annotation_dict): except KeyError: continue - annotation_variables = find_unknown_type_vars(value, annotation_node) + annotation_variables = find_unknown_type_vars(context, annotation_node) if annotation_variables: # Infer unknown type var - annotation_value_set = value.infer_node(annotation_node) + annotation_value_set = context.infer_node(annotation_node) star_count = executed_param._param_node.star_count actual_value_set = executed_param.infer(use_hints=False) if star_count == 1: @@ -337,22 +337,22 @@ def _infer_type_vars(annotation_value, value_set): return type_var_dict -def find_type_from_comment_hint_for(value, node, name): - return _find_type_from_comment_hint(value, node, node.children[1], name) +def find_type_from_comment_hint_for(context, node, name): + return _find_type_from_comment_hint(context, node, node.children[1], name) -def find_type_from_comment_hint_with(value, node, name): +def find_type_from_comment_hint_with(context, node, name): assert len(node.children[1].children) == 3, \ "Can only be here when children[1] is 'foo() as f'" varlist = node.children[1].children[2] - return _find_type_from_comment_hint(value, node, varlist, name) + return _find_type_from_comment_hint(context, node, varlist, name) -def find_type_from_comment_hint_assign(value, node, name): - return _find_type_from_comment_hint(value, node, node.children[0], name) +def find_type_from_comment_hint_assign(context, node, name): + return _find_type_from_comment_hint(context, node, node.children[0], name) -def _find_type_from_comment_hint(value, node, varlist, name): +def _find_type_from_comment_hint(context, node, varlist, name): index = None if varlist.type in ("testlist_star_expr", "exprlist", "testlist"): # something like "a, b = 1, 2" @@ -373,11 +373,11 @@ def _find_type_from_comment_hint(value, node, varlist, name): if match is None: return [] return _infer_annotation_string( - value, match.group(1).strip(), index + context, match.group(1).strip(), index ).execute_annotation() -def find_unknown_type_vars(value, node): +def find_unknown_type_vars(context, node): def check_node(node): if node.type in ('atom_expr', 'power'): trailer = node.children[-1] @@ -385,7 +385,7 @@ def find_unknown_type_vars(value, node): for subscript_node in _unpack_subscriptlist(trailer.children[1]): check_node(subscript_node) else: - type_var_set = value.infer_node(node) + type_var_set = context.infer_node(node) for type_var in type_var_set: if isinstance(type_var, TypeVar) and type_var not in found: found.append(type_var) diff --git a/jedi/inference/gradual/typeshed.py b/jedi/inference/gradual/typeshed.py index afc9bc2e..dfa8e1d7 100644 --- a/jedi/inference/gradual/typeshed.py +++ b/jedi/inference/gradual/typeshed.py @@ -7,6 +7,7 @@ from jedi._compatibility import FileNotFoundError, cast_path from jedi.parser_utils import get_cached_code_lines from jedi.inference.base_value import ValueSet, NO_VALUES from jedi.inference.gradual.stub_value import TypingModuleWrapper, StubModuleValue +from jedi.inference.context import ModuleContext _jedi_path = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) TYPESHED_PATH = os.path.join(_jedi_path, 'third_party', 'typeshed') @@ -235,7 +236,7 @@ def _load_from_typeshed(inference_state, python_value_set, parent_module_context if len(import_names) == 1: map_ = _cache_stub_file_map(inference_state.grammar.version_info) import_name = _IMPORT_MAP.get(import_name, import_name) - elif isinstance(parent_module_context, StubModuleValue): + elif isinstance(parent_module_context, ModuleContext): if not parent_module_context.is_package: # Only if it's a package (= a folder) something can be # imported. diff --git a/jedi/inference/gradual/typing.py b/jedi/inference/gradual/typing.py index 483990bb..391b4301 100644 --- a/jedi/inference/gradual/typing.py +++ b/jedi/inference/gradual/typing.py @@ -563,7 +563,6 @@ class AnnotatedClassContext(ClassContext): yield self._value.get_type_var_filter() - class AbstractAnnotatedClass(ClassMixin, ValueWrapper): def get_type_var_filter(self): return TypeVarFilter(self.get_generics(), self.list_type_vars()) diff --git a/jedi/inference/value/function.py b/jedi/inference/value/function.py index 21133f31..b54dc504 100644 --- a/jedi/inference/value/function.py +++ b/jedi/inference/value/function.py @@ -153,7 +153,7 @@ class MethodValue(FunctionValue): self.class_value = class_value def get_default_param_context(self): - return self.class_value + return self.class_value.as_context() def get_qualified_names(self): # Need to implement this, because the parent value of a method diff --git a/jedi/inference/value/instance.py b/jedi/inference/value/instance.py index 38e52297..c5fa0f70 100644 --- a/jedi/inference/value/instance.py +++ b/jedi/inference/value/instance.py @@ -473,7 +473,7 @@ class InstanceClassFilter(AbstractFilter): ] def __repr__(self): - return '<%s for %s>' % (self.__class__.__name__, self._class_filter.value) + return '<%s for %s>' % (self.__class__.__name__, self._class_filter.context) class SelfAttributeFilter(ClassFilter): diff --git a/jedi/inference/value/iterable.py b/jedi/inference/value/iterable.py index 6719245f..be4b7ee9 100644 --- a/jedi/inference/value/iterable.py +++ b/jedi/inference/value/iterable.py @@ -292,12 +292,10 @@ class DictComprehension(ComprehensionMixin, Sequence): def py__simple_getitem__(self, index): for keys, values in self._iterate(): for k in keys: - # TODO remove this isinstance. - if isinstance(k, compiled.CompiledObject): - # Be careful in the future if refactoring, index could be a - # slice. - if k.get_safe_value(default=object()) == index: - return values + # Be careful in the future if refactoring, index could be a + # slice object. + if k.get_safe_value(default=object()) == index: + return values raise SimpleGetItemNotFound() def _dict_keys(self):