tensorflow: feature columns (#10052)
Co-authored-by: Mehdi Drissi <mdrissi@snapchat.com>
Co-authored-by: Jelle Zijlstra <jelle.zijlstra@gmail.com>
@@ -16,7 +16,9 @@ tensorflow.Graph.__getattr__
tensorflow.Operation.__getattr__
tensorflow.Variable.__getattr__
tensorflow.keras.layers.Layer.__getattr__
tensorflow.python.feature_column.feature_column_v2.SharedEmbeddingColumnCreator.__getattr__
tensorflow.GradientTape.__getattr__

# Internal undocumented API
tensorflow.RaggedTensor.__init__
# Has an undocumented extra argument that tf.Variable, which acts like a subclass, uses
@@ -69,3 +71,13 @@ tensorflow.keras.layers.*.compute_output_shape
# pb2.pyi files generated by mypy-protobuf diverge from the runtime in many ways. These stubs
# are mainly tested in mypy-protobuf.
.*_pb2.*

# Uses namedtuple at runtime, but NamedTuple in stubs, and the two disagree about the name of
# __new__'s first argument (cls vs _cls).
tensorflow.io.RaggedFeature.__new__
tensorflow.io.FixedLenSequenceFeature.__new__
tensorflow.io.FixedLenFeature.__new__
tensorflow.io.SparseFeature.__new__

# Metaclass inconsistency. The runtime metaclass is defined in a C++ extension and is undocumented.
tensorflow.io.TFRecordWriter
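The namedtuple mismatch those four allowlist entries work around can be reproduced directly (illustrative sketch, not part of the diff; ToyFeature is a made-up name): collections.namedtuple generates a __new__ whose first parameter is spelled _cls, while typing.NamedTuple-based stubs spell it cls, so stubtest reports a signature difference.

import collections
import inspect

# A made-up namedtuple standing in for e.g. tf.io.FixedLenFeature at runtime.
ToyFeature = collections.namedtuple("ToyFeature", ["shape", "dtype"])

# The generated __new__ names its first parameter "_cls", not "cls".
print(inspect.signature(ToyFeature.__new__))  # (_cls, shape, dtype)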
@@ -9,7 +9,7 @@ from typing import Any, NoReturn, TypeVar, overload
from typing_extensions import ParamSpec, Self, TypeAlias

import numpy
-from tensorflow import initializers as initializers, keras as keras, math as math
+from tensorflow import feature_column as feature_column, initializers as initializers, io as io, keras as keras, math as math
from tensorflow._aliases import ContainerGradients, ContainerTensors, ContainerTensorsLike, Gradients, TensorLike

# Explicit import of DType is covered by the wildcard, but
stubs/tensorflow/tensorflow/feature_column/__init__.pyi (new file, 95 lines)
@@ -0,0 +1,95 @@
from collections.abc import Callable, Iterable, Sequence

import tensorflow as tf
from tensorflow import _ShapeLike
from tensorflow.python.feature_column import feature_column_v2 as fc, sequence_feature_column as seq_fc

def numeric_column(
    key: str,
    shape: _ShapeLike = (1,),
    default_value: float | None = None,
    dtype: tf.DType = ...,
    normalizer_fn: Callable[[tf.Tensor], tf.Tensor] | None = None,
) -> fc.NumericColumn: ...
def bucketized_column(source_column: fc.NumericColumn, boundaries: list[float] | tuple[float, ...]) -> fc.BucketizedColumn: ...
def embedding_column(
    categorical_column: fc.CategoricalColumn,
    dimension: int,
    combiner: fc._Combiners = "mean",
    initializer: Callable[[_ShapeLike], tf.Tensor] | None = None,
    ckpt_to_load_from: str | None = None,
    tensor_name_in_ckpt: str | None = None,
    max_norm: float | None = None,
    trainable: bool = True,
    use_safe_embedding_lookup: bool = True,
) -> fc.EmbeddingColumn: ...
def shared_embeddings(
    categorical_columns: Iterable[fc.CategoricalColumn],
    dimension: int,
    combiner: fc._Combiners = "mean",
    initializer: Callable[[_ShapeLike], tf.Tensor] | None = None,
    shared_embedding_collection_name: str | None = None,
    ckpt_to_load_from: str | None = None,
    tensor_name_in_ckpt: str | None = None,
    max_norm: float | None = None,
    trainable: bool = True,
    use_safe_embedding_lookup: bool = True,
) -> list[fc.SharedEmbeddingColumn]: ...
def categorical_column_with_identity(
    key: str, num_buckets: int, default_value: int | None = None
) -> fc.IdentityCategoricalColumn: ...
def categorical_column_with_hash_bucket(key: str, hash_bucket_size: int, dtype: tf.DType = ...) -> fc.HashedCategoricalColumn: ...
def categorical_column_with_vocabulary_file(
    key: str,
    vocabulary_file: str,
    vocabulary_size: int | None = None,
    dtype: tf.DType = ...,
    default_value: str | int | None = None,
    num_oov_buckets: int = 0,
    file_format: str | None = None,
) -> fc.VocabularyFileCategoricalColumn: ...
def categorical_column_with_vocabulary_list(
    key: str,
    vocabulary_list: Sequence[str] | Sequence[int],
    dtype: tf.DType | None = None,
    default_value: str | int | None = -1,
    num_oov_buckets: int = 0,
) -> fc.VocabularyListCategoricalColumn: ...
def indicator_column(categorical_column: fc.CategoricalColumn) -> fc.IndicatorColumn: ...
def weighted_categorical_column(
    categorical_column: fc.CategoricalColumn, weight_feature_key: str, dtype: tf.DType = ...
) -> fc.WeightedCategoricalColumn: ...
def crossed_column(
    keys: Iterable[str | fc.CategoricalColumn], hash_bucket_size: int, hash_key: int | None = None
) -> fc.CrossedColumn: ...
def sequence_numeric_column(
    key: str,
    shape: _ShapeLike = (1,),
    default_value: float = 0.0,
    dtype: tf.DType = ...,
    normalizer_fn: Callable[[tf.Tensor], tf.Tensor] | None = None,
) -> seq_fc.SequenceNumericColumn: ...
def sequence_categorical_column_with_identity(
    key: str, num_buckets: int, default_value: int | None = None
) -> fc.SequenceCategoricalColumn: ...
def sequence_categorical_column_with_hash_bucket(
    key: str, hash_bucket_size: int, dtype: tf.DType = ...
) -> fc.SequenceCategoricalColumn: ...
def sequence_categorical_column_with_vocabulary_file(
    key: str,
    vocabulary_file: str,
    vocabulary_size: int | None = None,
    num_oov_buckets: int = 0,
    default_value: str | int | None = None,
    dtype: tf.DType = ...,
) -> fc.SequenceCategoricalColumn: ...
def sequence_categorical_column_with_vocabulary_list(
    key: str,
    vocabulary_list: Sequence[str] | Sequence[int],
    dtype: tf.DType | None = None,
    default_value: str | int | None = -1,
    num_oov_buckets: int = 0,
) -> fc.SequenceCategoricalColumn: ...
def make_parse_example_spec(
    feature_columns: Iterable[fc.FeatureColumn],
) -> dict[str, tf.io.FixedLenFeature | tf.io.VarLenFeature]: ...
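A brief usage sketch of the factory functions typed above (illustrative only; assumes an installed TensorFlow runtime and made-up feature names):

import tensorflow as tf

# Each factory returns one of the column classes stubbed in
# tensorflow.python.feature_column.feature_column_v2.
price = tf.feature_column.numeric_column("price", shape=(1,))
price_buckets = tf.feature_column.bucketized_column(price, boundaries=[10.0, 50.0, 100.0])
color = tf.feature_column.categorical_column_with_vocabulary_list("color", ["red", "green", "blue"])
color_embedding = tf.feature_column.embedding_column(color, dimension=8)

# make_parse_example_spec maps columns back to tf.io feature specs, matching the
# dict[str, FixedLenFeature | VarLenFeature] return type declared above.
spec = tf.feature_column.make_parse_example_spec([price, price_buckets, color])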
stubs/tensorflow/tensorflow/io/__init__.pyi (new file, 106 lines)
@@ -0,0 +1,106 @@
from _typeshed import Incomplete
from collections.abc import Iterable, Mapping
from types import TracebackType
from typing import NamedTuple
from typing_extensions import Literal, Self, TypeAlias

from tensorflow import _DTypeLike, _ShapeLike, _TensorCompatible
from tensorflow._aliases import TensorLike
from tensorflow.io import gfile as gfile

_FeatureSpecs: TypeAlias = Mapping[str, FixedLenFeature | FixedLenSequenceFeature | VarLenFeature | RaggedFeature | SparseFeature]

_CompressionTypes: TypeAlias = Literal["ZLIB", "GZIP", "", 0, 1, 2] | None
_CompressionLevels: TypeAlias = Literal[0, 1, 2, 3, 4, 5, 6, 7, 8, 9] | None
_MemoryLevels: TypeAlias = Literal[1, 2, 3, 4, 5, 6, 7, 8, 9] | None

class TFRecordOptions:
    compression_type: _CompressionTypes | TFRecordOptions
    flush_mode: int | None  # The exact values allowed come from zlib
    input_buffer_size: int | None
    output_buffer_size: int | None
    window_bits: int | None
    compression_level: _CompressionLevels
    compression_method: str | None
    mem_level: _MemoryLevels
    compression_strategy: int | None  # The exact values allowed come from zlib

    def __init__(
        self,
        compression_type: _CompressionTypes | TFRecordOptions = None,
        flush_mode: int | None = None,
        input_buffer_size: int | None = None,
        output_buffer_size: int | None = None,
        window_bits: int | None = None,
        compression_level: _CompressionLevels = None,
        compression_method: str | None = None,
        mem_level: _MemoryLevels = None,
        compression_strategy: int | None = None,
    ) -> None: ...
    @classmethod
    def get_compression_type_string(cls, options: _CompressionTypes | TFRecordOptions) -> str: ...

class TFRecordWriter:
    def __init__(self, path: str, options: _CompressionTypes | TFRecordOptions | None = None) -> None: ...
    def write(self, record: bytes) -> None: ...
    def flush(self) -> None: ...
    def close(self) -> None: ...
    def __enter__(self) -> Self: ...
    def __exit__(
        self, exc_type: type[BaseException] | None, exc_val: BaseException | None, exc_tb: TracebackType | None
    ) -> None: ...

# Defaults are also missing here because pytype crashes when a default is present, as reported
# in this [issue](https://github.com/google/pytype/issues/1410#issue-1669793588). After the
# next release the defaults can be added back.
class FixedLenFeature(NamedTuple):
    shape: _ShapeLike
    dtype: _DTypeLike
    default_value: _TensorCompatible | None = ...

class FixedLenSequenceFeature(NamedTuple):
    shape: _ShapeLike
    dtype: _DTypeLike
    allow_missing: bool = ...
    default_value: _TensorCompatible | None = ...

class VarLenFeature(NamedTuple):
    dtype: _DTypeLike

class SparseFeature(NamedTuple):
    index_key: str | list[str]
    value_key: str
    dtype: _DTypeLike
    size: int | list[int]
    already_sorted: bool = ...

class RaggedFeature(NamedTuple):
    # Mypy doesn't support nested NamedTuples, but at runtime they actually do use
    # nested collections.namedtuple.
    class RowSplits(NamedTuple):  # type: ignore[misc]
        key: str

    class RowLengths(NamedTuple):  # type: ignore[misc]
        key: str

    class RowStarts(NamedTuple):  # type: ignore[misc]
        key: str

    class RowLimits(NamedTuple):  # type: ignore[misc]
        key: str

    class ValueRowIds(NamedTuple):  # type: ignore[misc]
        key: str

    class UniformRowLength(NamedTuple):  # type: ignore[misc]
        length: int

    dtype: _DTypeLike
    value_key: str | None = ...
    partitions: tuple[RowSplits | RowLengths | RowStarts | RowLimits | ValueRowIds | UniformRowLength, ...] = ...  # type: ignore[name-defined]
    row_splits_dtype: _DTypeLike = ...
    validate: bool = ...

def parse_example(
    serialized: _TensorCompatible, features: _FeatureSpecs, example_names: Iterable[str] | None = None, name: str | None = None
) -> dict[str, TensorLike]: ...
def __getattr__(name: str) -> Incomplete: ...
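A short sketch of the tf.io surface typed above (illustrative only; the path and feature names are made up):

import tensorflow as tf

# TFRecordWriter works as a context manager, matching the __enter__/__exit__ stubs.
options = tf.io.TFRecordOptions(compression_type="GZIP")
with tf.io.TFRecordWriter("/tmp/demo.tfrecord", options) as writer:
    writer.write(b"serialized tf.train.Example bytes")

# The NamedTuple classes above describe parsing specs; a mapping of them is the
# _FeatureSpecs alias accepted by parse_example.
features = {
    "price": tf.io.FixedLenFeature(shape=(1,), dtype=tf.float32, default_value=0.0),
    "tags": tf.io.VarLenFeature(dtype=tf.string),
}
# parsed = tf.io.parse_example(serialized_batch, features)  # -> dict[str, TensorLike]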
stubs/tensorflow/tensorflow/io/gfile.pyi (new file, 11 lines)
@@ -0,0 +1,11 @@
from _typeshed import Incomplete, StrOrBytesPath
from collections.abc import Iterable

def rmtree(path: StrOrBytesPath) -> None: ...
def isdir(path: StrOrBytesPath) -> bool: ...
def listdir(path: StrOrBytesPath) -> list[str]: ...
def exists(path: StrOrBytesPath) -> bool: ...
def copy(src: StrOrBytesPath, dst: StrOrBytesPath, overwrite: bool = False) -> None: ...
def makedirs(path: StrOrBytesPath) -> None: ...
def glob(pattern: str | bytes | Iterable[str | bytes]) -> list[str]: ...
def __getattr__(name: str) -> Incomplete: ...
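A quick sketch of the tf.io.gfile helpers typed above (illustrative only; the directory path is made up):

import tensorflow as tf

demo_dir = "/tmp/gfile_demo"
if not tf.io.gfile.exists(demo_dir):
    tf.io.gfile.makedirs(demo_dir)
print(tf.io.gfile.isdir(demo_dir))    # True
print(tf.io.gfile.listdir(demo_dir))  # [] for a freshly created directory
tf.io.gfile.rmtree(demo_dir)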
stubs/tensorflow/tensorflow/python/feature_column/feature_column_v2.pyi (new file, 271 lines)
@@ -0,0 +1,271 @@
# The types here are all undocumented, but all feature columns are return types of the
# public functions in tf.feature_column. As they are undocumented internals, they are
# incomplete (only some common methods are included) and do not have a __getattr__ Incomplete fallback.
from _typeshed import Incomplete
from abc import ABC, ABCMeta, abstractmethod
from collections.abc import Callable, Sequence
from typing_extensions import Literal, Self, TypeAlias

import tensorflow as tf
from tensorflow import _ShapeLike

_Combiners: TypeAlias = Literal["mean", "sqrtn", "sum"]
_ExampleSpec: TypeAlias = dict[str, tf.io.FixedLenFeature | tf.io.VarLenFeature]

class FeatureColumn(ABC):
    @property
    @abstractmethod
    def name(self) -> str: ...
    @property
    @abstractmethod
    def parse_example_spec(self) -> _ExampleSpec: ...
    def __lt__(self, other: FeatureColumn) -> bool: ...
    def __gt__(self, other: FeatureColumn) -> bool: ...
    @property
    @abstractmethod
    def parents(self) -> list[FeatureColumn | str]: ...

class DenseColumn(FeatureColumn, metaclass=ABCMeta): ...
class SequenceDenseColumn(FeatureColumn, metaclass=ABCMeta): ...

# These classes are mostly subclasses of collections.namedtuple, but we can't use
# typing.NamedTuple because they use multiple inheritance with other non-namedtuple classes.
# _cls instead of cls is because collections.namedtuple uses _cls for __new__.
class NumericColumn(DenseColumn):
    key: str
    shape: _ShapeLike
    default_value: float
    dtype: tf.DType
    normalizer_fn: Callable[[tf.Tensor], tf.Tensor] | None

    def __new__(
        _cls,
        key: str,
        shape: _ShapeLike,
        default_value: float,
        dtype: tf.DType,
        normalizer_fn: Callable[[tf.Tensor], tf.Tensor] | None,
    ) -> Self: ...
    @property
    def name(self) -> str: ...
    @property
    def parse_example_spec(self) -> _ExampleSpec: ...
    @property
    def parents(self) -> list[FeatureColumn | str]: ...

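A minimal sketch of the runtime pattern that comment describes (hypothetical DenseBase/ToyNumericColumn names, not TensorFlow code): a class inheriting from both an ordinary base class and a collections.namedtuple, which typing.NamedTuple cannot express, hence the hand-written classes with explicit __new__ methods.

import collections

class DenseBase:  # stand-in for the real DenseColumn ABC
    pass

class ToyNumericColumn(
    DenseBase,
    collections.namedtuple("ToyNumericColumn", ["key", "shape", "default_value"]),
):
    @property
    def name(self) -> str:
        return self.key

col = ToyNumericColumn("price", (1,), 0.0)
print(col.name, col.shape)  # price (1,)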
class CategoricalColumn(FeatureColumn):
    @property
    @abstractmethod
    def num_buckets(self) -> int: ...

class BucketizedColumn(DenseColumn, CategoricalColumn):
    source_column: NumericColumn
    boundaries: list[float] | tuple[float, ...]

    def __new__(_cls, source_column: NumericColumn, boundaries: list[float] | tuple[float, ...]) -> Self: ...
    @property
    def name(self) -> str: ...
    @property
    def num_buckets(self) -> int: ...
    @property
    def parse_example_spec(self) -> _ExampleSpec: ...
    @property
    def parents(self) -> list[FeatureColumn | str]: ...

class EmbeddingColumn(DenseColumn, SequenceDenseColumn):
    categorical_column: CategoricalColumn
    dimension: int
    combiner: _Combiners
    initializer: Callable[[_ShapeLike], tf.Tensor] | None
    ckpt_to_load_from: str | None
    tensor_name_in_ckpt: str | None
    max_norm: float | None
    trainable: bool
    use_safe_embedding_lookup: bool

    # This one subclasses collections.namedtuple and overrides __new__.
    def __new__(
        cls,
        categorical_column: CategoricalColumn,
        dimension: int,
        combiner: _Combiners,
        initializer: Callable[[_ShapeLike], tf.Tensor] | None,
        ckpt_to_load_from: str | None,
        tensor_name_in_ckpt: str | None,
        max_norm: float | None,
        trainable: bool,
        use_safe_embedding_lookup: bool = True,
    ) -> Self: ...
    @property
    def name(self) -> str: ...
    @property
    def parse_example_spec(self) -> _ExampleSpec: ...
    @property
    def parents(self) -> list[FeatureColumn | str]: ...

class SharedEmbeddingColumnCreator:
    def __init__(
        self,
        dimension: int,
        initializer: Callable[[_ShapeLike], tf.Tensor] | None,
        ckpt_to_load_from: str | None,
        tensor_name_in_ckpt: str | None,
        num_buckets: int,
        trainable: bool,
        name: str = "shared_embedding_column_creator",
        use_safe_embedding_lookup: bool = True,
    ) -> None: ...
    def __getattr__(self, name: str) -> Incomplete: ...

class SharedEmbeddingColumn(DenseColumn, SequenceDenseColumn):
    categorical_column: CategoricalColumn
    shared_embedding_column_creator: SharedEmbeddingColumnCreator
    combiner: _Combiners
    max_norm: float | None
    use_safe_embedding_lookup: bool

    def __new__(
        cls,
        categorical_column: CategoricalColumn,
        shared_embedding_column_creator: SharedEmbeddingColumnCreator,
        combiner: _Combiners,
        max_norm: float | None,
        use_safe_embedding_lookup: bool = True,
    ) -> Self: ...
    @property
    def name(self) -> str: ...
    @property
    def parse_example_spec(self) -> _ExampleSpec: ...
    @property
    def parents(self) -> list[FeatureColumn | str]: ...

class CrossedColumn(CategoricalColumn):
    keys: tuple[str, ...]
    hash_bucket_size: int
    hash_key: int | None

    def __new__(_cls, keys: tuple[str, ...], hash_bucket_size: int, hash_key: int | None) -> Self: ...
    @property
    def name(self) -> str: ...
    @property
    def num_buckets(self) -> int: ...
    @property
    def parse_example_spec(self) -> _ExampleSpec: ...
    @property
    def parents(self) -> list[FeatureColumn | str]: ...

class IdentityCategoricalColumn(CategoricalColumn):
    key: str
    number_buckets: int
    default_value: int | None

    def __new__(_cls, key: str, number_buckets: int, default_value: int | None) -> Self: ...
    @property
    def name(self) -> str: ...
    @property
    def num_buckets(self) -> int: ...
    @property
    def parse_example_spec(self) -> _ExampleSpec: ...
    @property
    def parents(self) -> list[FeatureColumn | str]: ...

class HashedCategoricalColumn(CategoricalColumn):
    key: str
    hash_bucket_size: int
    dtype: tf.DType

    def __new__(_cls, key: str, hash_bucket_size: int, dtype: tf.DType) -> Self: ...
    @property
    def name(self) -> str: ...
    @property
    def num_buckets(self) -> int: ...
    @property
    def parse_example_spec(self) -> _ExampleSpec: ...
    @property
    def parents(self) -> list[FeatureColumn | str]: ...

class VocabularyFileCategoricalColumn(CategoricalColumn):
    key: str
    vocabulary_file: str
    vocabulary_size: int | None
    num_oov_buckets: int
    dtype: tf.DType
    default_value: str | int | None
    file_format: str | None

    def __new__(
        cls,
        key: str,
        vocabulary_file: str,
        vocabulary_size: int | None,
        num_oov_buckets: int,
        dtype: tf.DType,
        default_value: str | int | None,
        file_format: str | None = None,
    ) -> Self: ...
    @property
    def name(self) -> str: ...
    @property
    def num_buckets(self) -> int: ...
    @property
    def parse_example_spec(self) -> _ExampleSpec: ...
    @property
    def parents(self) -> list[FeatureColumn | str]: ...

class VocabularyListCategoricalColumn(CategoricalColumn):
    key: str
    vocabulary_list: Sequence[str] | Sequence[int]
    dtype: tf.DType
    default_value: str | int | None
    num_oov_buckets: int

    def __new__(
        _cls, key: str, vocabulary_list: Sequence[str], dtype: tf.DType, default_value: str | int | None, num_oov_buckets: int
    ) -> Self: ...
    @property
    def name(self) -> str: ...
    @property
    def num_buckets(self) -> int: ...
    @property
    def parse_example_spec(self) -> _ExampleSpec: ...
    @property
    def parents(self) -> list[FeatureColumn | str]: ...

class WeightedCategoricalColumn(CategoricalColumn):
    categorical_column: CategoricalColumn
    weight_feature_key: str
    dtype: tf.DType

    def __new__(_cls, categorical_column: CategoricalColumn, weight_feature_key: str, dtype: tf.DType) -> Self: ...
    @property
    def name(self) -> str: ...
    @property
    def num_buckets(self) -> int: ...
    @property
    def parse_example_spec(self) -> _ExampleSpec: ...
    @property
    def parents(self) -> list[FeatureColumn | str]: ...

class IndicatorColumn(DenseColumn, SequenceDenseColumn):
    categorical_column: CategoricalColumn

    def __new__(_cls, categorical_column: CategoricalColumn) -> Self: ...
    @property
    def name(self) -> str: ...
    @property
    def parse_example_spec(self) -> _ExampleSpec: ...
    @property
    def parents(self) -> list[FeatureColumn | str]: ...

class SequenceCategoricalColumn(CategoricalColumn):
    categorical_column: CategoricalColumn

    def __new__(_cls, categorical_column: CategoricalColumn) -> Self: ...
    @property
    def name(self) -> str: ...
    @property
    def num_buckets(self) -> int: ...
    @property
    def parse_example_spec(self) -> _ExampleSpec: ...
    @property
    def parents(self) -> list[FeatureColumn | str]: ...
stubs/tensorflow/tensorflow/python/feature_column/sequence_feature_column.pyi (new file, 30 lines)
@@ -0,0 +1,30 @@
from collections.abc import Callable
from typing_extensions import Self

import tensorflow as tf
from tensorflow import _ShapeLike
from tensorflow.python.feature_column.feature_column_v2 import FeatureColumn, SequenceDenseColumn, _ExampleSpec

# Strangely, at runtime most of the sequence feature columns are defined in feature_column_v2,
# except for this one.
class SequenceNumericColumn(SequenceDenseColumn):
    key: str
    shape: _ShapeLike
    default_value: float
    dtype: tf.DType
    normalizer_fn: Callable[[tf.Tensor], tf.Tensor] | None

    def __new__(
        _cls,
        key: str,
        shape: _ShapeLike,
        default_value: float,
        dtype: tf.DType,
        normalizer_fn: Callable[[tf.Tensor], tf.Tensor] | None,
    ) -> Self: ...
    @property
    def name(self) -> str: ...
    @property
    def parse_example_spec(self) -> _ExampleSpec: ...
    @property
    def parents(self) -> list[FeatureColumn | str]: ...
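A final sketch tying the sequence factories to the classes above (illustrative only; feature names are made up):

import tensorflow as tf

# Per the stubs, sequence_numeric_column returns SequenceNumericColumn (defined here),
# while the sequence categorical factories return SequenceCategoricalColumn from
# feature_column_v2.
seq_price = tf.feature_column.sequence_numeric_column("price_history", shape=(1,))
seq_tokens = tf.feature_column.sequence_categorical_column_with_hash_bucket("tokens", hash_bucket_size=1000)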