使用tf.keras搭建mnist手写数字识别网络

阳光穿透心脏的1/2处 2022-03-21 06:39 568阅读 0赞

使用tf.keras搭建mnist手写数字识别网络

目录

使用tf.keras搭建mnist手写数字识别网络

1.使用tf.keras.Sequential搭建序列模型

1.1 tf.keras.Sequential 模型

1.2 搭建mnist手写数字识别网络(序列模型)

1.3 完整的训练代码

2.构建高级模型

2.1 函数式 API

2.2 搭建mnist手写数字识别网络(函数式 API)

2.3 完整的训练代码

  3. tf.keras高级应用:回调 tf.keras.callbacks.Callback

  4. tf.keras高级应用:自定义层tf.keras.layers.Layer


1.使用tf.keras.Sequential搭建序列模型

1.1 tf.keras.Sequential 模型

在 Keras 中,您可以通过组合层来构建模型。模型(通常)是由层构成的图。最常见的模型类型是层的堆叠:tf.keras.Sequential 模型。要构建一个简单的全连接网络(即多层感知器),请运行以下代码:

  1. model = keras.Sequential()
  2. # Adds a densely-connected layer with 64 units to the model:
  3. model.add(keras.layers.Dense(64, activation='relu'))
  4. # Add another:
  5. model.add(keras.layers.Dense(64, activation='relu'))
  6. # Add a softmax layer with 10 output units:
  7. model.add(keras.layers.Dense(10, activation='softmax'))

图文例子:

watermark_type_ZmFuZ3poZW5naGVpdGk_shadow_10_text_aHR0cHM6Ly9ibG9nLmNzZG4ubmV0L2d1eXVlYWxpYW4_size_16_color_FFFFFF_t_70

watermark_type_ZmFuZ3poZW5naGVpdGk_shadow_10_text_aHR0cHM6Ly9ibG9nLmNzZG4ubmV0L2d1eXVlYWxpYW4_size_16_color_FFFFFF_t_70 1

watermark_type_ZmFuZ3poZW5naGVpdGk_shadow_10_text_aHR0cHM6Ly9ibG9nLmNzZG4ubmV0L2d1eXVlYWxpYW4_size_16_color_FFFFFF_t_70 2

1.2 搭建mnist手写数字识别网络(序列模型)

  1. def mnist_cnn(input_shape):
  2. '''
  3. 构建一个CNN网络模型
  4. :param input_shape: 指定输入维度
  5. :return:
  6. '''
  7. model=keras.Sequential()
  8. model.add(keras.layers.Conv2D(filters=32,kernel_size = 5,strides = (1,1),
  9. padding = 'same',activation = tf.nn.relu,input_shape = input_shape))
  10. model.add(keras.layers.MaxPool2D(pool_size=(2,2), strides = (2,2), padding = 'valid'))
  11. model.add(keras.layers.Conv2D(filters=64,kernel_size = 3,strides = (1,1),padding = 'same',activation = tf.nn.relu))
  12. model.add(keras.layers.MaxPool2D(pool_size=(2,2), strides = (2,2), padding = 'valid'))
  13. model.add(keras.layers.Dropout(0.25))
  14. model.add(keras.layers.Flatten())
  15. model.add(keras.layers.Dense(units=128,activation = tf.nn.relu))
  16. model.add(keras.layers.Dropout(0.5))
  17. model.add(keras.layers.Dense(units=10,activation = tf.nn.softmax))
  18. return model

