diff --git a/stubs/annoy/annoy/__init__.pyi b/stubs/annoy/annoy/__init__.pyi index aa39233a7..74d9ad6e0 100644 --- a/stubs/annoy/annoy/__init__.pyi +++ b/stubs/annoy/annoy/__init__.pyi @@ -1,5 +1,8 @@ -from typing import Sequence, overload -from typing_extensions import Literal +from typing import Sized, overload +from typing_extensions import Literal, Protocol + +class _Vector(Protocol, Sized): + def __getitem__(self, i: int) -> float: ... class AnnoyIndex: f: int @@ -18,18 +21,18 @@ class AnnoyIndex: ) -> tuple[list[int], list[float]]: ... @overload def get_nns_by_vector( - self, vector: Sequence[float], n: int, search_k: int = ..., include_distances: Literal[False] = ... + self, vector: _Vector, n: int, search_k: int = ..., include_distances: Literal[False] = ... ) -> list[int]: ... @overload def get_nns_by_vector( - self, vector: Sequence[float], n: int, search_k: int, include_distances: Literal[True] + self, vector: _Vector, n: int, search_k: int, include_distances: Literal[True] ) -> tuple[list[int], list[float]]: ... @overload def get_nns_by_vector( - self, vector: Sequence[float], n: int, search_k: int = ..., *, include_distances: Literal[True] + self, vector: _Vector, n: int, search_k: int = ..., *, include_distances: Literal[True] ) -> tuple[list[int], list[float]]: ... def get_item_vector(self, __i: int) -> list[float]: ... - def add_item(self, i: int, vector: Sequence[float]) -> None: ... + def add_item(self, i: int, vector: _Vector) -> None: ... def on_disk_build(self, fn: str) -> Literal[True]: ... def build(self, n_trees: int, n_jobs: int = ...) -> Literal[True]: ... def unbuild(self) -> Literal[True]: ...