# Mirror of https://github.com/tensorflow/models.git (350 lines, 12 KiB, Python).
# Copyright 2025 The TensorFlow Authors. All Rights Reserved.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
|
|
"""Preprocessing ops."""
|
|
import functools
|
|
import tensorflow as tf, tf_keras
|
|
|
|
# Standard crop fraction for ImageNet (224 / 256 == 0.875). Not referenced by
# the functions in this section; presumably used by a crop helper elsewhere.
CROP_PROPORTION = 224 / 256
|
|
|
|
|
|
def random_apply(func, p, x):
  """Randomly applies `func` to `x` with probability `p`.

  Args:
    func: Callable taking `x`. Both `tf.cond` branches must produce compatible
      outputs, so `func(x)` should match the structure of `x`.
    p: Probability of applying `func`; cast to float32.
    x: The input to (maybe) transform.

  Returns:
    `func(x)` with probability `p`, otherwise `x` unchanged.
  """
  coin = tf.random.uniform([], minval=0, maxval=1, dtype=tf.float32)
  should_apply = tf.less(coin, tf.cast(p, tf.float32))
  return tf.cond(should_apply, lambda: func(x), lambda: x)
|
|
|
|
|
|
def random_brightness(image, max_delta, impl='simclrv2'):
  """Randomly changes image brightness, multiplicatively or additively.

  Args:
    image: Image `Tensor`.
    max_delta: Float controlling the jitter range (expected non-negative).
    impl: 'simclrv2' scales the image by a factor drawn uniformly from
      [max(1 - max_delta, 0), 1 + max_delta). 'simclrv1' delegates to
      `tf.image.random_brightness`, which adds a random delta.

  Returns:
    The brightness-jittered image `Tensor`.

  Raises:
    ValueError: If `impl` is neither 'simclrv1' nor 'simclrv2'.
  """
  if impl == 'simclrv2':
    lower = tf.maximum(1.0 - max_delta, 0)
    scale = tf.random.uniform([], lower, 1.0 + max_delta)
    return image * scale
  if impl == 'simclrv1':
    return tf.image.random_brightness(image, max_delta=max_delta)
  raise ValueError('Unknown impl {} for random brightness.'.format(impl))
|
|
|
|
|
|
def to_grayscale(image, keep_channels=True):
  """Converts an RGB image to grayscale.

  Args:
    image: RGB image `Tensor` (last dimension of size 3, as required by
      `tf.image.rgb_to_grayscale`).
    keep_channels: If True, replicate the single gray channel three times.
      The `[1, 1, 3]` tile multiples assume a rank-3 image.

  Returns:
    The grayscale image `Tensor`: 3 channels if `keep_channels`, otherwise 1.
  """
  gray = tf.image.rgb_to_grayscale(image)
  if not keep_channels:
    return gray
  return tf.tile(gray, [1, 1, 3])
|
|
|
|
|
|
def color_jitter_nonrand(image,
                         brightness=0,
                         contrast=0,
                         saturation=0,
                         hue=0,
                         impl='simclrv2'):
  """Distorts the color of the image (jittering order is fixed).

  The jitters are applied in the order brightness, contrast, saturation, hue.
  Each one is skipped when its strength is 0. The image is clipped to [0, 1]
  after every step.

  Args:
    image: The input image tensor.
    brightness: A float, specifying the brightness for color jitter.
    contrast: A float, specifying the contrast for color jitter.
    saturation: A float, specifying the saturation for color jitter.
    hue: A float, specifying the hue for color jitter.
    impl: 'simclrv1' or 'simclrv2'. Whether to use simclrv1 or simclrv2's
      version of random brightness.

  Returns:
    The distorted image tensor.
  """
  with tf.name_scope('distort_color'):

    def apply_transform(i, x, brightness, contrast, saturation, hue):
      """Applies the i-th transformation (0..3) if its strength is non-zero."""
      if brightness != 0 and i == 0:
        x = random_brightness(x, max_delta=brightness, impl=impl)
      elif contrast != 0 and i == 1:
        x = tf.image.random_contrast(
            x, lower=1 - contrast, upper=1 + contrast)
      elif saturation != 0 and i == 2:
        x = tf.image.random_saturation(
            x, lower=1 - saturation, upper=1 + saturation)
      elif hue != 0 and i == 3:
        # The index guard is required: without it, hue was also applied at
        # any earlier step whose own jitter strength was 0, and then again
        # at step 3.
        x = tf.image.random_hue(x, max_delta=hue)
      return x

    for i in range(4):
      image = apply_transform(i, image, brightness, contrast, saturation, hue)
      image = tf.clip_by_value(image, 0., 1.)
    return image