1.3 完整的训练代码

  1. # -*-coding: utf-8 -*-
  2. """
  3. @Project: tensorflow-yolov3
  4. @File : keras_mnist.py
  5. @Author : panjq
  6. @E-mail : pan_jinquan@163.com
  7. @Date : 2019-01-31 09:30:12
  8. """
  9. import tensorflow as tf
  10. from tensorflow import keras
  11. import matplotlib.pyplot as plt
  12. import numpy as np
  13. mnist=keras.datasets.mnist
  14. def get_train_val(mnist_path):
  15. # mnist下载地址:https://storage.googleapis.com/tensorflow/tf-keras-datasets/mnist.npz
  16. (train_images, train_labels), (test_images, test_labels) = mnist.load_data(mnist_path)
  17. print("train_images nums:{}".format(len(train_images)))
  18. print("test_images nums:{}".format(len(test_images)))
  19. return train_images, train_labels, test_images, test_labels
  20. def show_mnist(images,labels):
  21. for i in range(25):
  22. plt.subplot(5,5,i+1)
  23. plt.xticks([])
  24. plt.yticks([ ])
  25. plt.grid(False)
  26. plt.imshow(images[i],cmap=plt.cm.gray)
  27. plt.xlabel(str(labels[i]))
  28. plt.show()
  29. def one_hot(labels):
  30. onehot_labels=np.zeros(shape=[len(labels),10])
  31. for i in range(len(labels)):
  32. index=labels[i]
  33. onehot_labels[i][index]=1
  34. return onehot_labels
  35. def mnist_net(input_shape):
  36. '''
  37. 构建一个简单的全连接层网络模型:
  38. 输入层为28x28=784个输入节点
  39. 隐藏层120个节点
  40. 输出层10个节点
  41. :param input_shape: 指定输入维度
  42. :return:
  43. '''
  44. model = keras.Sequential()
  45. model.add(keras.layers.Flatten(input_shape=input_shape)) #输出层
  46. model.add(keras.layers.Dense(units=120, activation=tf.nn.relu)) #隐含层
  47. model.add(keras.layers.Dense(units=10, activation=tf.nn.softmax))#输出层
  48. return model
  49. def mnist_cnn(input_shape):
  50. '''
  51. 构建一个CNN网络模型
  52. :param input_shape: 指定输入维度
  53. :return:
  54. '''
  55. model=keras.Sequential()
  56. model.add(keras.layers.Conv2D(filters=32,kernel_size = 5,strides = (1,1),
  57. padding = 'same',activation = tf.nn.relu,input_shape = input_shape))
  58. model.add(keras.layers.MaxPool2D(pool_size=(2,2), strides = (2,2), padding = 'valid'))
  59. model.add(keras.layers.Conv2D(filters=64,kernel_size = 3,strides = (1,1),padding = 'same',activation = tf.nn.relu))
  60. model.add(keras.layers.MaxPool2D(pool_size=(2,2), strides = (2,2), padding = 'valid'))
  61. model.add(keras.layers.Dropout(0.25))
  62. model.add(keras.layers.Flatten())
  63. model.add(keras.layers.Dense(units=128,activation = tf.nn.relu))
  64. model.add(keras.layers.Dropout(0.5))
  65. model.add(keras.layers.Dense(units=10,activation = tf.nn.softmax))
  66. return model
  67. def trian_model(train_images,train_labels,test_images,test_labels):
  68. # re-scale to 0~1.0之间
  69. train_images=train_images/255.0
  70. test_images=test_images/255.0
  71. # mnist数据转换为四维
  72. train_images=np.expand_dims(train_images,axis = 3)
  73. test_images=np.expand_dims(test_images,axis = 3)
  74. print("train_images :{}".format(train_images.shape))
  75. print("test_images :{}".format(test_images.shape))
  76. train_labels=one_hot(train_labels)
  77. test_labels=one_hot(test_labels)
  78. # 建立模型
  79. # model = mnist_net(input_shape=(28,28))
  80. model=mnist_cnn(input_shape=(28,28,1))
  81. model.compile(optimizer=tf.train.AdamOptimizer(),loss="categorical_crossentropy",metrics=['accuracy'])
  82. model.fit(x=train_images,y=train_labels,epochs=5)
  83. test_loss,test_acc=model.evaluate(x=test_images,y=test_labels)
  84. print("Test Accuracy %.2f"%test_acc)
  85. # 开始预测
  86. cnt=0
  87. predictions=model.predict(test_images)
  88. for i in range(len(test_images)):
  89. target=np.argmax(predictions[i])
  90. label=np.argmax(test_labels[i])
  91. if target==label:
  92. cnt +=1
  93. print("correct prediction of total : %.2f"%(cnt/len(test_images)))
  94. model.save('mnist-model.h5')
  95. if __name__=="__main__":
  96. mnist_path = 'D:/MyGit/tensorflow-yolov3/data/mnist.npz'
  97. train_images, train_labels, test_images, test_labels=get_train_val(mnist_path)
  98. # show_mnist(train_images, train_labels)
  99. trian_model(train_images, train_labels, test_images, test_labels)

