深度学习Tensorflow 之 MNIST手写数字识别

小咪咪 2022-01-20 15:51 821阅读 0赞

MNIST数据集介绍和下载:http://yann.lecun.com/exdb/mnist/

一、数据集介绍:

MNIST是一个入门级的计算机视觉数据集

下载下来的数据集被分成两部分:60000行的训练数据集(mnist.train)和10000行的测试数据集(mnist.test)

二、TensorFlow实现MNIST手写数字识别

(1)构建一个只有输入层和输出层的简单神经网络模型,使用二次代价函数和梯度下降算法进行优化;代码如下:

  1. #TensorFlow实现MNIST手写数字识别-简单版本
  2. import tensorflow as tf
  3. #Tensorflow提供了一个类来处理MNIST数据
  4. from tensorflow.examples.tutorials.mnist import input_data
  5. #载入数据集
  6. mnist=input_data.read_data_sets('MNIST_data',one_hot=True)
  7. #设置每个批次的大小
  8. batch_size=100
  9. #计算一共有多少个批次
  10. n_batch=mnist.train.num_examples//batch_size
  11. #定义两个placeholder
  12. x=tf.placeholder(tf.float32,[None,784])
  13. y=tf.placeholder(tf.float32,[None,10])
  14. #创建一个简单的神经网络(只有输入层和输出层)
  15. Weights=tf.Variable(tf.zeros([784,10]))
  16. biases=tf.Variable(tf.zeros([10]))
  17. prediction=tf.nn.softmax(tf.matmul(x,Weights)+biases)
  18. #定义代价函数(均方差函数)
  19. loss=tf.reduce_mean(tf.square(y-prediction))
  20. #定义反向传播算法(使用梯度下降算法)
  21. train_step=tf.train.GradientDescentOptimizer(0.2).minimize(loss)
  22. #结果存放在一个布尔型列表中(argmax函数返回一维张量中最大的值所在的位置)
  23. correct_prediction=tf.equal(tf.argmax(y,1),tf.argmax(prediction,1))
  24. #求准确率(tf.cast将布尔值转换为float型)
  25. accuracy=tf.reduce_mean(tf.cast(correct_prediction,tf.float32))
  26. #创建会话
  27. with tf.Session() as sess:
  28. sess.run(tf.global_variables_initializer()) #初始化变量
  29. #训练次数
  30. for i in range(21):
  31. for batch in range(n_batch):
  32. batch_xs,batch_ys=mnist.train.next_batch(batch_size)
  33. sess.run(train_step,feed_dict={x:batch_xs,y:batch_ys})
  34. acc=sess.run(accuracy,feed_dict={x:mnist.test.images,y:mnist.test.labels})
  35. print("Iter"+str(i)+",Testing Accuracy"+str(acc))

 结果为:

1429709-20180629204009130-945389696.png

(2)模型同上,使用交叉熵函数和梯度下降算法进行优化,

把上面代码的代价函数改为下面的交叉熵代价函数:

  1. loss=tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(labels=y,logits=prediction))

 结果为:

1429709-20180629204126470-950787734.png

(3)构建一个多层的神经网络模型,使用交叉熵函数和梯度下降算法进行优化,添加Dropout防止过拟合;

模型结构如下:

1429709-20180629204204611-1342671438.png

代码如下:

  1. import tensorflow as tf
  2. from tensorflow.examples.tutorials.mnist import input_data
  3. #载入数据集
  4. mnist=input_data.read_data_sets('MNIST_data',one_hot=True)
  5. #设置每个批次的大小
  6. batch_size=100
  7. #计算一共有多少个批次
  8. n_batch=mnist.train.num_examples//batch_size
  9. #定义三个placeholder
  10. x=tf.placeholder(tf.float32,[None,784])
  11. y=tf.placeholder(tf.float32,[None,10])
  12. keep_prob=tf.placeholder(tf.float32) #存放百分率
  13. #创建一个多层神经网络模型
  14. #第一个隐藏层
  15. W1=tf.Variable(tf.truncated_normal([784,2000],stddev=0.1))
  16. b1=tf.Variable(tf.zeros([2000])+0.1)
  17. L1=tf.nn.tanh(tf.matmul(x,W1)+b1)
  18. L1_drop=tf.nn.dropout(L1,keep_prob) #keep_prob设置工作状态神经元的百分率
  19. #第二个隐藏层
  20. W2=tf.Variable(tf.truncated_normal([2000,2000],stddev=0.1))
  21. b2=tf.Variable(tf.zeros([2000])+0.1)
  22. L2=tf.nn.tanh(tf.matmul(L1_drop,W2)+b2)
  23. L2_drop=tf.nn.dropout(L2,keep_prob)
  24. #第三个隐藏层
  25. W3=tf.Variable(tf.truncated_normal([2000,1000],stddev=0.1))
  26. b3=tf.Variable(tf.zeros([1000])+0.1)
  27. L3=tf.nn.tanh(tf.matmul(L2_drop,W3)+b3)
  28. L3_drop=tf.nn.dropout(L3,keep_prob)
  29. #输出层
  30. W4=tf.Variable(tf.truncated_normal([1000,10],stddev=0.1))
  31. b4=tf.Variable(tf.zeros([10])+0.1)
  32. prediction=tf.nn.softmax(tf.matmul(L3_drop,W4)+b4)
  33. #定义交叉熵代价函数
  34. loss=tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(labels=y,logits=prediction))
  35. #定义反向传播算法(使用梯度下降算法)
  36. train_step=tf.train.GradientDescentOptimizer(0.2).minimize(loss)
  37. #结果存放在一个布尔型列表中(argmax函数返回一维张量中最大的值所在的位置)
  38. correct_prediction=tf.equal(tf.argmax(y,1),tf.argmax(prediction,1))
  39. #求准确率(tf.cast将布尔值转换为float型)
  40. accuracy=tf.reduce_mean(tf.cast(correct_prediction,tf.float32))
  41. #创建会话
  42. with tf.Session() as sess:
  43. sess.run(tf.global_variables_initializer()) #初始化变量
  44. #训练次数
  45. for i in range(21):
  46. for batch in range(n_batch):
  47. batch_xs,batch_ys=mnist.train.next_batch(batch_size)
  48. sess.run(train_step,feed_dict={x:batch_xs,y:batch_ys,keep_prob:1.0})
  49. #测试数据计算出的准确率
  50. test_acc=sess.run(accuracy,feed_dict={x:mnist.test.images,y:mnist.test.labels,keep_prob:1.0})
  51. print("Iter"+str(i)+",Testing Accuracy"+str(test_acc))

结果为:

1429709-20180629204309318-1726688130.png

发表评论

表情:
评论列表 (有 0 条评论,821人围观)

还没有评论,来说两句吧...

相关阅读