【Pytorch】Expected hidden[0] size (2, 136, 256), got [2, 256, 256]

╰半夏微凉° 2022-08-30 15:57 112阅读 0赞

问题

我在使用pytorch的 LSTM (RNN) 构建多类文本分类网络时遇到此错误,网络结构没有问题,能够运行起来,但是运行到几个batch后就报错Expected hidden[0] size (2, 136, 256), got [2, 256, 256]

分析

该错误是由于的训练数据不能被批量大小整除造成的。前面的batch都是256个,但是最后一个batch不足256,只有136个。
假设训练数据有 100个,batch大小为 16,划分为6个batch,最后一个batch将只有 4 个(100%16 = 4)个。

解决方案

(1)方法一
修改batchsize,让数据集大小能整除batchsize
(2)方法二
如果使用Dataloader,设置一个参数drop_last=True,会自动舍弃最后不足batchsize的batch

  1. from torch.utils.data import DataLoader
  2. train_loader = DataLoader(train_data, shuffle=True, batch_size=batch_size, drop_last=True)

参考:https://stackoverflow.com/questions/54878904/runtimeerror-expected-hidden0-size-2-20-256-got-2-50-256

发表评论

表情:
评论列表 (有 0 条评论,112人围观)

还没有评论,来说两句吧...

相关阅读

    相关 SHA 256算法

    1. SHA 256算法是什么 要理解SHA 256算法,我们需要先解释哈希函数。哈希函数又称散列函数,是将任何长度的信息转换为另一个值的过程。本质上,它包含数据块,这些

    相关 256创作纪念日

    不知不觉已经是写博客的第256天了,从一个躺平的人变成一个为一件事能坚持并不断去做是真的很爽,回过头看看自己,写了好多东西,也慢慢在成长,不再是以前那个只会玩的小孩了。 1

    相关 SHA256WithRSA

    在[上文][Link 1]中了解到SHA和RSA,工作中恰好用到扩展应用:SHA256WithRSA,本文总结下学习过程,备忘の 再提供另外一种方法,实现Java版pem密