深度阅读

nn Sequential 中使用LSTM/GRU

作者
作者
2023年08月22日
更新时间
7.99 分钟
阅读时间
0
阅读量

由于LSTM/GRU这种是多个输出,所以没法直接使用,可以采用自定义模块的方式来转化成一个。

还有一个方案:

# I made a module called SelectItem to pick out an element from a tuple or list

class SelectItem(nn.Module):
    def __init__(self, item_index):
        super(SelectItem, self).__init__()
        self._name = 'selectitem'
        self.item_index = item_index

    def forward(self, inputs):
        return inputs[self.item_index]
# SelectItem can be used in Sequential to pick out the hidden state:

net = nn.Sequential(
nn.GRU(dim_in, dim_out, batch_first=True),
SelectItem(1)
)

方案来自于

https://stackoverflow.com/questions/65906889/lstm-error-attributeerror-tuple-object-has-no-attribute-dim/65907794

https://stackoverflow.com/questions/50817916/how-do-i-add-lstm-gru-or-other-recurrent-layers-to-a-sequential-in-pytorch

博客作者

热爱技术,乐于分享,持续学习。专注于Web开发、系统架构设计和人工智能领域。