2.构建高级模型

2.1 函数式 API

tf.keras.Sequential 模型是层的简单堆叠,无法表示任意模型。使用 Keras 函数式 API 可以构建复杂的模型拓扑,例如:

多输入模型,
多输出模型,
具有共享层的模型(同一层被调用多次),
具有非序列数据流的模型(例如,残差连接 residual connections)

使用函数式 API 构建的模型具有以下特征:

层实例可调用并返回张量。
输入张量和输出张量用于定义 tf.keras.Model 实例。
此模型的训练方式和 Sequential 模型一样。

以下示例使用函数式 API 构建一个简单的全连接网络:

  1. inputs = keras.Input(shape=(32,)) # Returns a placeholder tensor
  2. # A layer instance is callable on a tensor, and returns a tensor.
  3. x = keras.layers.Dense(64, activation='relu')(inputs)
  4. x = keras.layers.Dense(64, activation='relu')(x)
  5. predictions = keras.layers.Dense(10, activation='softmax')(x)
  6. # Instantiate the model given inputs and outputs.
  7. model = keras.Model(inputs=inputs, outputs=predictions)
  8. # The compile step specifies the training configuration.
  9. model.compile(optimizer=tf.train.RMSPropOptimizer(0.001),
  10. loss='categorical_crossentropy',
  11. metrics=['accuracy'])
  12. # Trains for 5 epochs
  13. model.fit(data, labels, batch_size=32, epochs=5)

