# Customize Input Pipeline

## Overview

A task is a class that encapsulates the logic of loading data, building models, performing one-step training and validation, etc. It connects all components together and is called by the base Trainer. You can create your own task by inheriting from base Task, or from one of the tasks we already defined, if most of the operations can be reused. An ExampleTask inheriting from ImageClassificationTask can be found here.

In a task class, the build_inputs method is responsible for building the input pipeline for training and evaluation. Specifically, it will instantiate a Decoder object and a Parser object, which are used to create an InputReader that will generate a tf.data.Dataset object.

Here's an example code snippet that demonstrates how to create a custom build_inputs method:

```python
def build_inputs(
    self,
    params: exp_cfg.DataConfig,
    input_context: Optional[tf.distribute.InputContext] = None
) -> tf.data.Dataset:
  ...
  decoder = sample_input.Decoder()
  parser = sample_input.Parser(
      output_size=..., num_classes=...)
  reader = input_reader_factory.input_reader_generator(
      params,
      dataset_fn=dataset_fn.pick_dataset_fn(params.file_type),
      decoder_fn=decoder.decode,
      parser_fn=parser.parse_fn(params.is_training))
  ...
  dataset = reader.read(input_context=input_context)
  return dataset
```
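The `sample_input.Decoder` used above deserializes a raw `tf.train.Example` into a dictionary of tensors. As a rough, hypothetical sketch (the feature keys below are assumptions and must match your TFRecord schema), such a decoder could look like:

```python
import tensorflow as tf
from official.vision.dataloaders import decoder


class Decoder(decoder.Decoder):
  """Deserializes a tf.train.Example into a dict of raw tensors."""

  def __init__(self):
    # Feature keys are illustrative; align them with your TFRecord schema.
    self._keys_to_features = {
        'image/encoded': tf.io.FixedLenFeature((), tf.string, default_value=''),
        'image/class/label': tf.io.FixedLenFeature((), tf.int64, default_value=-1),
    }

  def decode(self, serialized_example):
    return tf.io.parse_single_example(
        serialized_example, self._keys_to_features)
```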

The class responsible for building the input pipeline is InputReader, which has the following interface:

```python
class InputReader:
  """Input reader that returns a tf.data.Dataset instance."""

  def __init__(
      self,
      params: cfg.DataConfig,
      dataset_fn=tf.data.TFRecordDataset,
      decoder_fn: Optional[Callable[..., Any]] = None,
      combine_fn: Optional[Callable[..., Any]] = None,
      sample_fn: Optional[Callable[..., Any]] = None,
      parser_fn: Optional[Callable[..., Any]] = None,
      filter_fn: Optional[Callable[..., tf.Tensor]] = None,
      transform_and_batch_fn: Optional[
          Callable[
              [tf.data.Dataset, Optional[tf.distribute.InputContext]],
              tf.data.Dataset,
          ]
      ] = None,
      postprocess_fn: Optional[Callable[..., Any]] = None,
  ):
    ...

  def read(self,
           input_context: Optional[tf.distribute.InputContext] = None,
           dataset: Optional[tf.data.Dataset] = None) -> tf.data.Dataset:
    """Generates a tf.data.Dataset object."""
    if dataset is None:
      dataset = self._read_data_source(self._matched_files, self._dataset_fn,
                                       input_context)
    dataset = self._decode_and_parse_dataset(dataset, self._global_batch_size,
                                             input_context)
    dataset = _maybe_map_fn(dataset, self._postprocess_fn)
    if not (self._enable_shared_tf_data_service_between_parallel_trainers and
            self._apply_tf_data_service_before_batching):
      dataset = self._maybe_apply_data_service(dataset, input_context)

    if self._deterministic is not None:
      options = tf.data.Options()
      options.deterministic = self._deterministic
      dataset = dataset.with_options(options)
    if self._autotune_algorithm:
      options = tf.data.Options()
      options.autotune.autotune_algorithm = (
          tf.data.experimental.AutotuneAlgorithm[self._autotune_algorithm])
      dataset = dataset.with_options(options)
    return dataset.prefetch(self._prefetch_buffer_size)
```

Therefore, customizing the input pipeline amounts to providing customized versions of dataset_fn, decoder_fn, etc. The execution order is generally as follows:

```
dataset_fn -> decoder_fn -> combine_fn -> parser_fn -> filter_fn ->
transform_and_batch_fn -> postprocess_fn
```

The transform_and_batch_fn is an optional function that merges multiple examples into a batch; its default behavior, if not specified, is dataset.batch. In this workflow, the functions before transform_and_batch_fn (e.g. dataset_fn, decoder_fn, parser_fn) consume tensors without the batch dimension, while postprocess_fn consumes tensors with the batch dimension.
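For instance, if plain `dataset.batch` is not enough, you can supply your own batching and post-batch functions. The sketch below is hypothetical; the names `custom_transform_and_batch_fn`, `custom_postprocess_fn`, and `GLOBAL_BATCH_SIZE` are placeholders for illustration:

```python
from typing import Optional
import tensorflow as tf

GLOBAL_BATCH_SIZE = 64  # Illustrative value.


def custom_transform_and_batch_fn(
    dataset: tf.data.Dataset,
    input_context: Optional[tf.distribute.InputContext] = None,
) -> tf.data.Dataset:
  """Batches per-example tensors; functions after this see a batch dimension."""
  per_replica_batch_size = (
      input_context.get_per_replica_batch_size(GLOBAL_BATCH_SIZE)
      if input_context else GLOBAL_BATCH_SIZE)
  return dataset.batch(per_replica_batch_size, drop_remainder=True)


def custom_postprocess_fn(images: tf.Tensor, labels: tf.Tensor):
  """Runs on already-batched tensors, e.g. a dtype cast or batch-level mixup."""
  return tf.cast(images, tf.bfloat16), labels
```

Both functions would be passed to the InputReader constructor as `transform_and_batch_fn` and `postprocess_fn` respectively.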

We have essentially covered decoder_fn above. parser_fn is another important one: it takes the dictionary of decoded raw tensors and parses it into a dictionary of tensors that can be consumed by the model. It is executed after decoder_fn.

It is also worth noting that optimization of the input pipeline through batching, shuffling, and prefetching is implemented in this class.
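These optimizations are largely driven by fields on the DataConfig passed to the InputReader. As a rough sketch only, a training configuration might look like the following; the exact set of available fields depends on the DataConfig definition your task uses, so treat the names and values below as assumptions:

```python
from official.core import config_definitions as cfg

# Illustrative configuration; field availability depends on your DataConfig.
data_config = cfg.DataConfig(
    input_path='/path/to/train*',  # Placeholder file pattern.
    is_training=True,
    global_batch_size=256,
    shuffle_buffer_size=10_000,    # Size of the shuffle buffer.
    cycle_length=16,               # Number of files read in parallel.
    cache=False,                   # Cache the dataset in memory if it fits.
)
```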

## Parser

A custom data loader can also be useful if you want to take advantage of features such as data augmentation.

Customizing preprocessing is useful because it allows the user to tailor the preprocessing steps to suit the specific requirements of the task. While there are standard preprocessing techniques that are commonly used, different applications may require different preprocessing steps. Additionally, custom preprocessing can also improve the efficiency and accuracy of the model by removing unnecessary steps, reducing computational resources or adding steps that are important to the specific task being addressed.

For example, tasks such as object detection or segmentation may require additional preprocessing steps such as resizing, cropping, or data augmentation to improve the robustness of the model. Below are some essential steps to customize a parser.

### Instructions

  • Create a Subclass

Like the Decoder, create class Parser(parser.Parser) in the same file. The Parser class should be a child class of the generic parser interface and must implement all of its abstract methods, in particular _parse_train_data and _parse_eval_data, which generate images and labels for model training and evaluation respectively. The example below takes only two arguments, but you can freely add as many arguments as needed.

```python
class Parser(parser.Parser):

  def __init__(self, output_size: List[int], num_classes: float):
    self._output_size = output_size
    self._num_classes = num_classes
    self._dtype = tf.float32
    ...
```

Refer to the data parser and processing class for Mask R-CNN for more complex cases. The class has multiple parameters related to data augmentation, masking, anchor boxes, data type of output image and more.


  • Complete Abstract Methods

To define your own Parser, override the abstract functions _parse_train_data and _parse_eval_data of the parser interface in the subclass; these are where decoded tensors are parsed, with pre-processing steps for training and evaluation respectively. The output of the two functions can be any structure, such as a tuple, list, or dictionary.

```python
  @abc.abstractmethod
  def _parse_train_data(self, decoded_tensors):
    """Generates images and labels that are usable for model training.

    Args:
      decoded_tensors: a dict of Tensors produced by the decoder.

    Returns:
      images: the image tensor.
      labels: a dict of Tensors that contains labels.
    """
    pass

  @abc.abstractmethod
  def _parse_eval_data(self, decoded_tensors):
    """Generates images and labels that are usable for model evaluation.

    Args:
      decoded_tensors: a dict of Tensors produced by the decoder.

    Returns:
      images: the image tensor.
      labels: a dict of Tensors that contains labels.
    """
    pass
```

The input of _parse_train_data and _parse_eval_data is a dict of Tensors produced by the decoder; the output of these two functions is typically a tuple of (processed_image, processed_label). The user may perform any processing steps in these two functions as long as the interface is aligned. Note that the processing steps in _parse_train_data and _parse_eval_data typically differ, since data augmentation is usually only applied to training. For example, refer to the data parser and processing steps for classification, where we can observe the following (a minimal sketch follows the two lists below):

- For _parse_train_data, the following steps are performed:

  - Image decoding
  - Random cropping
  - Random flipping
  - Color jittering
  - Image resizing
  - Auto-augmentation with AutoAugment, RandAugment, etc.
  - Image normalization

- For _parse_eval_data, the following steps are performed:

  - Image decoding
  - Center cropping
  - Image resizing
  - Image normalization
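Below is a minimal, hypothetical sketch of how these steps might map onto the two methods for the Parser sketched earlier. The feature keys ('image/encoded', 'image/class/label') are assumptions that must match your decoder's output, and a single random flip stands in for the full augmentation stack:

```python
import tensorflow as tf


class Parser(parser.Parser):
  # __init__ as sketched in the earlier snippet.

  def _parse_train_data(self, decoded_tensors):
    """Parses and augments a single training example (classification sketch)."""
    image = tf.io.decode_image(
        decoded_tensors['image/encoded'], channels=3, expand_animations=False)
    image = tf.image.convert_image_dtype(image, tf.float32)  # Scale to [0, 1].
    image = tf.image.random_flip_left_right(image)  # Stand-in for augmentations.
    image = tf.image.resize(image, self._output_size)
    image = tf.cast(image, self._dtype)
    label = tf.cast(decoded_tensors['image/class/label'], tf.int32)
    return image, label

  def _parse_eval_data(self, decoded_tensors):
    """Parses a single evaluation example without random augmentation."""
    image = tf.io.decode_image(
        decoded_tensors['image/encoded'], channels=3, expand_animations=False)
    image = tf.image.convert_image_dtype(image, tf.float32)
    image = tf.image.resize(image, self._output_size)
    image = tf.cast(image, self._dtype)
    label = tf.cast(decoded_tensors['image/class/label'], tf.int32)
    return image, label
```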

  • Additional Methods

The subclass (say sample_input.py) must include implementations for all of the abstract methods defined in the Decoder and Parser interfaces, as well as any additional methods that are necessary for the subclass's functionality.

For example, in object detection, the decoder takes the serialized example and outputs a dictionary of tensors with multiple fields that are processed and analyzed to detect objects and determine their location and orientation in the image. A separate method for each of these fields can make the code easier to read and maintain, especially when the class contains a large number of methods.

Refer to the data parser for object detection here.
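As a purely illustrative sketch of this pattern (the feature keys and helper names below are assumptions, not the actual TFM implementation), a detection decoder could dedicate one small method to each output field:

```python
import tensorflow as tf


class DetectionDecoder:
  """Sketch of a decoder that uses one helper method per output field."""

  def _decode_image(self, parsed):
    return tf.io.decode_image(
        parsed['image/encoded'], channels=3, expand_animations=False)

  def _decode_boxes(self, parsed):
    # Boxes stored as normalized [ymin, xmin, ymax, xmax] coordinates.
    ymin = tf.sparse.to_dense(parsed['image/object/bbox/ymin'])
    xmin = tf.sparse.to_dense(parsed['image/object/bbox/xmin'])
    ymax = tf.sparse.to_dense(parsed['image/object/bbox/ymax'])
    xmax = tf.sparse.to_dense(parsed['image/object/bbox/xmax'])
    return tf.stack([ymin, xmin, ymax, xmax], axis=-1)

  def _decode_classes(self, parsed):
    return tf.sparse.to_dense(parsed['image/object/class/label'])

  def decode(self, serialized_example):
    keys_to_features = {
        'image/encoded': tf.io.FixedLenFeature((), tf.string),
        'image/object/bbox/ymin': tf.io.VarLenFeature(tf.float32),
        'image/object/bbox/xmin': tf.io.VarLenFeature(tf.float32),
        'image/object/bbox/ymax': tf.io.VarLenFeature(tf.float32),
        'image/object/bbox/xmax': tf.io.VarLenFeature(tf.float32),
        'image/object/class/label': tf.io.VarLenFeature(tf.int64),
    }
    parsed = tf.io.parse_single_example(serialized_example, keys_to_features)
    return {
        'image': self._decode_image(parsed),
        'groundtruth_boxes': self._decode_boxes(parsed),
        'groundtruth_classes': self._decode_classes(parsed),
    }
```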

### Example

Creating a Parser is an optional step and it varies with the use case. Below are some use cases where we have included the Decoder and Parser based on the requirements.

| Use case       | Decoder/Parser          |
| -------------- | ----------------------- |
| Classification | Both Decoder and Parser |
| Segmentation   | Only Parser             |

## Input Pipeline

The Decoder and Parser discussed previously define how to decode and parse a single data point, e.g. an image. However, a complete input pipeline also needs to handle reading data from files in a distributed system, applying random perturbations, batching, etc. You may find more details about these concepts here.

We have established a well-tuned input pipeline, as defined in the InputReader class, so that the user won't need to modify it in most cases. The input pipeline roughly follows these steps:

- Shuffling the files
- Decoding
- Parsing
- Caching
- If training: repeat and shuffle
- Batching
- Prefetching
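Conceptually, the pipeline that InputReader assembles corresponds to a raw tf.data sequence roughly like the hypothetical sketch below; `decode_fn` and `parse_fn` are placeholders, and in practice you should rely on InputReader rather than hand-rolling this:

```python
import tensorflow as tf


def sketch_pipeline(file_pattern, decode_fn, parse_fn, is_training,
                    batch_size, shuffle_buffer_size=10_000):
  """Rough, illustrative equivalent of the steps InputReader performs."""
  files = tf.data.Dataset.list_files(file_pattern, shuffle=is_training)
  dataset = files.interleave(
      tf.data.TFRecordDataset, num_parallel_calls=tf.data.AUTOTUNE)
  dataset = dataset.map(decode_fn, num_parallel_calls=tf.data.AUTOTUNE)
  dataset = dataset.map(parse_fn, num_parallel_calls=tf.data.AUTOTUNE)
  # dataset = dataset.cache()  # Optional caching step.
  if is_training:
    dataset = dataset.repeat()
    dataset = dataset.shuffle(shuffle_buffer_size)
  dataset = dataset.batch(batch_size, drop_remainder=is_training)
  return dataset.prefetch(tf.data.AUTOTUNE)
```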

For the rest of this section, we will discuss one particular use case that requires modifying the typical pipeline by creating a subclass of InputReader.

### Combining multiple datasets

Create a custom InputReader by subclassing the InputReader interface. A custom InputReader class allows the user to combine multiple datasets, for example mixing a labeled and a pseudo-labeled dataset. The business logic is implemented in the read() method, which finally generates a tf.data.Dataset object.

The exact implementation of an InputReader can vary depending on the specific requirements of your task, the type and format of your input data, and the preprocessing requirements.

Here is an example of how to create a custom InputReader by subclassing the InputReader interface:

```python
class CustomInputReader(input_reader.InputReader):

  def __init__(self,
               params: cfg.DataConfig,
               dataset_fn=tf.data.TFRecordDataset,
               pseudo_label_dataset_fn=tf.data.TFRecordDataset,
               ...):
    ...

  def read(
      self,
      input_context: Optional[tf.distribute.InputContext] = None
  ) -> tf.data.Dataset:
    labeled_dataset = ...
    pseudo_labeled_dataset = ...
    dataset_concat = tf.data.Dataset.zip(
        (labeled_dataset, pseudo_labeled_dataset))
    ...
    return dataset_concat.prefetch(tf.data.experimental.AUTOTUNE)
```
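A hypothetical build_inputs in your task could then wire this reader in place of the default one; the output_size and num_classes values below are placeholders for illustration:

```python
def build_inputs(self, params, input_context=None):
  """Sketch of a task method that uses the custom reader."""
  decoder = sample_input.Decoder()
  parser = sample_input.Parser(output_size=[224, 224], num_classes=1001)
  reader = CustomInputReader(
      params,
      dataset_fn=dataset_fn.pick_dataset_fn(params.file_type),
      pseudo_label_dataset_fn=tf.data.TFRecordDataset,
      decoder_fn=decoder.decode,
      parser_fn=parser.parse_fn(params.is_training))
  return reader.read(input_context=input_context)
```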

### Example

Refer to the InputReader for vision in TFM. The CombinationDatasetInputReader class mixes a labeled and pseudo-labeled dataset and returns a tf.data.Dataset instance.