神经网络学习小记录12——批量学习tf.train.batch

向右看齐 2023-06-06 14:28 139阅读 0赞

神经网络学习小记录12——批量学习tf.train.batch

  • 学习前言
  • tf.train.batch函数
  • 测试代码
    • 1、allow_samller_final_batch=True
    • 2、allow_samller_final_batch=False

学习前言

当我在快乐的学习SSD训练部分的时候,我发现了一个batch我看不太懂,主要是因为tfrecords的数据读取方式我不理解,所以好好学一下batch吧!
在这里插入图片描述

tf.train.batch函数

  1. tf.train.batch(
  2. tensors,
  3. batch_size,
  4. num_threads=1,
  5. capacity=32,
  6. enqueue_many=False,
  7. shapes=None,
  8. dynamic_pad=False,
  9. allow_smaller_final_batch=False,
  10. shared_name=None,
  11. name=None
  12. )

其中:
1、tensors:利用slice_input_producer获得的数据组合。
2、batch_size:设置每次从队列中获取出队数据的数量。
3、num_threads:用来控制线程的数量,如果其值不唯一,由于线程执行的特性,数据获取可能变成乱序。
4、capacity:一个整数,用来设置队列中元素的最大数量
5、allow_samller_final_batch:当其为True时,如果队列中的样本数量小于batch_size,出队的数量会以最终遗留下来的样本进行出队;当其为False时,小于batch_size的样本不会做出队处理。
6、name:名字

测试代码

1、allow_samller_final_batch=True

  1. import pandas as pd
  2. import numpy as np
  3. import tensorflow as tf
  4. # 生成数据
  5. def generate_data():
  6. num = 18
  7. label = np.arange(num)
  8. return label
  9. # 获取数据
  10. def get_batch_data():
  11. label = generate_data()
  12. input_queue = tf.train.slice_input_producer([label], shuffle=False,num_epochs=2)
  13. label_batch = tf.train.batch(input_queue, batch_size=5, num_threads=1, capacity=64,allow_smaller_final_batch=True)
  14. return label_batch
  15. # 数据组
  16. label = get_batch_data()
  17. sess = tf.Session()
  18. # 初始化变量
  19. sess.run(tf.global_variables_initializer())
  20. sess.run(tf.local_variables_initializer())
  21. # 初始化batch训练的参数
  22. coord = tf.train.Coordinator()
  23. threads = tf.train.start_queue_runners(sess,coord)
  24. try:
  25. while not coord.should_stop():
  26. # 自动获取下一组数据
  27. l = sess.run(label)
  28. print(l)
  29. except tf.errors.OutOfRangeError:
  30. print('Done training')
  31. finally:
  32. coord.request_stop()
  33. coord.join(threads)
  34. sess.close()

运行结果为:

  1. [0 1 2 3 4]
  2. [5 6 7 8 9]
  3. [10 11 12 13 14]
  4. [15 16 17 0 1]
  5. [2 3 4 5 6]
  6. [ 7 8 9 10 11]
  7. [12 13 14 15 16]
  8. [17]
  9. Done training

2、allow_samller_final_batch=False

相比allow_samller_final_batch=True,输出结果少了[17]

  1. import pandas as pd
  2. import numpy as np
  3. import tensorflow as tf
  4. # 生成数据
  5. def generate_data():
  6. num = 18
  7. label = np.arange(num)
  8. return label
  9. # 获取数据
  10. def get_batch_data():
  11. label = generate_data()
  12. input_queue = tf.train.slice_input_producer([label], shuffle=False,num_epochs=2)
  13. label_batch = tf.train.batch(input_queue, batch_size=5, num_threads=1, capacity=64,allow_smaller_final_batch=False)
  14. return label_batch
  15. # 数据组
  16. label = get_batch_data()
  17. sess = tf.Session()
  18. # 初始化变量
  19. sess.run(tf.global_variables_initializer())
  20. sess.run(tf.local_variables_initializer())
  21. # 初始化batch训练的参数
  22. coord = tf.train.Coordinator()
  23. threads = tf.train.start_queue_runners(sess,coord)
  24. try:
  25. while not coord.should_stop():
  26. # 自动获取下一组数据
  27. l = sess.run(label)
  28. print(l)
  29. except tf.errors.OutOfRangeError:
  30. print('Done training')
  31. finally:
  32. coord.request_stop()
  33. coord.join(threads)
  34. sess.close()

运行结果为:

  1. [0 1 2 3 4]
  2. [5 6 7 8 9]
  3. [10 11 12 13 14]
  4. [15 16 17 0 1]
  5. [2 3 4 5 6]
  6. [ 7 8 9 10 11]
  7. [12 13 14 15 16]
  8. Done training

发表评论

表情:
评论列表 (有 0 条评论,139人围观)

还没有评论,来说两句吧...

相关阅读