2.2 搭建mnist手写数字识别网络(函数式 API)

  1. def mnist_cnn():
  2. """
  3. 使用keras定义mnist模型
  4. """
  5. # define a truncated_normal initializer
  6. tn_init = keras.initializers.truncated_normal(0, 0.1, SEED, dtype=tf.float32)
  7. # define a constant initializer
  8. const_init = keras.initializers.constant(0.1, tf.float32)
  9. # define a L2 regularizer
  10. l2_reg = keras.regularizers.l2(5e-4)
  11. """
  12. 输入占位符。如果输入图像的shape是(28, 28, 1),输入的一批图像(16张图)的shape
  13. 是(16, 28, 28, 1);那么,在定义Input时,shape参数只需要一张图像的大小,也就
  14. 是(28, 28, 1),而不是(16, 28, 28, 1)。
  15. input placeholder. the Input's parameter shape should be a image's
  16. shape (28, 28, 1) rather than a batch of image's shape (16, 28, 28, 1).
  17. """
  18. # inputs: shape(None, 28, 28, 1)
  19. inputs = layers.Input(shape=(IMAGE_SIZE, IMAGE_SIZE, NUM_CHANNELS), dtype=tf.float32)
  20. """
  21. 卷积,输出shape为(None, 28, 28, 32)。Conv2D的第一个参数为卷积核个数;第二个参数为卷积
  22. 核大小,和tensorflow不同的是,卷积核的大小只需指定卷积窗口的大小,例如在tensorflow中,
  23. 卷积核的shape为(5, 5, 1, 32),那么在Keras中,只需指定卷积窗口的大小(5, 5),
  24. 最后一维的大小会根据之前输入的形状自动推算,假如上一层的shape为(None, 28, 28, 1),那
  25. 么最后一维的大小为1;第三个参数为strides,和上一个参数同理。其他参数可查阅Keras的官方文档。
  26. """
  27. # conv1: shape(None, 28, 28, 32)
  28. conv1 = layers.Conv2D(32, (5, 5), strides=(1, 1), padding='same',
  29. activation='relu', use_bias=True,
  30. kernel_initializer=tn_init, name='conv1')(inputs)
  31. # pool1: shape(None, 14, 14, 32)
  32. pool1 = layers.MaxPool2D(pool_size=(2, 2), strides=(2, 2), padding='same', name='pool1')(conv1)
  33. # conv2: shape(None, 14, 14, 64)
  34. conv2 = layers.Conv2D(64, (5, 5), strides=(1, 1), padding='same',
  35. activation='relu', use_bias=True,
  36. kernel_initializer=tn_init,
  37. bias_initializer=const_init, name='conv2')(pool1)
  38. # pool2: shape(None, 7, 7, 64)
  39. pool2 = layers.MaxPool2D(pool_size=(2, 2), strides=(2, 2), padding='same', name='pool2')(conv2)
  40. # flatten: shape(None, 3136)
  41. flatten = layers.Flatten(name='flatten')(pool2)
  42. # fc1: shape(None, 512)
  43. fc1 = layers.Dense(512, 'relu', True, kernel_initializer=tn_init,
  44. bias_initializer=const_init, kernel_regularizer=l2_reg,
  45. bias_regularizer=l2_reg, name='fc1')(flatten)
  46. # dropout
  47. dropout1 = layers.Dropout(0.5, seed=SEED)(fc1)
  48. # dense2: shape(None, 10)
  49. fc2 = layers.Dense(NUM_LABELS, activation=None, use_bias=True,
  50. kernel_initializer=tn_init, bias_initializer=const_init, name='fc2',
  51. kernel_regularizer=l2_reg, bias_regularizer=l2_reg)(dropout1)
  52. # softmax: shape(None, 10)
  53. softmax = layers.Softmax(name='softmax')(fc2)
  54. # make new model
  55. model = keras.Model(inputs=inputs, outputs=softmax, name='nmist')
  56. return model

