深度阅读

使用seq2seq做知识提取

作者
作者
2023年08月22日
更新时间
11.21 分钟
阅读时间
0
阅读量

尝试使用reformer和的seq2seq做知识提取。
借助reformer-pytorch里面提供的demo。
https://github.com/lucidrains/reformer-pytorch


import torch
from reformer_pytorch import ReformerEncDec

DE_SEQ_LEN = 4096
EN_SEQ_LEN = 4096

enc_dec = ReformerEncDec(
    dim = 512,
    enc_num_tokens = 20000,
    enc_depth = 6,
    enc_max_seq_len = DE_SEQ_LEN,
    dec_num_tokens = 20000,
    dec_depth = 6,
    dec_max_seq_len = EN_SEQ_LEN
).cuda()

train_seq_in = torch.randint(0, 20000, (1, DE_SEQ_LEN)).long().cuda()
train_seq_out = torch.randint(0, 20000, (1, EN_SEQ_LEN)).long().cuda()
input_mask = torch.ones(1, DE_SEQ_LEN).bool().cuda()

loss = enc_dec(train_seq_in, train_seq_out, return_loss = True, enc_input_mask = input_mask)
loss.backward()
<h1>learn</h1>
<h1>evaluate with the following</h1>
eval_seq_in = torch.randint(0, 20000, (1, DE_SEQ_LEN)).long().cuda()
eval_seq_out_start = torch.tensor([[0.]]).long().cuda() # assume 0 is id of start token
samples = enc_dec.generate(eval_seq_in, eval_seq_out_start, seq_len = EN_SEQ_LEN, eos_token = 1) # assume 1 is id of stop token
print(samples.shape) # (1, <= 1024) decode the tokens


相关标签

博客作者

热爱技术,乐于分享,持续学习。专注于Web开发、系统架构设计和人工智能领域。