deeplabv3+: The Input Module Fully Explained (Input Data, More Dimensions, Multiple Annotations)

傷城~ 2024-04-18 08:42

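I covered the generator in an earlier post. This time the aim is to lay out a more detailed picture of how the data is actually processed, which you can use as a worked example. It is not exhaustive: going deeper takes a good deal of independent work, which most readers can handle, so I'll keep it brief (mostly because I'm busy).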

Source data: images / matrices

Target data: a standard TensorFlow Dataset

Main process:

  1. Source data -> build_voc_data.py -> TFRecord
  2. TFRecord -> data_generator.py -> Dataset

Step 1: convert the data to TFRecord

The images are read in directly as binary data (the encoded file bytes).


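The helpers in build_data.py then convert this data into features and write them to the TFRecord. Essentially, each value is wrapped in a tf.train.Feature, and tf.train.Example acts as a container of named slots that holds them.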


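So if you want to store additional data, such as depth maps or video, you need to change three places:

1. build_data.py: add your own fields in def image_seg_to_tfexample. If your data are integers, use the _int64 helper; for floating-point data, write your own float helper or store the data as a string (bytes).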

```python
return tf.train.Example(features=tf.train.Features(feature={
    'image/encoded': _bytes_list_feature(image_data),
    'image/filename': _bytes_list_feature(filename),
    'image/format': _bytes_list_feature(
        _IMAGE_FORMAT_MAP[FLAGS.image_format]),
    'image/height': _int64_list_feature(height),
    'image/width': _int64_list_feature(width),
    'image/channels': _int64_list_feature(3),
    'image/segmentation/class/encoded': (
        _bytes_list_feature(seg_data)),
    'image/segmentation/class/format': _bytes_list_feature(
        FLAGS.label_format),
}))
```
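For float data such as a depth map, the simplest route mentioned above is to store it as a string. Below is a minimal stdlib sketch, assuming the map is a row-major list of float32 values; `encode_float_map` and `decode_float_map` are hypothetical helper names, not part of build_data.py. The resulting bytes would be written with `_bytes_list_feature`, together with the height and width via `_int64_list_feature`, and on the reading side TF 1.x can decode them with `tf.decode_raw(..., tf.float32)` (little-endian by default) before reshaping. The other option, a float helper, would wrap the values in `tf.train.Feature(float_list=tf.train.FloatList(value=...))`.

```python
import struct


def encode_float_map(values, height, width):
    """Pack a row-major height x width float map into little-endian float32 bytes."""
    if len(values) != height * width:
        raise ValueError('expected %d values, got %d' % (height * width, len(values)))
    # '<' forces little-endian, matching tf.decode_raw's default byte order.
    return struct.pack('<%df' % len(values), *values)


def decode_float_map(data, height, width):
    """Unpack float32 bytes back into a list of height rows of width values."""
    flat = struct.unpack('<%df' % (height * width), data)
    return [list(flat[r * width:(r + 1) * width]) for r in range(height)]


# A 2 x 3 depth map stored row-major.
depth = [0.5, 1.25, 2.0, 3.75, 4.0, 0.0]
raw = encode_float_map(depth, 2, 3)
print(len(raw))  # 24 bytes: 6 values x 4 bytes each
print(decode_float_map(raw, 2, 3))  # [[0.5, 1.25, 2.0], [3.75, 4.0, 0.0]]
```

Packing the byte order explicitly keeps the file readable by `tf.decode_raw` no matter which machine wrote it.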

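2. build_voc_data.py: read in your own data. The code below reads the image and its annotation; you need to write the equivalent reading code for your data.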

```python
image_filename = os.path.join(
    FLAGS.image_folder, filenames[i] + '.' + FLAGS.image_format)
image_data = tf.gfile.FastGFile(image_filename, 'rb').read()
height, width = image_reader.read_image_dims(image_data)
# Read the semantic segmentation annotation.
seg_filename = os.path.join(
    FLAGS.semantic_segmentation_folder,
    filenames[i] + '.' + FLAGS.label_format)
seg_data = tf.gfile.FastGFile(seg_filename, 'rb').read()
seg_height, seg_width = label_reader.read_image_dims(seg_data)
```
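When you add another modality (say depth), it helps to build its path the same way as the image and label paths, and to verify that its size matches the image, because every per-pixel map has to stay aligned through cropping and flipping later. A minimal sketch; `modality_paths` and `check_same_size` are hypothetical helpers, and the folder names and extensions in the example are made-up values. Reading the extra file itself would mirror the image lines, e.g. `tf.gfile.FastGFile(path, 'rb').read()`.

```python
import os


def modality_paths(name, folders, formats):
    """Build one file path per modality, the same way as image_filename/seg_filename.

    folders and formats are dicts keyed by modality, e.g. 'image', 'label', 'depth'.
    """
    return {key: os.path.join(folders[key], name + '.' + formats[key])
            for key in folders}


def check_same_size(sizes):
    """Raise if any modality's (height, width) differs from the image's."""
    reference = sizes['image']
    for key, size in sizes.items():
        if size != reference:
            raise RuntimeError('Shape mismatched between image and %s: %s vs %s'
                               % (key, reference, size))


paths = modality_paths(
    'sample_0001',
    folders={'image': 'images', 'label': 'labels', 'depth': 'depth'},
    formats={'image': 'jpg', 'label': 'png', 'depth': 'png'})
print(paths['depth'])  # depth/sample_0001.png on POSIX systems
check_same_size({'image': (281, 500), 'label': (281, 500), 'depth': (281, 500)})
```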

3. build_voc_data.py: pass your data in so it is converted directly.

```python
example = build_data.image_seg_to_tfexample(
    image_data, filenames[i], height, width, seg_data)
```

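Step 2: build the Dataset from the TFRecord

1. The entry point is data_generator, which calls input_preprocess.py; input_preprocess in turn calls core/preprocess_utils.py.

When data_generator reads the TFRecord, the features dictionary must be extended with your new keys: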

```python
features = {
    'image/encoded':
        tf.FixedLenFeature((), tf.string, default_value=''),
    'image/filename':
        tf.FixedLenFeature((), tf.string, default_value=''),
    'image/format':
        tf.FixedLenFeature((), tf.string, default_value='jpeg'),
    'image/height':
        tf.FixedLenFeature((), tf.int64, default_value=0),
    'image/width':
        tf.FixedLenFeature((), tf.int64, default_value=0),
    'image/segmentation/class/encoded':
        tf.FixedLenFeature((), tf.string, default_value=''),
    'image/segmentation/class/format':
        tf.FixedLenFeature((), tf.string, default_value='png'),
    'image/superpixel16/class/encoded':
        tf.VarLenFeature(tf.int64),
}
```

Note: if you cannot specify the shape to read, use VarLenFeature. The sparse data it returns must then be converted to a dense tensor with `tf.sparse_tensor_to_dense`.
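A VarLenFeature comes back as a SparseTensor (indices, values, dense_shape). The pure-Python model below illustrates what `tf.sparse_tensor_to_dense` does in the 1-D case, followed by the reshape to height x width that a flattened per-pixel map such as the superpixel map still needs. In TF 1.x the equivalent would be roughly `tf.reshape(tf.sparse_tensor_to_dense(parsed_features['image/superpixel16/class/encoded']), [height, width])`; `sparse_to_dense` and `reshape_2d` are hypothetical illustration helpers, not the author's code.

```python
def sparse_to_dense(indices, values, dense_shape, default_value=0):
    """Pure-Python model of tf.sparse_tensor_to_dense for a 1-D SparseTensor.

    indices: list of [i] positions, values: the stored entries,
    dense_shape: [length]. Positions not listed get default_value.
    """
    dense = [default_value] * dense_shape[0]
    for (i,), value in zip(indices, values):
        dense[i] = value
    return dense


def reshape_2d(flat, height, width):
    """Turn a flat row-major list into height rows of width values."""
    if len(flat) != height * width:
        raise ValueError('cannot reshape %d values into %dx%d'
                         % (len(flat), height, width))
    return [flat[r * width:(r + 1) * width] for r in range(height)]


# A 2 x 3 superpixel map stored flat under a VarLenFeature.
sp = sparse_to_dense([[0], [1], [2], [3], [4], [5]], [3, 3, 7, 3, 7, 7], [6])
print(reshape_2d(sp, 2, 3))  # [[3, 3, 7], [3, 7, 7]]
```

The height and width needed for the reshape are exactly why 'image/height' and 'image/width' are stored alongside the data.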

2. Add your data's keys to the sample dictionary:

```python
sample = {
    common.IMAGE: image,
    common.IMAGE_NAME: image_name,
    common.HEIGHT: parsed_features['image/height'],
    common.WIDTH: parsed_features['image/width'],
    common.SUPER16: super16,
    common.SUPER8: super8,
    'shape': shape,
    'ori_image': image,
}
```

At this point you can already check, by debugging, whether your data is stored and read back correctly.

Debugging code:

```python
# Copyright 2018 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Training script for the DeepLab model.
See model.py for more details and usage.
"""
import six
import tensorflow as tf
from tensorflow.python.ops import math_ops
import sys
sys.path.append('..')
from deeplab import common
from deeplab import model
from deeplab.datasets import data_generator_super as data_generator
from deeplab.utils import train_utils
import os
from matplotlib import pyplot as plt
from skimage.segmentation import slic, mark_boundaries
os.environ['CUDA_VISIBLE_DEVICES'] = '5'
import numpy as np
flags = tf.app.flags
FLAGS = flags.FLAGS
# Settings for multi-GPUs/multi-replicas training.
flags.DEFINE_integer('num_clones', 1, 'Number of clones to deploy.')
flags.DEFINE_boolean('clone_on_cpu', False, 'Use CPUs to deploy clones.')
flags.DEFINE_integer('num_replicas', 1, 'Number of worker replicas.')
flags.DEFINE_integer('startup_delay_steps', 15,
                     'Number of training steps between replicas startup.')
flags.DEFINE_integer(
    'num_ps_tasks', 0,
    'The number of parameter servers. If the value is 0, then '
    'the parameters are handled locally by the worker.')
flags.DEFINE_string('master', '', 'BNS name of the tensorflow server')
flags.DEFINE_integer('task', 0, 'The task ID.')
# Settings for logging.
flags.DEFINE_string('train_logdir', None,
                    'Where the checkpoint and logs are stored.')
flags.DEFINE_integer('log_steps', 10,
                     'Display logging information at every log_steps.')
flags.DEFINE_integer('save_interval_secs', 1200,
                     'How often, in seconds, we save the model to disk.')
flags.DEFINE_integer('save_summaries_secs', 600,
                     'How often, in seconds, we compute the summaries.')
flags.DEFINE_boolean(
    'save_summaries_images', False,
    'Save sample inputs, labels, and semantic predictions as '
    'images to summary.')
# Settings for profiling.
flags.DEFINE_string('profile_logdir', None,
                    'Where the profile files are stored.')
# Settings for training strategy.
flags.DEFINE_enum('learning_policy', 'poly', ['poly', 'step'],
                  'Learning rate policy for training.')
# Use 0.007 when training on PASCAL augmented training set, train_aug. When
# fine-tuning on PASCAL trainval set, use learning rate=0.0001.
flags.DEFINE_float('base_learning_rate', .0001,
                   'The base learning rate for model training.')
flags.DEFINE_float('learning_rate_decay_factor', 0.1,
                   'The rate to decay the base learning rate.')
flags.DEFINE_integer('learning_rate_decay_step', 2000,
                     'Decay the base learning rate at a fixed step.')
flags.DEFINE_float('learning_power', 0.9,
                   'The power value used in the poly learning policy.')
flags.DEFINE_integer('training_number_of_steps', 30000,
                     'The number of steps used for training')
flags.DEFINE_float('momentum', 0.9, 'The momentum value to use')
# When fine_tune_batch_norm=True, use at least batch size larger than 12
# (batch size more than 16 is better). Otherwise, one could use smaller batch
# size and set fine_tune_batch_norm=False.
flags.DEFINE_integer('train_batch_size', 8,
                     'The number of images in each batch during training.')
# For weight_decay, use 0.00004 for MobileNet-V2 or Xception model variants.
# Use 0.0001 for ResNet model variants.
flags.DEFINE_float('weight_decay', 0.00004,
                   'The value of the weight decay for training.')
flags.DEFINE_list('train_crop_size', '513,513',
                  'Image crop size [height, width] during training.')
flags.DEFINE_float(
    'last_layer_gradient_multiplier', 1.0,
    'The gradient multiplier for last layers, which is used to '
    'boost the gradient of last layers if the value > 1.')
flags.DEFINE_boolean('upsample_logits', True,
                     'Upsample logits during training.')
# Hyper-parameters for NAS training strategy.
flags.DEFINE_float(
    'drop_path_keep_prob', 1.0,
    'Probability to keep each path in the NAS cell when training.')
# Settings for fine-tuning the network.
flags.DEFINE_string('tf_initial_checkpoint', None,
                    'The initial checkpoint in tensorflow format.')
# Set to False if one does not want to re-use the trained classifier weights.
flags.DEFINE_boolean('initialize_last_layer', True,
                     'Initialize the last layer.')
flags.DEFINE_boolean('last_layers_contain_logits_only', False,
                     'Only consider logits as last layers or not.')
flags.DEFINE_integer('slow_start_step', 0,
                     'Training model with small learning rate for few steps.')
flags.DEFINE_float('slow_start_learning_rate', 1e-4,
                   'Learning rate employed during slow start.')
# Set to True if one wants to fine-tune the batch norm parameters in DeepLabv3.
# Set to False and use small batch size to save GPU memory.
flags.DEFINE_boolean('fine_tune_batch_norm', True,
                     'Fine tune the batch norm parameters or not.')
flags.DEFINE_float('min_scale_factor', 0.5,
                   'Minimum scale factor for data augmentation.')
flags.DEFINE_float('max_scale_factor', 2.,
                   'Maximum scale factor for data augmentation.')
flags.DEFINE_float('scale_factor_step_size', 0.25,
                   'Scale factor step size for data augmentation.')
# For `xception_65`, use atrous_rates = [12, 24, 36] if output_stride = 8, or
# rates = [6, 12, 18] if output_stride = 16. For `mobilenet_v2`, use None. Note
# one could use different atrous_rates/output_stride during training/evaluation.
flags.DEFINE_multi_integer('atrous_rates', None,
                           'Atrous rates for atrous spatial pyramid pooling.')
flags.DEFINE_integer('output_stride', 16,
                     'The ratio of input to output spatial resolution.')
# Hard example mining related flags.
flags.DEFINE_integer(
    'hard_example_mining_step', 0,
    'The training step in which exact hard example mining kicks off. Note we '
    'gradually reduce the mining percent to the specified '
    'top_k_percent_pixels. For example, if hard_example_mining_step=100K and '
    'top_k_percent_pixels=0.25, then mining percent will gradually reduce from '
    '100% to 25% until 100K steps after which we only mine top 25% pixels.')
flags.DEFINE_float(
    'top_k_percent_pixels', 1.0,
    'The top k percent pixels (in terms of the loss values) used to compute '
    'loss during training. This is useful for hard pixel mining.')
# Quantization setting.
flags.DEFINE_integer(
    'quantize_delay_step', -1,
    'Steps to start quantized training. If < 0, will not quantize model.')
# Dataset settings.
flags.DEFINE_string('dataset', 'pascal_voc_seg',
                    'Name of the segmentation dataset.')
flags.DEFINE_string('train_split', 'train_aug',
                    'Which split of the dataset to be used for training')
flags.DEFINE_string('dataset_dir', None, 'Where the dataset reside.')


def main(unused_argv):
  tf.logging.set_verbosity(tf.logging.INFO)
  tf.gfile.MakeDirs(FLAGS.train_logdir)
  tf.logging.info('Training on %s set', FLAGS.train_split)
  graph = tf.Graph()
  with graph.as_default():
    with tf.device(tf.train.replica_device_setter(ps_tasks=FLAGS.num_ps_tasks)):
      assert FLAGS.train_batch_size % FLAGS.num_clones == 0, (
          'Training batch size not divisible by number of clones (GPUs).')
      clone_batch_size = FLAGS.train_batch_size // FLAGS.num_clones
      dataset = data_generator.Dataset(
          dataset_name=FLAGS.dataset,
          split_name=FLAGS.train_split,
          dataset_dir=FLAGS.dataset_dir,
          batch_size=clone_batch_size,
          crop_size=[int(sz) for sz in FLAGS.train_crop_size],
          min_resize_value=FLAGS.min_resize_value,
          max_resize_value=FLAGS.max_resize_value,
          resize_factor=FLAGS.resize_factor,
          min_scale_factor=FLAGS.min_scale_factor,
          max_scale_factor=FLAGS.max_scale_factor,
          scale_factor_step_size=FLAGS.scale_factor_step_size,
          model_variant=FLAGS.model_variant,
          num_readers=2,
          is_training=True,
          should_shuffle=True,
          should_repeat=True)
      iterator = dataset.get_one_shot_iterator()
      next_element = iterator.get_next()
    # Soft placement allows placing on CPU ops without GPU implementation.
    session_config = tf.ConfigProto(
        allow_soft_placement=True, log_device_placement=False)
    last_layers = model.get_extra_layer_scopes(
        FLAGS.last_layers_contain_logits_only)
    init_fn = None
    #FLAGS.tf_initial_checkpoint = '/home/DATA/liutian/tmp/tfdeeplab/deeplab/datasets/pascal_voc_seg/init_models/deeplabv3_pascal_train_aug/model.ckpt'
    if FLAGS.tf_initial_checkpoint:
      init_fn = train_utils.get_model_init_fn(
          FLAGS.train_logdir,
          FLAGS.tf_initial_checkpoint,
          FLAGS.initialize_last_layer,
          last_layers,
          ignore_missing_vars=True)
    stop_hook = tf.train.StopAtStepHook(
        last_step=FLAGS.training_number_of_steps)
    profile_dir = FLAGS.profile_logdir
    if profile_dir is not None:
      tf.gfile.MakeDirs(profile_dir)
    # Fetch one batch from the pipeline. sess.run returns plain NumPy values,
    # so they are printed directly instead of being passed to sess.run again.
    with tf.Session(graph=graph, config=session_config) as sess:
      batch = sess.run(next_element)
      shape = batch['shape']
      ori_image = np.array(batch['image']).astype(np.uint8)
      print('shape is ', shape)


if __name__ == '__main__':
  flags.mark_flag_as_required('train_logdir')
  flags.mark_flag_as_required('dataset_dir')
  tf.app.run()
```
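Once the script prints a batch, a quick sanity check is whether every per-pixel entry (label, superpixel maps, and so on) shares the same batch size, height, and width as the image. A minimal sketch, assuming the shapes are taken from the dictionary that `sess.run(next_element)` returns, e.g. `{k: v.shape for k, v in batch.items()}`; `check_batch_alignment` is a hypothetical helper, and the keys 'super16'/'super8' stand in for whatever strings your common.SUPER16/common.SUPER8 actually hold.

```python
def check_batch_alignment(shapes, image_key='image'):
    """Return {key: shape} for per-pixel entries whose leading dims differ from the image.

    shapes maps sample keys to shape tuples. Entries with fewer than 3
    dimensions (names, heights, widths, ...) are skipped; the rest must share
    (batch, height, width) with the image. An empty dict means all aligned.
    """
    reference = tuple(shapes[image_key][:3])
    mismatched = {}
    for key, shape in shapes.items():
        if len(shape) >= 3 and tuple(shape[:3]) != reference:
            mismatched[key] = tuple(shape)
    return mismatched


shapes = {'image': (8, 513, 513, 3), 'label': (8, 513, 513, 1),
          'super16': (8, 513, 513), 'height': (8,), 'shape': (8, 3)}
print(check_batch_alignment(shapes))  # {}
```

A non-empty result usually points at a map that was not cropped or resized together with the image in preprocessing.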

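3. Handle your data in preprocessing.

If the original data (image + annotation) are rescaled, used for multi-scale testing, or padded, and your data needs the same treatment, you have to modify def _preprocess_image.

That function passes the new data on to preprocess_image_and_label.

So the most important changes go into preprocess_utils in core.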
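The key property to preserve is that random augmentations draw their randomness once and apply it to every map. As far as I can tell from the official input_preprocess.py, random cropping and left-right flipping go through preprocess_utils.random_crop and preprocess_utils.flip_dim, which take a list such as [processed_image, label] and apply one shared random draw to all of its elements, so appending your own map to that list is usually enough. Padding is the exception: the image and the label are padded with different values there, so your map needs a pad value of its own. The pure-Python sketch below only illustrates the shared-draw idea; `random_flip_all` and `random_crop_all` are hypothetical names, not DeepLab functions.

```python
import random


def random_flip_all(maps, prob=0.5, rng=random):
    """Flip every H x W map left-right, or none of them, from ONE random draw."""
    if rng.random() < prob:
        return [[row[::-1] for row in m] for m in maps], True
    return [[list(row) for row in m] for m in maps], False


def random_crop_all(maps, crop_h, crop_w, rng=random):
    """Cut the same window (same top/left offset) out of every map."""
    height, width = len(maps[0]), len(maps[0][0])
    for m in maps:
        if len(m) != height or len(m[0]) != width:
            raise ValueError('all maps must share the same height and width')
    if crop_h > height or crop_w > width:
        raise ValueError('crop is larger than the maps')
    top = rng.randint(0, height - crop_h)  # randint includes both endpoints
    left = rng.randint(0, width - crop_w)
    return [[row[left:left + crop_w] for row in m[top:top + crop_h]] for m in maps]


rng = random.Random(0)
image = [[(r, c) for c in range(5)] for r in range(4)]  # each pixel remembers its position
label = [[10 * r + c for c in range(5)] for r in range(4)]
img_c, lab_c = random_crop_all([image, label], 2, 3, rng)
(img_f, lab_f), flipped = random_flip_all([img_c, lab_c], rng=rng)
# Whatever was drawn, label pixel (i, j) still belongs to image pixel (i, j).
```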

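Many functions accept only image and label. You can pass your data in as well and apply the same operations to it that are applied to image and label.

For example: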

```python
image = tf.squeeze(tf.image.resize_bilinear(
    tf.expand_dims(image, 0),
    new_dim,
    align_corners=True), [0])
if label is not None:
  label = tf.squeeze(tf.image.resize_nearest_neighbor(
      tf.expand_dims(label, 0),
      new_dim,
      align_corners=True), [0])
new_sup = []
if superpixels is not None:
  for superpixel in superpixels:
    # Superpixel IDs are discrete, so resize them like the label
    # (nearest neighbor), never bilinearly like the image.
    superpixel = tf.squeeze(tf.image.resize_nearest_neighbor(
        tf.expand_dims(superpixel, 0),
        new_dim,
        align_corners=True), [0])
    # todo: list union
    new_sup.append(superpixel)
else:
  new_sup.append(None)
return image, label, new_sup
```
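The snippet above resizes the image bilinearly but the label and each superpixel map with nearest neighbor. That choice matters: bilinear interpolation halfway between class 0 and class 15 would produce 7.5, an ID that does not exist, whereas nearest neighbor only ever copies existing values. The pure-Python sketch below shows nearest-neighbor resizing with corner alignment; the index mapping (scale by (in - 1) / (out - 1), round, clamp) approximates TF 1.x's align_corners=True behavior and is an illustration, not a bit-exact reimplementation. `resize_nearest` is a hypothetical name.

```python
import math


def _src_index(i, in_size, out_size):
    """Map output index i to a source index so the corner pixels line up."""
    if out_size == 1:
        return 0
    scale = (in_size - 1) / (out_size - 1)
    # Round half up (like C's roundf for non-negative values), then clamp.
    return min(int(math.floor(i * scale + 0.5)), in_size - 1)


def resize_nearest(grid, out_h, out_w):
    """Nearest-neighbor resize of an H x W grid of labels or superpixel IDs."""
    in_h, in_w = len(grid), len(grid[0])
    rows = [_src_index(r, in_h, out_h) for r in range(out_h)]
    cols = [_src_index(c, in_w, out_w) for c in range(out_w)]
    return [[grid[r][c] for c in cols] for r in rows]


label = [[0, 0, 15],
         [0, 15, 15]]
print(resize_nearest(label, 3, 5))
# [[0, 0, 0, 15, 15], [0, 15, 15, 15, 15], [0, 15, 15, 15, 15]]
```

Every output value is still 0 or 15, and the four corners keep their original values, which is what you want for any per-pixel ID map you add.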

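Finally, you only need to write the values returned by this function back into sample.

Debug once more to check that the preprocessing is correct. It rarely goes wrong, because you are reusing official code that has already been debugged.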