2.3 完整的训练代码

  1. # -*-coding: utf-8 -*-
  2. """
  3. @Project: tensorflow-yolov3
  4. @File : mnist_cnn2.py
  5. @Author : panjq
  6. @E-mail : pan_jinquan@163.com
  7. @Date : 2019-01-31 10:59:33
  8. """
  9. # 从tensorflow里导入keras和keras.layer
  10. from tensorflow import keras
  11. from tensorflow.keras import layers
  12. import tensorflow as tf
  13. import matplotlib.pyplot as plt
  14. import numpy as np
  15. mnist=keras.datasets.mnist
  16. # 图像的大小
  17. IMAGE_SIZE = 28
  18. # 图像的通道数,为1,即为灰度图像
  19. NUM_CHANNELS = 1
  20. # 图像像素值的范围
  21. PIXEL_DEPTH = 255
  22. # 分类数目,0~9总共有10类
  23. NUM_LABELS = 10
  24. # 验证集大小
  25. VALIDATION_SIZE = 5000 # Size of the validation set.
  26. # 种子
  27. SEED = 66478 # Set to None for random seed.
  28. # 批次大小
  29. BATCH_SIZE = 64
  30. # 训练多少个epoch
  31. NUM_EPOCHS = 10
  32. EVAL_BATCH_SIZE = 64
  33. def mnist_cnn():
  34. """
  35. 使用keras定义mnist模型
  36. """
  37. # define a truncated_normal initializer
  38. tn_init = keras.initializers.truncated_normal(0, 0.1, SEED, dtype=tf.float32)
  39. # define a constant initializer
  40. const_init = keras.initializers.constant(0.1, tf.float32)
  41. # define a L2 regularizer
  42. l2_reg = keras.regularizers.l2(5e-4)
  43. """
  44. 输入占位符。如果输入图像的shape是(28, 28, 1),输入的一批图像(16张图)的shape
  45. 是(16, 28, 28, 1);那么,在定义Input时,shape参数只需要一张图像的大小,也就
  46. 是(28, 28, 1),而不是(16, 28, 28, 1)。
  47. input placeholder. the Input's parameter shape should be a image's
  48. shape (28, 28, 1) rather than a batch of image's shape (16, 28, 28, 1).
  49. """
  50. # inputs: shape(None, 28, 28, 1)
  51. inputs = layers.Input(shape=(IMAGE_SIZE, IMAGE_SIZE, NUM_CHANNELS), dtype=tf.float32)
  52. """
  53. 卷积,输出shape为(None, 28, 28, 32)。Conv2D的第一个参数为卷积核个数;第二个参数为卷积
  54. 核大小,和tensorflow不同的是,卷积核的大小只需指定卷积窗口的大小,例如在tensorflow中,
  55. 卷积核的shape为(5, 5, 1, 32),那么在Keras中,只需指定卷积窗口的大小(5, 5),
  56. 最后一维的大小会根据之前输入的形状自动推算,假如上一层的shape为(None, 28, 28, 1),那
  57. 么最后一维的大小为1;第三个参数为strides,和上一个参数同理。其他参数可查阅Keras的官方文档。
  58. """
  59. # conv1: shape(None, 28, 28, 32)
  60. conv1 = layers.Conv2D(32, (5, 5), strides=(1, 1), padding='same',
  61. activation='relu', use_bias=True,
  62. kernel_initializer=tn_init, name='conv1')(inputs)
  63. # pool1: shape(None, 14, 14, 32)
  64. pool1 = layers.MaxPool2D(pool_size=(2, 2), strides=(2, 2), padding='same', name='pool1')(conv1)
  65. # conv2: shape(None, 14, 14, 64)
  66. conv2 = layers.Conv2D(64, (5, 5), strides=(1, 1), padding='same',
  67. activation='relu', use_bias=True,
  68. kernel_initializer=tn_init,
  69. bias_initializer=const_init, name='conv2')(pool1)
  70. # pool2: shape(None, 7, 7, 64)
  71. pool2 = layers.MaxPool2D(pool_size=(2, 2), strides=(2, 2), padding='same', name='pool2')(conv2)
  72. # flatten: shape(None, 3136)
  73. flatten = layers.Flatten(name='flatten')(pool2)
  74. # fc1: shape(None, 512)
  75. fc1 = layers.Dense(512, 'relu', True, kernel_initializer=tn_init,
  76. bias_initializer=const_init, kernel_regularizer=l2_reg,
  77. bias_regularizer=l2_reg, name='fc1')(flatten)
  78. # dropout
  79. dropout1 = layers.Dropout(0.5, seed=SEED)(fc1)
  80. # dense2: shape(None, 10)
  81. fc2 = layers.Dense(NUM_LABELS, activation=None, use_bias=True,
  82. kernel_initializer=tn_init, bias_initializer=const_init, name='fc2',
  83. kernel_regularizer=l2_reg, bias_regularizer=l2_reg)(dropout1)
  84. # softmax: shape(None, 10)
  85. softmax = layers.Softmax(name='softmax')(fc2)
  86. # make new model
  87. model = keras.Model(inputs=inputs, outputs=softmax, name='nmist')
  88. return model
  89. def get_train_val(mnist_path):
  90. # mnist下载地址:https://storage.googleapis.com/tensorflow/tf-keras-datasets/mnist.npz
  91. (train_images, train_labels), (test_images, test_labels) = mnist.load_data(mnist_path)
  92. print("train_images nums:{}".format(len(train_images)))
  93. print("test_images nums:{}".format(len(test_images)))
  94. return train_images, train_labels, test_images, test_labels
  95. def show_mnist(images,labels):
  96. for i in range(25):
  97. plt.subplot(5,5,i+1)
  98. plt.xticks([])
  99. plt.yticks([ ])
  100. plt.grid(False)
  101. plt.imshow(images[i],cmap=plt.cm.gray)
  102. plt.xlabel(str(labels[i]))
  103. plt.show()
  104. def one_hot(labels):
  105. onehot_labels=np.zeros(shape=[len(labels),10])
  106. for i in range(len(labels)):
  107. index=labels[i]
  108. onehot_labels[i][index]=1
  109. return onehot_labels
  110. def trian_model(train_images,train_labels,test_images,test_labels):
  111. # re-scale to 0~1.0之间
  112. train_images=train_images/255.0
  113. test_images=test_images/255.0
  114. # mnist数据转换为四维
  115. train_images=np.expand_dims(train_images,axis = 3)
  116. test_images=np.expand_dims(test_images,axis = 3)
  117. print("train_images :{}".format(train_images.shape))
  118. print("test_images :{}".format(test_images.shape))
  119. train_labels=one_hot(train_labels)
  120. test_labels=one_hot(test_labels)
  121. # 建立模型
  122. model=mnist_cnn()
  123. model.compile(optimizer=tf.train.AdamOptimizer(),loss="categorical_crossentropy",metrics=['accuracy'])
  124. model.fit(x=train_images,y=train_labels,epochs=5)
  125. test_loss,test_acc=model.evaluate(x=test_images,y=test_labels)
  126. print("Test Accuracy %.2f"%test_acc)
  127. # 开始预测
  128. cnt=0
  129. predictions=model.predict(test_images)
  130. for i in range(len(test_images)):
  131. target=np.argmax(predictions[i])
  132. label=np.argmax(test_labels[i])
  133. if target==label:
  134. cnt +=1
  135. print("correct prediction of total : %.2f"%(cnt/len(test_images)))
  136. model.save('mnist-model.h5')
  137. if __name__=="__main__":
  138. mnist_path = 'D:/MyGit/tensorflow-yolov3/data/mnist.npz'
  139. train_images, train_labels, test_images, test_labels=get_train_val(mnist_path)
  140. # show_mnist(train_images, train_labels)
  141. trian_model(train_images, train_labels, test_images, test_labels)

