# Copyright 2025 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""2D detection evaluator for the Waymo Open Dataset."""
|
|
import pprint
|
|
from absl import logging
|
|
|
|
import tensorflow as tf, tf_keras
|
|
from official.vision.ops import box_ops
|
|
from waymo_open_dataset import label_pb2
|
|
from waymo_open_dataset.metrics.python import wod_detection_evaluator
|
|
from waymo_open_dataset.protos import breakdown_pb2
|
|
from waymo_open_dataset.protos import metrics_pb2
|
|
|
|
|
|

def get_2d_detection_default_config():
  """Returns the config proto for WOD 2D detection Evaluation."""
  config = metrics_pb2.Config()

  config.breakdown_generator_ids.append(breakdown_pb2.Breakdown.OBJECT_TYPE)
  difficulty = config.difficulties.add()
  difficulty.levels.append(label_pb2.Label.LEVEL_1)
  difficulty.levels.append(label_pb2.Label.LEVEL_2)
  config.breakdown_generator_ids.append(breakdown_pb2.Breakdown.ALL_BUT_SIGN)
  difficulty = config.difficulties.add()
  difficulty.levels.append(label_pb2.Label.LEVEL_1)
  difficulty.levels.append(label_pb2.Label.LEVEL_2)
  config.matcher_type = metrics_pb2.MatcherProto.TYPE_HUNGARIAN
  config.iou_thresholds.append(0.0)
  config.iou_thresholds.append(0.7)
  config.iou_thresholds.append(0.5)
  config.iou_thresholds.append(0.5)
  config.iou_thresholds.append(0.5)
  config.box_type = label_pb2.Label.Box.TYPE_2D

  for i in range(100):
    config.score_cutoffs.append(i * 0.01)
  config.score_cutoffs.append(1.0)

  return config
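

# The `iou_thresholds` above are assumed to be indexed by label type, i.e.
# [UNKNOWN, VEHICLE, PEDESTRIAN, SIGN, CYCLIST] map to [0.0, 0.7, 0.5, 0.5,
# 0.5]. A caller can tweak the defaults before building the evaluator, for
# example (illustrative sketch only, not part of the upstream API):
#
#   config = get_2d_detection_default_config()
#   config.iou_thresholds[4] = 0.6  # Hypothetical looser cyclist threshold.
#   evaluator = WOD2dDetectionEvaluator(config=config)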


class WOD2dDetectionEvaluator(wod_detection_evaluator.WODDetectionEvaluator):
  """WOD 2D detection evaluation metric class."""

  def __init__(self, config=None):
    if config is None:
      config = get_2d_detection_default_config()
    super().__init__(config=config)

  def _remove_padding(self, tensor_dict, num_valid):
    """Remove the paddings of the prediction/groundtruth data."""
    result_tensor_dict = {}
    gather_indices = tf.range(num_valid)
    for k, v in tensor_dict.items():
      if 'frame_id' in k:
        result_tensor_dict[k] = tf.tile([v], [num_valid])
      else:
        result_tensor_dict[k] = tf.gather(v, gather_indices)
    return result_tensor_dict
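
  # Example (illustrative): with num_valid == 2, a padded [8, 4] boxes tensor
  # is trimmed to [2, 4], while a scalar frame-id entry is tiled to shape [2]
  # so every kept box stays associated with its frame.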

  def update_state(self, groundtruths, predictions):
    """Update the metrics state with prediction and ground-truth data.

    Args:
      groundtruths: a dictionary of Tensors including the fields below.
        Required fields:
          - source_id: a numpy array of int or string of shape [batch_size].
          - num_detections: a numpy array of int of shape [batch_size].
          - boxes: a numpy array of float of shape [batch_size, K, 4].
          - classes: a numpy array of int of shape [batch_size, K].
          - difficulties: a numpy array of int of shape [batch_size, K].

      predictions: a dictionary of tensors including the fields below.
        Required fields:
          - source_id: a numpy array of int or string of shape [batch_size].
          - image_info: a numpy array of float of shape [batch_size, 4, 2].
          - num_detections: a numpy array of int of shape [batch_size].
          - detection_boxes: a numpy array of float of shape [batch_size, K, 4].
          - detection_classes: a numpy array of int of shape [batch_size, K].
          - detection_scores: a numpy array of float of shape [batch_size, K].
    """
    # Preprocess potentially aggregated tensors.
    for k, v in groundtruths.items():
      if isinstance(v, tuple):
        groundtruths[k] = tf.concat(v, axis=0)
    for k, v in predictions.items():
      if isinstance(v, tuple):
        predictions[k] = tf.concat(v, axis=0)

    # Change cyclists' type id from 3 to 4, where 3 is reserved for sign.
    groundtruth_type = tf.cast(groundtruths['classes'], tf.uint8)
    groundtruth_type = tf.where(
        tf.equal(groundtruth_type, 3),
        tf.ones_like(groundtruth_type) * 4, groundtruth_type)
    prediction_type = tf.cast(predictions['detection_classes'], tf.uint8)
    prediction_type = tf.where(
        tf.equal(prediction_type, 3),
        tf.ones_like(prediction_type) * 4, prediction_type)

    # Rescale the detection boxes back to original scale.
    image_scale = tf.tile(predictions['image_info'][:, 2:3, :], (1, 1, 2))
    prediction_bbox = predictions['detection_boxes'] / image_scale

    batch_size = tf.shape(groundtruths['source_id'])[0]

    for i in tf.range(batch_size):
      frame_groundtruths = {
          'ground_truth_frame_id':
              groundtruths['source_id'][i],
          'ground_truth_bbox':
              box_ops.yxyx_to_cycxhw(
                  tf.cast(groundtruths['boxes'][i], tf.float32)),
          'ground_truth_type':
              groundtruth_type[i],
          'ground_truth_difficulty':
              tf.cast(groundtruths['difficulties'][i], tf.uint8),
      }
      frame_groundtruths = self._remove_padding(
          frame_groundtruths, groundtruths['num_detections'][i])
      frame_predictions = {
          'prediction_frame_id':
              groundtruths['source_id'][i],
          'prediction_bbox':
              box_ops.yxyx_to_cycxhw(
                  tf.cast(prediction_bbox[i], tf.float32)),
          'prediction_type':
              prediction_type[i],
          'prediction_score':
              tf.cast(predictions['detection_scores'][i], tf.float32),
          'prediction_overlap_nlz':
              tf.zeros_like(predictions['detection_scores'][i], dtype=tf.bool)
      }
      frame_predictions = self._remove_padding(
          frame_predictions, predictions['num_detections'][i])
      super().update_state(frame_groundtruths, frame_predictions)

  def evaluate(self):
    """Compute the final metrics."""
    ap, _, _, _, _, _, _ = super().evaluate()
    metric_dict = {}
    for i, name in enumerate(self._breakdown_names):
      # Skip sign metrics in 2d detection task.
      if 'SIGN' in name:
        continue
      metric_dict['WOD metrics/{}/AP'.format(name)] = ap[i]
    pp = pprint.PrettyPrinter()
    logging.info('WOD Detection Metrics: \n %s', pp.pformat(metric_dict))

    return metric_dict
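

# Example usage (illustrative sketch, not part of the original module): drives
# the evaluator with synthetic zero-filled batches purely to show the expected
# tensor shapes and the update_state()/evaluate() calling pattern. The batch
# size, box count, ids and dtypes below are made-up placeholders; a real
# pipeline feeds decoded Waymo Open Dataset frames and model outputs instead.
if __name__ == '__main__':
  batch_size, max_boxes = 2, 8  # Assumed sizes for the sketch only.
  groundtruths = {
      'source_id': tf.constant([1, 2], dtype=tf.int64),
      'num_detections': tf.constant([3, 5], dtype=tf.int32),
      'boxes': tf.zeros([batch_size, max_boxes, 4], dtype=tf.float32),
      'classes': tf.ones([batch_size, max_boxes], dtype=tf.int32),
      'difficulties': tf.ones([batch_size, max_boxes], dtype=tf.int32),
  }
  predictions = {
      'source_id': tf.constant([1, 2], dtype=tf.int64),
      # Row 2 of image_info carries the (y, x) scale used for rescaling above.
      'image_info': tf.ones([batch_size, 4, 2], dtype=tf.float32),
      'num_detections': tf.constant([4, 6], dtype=tf.int32),
      'detection_boxes': tf.zeros([batch_size, max_boxes, 4], dtype=tf.float32),
      'detection_classes': tf.ones([batch_size, max_boxes], dtype=tf.int32),
      'detection_scores': tf.zeros([batch_size, max_boxes], dtype=tf.float32),
  }
  evaluator = WOD2dDetectionEvaluator()
  evaluator.update_state(groundtruths, predictions)
  print(evaluator.evaluate())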