Prepare a test for #1479

This commit is contained in:
Dave Halter
2020-01-25 01:07:20 +01:00
parent 066b8b7165
commit 9c0efd5a67
6 changed files with 59 additions and 11 deletions

View File

@@ -1,6 +1,7 @@
import tempfile import tempfile
import shutil import shutil
import os import os
import sys
from functools import partial from functools import partial
import pytest import pytest
@@ -16,6 +17,9 @@ collect_ignore = [
'build/', 'build/',
'test/examples', 'test/examples',
] ]
if sys.version_info < (3, 5):
# Python 2 not supported syntax
collect_ignore.append('test/test_inference/test_mixed.py')
# The following hooks (pytest_configure, pytest_unconfigure) are used # The following hooks (pytest_configure, pytest_unconfigure) are used

View File

@@ -24,15 +24,18 @@ class MixedModuleContext(ModuleContext):
super(MixedModuleContext, self).__init__(tree_module_value) super(MixedModuleContext, self).__init__(tree_module_value)
self._namespace_objects = [NamespaceObject(n) for n in namespaces] self._namespace_objects = [NamespaceObject(n) for n in namespaces]
def _get_mixed_object(self, compiled_object):
return mixed.MixedObject(
compiled_object=compiled_object,
tree_value=self._value
)
def get_filters(self, *args, **kwargs): def get_filters(self, *args, **kwargs):
for filter in self._value.as_context().get_filters(*args, **kwargs): for filter in self._value.as_context().get_filters(*args, **kwargs):
yield filter yield filter
for namespace_obj in self._namespace_objects: for namespace_obj in self._namespace_objects:
compiled_object = _create(self.inference_state, namespace_obj) compiled_object = _create(self.inference_state, namespace_obj)
mixed_object = mixed.MixedObject( mixed_object = self._get_mixed_object(compiled_object)
compiled_object=compiled_object,
tree_value=self._value
)
for filter in mixed_object.get_filters(*args, **kwargs): for filter in mixed_object.get_filters(*args, **kwargs):
yield filter yield filter

View File

@@ -84,9 +84,10 @@ class MixedObject(ValueWrapper):
return MixedContext(self) return MixedContext(self)
def __repr__(self): def __repr__(self):
return '<%s: %s>' % ( return '<%s: %s; %s>' % (
type(self).__name__, type(self).__name__,
self.access_handle.get_repr() self.access_handle.get_repr(),
self._wrapped_value,
) )

View File

@@ -13,6 +13,7 @@ from jedi.inference.compiled.value import create_from_access_path
from jedi.inference.imports import _load_python_module from jedi.inference.imports import _load_python_module
from jedi.file_io import KnownContentFileIO from jedi.file_io import KnownContentFileIO
from jedi.inference.base_value import ValueSet from jedi.inference.base_value import ValueSet
from jedi.api.interpreter import MixedModuleContext
# For interpreter tests sometimes the path of this directory is in the sys # For interpreter tests sometimes the path of this directory is in the sys
# path, which we definitely don't want. So just remove it globally. # path, which we definitely don't want. So just remove it globally.
@@ -173,3 +174,14 @@ def module_injector():
inference_state.module_cache.add(names, ValueSet([v])) inference_state.module_cache.add(names, ValueSet([v]))
return module_injector return module_injector
@pytest.fixture(params=[False, True])
def class_findable(monkeypatch, request):
if not request.param:
monkeypatch.setattr(
MixedModuleContext,
'_get_mixed_object',
lambda self, compiled_object: compiled_object.as_context()
)
return request.param

View File

@@ -18,7 +18,7 @@ else:
eval(compile("""def exec_(source, global_map): eval(compile("""def exec_(source, global_map):
exec source in global_map """, 'blub', 'exec')) exec source in global_map """, 'blub', 'exec'))
if py_version > 34: if py_version > 35:
import typing import typing
else: else:
typing = None typing = None

View File

@@ -1,6 +1,4 @@
import sys from typing import Generic, TypeVar, List
if sys.version_info > (3, 5):
from typing import Generic, TypeVar, List
import pytest import pytest
@@ -18,7 +16,6 @@ def test_on_code():
assert i.infer() assert i.infer()
@pytest.mark.skipif('sys.version_info < (3,5)')
def test_generics_without_definition(): def test_generics_without_definition():
# Used to raise a recursion error # Used to raise a recursion error
T = TypeVar('T') T = TypeVar('T')
@@ -43,6 +40,37 @@ def test_generics_without_definition():
assert not interpreter('s.stack.pop().', locals()).complete() assert not interpreter('s.stack.pop().', locals()).complete()
@pytest.mark.parametrize(
'code, expected', [
('Foo.method()', 'int'),
('Foo.method()', 'int'),
('Foo().read()', 'str'),
('Foo.read()', 'str'),
]
)
def test_generics_methods(code, expected, class_findable):
T = TypeVar("T")
class Reader(Generic[T]):
@classmethod
def read(cls) -> T:
return cls()
def method(self) -> T:
return 1
class Foo(Reader[str]):
def transform(self) -> int:
return 42
defs = jedi.Interpreter(code, [locals()]).infer()
if class_findable:
def_, = defs
assert def_.name == expected
else:
assert not defs
def test_mixed_module_cache(): def test_mixed_module_cache():
"""Caused by #1479""" """Caused by #1479"""
interpreter = jedi.Interpreter('jedi', [{'jedi': jedi}]) interpreter = jedi.Interpreter('jedi', [{'jedi': jedi}])