"""Utility script to perform net surgery on a model. |
|
|
|
This script will perform net surgery on DeepLab models trained on a source |
|
dataset and create a new checkpoint for the target dataset. |
|
""" |

from typing import Any, Dict, Text, Tuple

from absl import app
from absl import flags
from absl import logging

import numpy as np
import tensorflow as tf

from google.protobuf import text_format
from deeplab2 import common
from deeplab2 import config_pb2
from deeplab2.data import dataset
from deeplab2.model import deeplab

FLAGS = flags.FLAGS

flags.DEFINE_string('source_dataset', 'cityscapes',
                    'Dataset name on which the model has been pretrained. '
                    'Supported datasets: `cityscapes`.')
flags.DEFINE_string('target_dataset', 'motchallenge_step',
                    'Dataset name for conversion. Supported datasets: '
                    '`motchallenge_step`.')
flags.DEFINE_string('input_config_path', None,
                    'Path to a config file that defines the DeepLab model and '
                    'the checkpoint path.')
flags.DEFINE_string('output_checkpoint_path', None,
                    'Output filename for the generated checkpoint file.')


_SUPPORTED_SOURCE_DATASETS = {'cityscapes'}
_SUPPORTED_TARGET_DATASETS = {'motchallenge_step'}
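
# Cityscapes train IDs kept for MOTChallenge-STEP, listed in target-class
# order. Per the standard Cityscapes train-ID definition these are:
# sidewalk (1), building (2), vegetation (8), sky (10), person (11),
# rider (12), and bicycle (18); the class names are annotation for clarity,
# not part of the original mapping.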
_CITYSCAPES_TO_MOTCHALLENGE_STEP = (
    1,
    2,
    8,
    10,
    11,
    12,
    18,
)

_DATASET_TO_INFO = {
    'cityscapes': dataset.CITYSCAPES_PANOPTIC_INFORMATION,
    'motchallenge_step': dataset.MOTCHALLENGE_STEP_INFORMATION,
}
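# Input shape (height, width, channels) for building the model variables once
# via a symbolic tf.keras.Input below.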
_INPUT_SIZE = (1025, 2049, 3)


def _load_model(
    config_path: Text,
    source_dataset: Text) -> Tuple[deeplab.DeepLab,
                                   config_pb2.ExperimentOptions]:
  """Loads a DeepLab model based on the config and dataset."""
  options = config_pb2.ExperimentOptions()
  with tf.io.gfile.GFile(config_path) as f:
    text_format.Parse(f.read(), options)
  options.model_options.panoptic_deeplab.semantic_head.output_channels = (
      _DATASET_TO_INFO[source_dataset].num_classes)
  model = deeplab.DeepLab(options, _DATASET_TO_INFO[source_dataset])
  return model, options


def _convert_bias(input_tensor: np.ndarray,
                  label_list: Tuple[int, ...]) -> np.ndarray:
  """Converts a 1D bias tensor w.r.t. the label list.

  We select a subset of the input_tensor based on the label_list.

  We assume input_tensor has shape = [num_classes], where input_tensor is the
  bias weights trained on the source dataset, and num_classes is the number of
  classes in the source dataset.

  Args:
    input_tensor: A numpy array with ndim == 1.
    label_list: A tuple of labels used for net surgery.

  Returns:
    A numpy array containing only the bias values selected by label_list.

  Raises:
    ValueError: If input_tensor's ndim != 1.
  """
  if input_tensor.ndim != 1:
    raise ValueError('The bias tensor should have ndim == 1.')

  num_elements = len(label_list)
  output_tensor = np.zeros(num_elements, dtype=np.float32)
  for i, label in enumerate(label_list):
    output_tensor[i] = input_tensor[label]
  return output_tensor
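
# Illustrative behavior (made-up values):
#   _convert_bias(np.array([0.1, 0.2, 0.3, 0.4], np.float32), (1, 3))
# returns array([0.2, 0.4], dtype=float32).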


def _convert_kernels(input_tensor: np.ndarray,
                     label_list: Tuple[int, ...]) -> np.ndarray:
  """Converts a 4D kernel tensor w.r.t. the label list.

  We select a subset of the input_tensor based on the label_list.

  We assume input_tensor has shape = [h, w, input_dim, num_classes], where
  input_tensor is the kernel weights trained on the source dataset, and
  num_classes is the number of classes in the source dataset.

  Args:
    input_tensor: A numpy array with ndim == 4.
    label_list: A tuple of labels used for net surgery.

  Returns:
    A numpy array containing only the kernels selected by label_list.

  Raises:
    ValueError: If input_tensor's ndim != 4.
  """
  if input_tensor.ndim != 4:
    raise ValueError('The kernels tensor should have ndim == 4.')

  num_elements = len(label_list)
  kernel_height, kernel_width, input_dim, _ = input_tensor.shape
  output_tensor = np.zeros(
      (kernel_height, kernel_width, input_dim, num_elements), dtype=np.float32)
  for i, label in enumerate(label_list):
    output_tensor[:, :, :, i] = input_tensor[:, :, :, label]
  return output_tensor
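
# Illustrative behavior (made-up shapes): a [1, 1, 256, 19] Cityscapes kernel
# shrinks to [1, 1, 256, 7] after selecting the seven channels listed in
# _CITYSCAPES_TO_MOTCHALLENGE_STEP.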


def _restore_checkpoint(restore_dict: Dict[Any, Any],
                        options: config_pb2.ExperimentOptions
                        ) -> tf.train.Checkpoint:
  """Reads the provided dict items from the checkpoint specified in options.

  Args:
    restore_dict: A mapping of checkpoint items to locations.
    options: An experiment configuration containing the checkpoint location.

  Returns:
    The loaded checkpoint.
  """
  ckpt = tf.train.Checkpoint(**restore_dict)
  if tf.io.gfile.isdir(options.model_options.initial_checkpoint):
    path = tf.train.latest_checkpoint(
        options.model_options.initial_checkpoint)
    status = ckpt.restore(path)
  else:
    status = ckpt.restore(options.model_options.initial_checkpoint)
  status.expect_partial().assert_existing_objects_matched()
  return ckpt
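
# Note: tf.train.Checkpoint matches saved and restored objects through the
# keyword names used at construction, so restoring with a subset dict (as
# main does below) simply skips the omitted items.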


def main(_) -> None:
  if FLAGS.source_dataset not in _SUPPORTED_SOURCE_DATASETS:
    raise ValueError('Source dataset is not supported. Use --help to get a '
                     'list of supported datasets.')
  if FLAGS.target_dataset not in _SUPPORTED_TARGET_DATASETS:
    raise ValueError('Target dataset is not supported. Use --help to get a '
                     'list of supported datasets.')

  logging.info('Loading DeepLab model from config %s', FLAGS.input_config_path)
  source_model, options = _load_model(FLAGS.input_config_path,
                                      FLAGS.source_dataset)
  logging.info('Loading pretrained checkpoint.')
  _restore_checkpoint(source_model.checkpoint_items, options)
  source_model(tf.keras.Input(_INPUT_SIZE), training=False)
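
  # Keras get_weights() on the final semantic conv returns its [kernel, bias]
  # pair; output channel i corresponds to source class i, which is what the
  # conversion helpers above rely on.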
  logging.info('Performing net surgery.')
  semantic_weights = (
      source_model._decoder._semantic_head.final_conv.get_weights())  # pylint: disable=protected-access

  if (FLAGS.source_dataset == 'cityscapes' and
      FLAGS.target_dataset == 'motchallenge_step'):
    semantic_weights[0] = _convert_kernels(semantic_weights[0],
                                           _CITYSCAPES_TO_MOTCHALLENGE_STEP)
    semantic_weights[1] = _convert_bias(semantic_weights[1],
                                        _CITYSCAPES_TO_MOTCHALLENGE_STEP)
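
  # Restore every checkpoint item except the final semantic layer, whose shape
  # no longer matches the target dataset; its converted weights are set by
  # hand below.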
  logging.info('Loading target model without the last semantic layer.')
  target_model, _ = _load_model(FLAGS.input_config_path, FLAGS.target_dataset)
  restore_dict = target_model.checkpoint_items
  del restore_dict[common.CKPT_SEMANTIC_LAST_LAYER]

  _restore_checkpoint(restore_dict, options)
  target_model(tf.keras.Input(_INPUT_SIZE), training=False)
  target_model._decoder._semantic_head.final_conv.set_weights(  # pylint: disable=protected-access
      semantic_weights)

  logging.info('Saving checkpoint to output path: %s',
               FLAGS.output_checkpoint_path)
  ckpt = tf.train.Checkpoint(**target_model.checkpoint_items)
  ckpt.save(FLAGS.output_checkpoint_path)


if __name__ == '__main__':
  flags.mark_flags_as_required(
      ['input_config_path', 'output_checkpoint_path'])
  app.run(main)