Hands-on with X2Paddle: Converting a TensorFlow VGG Model to a PaddlePaddle Model

╰半夏微凉° · 2022-10-17 13:51 · 182 reads · 0 likes

Lately I keep running into articles about PaddlePaddle. To be honest, I used Paddle back in graduate school and then set it aside; after several years of polishing it has matured into a solid deep learning framework, so it seemed like a good time to pick it up and learn it again.

PaddlePaddle's homepage showcases a number of projects that are all popular right now, but they are not today's focus. I have done model conversion between frameworks before, and today's task is in the same vein: PaddlePaddle provides its own open-source tool, X2Paddle, which, as the name suggests, converts models from other frameworks into PaddlePaddle's model format.

Without further ado, let's look at X2Paddle first.

If you are interested, you can study the project in depth on your own; today we will learn the concrete usage through a hands-on example. VGG_16 is a classic model in computer vision, so we will use it to show how to convert a TensorFlow-trained model into a PaddlePaddle model.

First, download the pretrained model from:

http://download.tensorflow.org/models/vgg_16_2016_08_28.tar.gz

After downloading, extract the archive; it contains the checkpoint file vgg_16.ckpt used below.
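The extraction step can be scripted with Python's standard tarfile module. This is only an illustration: to keep it self-contained it builds a stand-in archive first, whereas in practice you would open the downloaded vgg_16_2016_08_28.tar.gz directly.

```python
import os
import tarfile

# For illustration, build a stand-in archive (the real one is
# vgg_16_2016_08_28.tar.gz from the download link above).
os.makedirs("stage", exist_ok=True)
with open("stage/vgg_16.ckpt", "wb") as f:
    f.write(b"placeholder checkpoint bytes")
with tarfile.open("vgg_16_demo.tar.gz", "w:gz") as archive:
    archive.add("stage/vgg_16.ckpt", arcname="vgg_16.ckpt")

# Extraction mirrors what `tar -xzf vgg_16_2016_08_28.tar.gz` does:
with tarfile.open("vgg_16_demo.tar.gz", "r:gz") as archive:
    archive.extractall(".")
print(os.path.exists("vgg_16.ckpt"))  # True
```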

Next, we reload the parameters and save the network structure together with the parameters as a checkpoint model:

```python
with tf.Session() as sess:
    inputs = tf.placeholder(dtype=tf.float32, shape=[None, 224, 224, 3], name="inputs")
    logits, endpoint = vgg.vgg_16(inputs, num_classes=1000, is_training=False)
    load_model = slim.assign_from_checkpoint_fn("vgg_16.ckpt", slim.get_model_variables("vgg_16"))
    load_model(sess)
    numpy.random.seed(13)  # seed so the Paddle verification later sees the same input
    data = numpy.random.rand(5, 224, 224, 3)
    input_tensor = sess.graph.get_tensor_by_name("inputs:0")
    output_tensor = sess.graph.get_tensor_by_name("vgg_16/fc8/squeezed:0")
    result = sess.run([output_tensor], {input_tensor: data})
    numpy.save("tensorflow.npy", numpy.array(result))
    saver = tf.train.Saver()
    saver.save(sess, "checkpoint/model")
```
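The numpy.save call above is what makes the later cross-framework comparison possible: the TensorFlow outputs go to disk and are reloaded next to the Paddle outputs. A minimal round trip, independent of TensorFlow, looks like this:

```python
import numpy

# Save an array to disk and reload it; numpy.save / numpy.load preserve
# dtype, shape, and values exactly, so saved outputs can be compared later.
reference = numpy.arange(6, dtype="float32").reshape(2, 3)
numpy.save("reference.npy", reference)
restored = numpy.load("reference.npy")
print(restored.shape)                  # (2, 3)
print((restored == reference).all())   # True
```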

With that done, we can proceed to convert the model into a PaddlePaddle model:

```python
parser = convert._get_parser()
parser.meta_file = "checkpoint/model.meta"
parser.ckpt_dir = "checkpoint"
parser.in_nodes = ["inputs"]
parser.input_shape = ["None,224,224,3"]
parser.output_nodes = ["vgg_16/fc8/squeezed"]
parser.use_cuda = "True"
parser.input_format = "NHWC"
parser.save_dir = "paddle_model"
convert.run(parser)
```

Since this is a first learning run, I am attaching the complete log output:

```
Use tf.compat.v1.graph_util.extract_sub_graph
INFO:root:Tensorflow model loaded!
INFO:root:TotalNum:86,TraslatedNum:1,CurrentNode:inputs
INFO:root:TotalNum:86,TraslatedNum:2,CurrentNode:vgg_16/conv1/conv1_1/weights
INFO:root:TotalNum:86,TraslatedNum:3,CurrentNode:vgg_16/conv1/conv1_1/biases
INFO:root:TotalNum:86,TraslatedNum:4,CurrentNode:vgg_16/conv1/conv1_2/weights
INFO:root:TotalNum:86,TraslatedNum:5,CurrentNode:vgg_16/conv1/conv1_2/biases
INFO:root:TotalNum:86,TraslatedNum:6,CurrentNode:vgg_16/conv2/conv2_1/weights
INFO:root:TotalNum:86,TraslatedNum:7,CurrentNode:vgg_16/conv2/conv2_1/biases
INFO:root:TotalNum:86,TraslatedNum:8,CurrentNode:vgg_16/conv2/conv2_2/weights
INFO:root:TotalNum:86,TraslatedNum:9,CurrentNode:vgg_16/conv2/conv2_2/biases
INFO:root:TotalNum:86,TraslatedNum:10,CurrentNode:vgg_16/conv3/conv3_1/weights
INFO:root:TotalNum:86,TraslatedNum:11,CurrentNode:vgg_16/conv3/conv3_1/biases
INFO:root:TotalNum:86,TraslatedNum:12,CurrentNode:vgg_16/conv3/conv3_2/weights
INFO:root:TotalNum:86,TraslatedNum:13,CurrentNode:vgg_16/conv3/conv3_2/biases
INFO:root:TotalNum:86,TraslatedNum:14,CurrentNode:vgg_16/conv3/conv3_3/weights
INFO:root:TotalNum:86,TraslatedNum:15,CurrentNode:vgg_16/conv3/conv3_3/biases
INFO:root:TotalNum:86,TraslatedNum:16,CurrentNode:vgg_16/conv4/conv4_1/weights
INFO:root:TotalNum:86,TraslatedNum:17,CurrentNode:vgg_16/conv4/conv4_1/biases
INFO:root:TotalNum:86,TraslatedNum:18,CurrentNode:vgg_16/conv4/conv4_2/weights
INFO:root:TotalNum:86,TraslatedNum:19,CurrentNode:vgg_16/conv4/conv4_2/biases
INFO:root:TotalNum:86,TraslatedNum:20,CurrentNode:vgg_16/conv4/conv4_3/weights
INFO:root:TotalNum:86,TraslatedNum:21,CurrentNode:vgg_16/conv4/conv4_3/biases
INFO:root:TotalNum:86,TraslatedNum:22,CurrentNode:vgg_16/conv5/conv5_1/weights
INFO:root:TotalNum:86,TraslatedNum:23,CurrentNode:vgg_16/conv5/conv5_1/biases
INFO:root:TotalNum:86,TraslatedNum:24,CurrentNode:vgg_16/conv5/conv5_2/weights
INFO:root:TotalNum:86,TraslatedNum:25,CurrentNode:vgg_16/conv5/conv5_2/biases
INFO:root:TotalNum:86,TraslatedNum:26,CurrentNode:vgg_16/conv5/conv5_3/weights
INFO:root:TotalNum:86,TraslatedNum:27,CurrentNode:vgg_16/conv5/conv5_3/biases
INFO:root:TotalNum:86,TraslatedNum:28,CurrentNode:vgg_16/fc6/weights
INFO:root:TotalNum:86,TraslatedNum:29,CurrentNode:vgg_16/fc6/biases
INFO:root:TotalNum:86,TraslatedNum:30,CurrentNode:vgg_16/fc7/weights
INFO:root:TotalNum:86,TraslatedNum:31,CurrentNode:vgg_16/fc7/biases
INFO:root:TotalNum:86,TraslatedNum:32,CurrentNode:vgg_16/fc8/weights
INFO:root:TotalNum:86,TraslatedNum:33,CurrentNode:vgg_16/fc8/biases
INFO:root:TotalNum:86,TraslatedNum:34,CurrentNode:vgg_16/conv1/conv1_1/Conv2D
INFO:root:TotalNum:86,TraslatedNum:35,CurrentNode:vgg_16/conv1/conv1_1/BiasAdd
INFO:root:TotalNum:86,TraslatedNum:36,CurrentNode:vgg_16/conv1/conv1_1/Relu
INFO:root:TotalNum:86,TraslatedNum:37,CurrentNode:vgg_16/conv1/conv1_2/Conv2D
INFO:root:TotalNum:86,TraslatedNum:38,CurrentNode:vgg_16/conv1/conv1_2/BiasAdd
INFO:root:TotalNum:86,TraslatedNum:39,CurrentNode:vgg_16/conv1/conv1_2/Relu
INFO:root:TotalNum:86,TraslatedNum:40,CurrentNode:vgg_16/pool1/MaxPool
INFO:root:TotalNum:86,TraslatedNum:41,CurrentNode:vgg_16/conv2/conv2_1/Conv2D
INFO:root:TotalNum:86,TraslatedNum:42,CurrentNode:vgg_16/conv2/conv2_1/BiasAdd
INFO:root:TotalNum:86,TraslatedNum:43,CurrentNode:vgg_16/conv2/conv2_1/Relu
INFO:root:TotalNum:86,TraslatedNum:44,CurrentNode:vgg_16/conv2/conv2_2/Conv2D
INFO:root:TotalNum:86,TraslatedNum:45,CurrentNode:vgg_16/conv2/conv2_2/BiasAdd
INFO:root:TotalNum:86,TraslatedNum:46,CurrentNode:vgg_16/conv2/conv2_2/Relu
INFO:root:TotalNum:86,TraslatedNum:47,CurrentNode:vgg_16/pool2/MaxPool
INFO:root:TotalNum:86,TraslatedNum:48,CurrentNode:vgg_16/conv3/conv3_1/Conv2D
INFO:root:TotalNum:86,TraslatedNum:49,CurrentNode:vgg_16/conv3/conv3_1/BiasAdd
INFO:root:TotalNum:86,TraslatedNum:50,CurrentNode:vgg_16/conv3/conv3_1/Relu
INFO:root:TotalNum:86,TraslatedNum:51,CurrentNode:vgg_16/conv3/conv3_2/Conv2D
INFO:root:TotalNum:86,TraslatedNum:52,CurrentNode:vgg_16/conv3/conv3_2/BiasAdd
INFO:root:TotalNum:86,TraslatedNum:53,CurrentNode:vgg_16/conv3/conv3_2/Relu
INFO:root:TotalNum:86,TraslatedNum:54,CurrentNode:vgg_16/conv3/conv3_3/Conv2D
INFO:root:TotalNum:86,TraslatedNum:55,CurrentNode:vgg_16/conv3/conv3_3/BiasAdd
INFO:root:TotalNum:86,TraslatedNum:56,CurrentNode:vgg_16/conv3/conv3_3/Relu
INFO:root:TotalNum:86,TraslatedNum:57,CurrentNode:vgg_16/pool3/MaxPool
INFO:root:TotalNum:86,TraslatedNum:58,CurrentNode:vgg_16/conv4/conv4_1/Conv2D
INFO:root:TotalNum:86,TraslatedNum:59,CurrentNode:vgg_16/conv4/conv4_1/BiasAdd
INFO:root:TotalNum:86,TraslatedNum:60,CurrentNode:vgg_16/conv4/conv4_1/Relu
INFO:root:TotalNum:86,TraslatedNum:61,CurrentNode:vgg_16/conv4/conv4_2/Conv2D
INFO:root:TotalNum:86,TraslatedNum:62,CurrentNode:vgg_16/conv4/conv4_2/BiasAdd
INFO:root:TotalNum:86,TraslatedNum:63,CurrentNode:vgg_16/conv4/conv4_2/Relu
INFO:root:TotalNum:86,TraslatedNum:64,CurrentNode:vgg_16/conv4/conv4_3/Conv2D
INFO:root:TotalNum:86,TraslatedNum:65,CurrentNode:vgg_16/conv4/conv4_3/BiasAdd
INFO:root:TotalNum:86,TraslatedNum:66,CurrentNode:vgg_16/conv4/conv4_3/Relu
INFO:root:TotalNum:86,TraslatedNum:67,CurrentNode:vgg_16/pool4/MaxPool
INFO:root:TotalNum:86,TraslatedNum:68,CurrentNode:vgg_16/conv5/conv5_1/Conv2D
INFO:root:TotalNum:86,TraslatedNum:69,CurrentNode:vgg_16/conv5/conv5_1/BiasAdd
INFO:root:TotalNum:86,TraslatedNum:70,CurrentNode:vgg_16/conv5/conv5_1/Relu
INFO:root:TotalNum:86,TraslatedNum:71,CurrentNode:vgg_16/conv5/conv5_2/Conv2D
INFO:root:TotalNum:86,TraslatedNum:72,CurrentNode:vgg_16/conv5/conv5_2/BiasAdd
INFO:root:TotalNum:86,TraslatedNum:73,CurrentNode:vgg_16/conv5/conv5_2/Relu
INFO:root:TotalNum:86,TraslatedNum:74,CurrentNode:vgg_16/conv5/conv5_3/Conv2D
INFO:root:TotalNum:86,TraslatedNum:75,CurrentNode:vgg_16/conv5/conv5_3/BiasAdd
INFO:root:TotalNum:86,TraslatedNum:76,CurrentNode:vgg_16/conv5/conv5_3/Relu
INFO:root:TotalNum:86,TraslatedNum:77,CurrentNode:vgg_16/pool5/MaxPool
INFO:root:TotalNum:86,TraslatedNum:78,CurrentNode:vgg_16/fc6/Conv2D
INFO:root:TotalNum:86,TraslatedNum:79,CurrentNode:vgg_16/fc6/BiasAdd
INFO:root:TotalNum:86,TraslatedNum:80,CurrentNode:vgg_16/fc6/Relu
INFO:root:TotalNum:86,TraslatedNum:81,CurrentNode:vgg_16/fc7/Conv2D
INFO:root:TotalNum:86,TraslatedNum:82,CurrentNode:vgg_16/fc7/BiasAdd
INFO:root:TotalNum:86,TraslatedNum:83,CurrentNode:vgg_16/fc7/Relu
INFO:root:TotalNum:86,TraslatedNum:84,CurrentNode:vgg_16/fc8/Conv2D
INFO:root:TotalNum:86,TraslatedNum:85,CurrentNode:vgg_16/fc8/BiasAdd
INFO:root:TotalNum:86,TraslatedNum:86,CurrentNode:vgg_16/fc8/squeezed
INFO:root:Model translated!
```


After a successful conversion, a paddle_model directory appears containing the converted model files.

The conversion itself is now complete, but conversions rarely escape precision-loss concerns, so it is worth verifying the numbers. The code:

```python
model = ml.ModelLoader("paddle_model", use_cuda=False)
numpy.random.seed(13)
data = numpy.random.rand(5, 224, 224, 3).astype("float32")
# NHWC -> NCHW
data = numpy.transpose(data, (0, 3, 1, 2))
results = model.inference(feed_dict={model.inputs[0]: data})
numpy.save("paddle.npy", numpy.array(results))
```

Running it raised an error:

```
AssertionError: In PaddlePaddle 2.x, we turn on dynamic graph mode by default, and 'data()' is only supported in static graph mode. So if you want to use this api, please call 'paddle.enable_static()' before this api to enter static graph mode.
```

The fix is essentially just a static-graph issue; the corrected code is below.

```python
paddle.enable_static()
model = ml.ModelLoader("paddle_model", use_cuda=False)
numpy.random.seed(13)
data = numpy.random.rand(5, 224, 224, 3).astype("float32")
# NHWC -> NCHW
data = numpy.transpose(data, (0, 3, 1, 2))
results = model.inference(feed_dict={model.inputs[0]: data})
numpy.save("paddle.npy", numpy.array(results))
```
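The NHWC-to-NCHW transpose in the verification code is easy to get wrong, so a quick sanity check on a dummy array is worthwhile: axis 3 (channels) must move to position 1 while every pixel keeps its value.

```python
import numpy

# A batch of 5 images in NHWC layout (batch, height, width, channels).
nhwc = numpy.random.rand(5, 224, 224, 3).astype("float32")

# Move the channel axis to position 1: NHWC -> NCHW.
nchw = numpy.transpose(nhwc, (0, 3, 1, 2))
print(nchw.shape)  # (5, 3, 224, 224)

# The transpose only permutes axes, so any given pixel value is unchanged.
print(nhwc[0, 10, 20, 2] == nchw[0, 2, 10, 20])  # True
```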

Then we measure the discrepancy between the two models: load both saved result files and take the element-wise absolute difference with numpy.fabs:

```python
paddle_result = numpy.load("paddle.npy")
tensorflow_result = numpy.load("tensorflow.npy")
diff = numpy.fabs(paddle_result - tensorflow_result)
print(numpy.max(diff))
```
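The same max-absolute-difference check can be seen on toy data: perturb one array by a tiny amount and the reported maximum is on the order of that perturbation.

```python
import numpy

a = numpy.array([1.0, 2.0, 3.0])
b = a + numpy.array([1e-6, -2e-6, 0.0])  # simulate a tiny conversion error

# Element-wise absolute difference, then take the worst case.
diff = numpy.fabs(a - b)
print(numpy.max(diff))  # approximately 2e-06
```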

The printed maximum difference is quite small, well within an acceptable range.

Finally, here is the complete code:

```python
#!/usr/bin/env python
# encoding: utf-8
from __future__ import division
'''
__Author__: 沂水寒城
Purpose: hands-on conversion of a TensorFlow VGG model to Paddle
'''
import os
import tensorflow.contrib.slim as slim
from tensorflow.contrib.slim.nets import vgg
import tensorflow as tf
import numpy

'''
Save the model in checkpoint format
'''
with tf.Session() as sess:
    inputs = tf.placeholder(dtype=tf.float32, shape=[None, 224, 224, 3], name="inputs")
    logits, endpoint = vgg.vgg_16(inputs, num_classes=1000, is_training=False)
    load_model = slim.assign_from_checkpoint_fn("vgg_16.ckpt", slim.get_model_variables("vgg_16"))
    load_model(sess)
    numpy.random.seed(13)  # seed so the Paddle run below sees the same input
    data = numpy.random.rand(5, 224, 224, 3)
    input_tensor = sess.graph.get_tensor_by_name("inputs:0")
    output_tensor = sess.graph.get_tensor_by_name("vgg_16/fc8/squeezed:0")
    result = sess.run([output_tensor], {input_tensor: data})
    numpy.save("tensorflow.npy", numpy.array(result))
    saver = tf.train.Saver()
    saver.save(sess, "checkpoint/model")

'''
Convert the model into a PaddlePaddle model
'''
import tf2fluid.convert as convert
import argparse
parser = convert._get_parser()
parser.meta_file = "checkpoint/model.meta"
parser.ckpt_dir = "checkpoint"
parser.in_nodes = ["inputs"]
parser.input_shape = ["None,224,224,3"]
parser.output_nodes = ["vgg_16/fc8/squeezed"]
parser.use_cuda = "True"
parser.input_format = "NHWC"
parser.save_dir = "paddle_model"
convert.run(parser)

'''
Compare prediction results
'''
import numpy
import paddle
import tf2fluid.model_loader as ml
paddle.enable_static()
model = ml.ModelLoader("paddle_model", use_cuda=False)
numpy.random.seed(13)
data = numpy.random.rand(5, 224, 224, 3).astype("float32")
# NHWC -> NCHW
data = numpy.transpose(data, (0, 3, 1, 2))
results = model.inference(feed_dict={model.inputs[0]: data})
numpy.save("paddle.npy", numpy.array(results))

'''
Compare the difference between the two models
'''
import numpy
paddle_result = numpy.load("paddle.npy")
tensorflow_result = numpy.load("tensorflow.npy")
diff = numpy.fabs(paddle_result - tensorflow_result)
print(numpy.max(diff))
```

That concludes this first hands-on run; I will dig deeper when time allows.
