diff --git a/jedi/evaluate/docstrings.py b/jedi/evaluate/docstrings.py index f9c3ffb3..277012f9 100644 --- a/jedi/evaluate/docstrings.py +++ b/jedi/evaluate/docstrings.py @@ -207,12 +207,12 @@ def follow_param(module_context, param): @memoize_default() -def infer_return_types(module_context, func): +def infer_return_types(function_context): def search_return_in_docstr(code): for p in DOCSTRING_RETURN_PATTERNS: match = p.search(code) if match: return _strip_rst_role(match.group(1)) - type_str = search_return_in_docstr(clean_scope_docstring(func)) - return _evaluate_for_statement_string(module_context, type_str) + type_str = search_return_in_docstr(function_context.py__doc__()) + return _evaluate_for_statement_string(function_context.get_root_context(), type_str) diff --git a/jedi/evaluate/pep0484.py b/jedi/evaluate/pep0484.py index 54f55903..db1f4c3d 100644 --- a/jedi/evaluate/pep0484.py +++ b/jedi/evaluate/pep0484.py @@ -99,9 +99,10 @@ def py__annotations__(funcdef): @memoize_default() -def infer_return_types(context, func): - annotation = py__annotations__(func).get("return", None) - return _evaluate_for_annotation(context, annotation) +def infer_return_types(function_context): + annotation = py__annotations__(function_context.tree_node).get("return", None) + module_context = function_context.get_root_context() + return _evaluate_for_annotation(module_context, annotation) _typing_module = None diff --git a/jedi/evaluate/representation.py b/jedi/evaluate/representation.py index 57de5daa..322ce723 100644 --- a/jedi/evaluate/representation.py +++ b/jedi/evaluate/representation.py @@ -330,8 +330,8 @@ class FunctionExecutionContext(context.TreeContext): returns = funcdef.yields else: returns = funcdef.returns - types = set(docstrings.infer_return_types(self.get_root_context(), funcdef)) - types |= set(pep0484.infer_return_types(self.get_root_context(), funcdef)) + types = set(docstrings.infer_return_types(self.function_context)) + types |= set(pep0484.infer_return_types(self.function_context)) for r in returns: check = flow_analysis.reachability_check(self, funcdef, r)