【Bert4keras】Fixing "Key bert/embeddings/word_embeddings not found in checkpoint"

阳光穿透心脏的1/2处 · 2022-09-07 09:59

1 Problem

Pretraining with Su Jianlin's bert4keras produces several checkpoint files, but loading the pretrained model for downstream training fails with the error "Key bert/embeddings/word_embeddings not found in checkpoint".
(1) The pretrain.py file

# Pretraining script
import os
os.environ['TF_KERAS'] = '1'  # tf.keras is required

import tensorflow as tf
from bert4keras.backend import keras, K
from bert4keras.models import build_transformer_model
from bert4keras.optimizers import Adam
from bert4keras.optimizers import extend_with_gradient_accumulation
from bert4keras.optimizers import extend_with_layer_adaptation
from bert4keras.optimizers import extend_with_piecewise_linear_lr
from bert4keras.optimizers import extend_with_weight_decay
from keras.layers import Input, Lambda
from keras.models import Model
from data_utils import TrainingDatasetRoBERTa

# Corpus paths and model save path
model_saved_path = 'pre_models/bert_model.ckpt'
corpus_paths = [f'corpus_tfrecord/corpus.{i}.tfrecord' for i in range(10)]

# Other settings
sequence_length = 512
batch_size = 64
config_path = 'pre_models/bert_config.json'
checkpoint_path = None  # set to None when training from scratch
learning_rate = 0.00176
weight_decay_rate = 0.01
num_warmup_steps = 3125
num_train_steps = 5000
steps_per_epoch = 100
grad_accum_steps = 16  # > 1 enables gradient accumulation
epochs = num_train_steps * grad_accum_steps // steps_per_epoch
exclude_from_weight_decay = ['Norm', 'bias']
tpu_address = None  # set to None when running on (multi-)GPU
which_optimizer = 'lamb'  # 'adam' or 'lamb'; both come with weight decay
lr_schedule = {
    num_warmup_steps * grad_accum_steps: 1.0,
    num_train_steps * grad_accum_steps: 0.0,
}
floatx = K.floatx()

# Read the dataset and build the data tensors
dataset = TrainingDatasetRoBERTa.load_tfrecord(
    record_names=corpus_paths,
    sequence_length=sequence_length,
    batch_size=batch_size // grad_accum_steps,
)


def build_transformer_model_with_mlm():
    """BERT model with the MLM head."""
    bert = build_transformer_model(
        config_path, with_mlm='linear', return_keras_model=False
    )
    proba = bert.model.output

    # Auxiliary inputs
    token_ids = Input(shape=(None,), dtype='int64', name='token_ids')  # target ids
    is_masked = Input(shape=(None,), dtype=floatx, name='is_masked')  # mask flags

    def mlm_loss(inputs):
        """Loss computation; must be wrapped as a layer."""
        y_true, y_pred, mask = inputs
        loss = K.sparse_categorical_crossentropy(
            y_true, y_pred, from_logits=True
        )
        loss = K.sum(loss * mask) / (K.sum(mask) + K.epsilon())
        return loss

    def mlm_acc(inputs):
        """Accuracy computation; must be wrapped as a layer."""
        y_true, y_pred, mask = inputs
        y_true = K.cast(y_true, floatx)
        acc = keras.metrics.sparse_categorical_accuracy(y_true, y_pred)
        acc = K.sum(acc * mask) / (K.sum(mask) + K.epsilon())
        return acc

    mlm_loss = Lambda(mlm_loss, name='mlm_loss')([token_ids, proba, is_masked])
    mlm_acc = Lambda(mlm_acc, name='mlm_acc')([token_ids, proba, is_masked])

    train_model = Model(
        bert.model.inputs + [token_ids, is_masked], [mlm_loss, mlm_acc]
    )

    loss = {
        'mlm_loss': lambda y_true, y_pred: y_pred,
        'mlm_acc': lambda y_true, y_pred: K.stop_gradient(y_pred),
    }

    return bert, train_model, loss


def build_transformer_model_for_pretraining():
    """Build the training model; works on both TPU and GPU.

    Note: stick to standard Keras layer usage throughout; more free-form
    "cut-and-paste" constructions may fail when training on TPU. Also be
    aware that TPUs do not support every TensorFlow op, in particular
    dynamic (variable-length) ops, so write such computations with care.
    """
    bert, train_model, loss = build_transformer_model_with_mlm()

    # Optimizer
    optimizer = extend_with_weight_decay(Adam)
    if which_optimizer == 'lamb':
        optimizer = extend_with_layer_adaptation(optimizer)
    optimizer = extend_with_piecewise_linear_lr(optimizer)
    optimizer_params = {
        'learning_rate': learning_rate,
        'lr_schedule': lr_schedule,
        'weight_decay_rate': weight_decay_rate,
        'exclude_from_weight_decay': exclude_from_weight_decay,
        'bias_correction': False,
    }
    if grad_accum_steps > 1:
        optimizer = extend_with_gradient_accumulation(optimizer)
        optimizer_params['grad_accum_steps'] = grad_accum_steps
    optimizer = optimizer(**optimizer_params)

    # Compile the model
    train_model.compile(loss=loss, optimizer=optimizer)

    # If a checkpoint is given, load it. Note: it must be loaded here,
    # after compile, to guarantee no error is raised.
    if checkpoint_path is not None:
        bert.load_weights_from_checkpoint(checkpoint_path)

    return train_model, bert


if tpu_address is None:
    # Single-machine multi-GPU mode (multi-machine multi-GPU is similar,
    # but needs matching hardware/software; see https://tf.wiki)
    strategy = tf.distribute.MirroredStrategy()
else:
    # TPU mode
    resolver = tf.distribute.cluster_resolver.TPUClusterResolver(
        tpu=tpu_address
    )
    tf.config.experimental_connect_to_host(resolver.master())
    tf.tpu.experimental.initialize_tpu_system(resolver)
    strategy = tf.distribute.experimental.TPUStrategy(resolver)

with strategy.scope():
    train_model, bert = build_transformer_model_for_pretraining()
    train_model.summary()


class ModelCheckpoint(keras.callbacks.Callback):
    """Automatically save the latest model."""
    def on_epoch_end(self, epoch, logs=None):
        self.model.save_weights(model_saved_path, overwrite=True)


checkpoint = ModelCheckpoint()  # save the model
csv_logger = keras.callbacks.CSVLogger('training.log')  # log training

# Model training
train_model.fit(
    dataset,
    steps_per_epoch=steps_per_epoch,
    epochs=epochs,
    callbacks=[checkpoint, csv_logger],
)

(2) The train.py file
When the model is loaded in train.py, it raises the error "Key bert/embeddings/word_embeddings not found in checkpoint".

# Load the pretrained model
bert = build_transformer_model(
    config_path=BERT_CONFIG_PATH,
    checkpoint_path=model_saved_path,
    return_keras_model=False,
)

2 Solution

As Su Jianlin's blog explains, you need to load the pretrained weights and re-save them as a fresh checkpoint, then use that checkpoint instead; loading it no longer raises the error. The reason is that train_model.save_weights stores the weights under Keras's variable names, while bert.save_weights_as_checkpoint writes a checkpoint with the standard BERT variable names that build_transformer_model expects. So add the two lines below at the end of pretrain.py, comment out the train_model.fit call, and run the file once more. After that, train.py runs without the error.

# The code above is unchanged and omitted here
# ...
# Model training
# train_model.fit(
#     dataset,
#     steps_per_epoch=steps_per_epoch,
#     epochs=epochs,
#     callbacks=[checkpoint, csv_logger],
# )
train_model.load_weights(model_saved_path)
bert.save_weights_as_checkpoint(filename='bert_model/bert_model.ckpt')
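After rerunning pretrain.py with these two lines, the converted checkpoint can be spot-checked before pointing train.py at it. A small verification sketch (the path bert_model/bert_model.ckpt matches the save_weights_as_checkpoint call above):

# Verification sketch: the converted checkpoint should now use BERT naming.
import tensorflow as tf

names = [name for name, _ in tf.train.list_variables('bert_model/bert_model.ckpt')]
assert 'bert/embeddings/word_embeddings' in names  # the key train.py failed to find

Note that since the converted checkpoint lives at 'bert_model/bert_model.ckpt' rather than the old model_saved_path, train.py should pass that new path as checkpoint_path.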
