Typing cleanup:
- Changing the input types of `tf_utils.construct_and_lookup_table` to `tf.Tensor`, since it is not currently used or tested with composite tensors (for composite inputs to mappers, it is applied to the flat values; see the sketch after this list).
- Removing `common_types.ConsistentTensorType`, which is not currently used anywhere, and renaming `common_types.ConsistentInputTensorType` to `common_types.ConsistentTensorType`.
- Renaming `common_types.InputTensorType` to `common_types.TensorType`.
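
A minimal sketch of the flat-values pattern the first bullet refers to, with hypothetical helper names (this is not the actual `tf_utils` API): mappers unpack a composite input, apply the dense function such as a table lookup to its values, and repack the result, so the lookup helper itself only ever sees a `tf.Tensor`:

    import tensorflow as tf

    def _apply_to_flat_values(fn, x):
      # `fn` maps tf.Tensor -> tf.Tensor, e.g. a vocabulary table lookup.
      if isinstance(x, tf.SparseTensor):
        return tf.SparseTensor(x.indices, fn(x.values), x.dense_shape)
      if isinstance(x, tf.RaggedTensor):
        return x.with_flat_values(fn(x.flat_values))
      return fn(x)  # Dense inputs are passed through unchanged.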

PiperOrigin-RevId: 404825689
iindyk authored and tf-transform-team committed Oct 21, 2021
1 parent 35df75a commit b06e87b
Showing 11 changed files with 140 additions and 148 deletions.
44 changes: 22 additions & 22 deletions tensorflow_transform/analyzers.py
@@ -112,7 +112,7 @@

 def _apply_cacheable_combiner(
     combiner: analyzer_nodes.Combiner,
-    *tensor_inputs: common_types.InputTensorType) -> Tuple[tf.Tensor, ...]:
+    *tensor_inputs: common_types.TensorType) -> Tuple[tf.Tensor, ...]:
   """Applies the combiner over the whole dataset possibly utilizing cache."""
   input_values_node = analyzer_nodes.get_input_tensors_value_nodes(
       tensor_inputs)
@@ -137,7 +137,7 @@ def _apply_cacheable_combiner(

 def _apply_cacheable_combiner_per_key(
     combiner: analyzer_nodes.Combiner,
-    *tensor_inputs: common_types.InputTensorType) -> Tuple[tf.Tensor, ...]:
+    *tensor_inputs: common_types.TensorType) -> Tuple[tf.Tensor, ...]:
   """Similar to _apply_cacheable_combiner but this is computed per key."""
   input_values_node = analyzer_nodes.get_input_tensors_value_nodes(
       tensor_inputs)
@@ -162,7 +162,7 @@ def _apply_cacheable_combiner_per_key(

 def _apply_cacheable_combiner_per_key_large(
     combiner: analyzer_nodes.Combiner, key_vocabulary_filename: str,
-    *tensor_inputs: common_types.InputTensorType
+    *tensor_inputs: common_types.TensorType
 ) -> Union[tf.Tensor, common_types.Asset]:
   """Similar to above but saves the combined result to a file."""
   input_values_node = analyzer_nodes.get_input_tensors_value_nodes(
@@ -382,7 +382,7 @@ def _numeric_combine(inputs: List[tf.Tensor],

 @common.log_api_use(common.ANALYZER_COLLECTION)
 def min(  # pylint: disable=redefined-builtin
-    x: common_types.InputTensorType,
+    x: common_types.TensorType,
     reduce_instance_dims: bool = True,
     name: Optional[str] = None) -> tf.Tensor:
   """Computes the minimum of the values of a `Tensor` over the whole dataset.
@@ -409,7 +409,7 @@ def min(  # pylint: disable=redefined-builtin

 @common.log_api_use(common.ANALYZER_COLLECTION)
 def max(  # pylint: disable=redefined-builtin
-    x: common_types.InputTensorType,
+    x: common_types.TensorType,
     reduce_instance_dims: bool = True,
     name: Optional[str] = None) -> tf.Tensor:
   """Computes the maximum of the values of a `Tensor` over the whole dataset.
@@ -433,7 +433,7 @@ def max(  # pylint: disable=redefined-builtin
   return _min_and_max(x, reduce_instance_dims, name)[1]


-def _min_and_max(x: common_types.InputTensorType,
+def _min_and_max(x: common_types.TensorType,
                  reduce_instance_dims: bool = True,
                  name: Optional[str] = None) -> Tuple[tf.Tensor, tf.Tensor]:
   """Computes the min and max of the values of a `Tensor` or `CompositeTensor`.
@@ -482,8 +482,8 @@ def _min_and_max(x: common_types.InputTensorType,


 def _min_and_max_per_key(
-    x: common_types.InputTensorType,
-    key: common_types.InputTensorType,
+    x: common_types.TensorType,
+    key: common_types.TensorType,
     reduce_instance_dims: bool = True,
     key_vocabulary_filename: Optional[str] = None,
     name: Optional[str] = None
@@ -577,7 +577,7 @@ def _sum_combine_fn_and_dtype(

 @common.log_api_use(common.ANALYZER_COLLECTION)
 def sum(  # pylint: disable=redefined-builtin
-    x: common_types.InputTensorType,
+    x: common_types.TensorType,
     reduce_instance_dims: bool = True,
     name: Optional[str] = None) -> tf.Tensor:
   """Computes the sum of the values of a `Tensor` over the whole dataset.
@@ -629,7 +629,7 @@ def sum(  # pylint: disable=redefined-builtin


 @common.log_api_use(common.ANALYZER_COLLECTION)
-def histogram(x: common_types.InputTensorType,
+def histogram(x: common_types.TensorType,
               boundaries: Optional[Union[tf.Tensor, int]] = None,
               categorical: Optional[bool] = False,
               name: Optional[str] = None) -> Tuple[tf.Tensor, tf.Tensor]:
@@ -702,7 +702,7 @@ def histogram(x: common_types.InputTensorType,


 @common.log_api_use(common.ANALYZER_COLLECTION)
-def size(x: common_types.InputTensorType,
+def size(x: common_types.TensorType,
          reduce_instance_dims: bool = True,
          name: Optional[str] = None) -> tf.Tensor:
   """Computes the total size of instances in a `Tensor` over the whole dataset.
@@ -730,7 +730,7 @@ def size(x: common_types.InputTensorType,


 @common.log_api_use(common.ANALYZER_COLLECTION)
-def count_per_key(key: common_types.InputTensorType,
+def count_per_key(key: common_types.TensorType,
                   key_vocabulary_filename: Optional[str] = None,
                   name: Optional[str] = None):
   """Computes the count of each element of a `Tensor`.
@@ -779,7 +779,7 @@ def count_per_key(key: common_types.InputTensorType,


 @common.log_api_use(common.ANALYZER_COLLECTION)
-def mean(x: common_types.InputTensorType,
+def mean(x: common_types.TensorType,
         reduce_instance_dims: bool = True,
         name: Optional[str] = None,
         output_dtype: Optional[tf.DType] = None) -> tf.Tensor:
@@ -807,7 +807,7 @@ def mean(x: common_types.InputTensorType,


 @common.log_api_use(common.ANALYZER_COLLECTION)
-def var(x: common_types.InputTensorType,
+def var(x: common_types.TensorType,
        reduce_instance_dims: bool = True,
        name: Optional[str] = None,
        output_dtype: Optional[tf.DType] = None) -> tf.Tensor:
@@ -837,7 +837,7 @@ def var(x: common_types.InputTensorType,
   return _mean_and_var(x, reduce_instance_dims, output_dtype)[1]


-def _mean_and_var(x: common_types.InputTensorType,
+def _mean_and_var(x: common_types.TensorType,
                   reduce_instance_dims: bool = True,
                   output_dtype: Optional[tf.DType] = None):
   """More efficient combined `mean` and `var`. See `var`."""
@@ -876,7 +876,7 @@ def _mean_and_var(x: common_types.InputTensorType,


 @common.log_api_use(common.ANALYZER_COLLECTION)
-def tukey_location(x: common_types.InputTensorType,
+def tukey_location(x: common_types.TensorType,
                    reduce_instance_dims: Optional[bool] = True,
                    output_dtype: Optional[tf.DType] = None,
                    name: Optional[str] = None) -> tf.Tensor:
@@ -913,7 +913,7 @@ def tukey_location(x: common_types.InputTensorType,


 @common.log_api_use(common.ANALYZER_COLLECTION)
-def tukey_scale(x: common_types.InputTensorType,
+def tukey_scale(x: common_types.TensorType,
                 reduce_instance_dims: Optional[bool] = True,
                 output_dtype: Optional[tf.DType] = None,
                 name: Optional[str] = None) -> tf.Tensor:
@@ -951,7 +951,7 @@ def tukey_scale(x: common_types.InputTensorType,


 @common.log_api_use(common.ANALYZER_COLLECTION)
-def tukey_h_params(x: common_types.InputTensorType,
+def tukey_h_params(x: common_types.TensorType,
                    reduce_instance_dims: bool = True,
                    output_dtype: Optional[tf.DType] = None,
                    name: Optional[str] = None) -> Tuple[tf.Tensor, tf.Tensor]:
@@ -988,7 +988,7 @@ def tukey_h_params(x: common_types.InputTensorType,


 def _tukey_parameters(
-    x: common_types.InputTensorType,
+    x: common_types.TensorType,
     reduce_instance_dims: bool = True,
     output_dtype: Optional[tf.DType] = None
 ) -> Tuple[tf.Tensor, tf.Tensor, tf.Tensor, tf.Tensor]:
@@ -1027,8 +1027,8 @@ def _tukey_parameters(


 def _mean_and_var_per_key(
-    x: common_types.InputTensorType,
-    key: common_types.InputTensorType,
+    x: common_types.TensorType,
+    key: common_types.TensorType,
     reduce_instance_dims: bool = True,
     output_dtype: Optional[tf.DType] = None,
     key_vocabulary_filename: Optional[str] = None
@@ -1652,7 +1652,7 @@ def _register_vocab(sanitized_filename: str,
 # https://github.com/tensorflow/community/blob/master/rfcs/20190116-embedding-partitioned-variable.md#goals
 @common.log_api_use(common.ANALYZER_COLLECTION)
 def vocabulary(
-    x: common_types.InputTensorType,
+    x: common_types.TensorType,
     top_k: Optional[int] = None,
     frequency_threshold: Optional[int] = None,
     vocab_filename: Optional[str] = None,
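
Caller-visible behavior is unchanged by this rename: the analyzers above still accept any of the three tensor kinds in the `common_types.TensorType` union. A minimal hypothetical `preprocessing_fn` for illustration, assuming a dense feature 'x':

    import tensorflow_transform as tft

    def preprocessing_fn(inputs):
      x = inputs['x']
      # Analyzers reduce over the whole dataset and return dense tensors,
      # so the mean can be broadcast against the batched input.
      return {'x_centered': x - tft.mean(x)}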
8 changes: 3 additions & 5 deletions tensorflow_transform/common_types.py
@@ -45,16 +45,14 @@

 DomainType = Union[schema_pb2.IntDomain, schema_pb2.FloatDomain,
                    schema_pb2.StringDomain]
-InputTensorType = Union[tf.Tensor, tf.SparseTensor, tf.RaggedTensor]
+TensorType = Union[tf.Tensor, tf.SparseTensor, tf.RaggedTensor]
+ConsistentTensorType = TypeVar('ConsistentTensorType', tf.Tensor,
+                               tf.SparseTensor, tf.RaggedTensor)
 SparseTensorValueType = Union[tf.SparseTensor, tf.compat.v1.SparseTensorValue]
 RaggedTensorValueType = Union[tf.RaggedTensor,
                               tf.compat.v1.ragged.RaggedTensorValue]
 TensorValueType = Union[tf.Tensor, np.ndarray, SparseTensorValueType,
                         RaggedTensorValueType]
-ConsistentInputTensorType = TypeVar('ConsistentInputTensorType', tf.Tensor,
-                                    tf.SparseTensor, tf.RaggedTensor)
-ConsistentTensorType = TypeVar('ConsistentTensorType', tf.Tensor,
-                               tf.SparseTensor)
 TemporaryAnalyzerOutputType = Union[tf.Tensor, Asset]
 VocabularyFileFormatType = Literal['text', 'tfrecord_gzip']
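
The two aliases that remain serve different purposes, shown here in a minimal sketch (the `scale` mapper is hypothetical): `TensorType` is a plain `Union` and says nothing about a function's output, while `ConsistentTensorType` is a constrained `TypeVar`, so annotating a mapper with it tells type checkers that the output tensor kind matches the input:

    from typing import TypeVar, Union

    import tensorflow as tf

    TensorType = Union[tf.Tensor, tf.SparseTensor, tf.RaggedTensor]
    ConsistentTensorType = TypeVar('ConsistentTensorType', tf.Tensor,
                                   tf.SparseTensor, tf.RaggedTensor)

    # Hypothetical mapper: a tf.SparseTensor argument yields a
    # tf.SparseTensor result under type checking; a plain Union could
    # not express that relationship.
    def scale(x: ConsistentTensorType) -> ConsistentTensorType:
      ...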
2 changes: 1 addition & 1 deletion tensorflow_transform/experimental/analyzers.py
@@ -34,7 +34,7 @@


 def _apply_analyzer(analyzer_def_cls: Type[analyzer_nodes.AnalyzerDef],
-                    *tensor_inputs: common_types.InputTensorType,
+                    *tensor_inputs: common_types.TensorType,
                     **analyzer_def_kwargs: Any) -> Tuple[tf.Tensor, ...]:
   """Applies the analyzer over the whole dataset.
3 changes: 1 addition & 2 deletions tensorflow_transform/graph_tools.py
@@ -903,8 +903,7 @@ def validate_value(self, value):


 def get_analyzers_fingerprint(
-    graph: tf.Graph,
-    structured_inputs: Mapping[str, common_types.InputTensorType]
+    graph: tf.Graph, structured_inputs: Mapping[str, common_types.TensorType]
 ) -> Mapping[str, Set[bytes]]:
   """Computes fingerprints for all analyzers in `graph`.
20 changes: 10 additions & 10 deletions tensorflow_transform/impl_helper.py
@@ -595,8 +595,8 @@ def _check_valid_sparse_tensor(indices: Union[_CompositeComponentType,
 # `preprocessing_fn` using tf.function as is and another that will return
 # specific outputs requested for.
 def get_traced_transform_fn(
-    preprocessing_fn: Callable[[Mapping[str, common_types.InputTensorType]],
-                               Mapping[str, common_types.InputTensorType]],
+    preprocessing_fn: Callable[[Mapping[str, common_types.TensorType]],
+                               Mapping[str, common_types.TensorType]],
     input_signature: Mapping[str, tf.TypeSpec],
     tf_graph_context: graph_context.TFGraphContext,
     output_keys_to_name_map: Optional[Dict[str,
@@ -720,8 +720,8 @@ def trace_preprocessing_function(preprocessing_fn,

 def _trace_and_write_transform_fn(
     saved_model_dir: str,
-    preprocessing_fn: Callable[[Mapping[str, common_types.InputTensorType]],
-                               Mapping[str, common_types.InputTensorType]],
+    preprocessing_fn: Callable[[Mapping[str, common_types.TensorType]],
+                               Mapping[str, common_types.TensorType]],
     input_signature: Mapping[str, tf.TypeSpec], base_temp_dir: Optional[str],
     tensor_replacement_map: Optional[Dict[str, tf.Tensor]],
     output_keys_to_name_map: Optional[Dict[str,
@@ -743,9 +743,9 @@ def _trace_and_write_transform_fn(

 def _trace_and_get_metadata(
     concrete_transform_fn: function.ConcreteFunction,
-    structured_inputs: Mapping[str, common_types.InputTensorType],
-    preprocessing_fn: Callable[[Mapping[str, common_types.InputTensorType]],
-                               Mapping[str, common_types.InputTensorType]],
+    structured_inputs: Mapping[str, common_types.TensorType],
+    preprocessing_fn: Callable[[Mapping[str, common_types.TensorType]],
+                               Mapping[str, common_types.TensorType]],
     base_temp_dir: Optional[str],
     tensor_replacement_map: Optional[Dict[str, tf.Tensor]]
 ) -> dataset_metadata.DatasetMetadata:
@@ -768,7 +768,7 @@ def _trace_and_get_metadata(

 def _validate_analyzers_fingerprint(
     baseline_analyzers_fingerprint: Mapping[str, Set[bytes]], graph: tf.Graph,
-    structured_inputs: Mapping[str, common_types.InputTensorType]):
+    structured_inputs: Mapping[str, common_types.TensorType]):
   """Validates analyzers fingerprint in `graph` is same as baseline."""
   analyzers_fingerprint = graph_tools.get_analyzers_fingerprint(
       graph, structured_inputs)
@@ -787,8 +787,8 @@ def _validate_analyzers_fingerprint(

 def trace_and_write_v2_saved_model(
     saved_model_dir: str,
-    preprocessing_fn: Callable[[Mapping[str, common_types.InputTensorType]],
-                               Mapping[str, common_types.InputTensorType]],
+    preprocessing_fn: Callable[[Mapping[str, common_types.TensorType]],
+                               Mapping[str, common_types.TensorType]],
     input_signature: Mapping[str, tf.TypeSpec], base_temp_dir: Optional[str],
     baseline_analyzers_fingerprint: Mapping[str, Set[bytes]],
     tensor_replacement_map: Optional[Dict[str, tf.Tensor]],
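
The `Callable` annotation repeated through these signatures encodes the `preprocessing_fn` contract. A hypothetical local alias (not part of this change) that states it once:

    from typing import Callable, Mapping

    from tensorflow_transform import common_types

    # Tensors in, tensors out: a mapping of feature names to tensors is
    # transformed into another such mapping.
    PreprocessingFn = Callable[[Mapping[str, common_types.TensorType]],
                               Mapping[str, common_types.TensorType]]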