From 8061e58dcf7d6526aff5583b601567adb08e5fe3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ho=C3=ABl=20Bagard?= <34478245+hoel-bagard@users.noreply.github.com> Date: Sat, 17 Feb 2024 13:08:43 +0900 Subject: [PATCH] `tensorflow`: add `tf.strings` module (#11380) Partially taken from: https://github.com/hmc-cs-mdrissi/tensorflow_stubs/blob/main/stubs/tensorflow/strings.pyi Co-authored-by: Jelle Zijlstra --- stubs/tensorflow/tensorflow/_aliases.pyi | 2 + stubs/tensorflow/tensorflow/strings.pyi | 241 +++++++++++++++++++++++ 2 files changed, 243 insertions(+) create mode 100644 stubs/tensorflow/tensorflow/strings.pyi diff --git a/stubs/tensorflow/tensorflow/_aliases.pyi b/stubs/tensorflow/tensorflow/_aliases.pyi index 40042b807..842c5e5ac 100644 --- a/stubs/tensorflow/tensorflow/_aliases.pyi +++ b/stubs/tensorflow/tensorflow/_aliases.pyi @@ -32,6 +32,8 @@ FloatDataSequence: TypeAlias = Sequence[float] | Sequence[FloatDataSequence] IntDataSequence: TypeAlias = Sequence[int] | Sequence[IntDataSequence] StrDataSequence: TypeAlias = Sequence[str] | Sequence[StrDataSequence] ScalarTensorCompatible: TypeAlias = tf.Tensor | str | float | np.ndarray[Any, Any] | np.number[Any] +UIntTensorCompatible: TypeAlias = tf.Tensor | int | UIntArray +StringTensorCompatible: TypeAlias = tf.Tensor | str | npt.NDArray[np.str_] | Sequence[StringTensorCompatible] TensorCompatible: TypeAlias = ScalarTensorCompatible | Sequence[TensorCompatible] # _TensorCompatibleT = TypeVar("_TensorCompatibleT", bound=TensorCompatible) diff --git a/stubs/tensorflow/tensorflow/strings.pyi b/stubs/tensorflow/tensorflow/strings.pyi new file mode 100644 index 000000000..85937f6f5 --- /dev/null +++ b/stubs/tensorflow/tensorflow/strings.pyi @@ -0,0 +1,241 @@ +from collections.abc import Sequence +from typing import Literal, TypeVar, overload + +from tensorflow import RaggedTensor, Tensor +from tensorflow._aliases import StringTensorCompatible, TensorCompatible, UIntTensorCompatible +from tensorflow.dtypes import DType + +_TensorOrRaggedTensor = TypeVar("_TensorOrRaggedTensor", Tensor, RaggedTensor) + +@overload +def as_string( + input: TensorCompatible, + precision: int = -1, + scientific: bool = False, + shortest: bool = False, + width: int = -1, + fill: str = "", + name: str | None = None, +) -> Tensor: ... +@overload +def as_string( + input: RaggedTensor, + precision: int = -1, + scientific: bool = False, + shortest: bool = False, + width: int = -1, + fill: str = "", + name: str | None = None, +) -> RaggedTensor: ... +def bytes_split(input: TensorCompatible | RaggedTensor, name: str | None = None) -> RaggedTensor: ... +def format( + template: str, inputs: TensorCompatible, placeholder: str = "{}", summarize: int = 3, name: str | None = None +) -> Tensor: ... +def join(inputs: Sequence[TensorCompatible | RaggedTensor], separator: str = "", name: str | None = None) -> Tensor: ... +@overload +def length(input: TensorCompatible, unit: Literal["BYTE", "UTF8_CHAR"] = "BYTE", name: str | None = None) -> Tensor: ... +@overload +def length(input: RaggedTensor, unit: Literal["BYTE", "UTF8_CHAR"] = "BYTE", name: str | None = None) -> RaggedTensor: ... +@overload +def lower(input: TensorCompatible, encoding: Literal["utf-8", ""] = "", name: str | None = None) -> Tensor: ... +@overload +def lower(input: RaggedTensor, encoding: Literal["utf-8", ""] = "", name: str | None = None) -> RaggedTensor: ... +def ngrams( + data: StringTensorCompatible | RaggedTensor, + ngram_width: int | Sequence[int], + separator: str = " ", + pad_values: tuple[int, int] | str | None = None, + padding_width: int | None = None, + preserve_short_sequences: bool = False, + name: str | None = None, +) -> RaggedTensor: ... +def reduce_join( + inputs: StringTensorCompatible | RaggedTensor, + axis: int | None = None, + keepdims: bool = False, + separator: str = "", + name: str | None = None, +) -> Tensor: ... +@overload +def regex_full_match(input: StringTensorCompatible, pattern: StringTensorCompatible, name: str | None = None) -> Tensor: ... +@overload +def regex_full_match(input: RaggedTensor, pattern: StringTensorCompatible, name: str | None = None) -> RaggedTensor: ... +@overload +def regex_replace( + input: StringTensorCompatible, + pattern: StringTensorCompatible, + rewrite: StringTensorCompatible, + replace_global: bool = True, + name: str | None = None, +) -> Tensor: ... +@overload +def regex_replace( + input: RaggedTensor, + pattern: StringTensorCompatible, + rewrite: StringTensorCompatible, + replace_global: bool = True, + name: str | None = None, +) -> RaggedTensor: ... +def split( + input: StringTensorCompatible | RaggedTensor, + sep: StringTensorCompatible | None = None, + maxsplit: int = -1, + name: str | None = None, +) -> RaggedTensor: ... +@overload +def strip(input: StringTensorCompatible, name: str | None = None) -> Tensor: ... +@overload +def strip(input: RaggedTensor, name: str | None = None) -> RaggedTensor: ... +@overload +def substr( + input: StringTensorCompatible, + pos: TensorCompatible, + len: TensorCompatible, + unit: Literal["BYTE", "UTF8_CHAR"] = "BYTE", + name: str | None = None, +) -> Tensor: ... +@overload +def substr( + input: RaggedTensor, + pos: TensorCompatible, + len: TensorCompatible, + unit: Literal["BYTE", "UTF8_CHAR"] = "BYTE", + name: str | None = None, +) -> RaggedTensor: ... +@overload +def to_hash_bucket(input: StringTensorCompatible, num_buckets: int, name: str | None = None) -> Tensor: ... +@overload +def to_hash_bucket(input: RaggedTensor, num_buckets: int, name: str | None = None) -> RaggedTensor: ... +@overload +def to_hash_bucket_fast(input: StringTensorCompatible, num_buckets: int, name: str | None = None) -> Tensor: ... +@overload +def to_hash_bucket_fast(input: RaggedTensor, num_buckets: int, name: str | None = None) -> RaggedTensor: ... +@overload +def to_hash_bucket_strong( + input: StringTensorCompatible, num_buckets: int, key: Sequence[int], name: str | None = None +) -> Tensor: ... +@overload +def to_hash_bucket_strong(input: RaggedTensor, num_buckets: int, key: Sequence[int], name: str | None = None) -> RaggedTensor: ... +@overload +def to_number(input: StringTensorCompatible, out_type: DType = ..., name: str | None = None) -> Tensor: ... +@overload +def to_number(input: RaggedTensor, out_type: DType = ..., name: str | None = None) -> RaggedTensor: ... +@overload +def unicode_decode( + input: StringTensorCompatible, + input_encoding: str, + errors: Literal["replace", "strict", "ignore"] = "replace", + replacement_char: int = 65533, + replace_control_characters: bool = False, + name: str | None = None, +) -> Tensor | RaggedTensor: ... +@overload +def unicode_decode( + input: RaggedTensor, + input_encoding: str, + errors: Literal["replace", "strict", "ignore"] = "replace", + replacement_char: int = 65533, + replace_control_characters: bool = False, + name: str | None = None, +) -> RaggedTensor: ... +@overload +def unicode_decode_with_offsets( + input: StringTensorCompatible, + input_encoding: str, + errors: Literal["replace", "strict", "ignore"] = "replace", + replacement_char: int = 65533, + replace_control_characters: bool = False, + name: str | None = None, +) -> tuple[_TensorOrRaggedTensor, _TensorOrRaggedTensor]: ... +@overload +def unicode_decode_with_offsets( + input: RaggedTensor, + input_encoding: str, + errors: Literal["replace", "strict", "ignore"] = "replace", + replacement_char: int = 65533, + replace_control_characters: bool = False, + name: str | None = None, +) -> tuple[RaggedTensor, RaggedTensor]: ... +@overload +def unicode_encode( + input: TensorCompatible, + output_encoding: Literal["UTF-8", "UTF-16-BE", "UTF-32-BE"], + errors: Literal["replace", "strict", "ignore"] = "replace", + replacement_char: int = 65533, + name: str | None = None, +) -> Tensor: ... +@overload +def unicode_encode( + input: RaggedTensor, + output_encoding: Literal["UTF-8", "UTF-16-BE", "UTF-32-BE"], + errors: Literal["replace", "strict", "ignore"] = "replace", + replacement_char: int = 65533, + name: str | None = None, +) -> RaggedTensor: ... +@overload +def unicode_script(input: TensorCompatible, name: str | None = None) -> Tensor: ... +@overload +def unicode_script(input: RaggedTensor, name: str | None = None) -> RaggedTensor: ... +@overload +def unicode_split( + input: StringTensorCompatible, + input_encoding: str, + errors: Literal["replace", "strict", "ignore"] = "replace", + replacement_char: int = 65533, + name: str | None = None, +) -> Tensor | RaggedTensor: ... +@overload +def unicode_split( + input: RaggedTensor, + input_encoding: str, + errors: Literal["replace", "strict", "ignore"] = "replace", + replacement_char: int = 65533, + name: str | None = None, +) -> RaggedTensor: ... +@overload +def unicode_split_with_offsets( + input: StringTensorCompatible, + input_encoding: str, + errors: Literal["replace", "strict", "ignore"] = "replace", + replacement_char: int = 65533, + name: str | None = None, +) -> tuple[_TensorOrRaggedTensor, _TensorOrRaggedTensor]: ... +@overload +def unicode_split_with_offsets( + input: RaggedTensor, + input_encoding: str, + errors: Literal["replace", "strict", "ignore"] = "replace", + replacement_char: int = 65533, + name: str | None = None, +) -> tuple[RaggedTensor, RaggedTensor]: ... +@overload +def unicode_transcode( + input: StringTensorCompatible, + input_encoding: str, + output_encoding: Literal["UTF-8", "UTF-16-BE", "UTF-32-BE"], + errors: Literal["replace", "strict", "ignore"] = "replace", + replacement_char: int = 65533, + replace_control_characters: bool = False, + name: str | None = None, +) -> Tensor: ... +@overload +def unicode_transcode( + input: RaggedTensor, + input_encoding: str, + output_encoding: Literal["UTF-8", "UTF-16-BE", "UTF-32-BE"], + errors: Literal["replace", "strict", "ignore"] = "replace", + replacement_char: int = 65533, + replace_control_characters: bool = False, + name: str | None = None, +) -> RaggedTensor: ... +def unsorted_segment_join( + inputs: StringTensorCompatible, + segment_ids: UIntTensorCompatible, + num_segments: UIntTensorCompatible, + separator: str = "", + name: str | None = None, +) -> Tensor: ... +@overload +def upper(input: TensorCompatible, encoding: Literal["utf-8", ""] = "", name: str | None = None) -> Tensor: ... +@overload +def upper(input: RaggedTensor, encoding: Literal["utf-8", ""] = "", name: str | None = None) -> RaggedTensor: ...