|
|
|
|
|
|
def color_jitter_rand(image,
                      brightness=0,
                      contrast=0,
                      saturation=0,
                      hue=0,
                      impl='simclrv2'):
  """Distorts the color of the image (jittering order is random).

  One random permutation of the four jitters (brightness, contrast,
  saturation, hue) is drawn per call. The image is clipped to [0, 1] after
  every step. A jitter whose strength is 0 leaves the image unchanged.

  Args:
    image: The input image tensor.
    brightness: A float, specifying the brightness for color jitter.
    contrast: A float, specifying the contrast for color jitter.
    saturation: A float, specifying the saturation for color jitter.
    hue: A float, specifying the hue for color jitter.
    impl: 'simclrv1' or 'simclrv2'. Whether to use simclrv1 or simclrv2's
      version of random brightness.

  Returns:
    The distorted image tensor.
  """
  with tf.name_scope('distort_color'):

    def apply_transform(op_index, x):
      """Applies the jitter selected by the scalar int32 tensor `op_index`.

      Mapping: 0 -> brightness, 1 -> contrast, 2 -> saturation, 3 -> hue.
      """

      def jitter_brightness():
        if brightness == 0:
          return x
        return random_brightness(x, max_delta=brightness, impl=impl)

      def jitter_contrast():
        if contrast == 0:
          return x
        return tf.image.random_contrast(
            x, lower=1 - contrast, upper=1 + contrast)

      def jitter_saturation():
        if saturation == 0:
          return x
        return tf.image.random_saturation(
            x, lower=1 - saturation, upper=1 + saturation)

      def jitter_hue():
        if hue == 0:
          return x
        return tf.image.random_hue(x, max_delta=hue)

      # Two-level binary dispatch on a tensor index: first {0, 1} vs {2, 3},
      # then the member within the chosen pair.
      def low_pair():
        return tf.cond(tf.less(op_index, 1), jitter_brightness,
                       jitter_contrast)

      def high_pair():
        return tf.cond(tf.less(op_index, 3), jitter_saturation, jitter_hue)

      return tf.cond(tf.less(op_index, 2), low_pair, high_pair)

    order = tf.random.shuffle(tf.range(4))
    for step in range(4):
      image = apply_transform(order[step], image)
      image = tf.clip_by_value(image, 0., 1.)
    return image
|
|
|
|
|
|
def color_jitter(image, strength, random_order=True, impl='simclrv2'):
  """Distorts the color of the image.

  Brightness, contrast and saturation jitter use 0.8 * `strength`; hue uses
  0.2 * `strength`.

  Args:
    image: The input image tensor.
    strength: the floating number for the strength of the color augmentation.
    random_order: A bool, specifying whether to randomize the jittering order.
    impl: 'simclrv1' or 'simclrv2'. Whether to use simclrv1 or simclrv2's
      version of random brightness.

  Returns:
    The distorted image tensor.
  """
  jitter_fn = color_jitter_rand if random_order else color_jitter_nonrand
  return jitter_fn(
      image,
      0.8 * strength,  # brightness
      0.8 * strength,  # contrast
      0.8 * strength,  # saturation
      0.2 * strength,  # hue
      impl=impl)