3. tf.keras高级应用:回调 tf.keras.callbacks.Callback

回调是传递给模型的对象,用于在训练期间自定义该模型并扩展其行为。您可以编写自定义回调,也可以使用包含以下方法的内置 tf.keras.callbacks:

  1. tf.keras.callbacks.ModelCheckpoint:定期保存模型的检查点。
  2. tf.keras.callbacks.LearningRateScheduler:动态更改学习速率。
  3. tf.keras.callbacks.EarlyStopping:在验证效果不再改进时中断训练。
  4. tf.keras.callbacks.TensorBoard:使用 TensorBoard 监控模型的行为。

要使用 tf.keras.callbacks.Callback,请将其传递给模型的 fit 方法:

  1. callbacks = [
  2. # Interrupt training if `val_loss` stops improving for over 2 epochs
  3. keras.callbacks.EarlyStopping(patience=2, monitor='val_loss'),
  4. # Write TensorBoard logs to `./logs` directory
  5. keras.callbacks.TensorBoard(log_dir='./logs')
  6. ]
  7. model.fit(data, labels, batch_size=32, epochs=5, callbacks=callbacks,
  8. validation_data=(val_data, val_targets))

将mnist训练代码增加回调打印模型的信息:

  1. def trian_model(train_images,train_labels,test_images,test_labels):
  2. # re-scale to 0~1.0之间
  3. train_images=train_images/255.0
  4. test_images=test_images/255.0
  5. # mnist数据转换为四维
  6. train_images=np.expand_dims(train_images,axis = 3)
  7. test_images=np.expand_dims(test_images,axis = 3)
  8. print("train_images :{}".format(train_images.shape))
  9. print("test_images :{}".format(test_images.shape))
  10. train_labels=one_hot(train_labels)
  11. test_labels=one_hot(test_labels)
  12. # 建立模型
  13. model=mnist_cnn()
  14. # 打印模型的信息
  15. model.summary()
  16. # 编译模型;第一个参数是优化器;第二个参数为loss,因为是多元分类问题,故为
  17. # 'categorical_crossentropy';第三个参数为metrics,就是在训练的时候需
  18. # 要监控的指标列表。
  19. # compile model
  20. model.compile(optimizer=tf.train.AdamOptimizer(),loss="categorical_crossentropy",metrics=['accuracy'])
  21. # model.compile(optimizer=keras.optimizers.SGD(lr=0.01, momentum=0.9, decay=1e-5),loss='categorical_crossentropy', metrics=['accuracy'])
  22. # 设置回调
  23. # setting callbacks
  24. callbacks = [
  25. # 把TensorBoard的日志写入文件夹'./logs'
  26. # write TensorBoard' logs to directory 'logs'
  27. keras.callbacks.TensorBoard(log_dir='./logs'),
  28. ]
  29. # 开始训练
  30. # start training
  31. model.fit(train_images, train_labels, BATCH_SIZE, epochs=5,
  32. validation_data=(test_images, test_labels), callbacks=callbacks)
  33. # evaluate
  34. print('', 'evaluating on test sets...')
  35. loss, accuracy = model.evaluate(test_images, test_labels)
  36. print('test loss:', loss)
  37. print('test Accuracy:', accuracy)
  38. # save model
  39. model.save('mnist-model.h5')

