DGL 的GATConv报错:Expect number of features to match number of nodes (len(u)). Got 2397 and 799 instead

刺骨的言语ヽ痛彻心扉 2023-02-22 05:22 86阅读 0赞

使用DGL的GATConv层,居然意外的出现如下错误:

  1. dgl._ffi.base.DGLError: Expect number of features to match number of nodes (len(u)). Got 2397 and 799 instead.

注意到799是节点数,而2390刚好是799的3倍,这个3恰好又是num_heads的数值。因此
GATConv的返回值的shape为: ( N , H , M ) (N,H,M) (N,H,M) ,其中 N N N 是节点个数, H H H 是特征长度,而 M M M是头的数目。
当不做任何处理,DGL会默认对返回的矩阵做reshape,reshape的目标是(-1,H) 于是矩阵的行数就变成了 N × M N\times M N×M 了,此时就不对了。

解决方法:对GATConv的返回值执行一次flatten:

  1. def forward(g):
  2. ···
  3. for layer in self.layers:
  4. pkt_length_matrix = layer(g,pkt_length_matrix.to(th.device(self.device)))
  5. arv_time_matrix = layer(g,arv_time_matrix.to(th.device(self.device)))
  6. if self.layer_type =='GAT':
  7. pkt_length_matrix = th.flatten(pkt_length_matrix,1)
  8. arv_time_matrix= th.flatten(arv_time_matrix,1)
  9. ···

同时,下一层GATConv的in_feat设置为上一层的out_feat × \times × num_heads。
这个就可以了。

发表评论

表情:
评论列表 (有 0 条评论,86人围观)

还没有评论,来说两句吧...

相关阅读