Default / 默认 · September 1, 2021

Knowledge Extraction with seq2seq


This post tries out Reformer's seq2seq (encoder-decoder) model for knowledge extraction.
It builds on the demo provided in reformer-pytorch:
https://github.com/lucidrains/reformer-pytorch


import torch
from reformer_pytorch import ReformerEncDec

# Maximum sequence lengths for the encoder input and the decoder output
DE_SEQ_LEN = 4096
EN_SEQ_LEN = 4096

enc_dec = ReformerEncDec(
    dim = 512,
    enc_num_tokens = 20000,
    enc_depth = 6,
    enc_max_seq_len = DE_SEQ_LEN,
    dec_num_tokens = 20000,
    dec_depth = 6,
    dec_max_seq_len = EN_SEQ_LEN
).cuda()

# Random token ids stand in for real training data
train_seq_in = torch.randint(0, 20000, (1, DE_SEQ_LEN)).long().cuda()
train_seq_out = torch.randint(0, 20000, (1, EN_SEQ_LEN)).long().cuda()
input_mask = torch.ones(1, DE_SEQ_LEN).bool().cuda()

loss = enc_dec(train_seq_in, train_seq_out, return_loss = True, enc_input_mask = input_mask)
loss.backward()
# learn

# evaluate with the following
eval_seq_in = torch.randint(0, 20000, (1, DE_SEQ_LEN)).long().cuda()
eval_seq_out_start = torch.tensor([[0.]]).long().cuda() # assume 0 is id of start token
samples = enc_dec.generate(eval_seq_in, eval_seq_out_start, seq_len = EN_SEQ_LEN, eos_token = 1) # assume 1 is id of stop token
print(samples.shape) # (1, <= EN_SEQ_LEN); decode the tokens afterwards
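The demo feeds the model random token ids, so for real knowledge extraction the input text and the target output first have to be turned into id sequences. A minimal sketch of that step follows, in plain Python. Everything in it is an assumption made for illustration and not part of reformer-pytorch: the character-level vocabulary, the helper names `build_vocab`, `encode`, and `decode`, the padding id 2, the "subject | relation | object" target format, and the made-up example sentence. Only the start id 0 and stop id 1 follow the assumptions stated in the demo.

```python
# Hypothetical helpers for illustration; none of this comes from reformer-pytorch.
BOS_ID = 0  # start token id, as assumed in the demo
EOS_ID = 1  # stop token id, as assumed in the demo
PAD_ID = 2  # padding id (an assumption, the demo does not define one)
OFFSET = 3  # first id available for real characters


def build_vocab(texts):
    """Map every character seen in `texts` to an integer id >= OFFSET."""
    chars = sorted({ch for text in texts for ch in text})
    return {ch: i + OFFSET for i, ch in enumerate(chars)}


def encode(text, vocab, max_len):
    """Return (ids, mask): BOS + characters + EOS, padded to max_len.

    mask is True for real tokens and False for padding.
    """
    ids = [BOS_ID] + [vocab[ch] for ch in text] + [EOS_ID]
    if len(ids) > max_len:
        raise ValueError("sequence longer than max_len")
    pad = max_len - len(ids)
    mask = [True] * len(ids) + [False] * pad
    return ids + [PAD_ID] * pad, mask


def decode(ids, vocab):
    """Invert encode: skip BOS and padding, stop at the first EOS."""
    inverse = {i: ch for ch, i in vocab.items()}
    chars = []
    for i in ids:
        if i == EOS_ID:
            break
        if i in (BOS_ID, PAD_ID):
            continue
        chars.append(inverse[i])
    return "".join(chars)


# Made-up example pair: a sentence and the triple to extract from it.
source = "Marie Curie was born in Warsaw."
target = "Marie Curie | born_in | Warsaw"

vocab = build_vocab([source, target])
src_ids, src_mask = encode(source, vocab, max_len=64)
tgt_ids, tgt_mask = encode(target, vocab, max_len=64)

print(decode(tgt_ids, vocab))
```

Wrapped as `torch.tensor([src_ids])`, `torch.tensor([tgt_ids])`, and `torch.tensor([src_mask])`, these lists could presumably take the place of `train_seq_in`, `train_seq_out`, and `input_mask` in the demo, with `max_len` set to the model's sequence lengths. A subword tokenizer would likely be a better fit than single characters for real text.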