mirror of https://github.com/tensorflow/models.git
99 lines
3.4 KiB
Python
99 lines
3.4 KiB
Python
# Copyright 2025 The TensorFlow Authors. All Rights Reserved.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
|
|
"""Common utilities for the Keras uplift library."""
|
|
|
|
from typing import Tuple
|
|
import tensorflow as tf, tf_keras
|
|
|
|
|
|
def expand_to_match_rank(a: tf.Tensor, b: tf.Tensor) -> tf.Tensor:
  """Appends trailing size-1 axes to `a` until its rank equals that of `b`.

  Args:
    a: a `tf.Tensor` of shape (D0, D1, ..., Dn).
    b: a `tf.Tensor` of shape (D0, D1, ..., Dn, Dn+1, ..., Dn+m).

  Returns:
    A `tf.Tensor` of shape (D0, D1, ..., Dn, 1, ..., 1), with m trailing ones,
    when `b` has a higher rank than `a`; otherwise `a` itself, untouched.
  """
  # Both ranks must be statically known; comparing an unknown (None) rank
  # with an int raises TypeError.
  target_rank = b.shape.rank
  expanded = a
  while expanded.shape.rank < target_rank:
    expanded = tf.expand_dims(expanded, axis=-1)
  return expanded
|
|
|
|
|
|
def split_by_treatment(
    values: tf.Tensor, is_treatment: tf.Tensor
) -> Tuple[tf.Tensor, tf.Tensor]:
  """Splits a tensor into control and treatment tensors.

  Each row (slice along axis 0) of `values` is routed to the control or the
  treatment output according to the matching entry of `is_treatment`.

  Args:
    values: a `tf.Tensor` of shape (D0, D1, ..., DN).
    is_treatment: a `tf.Tensor` of shape (D0,) or (D0, 1) castable to boolean
      indicating if the example belongs to the treatment group (True) or control
      group (False). When it is not a boolean tensor, all of its values must be
      either 0 or 1.

  Returns:
    A tuple `(control_values, treatment_values)` holding the rows of `values`
    whose `is_treatment` entry is False and True, respectively.

  Raises:
    ValueError: if `is_treatment` is not of shape (D0,) or (D0, 1), if the
      leading dimensions of `values` and `is_treatment` differ, or if
      `is_treatment` is a string tensor.
    tf.errors.InvalidArgumentError: (eager mode) if a non-boolean
      `is_treatment` contains values other than 0 or 1.
  """
  # Only rank 1 (D0,) and rank 2 (D0, 1) are accepted. The rank must be
  # compared explicitly: comparing the TensorShape itself with an int coerces
  # the int into a shape, which is not a rank check.
  if is_treatment.shape.rank not in (1, 2) or (
      is_treatment.shape.rank == 2 and is_treatment.shape[1] != 1
  ):
    raise ValueError(
        "is_treatment tensor must be a tensor of shape (D0,) or (D0, 1) but got"
        f" a tensor of shape {is_treatment.shape} instead."
    )

  if values.shape[0] != is_treatment.shape[0]:
    raise ValueError(
        "values and is_treatment must be tensors of shapes (D0, D1, ..., DN)"
        f" and (D0, 1) (or (D0,)), but got tensors of shapes {values.shape} and"
        f" {is_treatment.shape} respectively."
    )

  # Strings cannot be meaningfully cast to bool, so reject them up front.
  if is_treatment.dtype == tf.string:
    raise ValueError(
        "is_treatment must be a tensor castable to boolean but got tensor"
        f" {is_treatment} of dtype {is_treatment.dtype} instead."
    )

  # Assert the is_treatment tensor contains only 0 or 1 values; otherwise the
  # bool cast below would silently treat any non-zero value as treatment.
  if is_treatment.dtype != tf.bool:
    is_treatment_float = tf.cast(is_treatment, tf.float32)
    tf.debugging.assert_equal(
        tf.reduce_all(
            tf.logical_or(is_treatment_float == 1.0, is_treatment_float == 0.0)
        ),
        tf.convert_to_tensor(True),
        message=(
            "When is_treatment is not a boolean tensor all of its values must"
            f" either be 0 or 1, but got tensor {is_treatment} instead."
        ),
    )

  # Normalize to (D0, 1) so tf.where yields (row, 0) coordinates in both cases.
  if is_treatment.shape.rank == 1:
    is_treatment = tf.expand_dims(is_treatment, axis=1)

  is_treatment = tf.cast(is_treatment, tf.bool)

  # Column 0 of the tf.where coordinates is the row index into `values`.
  control_indices = tf.cast(tf.where(~is_treatment)[:, 0], dtype=tf.int32)
  treatment_indices = tf.cast(tf.where(is_treatment)[:, 0], dtype=tf.int32)

  control_values = tf.gather(values, control_indices)
  treatment_values = tf.gather(values, treatment_indices)

  return control_values, treatment_values
|