1
0
forked from VimPlug/jedi

Some class fixes.

This commit is contained in:
Dave Halter
2016-11-01 18:28:47 +01:00
parent 9a55c9cf50
commit 2eb701d2d2
6 changed files with 98 additions and 55 deletions

View File

@@ -111,19 +111,16 @@ class Evaluator(object):
self.execution_recursion_detector = recursion.ExecutionRecursionDetector(self) self.execution_recursion_detector = recursion.ExecutionRecursionDetector(self)
def wrap(self, element, parent_context): def wrap(self, element, parent_context):
if isinstance(element, Context) or element is None:
# TODO this is so ugly, please refactor.
return element
if element.type == 'classdef': if element.type == 'classdef':
return er.ClassContext(self, element, parent_context) return er.ClassContext(self, element, parent_context)
elif element.type == 'funcdef': elif element.type == 'funcdef':
return er.FunctionContext(self, parent_context, element) return er.AnonymousFunctionExecution(self, parent_context, element)
elif element.type == 'lambda': elif element.type == 'lambda':
return er.LambdaWrapper(self, element) return er.LambdaWrapper(self, element)
elif element.type == 'file_input': elif element.type == 'file_input':
return er.ModuleContext(self, element) return er.ModuleContext(self, element)
else: else:
raise DeprecationWarning
return element return element
def find_types(self, context, name_str, position=None, search_global=False, def find_types(self, context, name_str, position=None, search_global=False,
@@ -317,6 +314,7 @@ class Evaluator(object):
elif element.type == 'dotted_name': elif element.type == 'dotted_name':
types = self._eval_atom(context, element.children[0]) types = self._eval_atom(context, element.children[0])
for next_name in element.children[2::2]: for next_name in element.children[2::2]:
# TODO add search_global=True?
types = set(chain.from_iterable(self.find_types(typ, next_name) types = set(chain.from_iterable(self.find_types(typ, next_name)
for typ in types)) for typ in types))
types = types types = types

View File

@@ -9,7 +9,7 @@ from jedi.evaluate import flow_analysis
from jedi.common import to_list, unite from jedi.common import to_list, unite
class AbstractNameDefinition(): class AbstractNameDefinition(object):
start_pos = None start_pos = None
string_name = None string_name = None
parent_context = None parent_context = None
@@ -96,6 +96,8 @@ class AbstractFilter(object):
class AbstractUsedNamesFilter(AbstractFilter): class AbstractUsedNamesFilter(AbstractFilter):
name_class = TreeNameDefinition
def __init__(self, context, parser_scope, origin_scope=None): def __init__(self, context, parser_scope, origin_scope=None):
super(AbstractUsedNamesFilter, self).__init__(origin_scope) super(AbstractUsedNamesFilter, self).__init__(origin_scope)
self._parser_scope = parser_scope self._parser_scope = parser_scope
@@ -111,7 +113,7 @@ class AbstractUsedNamesFilter(AbstractFilter):
return self._convert_names(self._filter(names)) return self._convert_names(self._filter(names))
def _convert_names(self, names): def _convert_names(self, names):
return [TreeNameDefinition(self._context, name) for name in names] return [self.name_class(self._context, name) for name in names]
def values(self): def values(self):
return self._convert_names(name for name_list in self._used_names.values() return self._convert_names(name for name_list in self._used_names.values()

View File

@@ -210,8 +210,8 @@ class NameFinder(object):
# deliver types. # deliver types.
self._found_predefined_if_name = types self._found_predefined_if_name = types
else: else:
check = flow_analysis.reachability_check(self._context, self.context, check = flow_analysis.reachability_check(
origin_scope) self._context, self.context, origin_scope)
if check is flow_analysis.UNREACHABLE: if check is flow_analysis.UNREACHABLE:
self._found_predefined_if_name = set() self._found_predefined_if_name = set()
else: else:
@@ -341,6 +341,7 @@ class NameFinder(object):
def _get_global_stmt_scopes(evaluator, global_stmt, name): def _get_global_stmt_scopes(evaluator, global_stmt, name):
raise DeprecationWarning
global_stmt_scope = global_stmt.get_parent_scope() global_stmt_scope = global_stmt.get_parent_scope()
module = global_stmt_scope.get_parent_until() module = global_stmt_scope.get_parent_until()
for used_name in module.used_names[str(name)]: for used_name in module.used_names[str(name)]:
@@ -400,7 +401,18 @@ def _apply_decorators(evaluator, context, node):
Returns the function, that should to be executed in the end. Returns the function, that should to be executed in the end.
This is also the places where the decorators are processed. This is also the places where the decorators are processed.
""" """
decoratee_context = evaluator.wrap(node, parent_context=context) if node.type == 'classdef':
decoratee_context = er.ClassContext(
evaluator,
parent_context=context,
classdef=node
)
else:
decoratee_context = er.FunctionContext(
evaluator,
parent_context=context,
funcdef=node
)
initial = values = set([decoratee_context]) initial = values = set([decoratee_context])
for dec in reversed(node.get_decorators()): for dec in reversed(node.get_decorators()):
debug.dbg('decorator: %s %s', dec, values) debug.dbg('decorator: %s %s', dec, values)

View File

@@ -1,9 +1,9 @@
from abc import abstractproperty from abc import abstractproperty
from jedi.common import unite from jedi.common import unite, to_list
from jedi import debug from jedi import debug
from jedi.evaluate import compiled from jedi.evaluate import compiled
from jedi.evaluate.filters import ParserTreeFilter, ContextName from jedi.evaluate.filters import ParserTreeFilter, ContextName, TreeNameDefinition
from jedi.evaluate.context import Context from jedi.evaluate.context import Context
@@ -91,13 +91,13 @@ class AbstractInstanceContext(Context):
if isinstance(cls, compiled.CompiledObject): if isinstance(cls, compiled.CompiledObject):
yield SelfNameFilter(self._evaluator, self, cls, origin_scope) yield SelfNameFilter(self._evaluator, self, cls, origin_scope)
else: else:
yield SelfNameFilter(self._evaluator, self, cls.base, origin_scope) yield SelfNameFilter(self._evaluator, self, cls.classdef, origin_scope)
for cls in self._class_context.py__mro__(): for cls in self._class_context.py__mro__():
if isinstance(cls, compiled.CompiledObject): if isinstance(cls, compiled.CompiledObject):
yield CompiledInstanceClassFilter(self._evaluator, self, cls) yield CompiledInstanceClassFilter(self._evaluator, self, cls)
else: else:
yield InstanceClassFilter(self._evaluator, self, cls.base, origin_scope) yield InstanceClassFilter(self._evaluator, self, cls.classdef, origin_scope)
def py__getitem__(self, index): def py__getitem__(self, index):
try: try:
@@ -165,7 +165,31 @@ class CompiledInstanceClassFilter(compiled.CompiledObjectFilter):
for name in names] for name in names]
class BoundMethod(object):
def __init__(self, function):
self._function = function
def __getattr__(self, name):
return getattr(self._function, name)
class InstanceNameDefinition(TreeNameDefinition):
@to_list
def infer(self):
contexts = super(InstanceNameDefinition, self).infer()
from jedi.evaluate.representation import FunctionContext
for context in contexts:
"""
if isinstance(contexts, FunctionContext):
# TODO what about compiled objects?
yield BoundMethod(context)
else:
"""
yield context
class InstanceClassFilter(ParserTreeFilter): class InstanceClassFilter(ParserTreeFilter):
name_class = InstanceNameDefinition
def __init__(self, evaluator, context, parser_scope, origin_scope): def __init__(self, evaluator, context, parser_scope, origin_scope):
super(InstanceClassFilter, self).__init__( super(InstanceClassFilter, self).__init__(
evaluator=evaluator, evaluator=evaluator,

View File

@@ -417,9 +417,9 @@ class ClassContext(use_metaclass(CachedMetaClass, context.TreeContext, Wrapper))
This class is not only important to extend `tree.Class`, it is also a This class is not only important to extend `tree.Class`, it is also a
important for descriptors (if the descriptor methods are evaluated or not). important for descriptors (if the descriptor methods are evaluated or not).
""" """
def __init__(self, evaluator, base, parent_context): def __init__(self, evaluator, classdef, parent_context):
super(ClassContext, self).__init__(evaluator, parent_context=parent_context) super(ClassContext, self).__init__(evaluator, parent_context=parent_context)
self.base = base self.classdef = classdef
@memoize_default(default=()) @memoize_default(default=())
def py__mro__(self): def py__mro__(self):
@@ -430,38 +430,41 @@ class ClassContext(use_metaclass(CachedMetaClass, context.TreeContext, Wrapper))
mro = [self] mro = [self]
# TODO Do a proper mro resolution. Currently we are just listing # TODO Do a proper mro resolution. Currently we are just listing
# classes. However, it's a complicated algorithm. # classes. However, it's a complicated algorithm.
for cls in self.py__bases__(): for lazy_cls in self.py__bases__():
# TODO detect for TypeError: duplicate base class str, # TODO there's multiple different mro paths possible if this yields
# e.g. `class X(str, str): pass` # multiple possibilities. Could be changed to be more correct.
try: for cls in lazy_cls.infer():
mro_method = cls.py__mro__ # TODO detect for TypeError: duplicate base class str,
except AttributeError: # e.g. `class X(str, str): pass`
# TODO add a TypeError like: try:
""" mro_method = cls.py__mro__
>>> class Y(lambda: test): pass except AttributeError:
Traceback (most recent call last): # TODO add a TypeError like:
File "<stdin>", line 1, in <module> """
TypeError: function() argument 1 must be code, not str >>> class Y(lambda: test): pass
>>> class Y(1): pass Traceback (most recent call last):
Traceback (most recent call last): File "<stdin>", line 1, in <module>
File "<stdin>", line 1, in <module> TypeError: function() argument 1 must be code, not str
TypeError: int() takes at most 2 arguments (3 given) >>> class Y(1): pass
""" Traceback (most recent call last):
pass File "<stdin>", line 1, in <module>
else: TypeError: int() takes at most 2 arguments (3 given)
add(cls) """
for cls_new in mro_method(): pass
add(cls_new) else:
add(cls)
for cls_new in mro_method():
add(cls_new)
return tuple(mro) return tuple(mro)
@memoize_default(default=()) @memoize_default(default=())
def py__bases__(self): def py__bases__(self):
arglist = self.base.get_super_arglist() arglist = self.classdef.get_super_arglist()
if arglist: if arglist:
args = param.Arguments(self._evaluator, self, arglist) args = param.TreeArguments(self._evaluator, self, arglist)
return list(chain.from_iterable(args.eval_args())) return [value for key, value in args.unpack() if key is None]
else: else:
return [compiled.create(self._evaluator, object)] return [context.LazyKnownContext(compiled.create(self._evaluator, object))]
def py__call__(self, params): def py__call__(self, params):
return set([TreeInstance(self._evaluator, self.parent_context, self, params)]) return set([TreeInstance(self._evaluator, self.parent_context, self, params)])
@@ -488,44 +491,38 @@ class ClassContext(use_metaclass(CachedMetaClass, context.TreeContext, Wrapper))
def get_filters(self, search_global, until_position=None, origin_scope=None, is_instance=False): def get_filters(self, search_global, until_position=None, origin_scope=None, is_instance=False):
if search_global: if search_global:
yield ParserTreeFilter(self._evaluator, self, self.base, until_position, origin_scope=origin_scope) yield ParserTreeFilter(self._evaluator, self, self.classdef, until_position, origin_scope=origin_scope)
else: else:
for scope in self.py__mro__(): for scope in self.py__mro__():
if isinstance(scope, compiled.CompiledObject): if isinstance(scope, compiled.CompiledObject):
for filter in scope.get_filters(is_instance=is_instance): for filter in scope.get_filters(is_instance=is_instance):
yield filter yield filter
else: else:
yield ParserTreeFilter(self._evaluator, self, scope.base, origin_scope=origin_scope) yield ParserTreeFilter(self._evaluator, self, scope.classdef, origin_scope=origin_scope)
def is_class(self): def is_class(self):
return True return True
def get_subscope_by_name(self, name): def get_subscope_by_name(self, name):
raise DeprecationWarning
for s in self.py__mro__(): for s in self.py__mro__():
for sub in reversed(s.subscopes): for sub in reversed(s.subscopes):
if sub.name.value == name: if sub.name.value == name:
return sub return sub
raise KeyError("Couldn't find subscope.") raise KeyError("Couldn't find subscope.")
def __getattr__(self, name):
if name not in ['start_pos', 'end_pos', 'parent', 'raw_doc',
'doc', 'get_imports', 'get_parent_until', 'get_code',
'subscopes', 'names_dict', 'type']:
return super(ClassContext, self).__getattribute__(name)
return getattr(self.base, name)
def __repr__(self): def __repr__(self):
return "<e%s of %s>" % (type(self).__name__, self.base) return "<e%s of %s>" % (type(self).__name__, self.classdef)
class FunctionContext(use_metaclass(CachedMetaClass, context.TreeContext, Wrapper)): class FunctionContext(use_metaclass(CachedMetaClass, context.TreeContext, Wrapper)):
""" """
Needed because of decorators. Decorators are evaluated here. Needed because of decorators. Decorators are evaluated here.
""" """
def __init__(self, evaluator, parent_context, func): def __init__(self, evaluator, parent_context, funcdef):
""" This should not be called directly """ """ This should not be called directly """
super(FunctionContext, self).__init__(evaluator, parent_context) super(FunctionContext, self).__init__(evaluator, parent_context)
self.base = self.base_func = func self.base = self.base_func = funcdef
def names_dicts(self, search_global): def names_dicts(self, search_global):
if search_global: if search_global:
@@ -713,6 +710,16 @@ class FunctionExecutionContext(Executed):
return "<%s of %s>" % (type(self).__name__, self.funcdef) return "<%s of %s>" % (type(self).__name__, self.funcdef)
class AnonymousFunctionExecution(FunctionExecutionContext):
def __init__(self, evaluator, parent_context, funcdef):
super(AnonymousFunctionExecution, self).__init__(
evaluator, parent_context, funcdef, var_args=None)
@memoize_default(default=NO_DEFAULT)
def get_params(self):
return []
class GlobalName(helpers.FakeName): class GlobalName(helpers.FakeName):
def __init__(self, name): def __init__(self, name):
""" """

View File

@@ -153,7 +153,7 @@ def builtins_super(evaluator, types, objects, scope):
cls = cls.base cls = cls.base
su = cls.py__bases__() su = cls.py__bases__()
if su: if su:
return evaluator.execute(su[0]) return su[0].infer()
return set() return set()