神经网络学习小记录8——利用Keras进行简单RNN网络构建

谁借莪1个温暖的怀抱¢ 2023-08-17 16:22 251阅读 0赞

神经网络学习小记录8——利用Keras进行简单RNN网络构建

  • 学习前言
  • Keras中构建RNN的重要函数
    • 1、SimpleRNN
    • 2、model.train_on_batch
  • 全部代码

学习前言

还要会RNN噢。
在这里插入图片描述

Keras中构建RNN的重要函数

1、SimpleRNN

SimpleRNN用于在Keras中构建普通的简单RNN层,在使用前需要import。

  1. from keras.layers import SimpleRNN

在实际使用时,需要用到几个参数。

  1. model.add(
  2. SimpleRNN(
  3. batch_input_shape = (BATCH_SIZE,TIME_STEPS,INPUT_SIZE),
  4. output_dim = CELL_SIZE,
  5. )
  6. )

其中,batch_input_shape代表RNN输入数据的shape,shape的内容分别是每一次训练使用的BATCH,TIME_STEPS表示这个RNN按顺序输入的时间点的数量,INPUT_SIZE表示每一个时间点的输入数据大小。
CELL_SIZE代表训练每一个时间点的神经元数量。

2、model.train_on_batch

与之前的训练CNN网络和普通分类网络不同,RNN网络在建立时就规定了batch_input_shape,所以训练的时候也需要一定量一定量的传入训练数据。
model.train_on_batch在使用前需要对数据进行处理。获取指定BATCH大小的训练集。

  1. X_batch = X_train[index_start:index_start + BATCH_SIZE,:,:]
  2. Y_batch = Y_train[index_start:index_start + BATCH_SIZE,:]
  3. index_start += BATCH_SIZE

具体训练过程如下:

  1. for i in range(500):
  2. X_batch = X_train[index_start:index_start + BATCH_SIZE,:,:]
  3. Y_batch = Y_train[index_start:index_start + BATCH_SIZE,:]
  4. index_start += BATCH_SIZE
  5. cost = model.train_on_batch(X_batch,Y_batch)
  6. if index_start >= X_train.shape[0]:
  7. index_start = 0
  8. if i%100 == 0:
  9. ## acc
  10. cost,accuracy = model.evaluate(X_test,Y_test,batch_size=50)
  11. ## W,b = model.layers[0].get_weights()
  12. print("accuracy:",accuracy)
  13. x = X_test[1].reshape(1,28,28)

全部代码

这是一个RNN神经网络的例子,用于识别手写体。(源自莫烦Python)

  1. import numpy as np
  2. from keras.models import Sequential
  3. from keras.layers import SimpleRNN,Activation,Dense ## 全连接层
  4. from keras.datasets import mnist
  5. from keras.utils import np_utils
  6. from keras.optimizers import Adam
  7. TIME_STEPS = 28
  8. INPUT_SIZE = 28
  9. BATCH_SIZE = 50
  10. index_start = 0
  11. OUTPUT_SIZE = 10
  12. CELL_SIZE = 75
  13. LR = 1e-3
  14. (X_train,Y_train),(X_test,Y_test) = mnist.load_data()
  15. X_train = X_train.reshape(-1,28,28)/255
  16. X_test = X_test.reshape(-1,28,28)/255
  17. Y_train = np_utils.to_categorical(Y_train,num_classes= 10)
  18. Y_test = np_utils.to_categorical(Y_test,num_classes= 10)
  19. model = Sequential()
  20. # conv1
  21. model.add(
  22. SimpleRNN(
  23. batch_input_shape = (BATCH_SIZE,TIME_STEPS,INPUT_SIZE),
  24. output_dim = CELL_SIZE,
  25. )
  26. )
  27. model.add(Dense(OUTPUT_SIZE))
  28. model.add(Activation("softmax"))
  29. adam = Adam(LR)
  30. ## compile
  31. model.compile(loss = 'categorical_crossentropy',optimizer = adam,metrics = ['accuracy'])
  32. ## tarin
  33. for i in range(500):
  34. X_batch = X_train[index_start:index_start + BATCH_SIZE,:,:]
  35. Y_batch = Y_train[index_start:index_start + BATCH_SIZE,:]
  36. index_start += BATCH_SIZE
  37. cost = model.train_on_batch(X_batch,Y_batch)
  38. if index_start >= X_train.shape[0]:
  39. index_start = 0
  40. if i%100 == 0:
  41. ## acc
  42. cost,accuracy = model.evaluate(X_test,Y_test,batch_size=50)
  43. ## W,b = model.layers[0].get_weights()
  44. print("accuracy:",accuracy)

实验结果为:

  1. 10000/10000 [==============================] - 1s 147us/step
  2. accuracy: 0.09329999938607215
  3. …………………………
  4. 10000/10000 [==============================] - 1s 112us/step
  5. accuracy: 0.9395000022649765
  6. 10000/10000 [==============================] - 1s 109us/step
  7. accuracy: 0.9422999995946885
  8. 10000/10000 [==============================] - 1s 114us/step
  9. accuracy: 0.9534000000357628
  10. 10000/10000 [==============================] - 1s 112us/step
  11. accuracy: 0.9566000008583069
  12. 10000/10000 [==============================] - 1s 113us/step
  13. accuracy: 0.950799999833107
  14. 10000/10000 [==============================] - 1s 116us/step
  15. 10000/10000 [==============================] - 1s 112us/step
  16. accuracy: 0.9474999988079071
  17. 10000/10000 [==============================] - 1s 111us/step
  18. accuracy: 0.9515000003576278
  19. 10000/10000 [==============================] - 1s 114us/step
  20. accuracy: 0.9288999977707862
  21. 10000/10000 [==============================] - 1s 115us/step
  22. accuracy: 0.9487999993562698

有不懂的朋友可以评论询问噢。

发表评论

表情:
评论列表 (有 0 条评论,251人围观)

还没有评论,来说两句吧...

相关阅读