|
|
|
|
|
|
def random_color_jitter(image,
                        p=1.0,
                        color_jitter_strength=1.0,
                        impl='simclrv2'):
  """Perform random color jitter.

  With probability `p` the image is transformed. A transformed image first
  gets color jitter with probability 0.8, then grayscale conversion with
  probability 0.2.

  Args:
    image: The input image tensor.
    p: Probability of applying the whole transformation.
    color_jitter_strength: Strength passed to `color_jitter`.
    impl: 'simclrv1' or 'simclrv2' brightness implementation.

  Returns:
    The (possibly) transformed image tensor.
  """
  jitter = functools.partial(
      color_jitter, strength=color_jitter_strength, impl=impl)

  def _jitter_then_gray(img):
    jittered = random_apply(jitter, p=0.8, x=img)
    return random_apply(to_grayscale, p=0.2, x=jittered)

  return random_apply(_jitter_then_gray, p=p, x=image)
|
|
|
|
|
|
def gaussian_blur(image, kernel_size, sigma, padding='SAME'):
  """Blurs the given image with separable convolution.

  Args:
    image: Tensor of shape [height, width, channels] with a floating-point
      dtype. An already-batched rank-4 Tensor is also accepted and is
      convolved without adding a batch dimension.
    kernel_size: Integer (Python int or Tensor) for the size of the blur
      kernel. This should be an odd number. If it is an even number, the
      actual kernel size will be size + 1.
    sigma: Sigma value for the Gaussian operator.
    padding: Padding to use for the convolution. Typically 'SAME' or 'VALID'.

  Returns:
    A Tensor representing the blurred image, with the same rank and dtype as
    `image`.
  """
  radius = tf.cast(kernel_size / 2, dtype=tf.int32)
  kernel_size = radius * 2 + 1
  x = tf.cast(tf.range(-radius, radius + 1), dtype=tf.float32)
  blur_filter = tf.exp(-tf.pow(x, 2.0) /
                       (2.0 * tf.pow(tf.cast(sigma, dtype=tf.float32), 2.0)))
  blur_filter /= tf.reduce_sum(blur_filter)
  # The kernel is built in float32, then cast because tf.nn.depthwise_conv2d
  # requires the filter dtype to match the input's. Without the cast,
  # float16/bfloat16/float64 images fail. It is a no-op for float32.
  blur_filter = tf.cast(blur_filter, image.dtype)
  # One vertical and one horizontal filter.
  blur_v = tf.reshape(blur_filter, [kernel_size, 1, 1, 1])
  blur_h = tf.reshape(blur_filter, [1, kernel_size, 1, 1])
  num_channels = tf.shape(image)[-1]
  # Depthwise convolution: the same 1-D kernel is replicated per channel.
  blur_h = tf.tile(blur_h, [1, 1, num_channels, 1])
  blur_v = tf.tile(blur_v, [1, 1, num_channels, 1])
  expand_batch_dim = image.shape.ndims == 3
  if expand_batch_dim:
    # Tensorflow requires batched input to convolutions, which we can fake with
    # an extra dimension.
    image = tf.expand_dims(image, axis=0)
  blurred = tf.nn.depthwise_conv2d(
      image, blur_h, strides=[1, 1, 1, 1], padding=padding)
  blurred = tf.nn.depthwise_conv2d(
      blurred, blur_v, strides=[1, 1, 1, 1], padding=padding)
  if expand_batch_dim:
    blurred = tf.squeeze(blurred, axis=0)
  return blurred
|
|
|
|
|
|
def random_blur(image, height, width, p=0.5):
  """Randomly blur an image.

  When applied, the blur uses a kernel size of `height // 10` and a sigma
  drawn uniformly from [0.1, 2.0).

  Args:
    image: `Tensor` representing an image of arbitrary size.
    height: Height of output image; only used to derive the kernel size.
    width: Width of output image. Unused.
    p: probability of applying this transformation.

  Returns:
    A preprocessed image `Tensor`.
  """
  del width  # Unused by this op.

  def _blurred(img):
    blur_sigma = tf.random.uniform([], 0.1, 2.0, dtype=tf.float32)
    blur_kernel_size = height // 10
    return gaussian_blur(
        img, kernel_size=blur_kernel_size, sigma=blur_sigma, padding='SAME')

  return random_apply(_blurred, p=p, x=image)
