mirror of
https://github.com/davidhalter/typeshed.git
synced 2025-12-07 20:54:28 +08:00
tensorflow: add tf.strings module (#11380)
Partially taken from: https://github.com/hmc-cs-mdrissi/tensorflow_stubs/blob/main/stubs/tensorflow/strings.pyi Co-authored-by: Jelle Zijlstra <jelle.zijlstra@gmail.com>
This commit is contained in:
@@ -32,6 +32,8 @@ FloatDataSequence: TypeAlias = Sequence[float] | Sequence[FloatDataSequence]
|
||||
IntDataSequence: TypeAlias = Sequence[int] | Sequence[IntDataSequence]
|
||||
StrDataSequence: TypeAlias = Sequence[str] | Sequence[StrDataSequence]
|
||||
ScalarTensorCompatible: TypeAlias = tf.Tensor | str | float | np.ndarray[Any, Any] | np.number[Any]
|
||||
UIntTensorCompatible: TypeAlias = tf.Tensor | int | UIntArray
|
||||
StringTensorCompatible: TypeAlias = tf.Tensor | str | npt.NDArray[np.str_] | Sequence[StringTensorCompatible]
|
||||
|
||||
TensorCompatible: TypeAlias = ScalarTensorCompatible | Sequence[TensorCompatible]
|
||||
# _TensorCompatibleT = TypeVar("_TensorCompatibleT", bound=TensorCompatible)
|
||||
|
||||
241
stubs/tensorflow/tensorflow/strings.pyi
Normal file
241
stubs/tensorflow/tensorflow/strings.pyi
Normal file
@@ -0,0 +1,241 @@
|
||||
from collections.abc import Sequence
|
||||
from typing import Literal, TypeVar, overload
|
||||
|
||||
from tensorflow import RaggedTensor, Tensor
|
||||
from tensorflow._aliases import StringTensorCompatible, TensorCompatible, UIntTensorCompatible
|
||||
from tensorflow.dtypes import DType
|
||||
|
||||
_TensorOrRaggedTensor = TypeVar("_TensorOrRaggedTensor", Tensor, RaggedTensor)
|
||||
|
||||
@overload
|
||||
def as_string(
|
||||
input: TensorCompatible,
|
||||
precision: int = -1,
|
||||
scientific: bool = False,
|
||||
shortest: bool = False,
|
||||
width: int = -1,
|
||||
fill: str = "",
|
||||
name: str | None = None,
|
||||
) -> Tensor: ...
|
||||
@overload
|
||||
def as_string(
|
||||
input: RaggedTensor,
|
||||
precision: int = -1,
|
||||
scientific: bool = False,
|
||||
shortest: bool = False,
|
||||
width: int = -1,
|
||||
fill: str = "",
|
||||
name: str | None = None,
|
||||
) -> RaggedTensor: ...
|
||||
def bytes_split(input: TensorCompatible | RaggedTensor, name: str | None = None) -> RaggedTensor: ...
|
||||
def format(
|
||||
template: str, inputs: TensorCompatible, placeholder: str = "{}", summarize: int = 3, name: str | None = None
|
||||
) -> Tensor: ...
|
||||
def join(inputs: Sequence[TensorCompatible | RaggedTensor], separator: str = "", name: str | None = None) -> Tensor: ...
|
||||
@overload
|
||||
def length(input: TensorCompatible, unit: Literal["BYTE", "UTF8_CHAR"] = "BYTE", name: str | None = None) -> Tensor: ...
|
||||
@overload
|
||||
def length(input: RaggedTensor, unit: Literal["BYTE", "UTF8_CHAR"] = "BYTE", name: str | None = None) -> RaggedTensor: ...
|
||||
@overload
|
||||
def lower(input: TensorCompatible, encoding: Literal["utf-8", ""] = "", name: str | None = None) -> Tensor: ...
|
||||
@overload
|
||||
def lower(input: RaggedTensor, encoding: Literal["utf-8", ""] = "", name: str | None = None) -> RaggedTensor: ...
|
||||
def ngrams(
|
||||
data: StringTensorCompatible | RaggedTensor,
|
||||
ngram_width: int | Sequence[int],
|
||||
separator: str = " ",
|
||||
pad_values: tuple[int, int] | str | None = None,
|
||||
padding_width: int | None = None,
|
||||
preserve_short_sequences: bool = False,
|
||||
name: str | None = None,
|
||||
) -> RaggedTensor: ...
|
||||
def reduce_join(
|
||||
inputs: StringTensorCompatible | RaggedTensor,
|
||||
axis: int | None = None,
|
||||
keepdims: bool = False,
|
||||
separator: str = "",
|
||||
name: str | None = None,
|
||||
) -> Tensor: ...
|
||||
@overload
|
||||
def regex_full_match(input: StringTensorCompatible, pattern: StringTensorCompatible, name: str | None = None) -> Tensor: ...
|
||||
@overload
|
||||
def regex_full_match(input: RaggedTensor, pattern: StringTensorCompatible, name: str | None = None) -> RaggedTensor: ...
|
||||
@overload
|
||||
def regex_replace(
|
||||
input: StringTensorCompatible,
|
||||
pattern: StringTensorCompatible,
|
||||
rewrite: StringTensorCompatible,
|
||||
replace_global: bool = True,
|
||||
name: str | None = None,
|
||||
) -> Tensor: ...
|
||||
@overload
|
||||
def regex_replace(
|
||||
input: RaggedTensor,
|
||||
pattern: StringTensorCompatible,
|
||||
rewrite: StringTensorCompatible,
|
||||
replace_global: bool = True,
|
||||
name: str | None = None,
|
||||
) -> RaggedTensor: ...
|
||||
def split(
|
||||
input: StringTensorCompatible | RaggedTensor,
|
||||
sep: StringTensorCompatible | None = None,
|
||||
maxsplit: int = -1,
|
||||
name: str | None = None,
|
||||
) -> RaggedTensor: ...
|
||||
@overload
|
||||
def strip(input: StringTensorCompatible, name: str | None = None) -> Tensor: ...
|
||||
@overload
|
||||
def strip(input: RaggedTensor, name: str | None = None) -> RaggedTensor: ...
|
||||
@overload
|
||||
def substr(
|
||||
input: StringTensorCompatible,
|
||||
pos: TensorCompatible,
|
||||
len: TensorCompatible,
|
||||
unit: Literal["BYTE", "UTF8_CHAR"] = "BYTE",
|
||||
name: str | None = None,
|
||||
) -> Tensor: ...
|
||||
@overload
|
||||
def substr(
|
||||
input: RaggedTensor,
|
||||
pos: TensorCompatible,
|
||||
len: TensorCompatible,
|
||||
unit: Literal["BYTE", "UTF8_CHAR"] = "BYTE",
|
||||
name: str | None = None,
|
||||
) -> RaggedTensor: ...
|
||||
@overload
|
||||
def to_hash_bucket(input: StringTensorCompatible, num_buckets: int, name: str | None = None) -> Tensor: ...
|
||||
@overload
|
||||
def to_hash_bucket(input: RaggedTensor, num_buckets: int, name: str | None = None) -> RaggedTensor: ...
|
||||
@overload
|
||||
def to_hash_bucket_fast(input: StringTensorCompatible, num_buckets: int, name: str | None = None) -> Tensor: ...
|
||||
@overload
|
||||
def to_hash_bucket_fast(input: RaggedTensor, num_buckets: int, name: str | None = None) -> RaggedTensor: ...
|
||||
@overload
|
||||
def to_hash_bucket_strong(
|
||||
input: StringTensorCompatible, num_buckets: int, key: Sequence[int], name: str | None = None
|
||||
) -> Tensor: ...
|
||||
@overload
|
||||
def to_hash_bucket_strong(input: RaggedTensor, num_buckets: int, key: Sequence[int], name: str | None = None) -> RaggedTensor: ...
|
||||
@overload
|
||||
def to_number(input: StringTensorCompatible, out_type: DType = ..., name: str | None = None) -> Tensor: ...
|
||||
@overload
|
||||
def to_number(input: RaggedTensor, out_type: DType = ..., name: str | None = None) -> RaggedTensor: ...
|
||||
@overload
|
||||
def unicode_decode(
|
||||
input: StringTensorCompatible,
|
||||
input_encoding: str,
|
||||
errors: Literal["replace", "strict", "ignore"] = "replace",
|
||||
replacement_char: int = 65533,
|
||||
replace_control_characters: bool = False,
|
||||
name: str | None = None,
|
||||
) -> Tensor | RaggedTensor: ...
|
||||
@overload
|
||||
def unicode_decode(
|
||||
input: RaggedTensor,
|
||||
input_encoding: str,
|
||||
errors: Literal["replace", "strict", "ignore"] = "replace",
|
||||
replacement_char: int = 65533,
|
||||
replace_control_characters: bool = False,
|
||||
name: str | None = None,
|
||||
) -> RaggedTensor: ...
|
||||
@overload
|
||||
def unicode_decode_with_offsets(
|
||||
input: StringTensorCompatible,
|
||||
input_encoding: str,
|
||||
errors: Literal["replace", "strict", "ignore"] = "replace",
|
||||
replacement_char: int = 65533,
|
||||
replace_control_characters: bool = False,
|
||||
name: str | None = None,
|
||||
) -> tuple[_TensorOrRaggedTensor, _TensorOrRaggedTensor]: ...
|
||||
@overload
|
||||
def unicode_decode_with_offsets(
|
||||
input: RaggedTensor,
|
||||
input_encoding: str,
|
||||
errors: Literal["replace", "strict", "ignore"] = "replace",
|
||||
replacement_char: int = 65533,
|
||||
replace_control_characters: bool = False,
|
||||
name: str | None = None,
|
||||
) -> tuple[RaggedTensor, RaggedTensor]: ...
|
||||
@overload
|
||||
def unicode_encode(
|
||||
input: TensorCompatible,
|
||||
output_encoding: Literal["UTF-8", "UTF-16-BE", "UTF-32-BE"],
|
||||
errors: Literal["replace", "strict", "ignore"] = "replace",
|
||||
replacement_char: int = 65533,
|
||||
name: str | None = None,
|
||||
) -> Tensor: ...
|
||||
@overload
|
||||
def unicode_encode(
|
||||
input: RaggedTensor,
|
||||
output_encoding: Literal["UTF-8", "UTF-16-BE", "UTF-32-BE"],
|
||||
errors: Literal["replace", "strict", "ignore"] = "replace",
|
||||
replacement_char: int = 65533,
|
||||
name: str | None = None,
|
||||
) -> RaggedTensor: ...
|
||||
@overload
|
||||
def unicode_script(input: TensorCompatible, name: str | None = None) -> Tensor: ...
|
||||
@overload
|
||||
def unicode_script(input: RaggedTensor, name: str | None = None) -> RaggedTensor: ...
|
||||
@overload
|
||||
def unicode_split(
|
||||
input: StringTensorCompatible,
|
||||
input_encoding: str,
|
||||
errors: Literal["replace", "strict", "ignore"] = "replace",
|
||||
replacement_char: int = 65533,
|
||||
name: str | None = None,
|
||||
) -> Tensor | RaggedTensor: ...
|
||||
@overload
|
||||
def unicode_split(
|
||||
input: RaggedTensor,
|
||||
input_encoding: str,
|
||||
errors: Literal["replace", "strict", "ignore"] = "replace",
|
||||
replacement_char: int = 65533,
|
||||
name: str | None = None,
|
||||
) -> RaggedTensor: ...
|
||||
@overload
|
||||
def unicode_split_with_offsets(
|
||||
input: StringTensorCompatible,
|
||||
input_encoding: str,
|
||||
errors: Literal["replace", "strict", "ignore"] = "replace",
|
||||
replacement_char: int = 65533,
|
||||
name: str | None = None,
|
||||
) -> tuple[_TensorOrRaggedTensor, _TensorOrRaggedTensor]: ...
|
||||
@overload
|
||||
def unicode_split_with_offsets(
|
||||
input: RaggedTensor,
|
||||
input_encoding: str,
|
||||
errors: Literal["replace", "strict", "ignore"] = "replace",
|
||||
replacement_char: int = 65533,
|
||||
name: str | None = None,
|
||||
) -> tuple[RaggedTensor, RaggedTensor]: ...
|
||||
@overload
|
||||
def unicode_transcode(
|
||||
input: StringTensorCompatible,
|
||||
input_encoding: str,
|
||||
output_encoding: Literal["UTF-8", "UTF-16-BE", "UTF-32-BE"],
|
||||
errors: Literal["replace", "strict", "ignore"] = "replace",
|
||||
replacement_char: int = 65533,
|
||||
replace_control_characters: bool = False,
|
||||
name: str | None = None,
|
||||
) -> Tensor: ...
|
||||
@overload
|
||||
def unicode_transcode(
|
||||
input: RaggedTensor,
|
||||
input_encoding: str,
|
||||
output_encoding: Literal["UTF-8", "UTF-16-BE", "UTF-32-BE"],
|
||||
errors: Literal["replace", "strict", "ignore"] = "replace",
|
||||
replacement_char: int = 65533,
|
||||
replace_control_characters: bool = False,
|
||||
name: str | None = None,
|
||||
) -> RaggedTensor: ...
|
||||
def unsorted_segment_join(
|
||||
inputs: StringTensorCompatible,
|
||||
segment_ids: UIntTensorCompatible,
|
||||
num_segments: UIntTensorCompatible,
|
||||
separator: str = "",
|
||||
name: str | None = None,
|
||||
) -> Tensor: ...
|
||||
@overload
|
||||
def upper(input: TensorCompatible, encoding: Literal["utf-8", ""] = "", name: str | None = None) -> Tensor: ...
|
||||
@overload
|
||||
def upper(input: RaggedTensor, encoding: Literal["utf-8", ""] = "", name: str | None = None) -> RaggedTensor: ...
|
||||
Reference in New Issue
Block a user