4. tf.keras高级应用:自定义层tf.keras.layers.Layer

通过对 tf.keras.layers.Layer 进行子类化并实现以下方法来创建自定义层:

build:创建层的权重。使用 add_weight 方法添加权重。
call:定义前向传播。
compute_output_shape:指定在给定输入形状的情况下如何计算层的输出形状。
或者,可以通过实现 get_config 方法和 from_config 类方法序列化层。

下面是一个使用核矩阵实现输入 matmul 的自定义层示例:

  1. class MyLayer(keras.layers.Layer):
  2. def __init__(self, output_dim, **kwargs):
  3. self.output_dim = output_dim
  4. super(MyLayer, self).__init__(**kwargs)
  5. def build(self, input_shape):
  6. shape = tf.TensorShape((input_shape[1], self.output_dim))
  7. # Create a trainable weight variable for this layer.
  8. self.kernel = self.add_weight(name='kernel',
  9. shape=shape,
  10. initializer='uniform',
  11. trainable=True)
  12. # Be sure to call this at the end
  13. super(MyLayer, self).build(input_shape)
  14. def call(self, inputs):
  15. return tf.matmul(inputs, self.kernel)
  16. def compute_output_shape(self, input_shape):
  17. shape = tf.TensorShape(input_shape).as_list()
  18. shape[-1] = self.output_dim
  19. return tf.TensorShape(shape)
  20. def get_config(self):
  21. base_config = super(MyLayer, self).get_config()
  22. base_config['output_dim'] = self.output_dim
  23. @classmethod
  24. def from_config(cls, config):
  25. return cls(**config)
  26. # Create a model using the custom layer
  27. model = keras.Sequential([MyLayer(10),
  28. keras.layers.Activation('softmax')])
  29. # The compile step specifies the training configuration
  30. model.compile(optimizer=tf.train.RMSPropOptimizer(0.001),
  31. loss='categorical_crossentropy',
  32. metrics=['accuracy'])
  33. # Trains for 5 epochs.
  34. model.fit(data, targets, batch_size=32, epochs=5)

发表评论

表情:
评论列表 (有 0 条评论,568人围观)

还没有评论,来说两句吧...

相关阅读