|
|
|
|
|
|
def distorted_bounding_box_crop(image,
                                bbox,
                                min_object_covered=0.1,
                                aspect_ratio_range=(0.75, 1.33),
                                area_range=(0.05, 1.0),
                                max_attempts=100,
                                scope=None):
  """Generates cropped_image using one of the bboxes randomly distorted.

  See `tf.image.sample_distorted_bounding_box` for more documentation.

  Args:
    image: `Tensor` of image data. Its `tf.shape` is passed as the image size,
      and the sampled begin/size are unstacked into three components, so a
      rank-3 [height, width, channels] image is presumably expected.
    bbox: `Tensor` of bounding boxes arranged `[1, num_boxes, coords]`
      where each coordinate is [0, 1) and the coordinates are arranged
      as `[ymin, xmin, ymax, xmax]`. If num_boxes is 0 then use the whole
      image.
    min_object_covered: An optional `float`. Defaults to `0.1`. The cropped
      area of the image must contain at least this fraction of any bounding
      box supplied.
    aspect_ratio_range: An optional list of `float`s. The cropped area of the
      image must have an aspect ratio = width / height within this range.
    area_range: An optional list of `float`s. The cropped area of the image
      must contain a fraction of the supplied image within this range.
    max_attempts: An optional `int`. Number of attempts at generating a cropped
      region of the image of the specified constraints. After `max_attempts`
      failures, return the entire image.
    scope: Optional `str` for name scope. Defaults to
      'distorted_bounding_box_crop'.

  Returns:
    The cropped image `Tensor`. Only the image is returned; the sampled
    bounding box is not.
  """
  with tf.name_scope(scope or 'distorted_bounding_box_crop'):
    shape = tf.shape(image)
    bbox_begin, bbox_size, _ = tf.image.sample_distorted_bounding_box(
        shape,
        bounding_boxes=bbox,
        min_object_covered=min_object_covered,
        aspect_ratio_range=aspect_ratio_range,
        area_range=area_range,
        max_attempts=max_attempts,
        use_image_if_no_bounding_boxes=True)

    # Crop the image to the specified bounding box. Only the two spatial
    # components of begin/size are needed; the third is discarded.
    offset_y, offset_x, _ = tf.unstack(bbox_begin)
    target_height, target_width, _ = tf.unstack(bbox_size)
    image = tf.image.crop_to_bounding_box(
        image, offset_y, offset_x, target_height, target_width)

    return image
|
|
|
|
|
|
def crop_and_resize(image, height, width):
  """Make a random crop and resize it to height `height` and width `width`.

  The crop covers 8%-100% of the image area. Its aspect ratio lies within
  [3/4, 4/3] times the target ratio `width / height`. Resizing is bicubic.

  Args:
    image: Tensor representing the image.
    height: Desired image height.
    width: Desired image width.

  Returns:
    A `height` x `width` x channels Tensor holding a random crop of `image`.
  """
  whole_image_box = tf.constant(
      [0.0, 0.0, 1.0, 1.0], dtype=tf.float32, shape=[1, 1, 4])
  target_ratio = width / height
  cropped = distorted_bounding_box_crop(
      image,
      whole_image_box,
      min_object_covered=0.1,
      aspect_ratio_range=(3. / 4 * target_ratio, 4. / 3. * target_ratio),
      area_range=(0.08, 1.0),
      max_attempts=100,
      scope=None)
  # tf.image.resize is called on a batch of one, then the batch is unwrapped.
  resized_batch = tf.image.resize(
      [cropped], [height, width], method=tf.image.ResizeMethod.BICUBIC)
  return resized_batch[0]
|
|
|
|
|
|
def random_crop_with_resize(image, height, width, p=1.0):
  """Randomly crop and resize an image.

  With probability `p`, `crop_and_resize(image, height, width)` is applied.
  Otherwise the image is returned unchanged.

  Args:
    image: `Tensor` representing an image of arbitrary size.
    height: Height of output image.
    width: Width of output image.
    p: Probability of applying this transformation.

  Returns:
    A preprocessed image `Tensor`.
  """
  return random_apply(
      lambda img: crop_and_resize(img, height, width), p=p, x=image)
|