Handwritten digit recognition with TensorFlow

柔情只为你懂 2022-11-21 04:30
The complete script: a two-hidden-layer perceptron trained on MNIST, written against the TensorFlow 1.x compatibility API.

```python
import numpy as np
# import tensorflow as tf
import tensorflow.compat.v1 as tf
tf.disable_v2_behavior()  # run in TF1 graph mode so tf.placeholder does not raise an error
import matplotlib.pyplot as plt  # imported but not used below
import input_data  # MNIST helper from the TensorFlow tutorials; the dataset is downloaded locally on first use

mnist = input_data.read_data_sets('data/', one_hot=True)

# network topology
n_hidden_1 = 256   # units in the first hidden layer
n_hidden_2 = 128   # units in the second hidden layer
n_input = 784      # 28x28 pixels, flattened
n_classes = 10     # digits 0-9

# inputs and outputs
x = tf.placeholder("float", [None, n_input])
y = tf.placeholder("float", [None, n_classes])

# network parameters
stddev = 0.1
weights = {
    'w1': tf.Variable(tf.random_normal([n_input, n_hidden_1], stddev=stddev)),
    'w2': tf.Variable(tf.random_normal([n_hidden_1, n_hidden_2], stddev=stddev)),
    'out': tf.Variable(tf.random_normal([n_hidden_2, n_classes], stddev=stddev))
}
biases = {
    'b1': tf.Variable(tf.random_normal([n_hidden_1])),
    'b2': tf.Variable(tf.random_normal([n_hidden_2])),
    'out': tf.Variable(tf.random_normal([n_classes]))
}
print("NETWORK READY")

def multilayer_perceptron(_X, _weights, _biases):
    layer_1 = tf.nn.sigmoid(tf.add(tf.matmul(_X, _weights['w1']), _biases['b1']))
    layer_2 = tf.nn.sigmoid(tf.add(tf.matmul(layer_1, _weights['w2']), _biases['b2']))
    return tf.matmul(layer_2, _weights['out']) + _biases['out']  # raw logits

# prediction
pred = multilayer_perceptron(x, weights, biases)

# loss and optimizer
cost = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=pred, labels=y))
optm = tf.train.GradientDescentOptimizer(learning_rate=0.001).minimize(cost)
corr = tf.equal(tf.argmax(pred, 1), tf.argmax(y, 1))
accr = tf.reduce_mean(tf.cast(corr, "float"))

# initializer
init = tf.global_variables_initializer()
print("FUNCTIONS READY")
```
```python
# training settings
training_epochs = 20
batch_size = 100
display_step = 4

# launch the graph
sess = tf.Session()
sess.run(init)

# optimize
for epoch in range(training_epochs):
    avg_cost = 0.
    total_batch = int(mnist.train.num_examples / batch_size)
    # iterate over mini-batches
    for i in range(total_batch):
        batch_xs, batch_ys = mnist.train.next_batch(batch_size)
        feeds = {x: batch_xs, y: batch_ys}
        sess.run(optm, feed_dict=feeds)
        avg_cost += sess.run(cost, feed_dict=feeds)
    avg_cost = avg_cost / total_batch
    # display progress every few epochs
    if (epoch + 1) % display_step == 0:
        print("Epoch:%03d/%03d cost:%.9f" % (epoch, training_epochs, avg_cost))
        feeds = {x: batch_xs, y: batch_ys}
        training_acc = sess.run(accr, feed_dict=feeds)
        print("Train Accuracy:%.3f" % (training_acc))
        feeds = {x: mnist.test.images, y: mnist.test.labels}
        test_acc = sess.run(accr, feed_dict=feeds)
        print("Test Accuracy:%.3f" % (test_acc))
print("Optimization Finished")
```
