PyTorch -- RNN Language Model -- "Recurrent neural network based language model"

The paper applies a recurrent neural network to language modeling. The code below implements a minimal version of the idea: given the first words of a sentence, predict the last word.

Paper link: 88888888

Model architecture:
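The original figure is not reproduced here; it depicts the standard Elman recurrence that nn.RNN implements (tanh is nn.RNN's default nonlinearity). At each step the hidden state is updated from the current input and the previous hidden state, and the last hidden state is projected to vocabulary logits (the softmax is applied inside CrossEntropyLoss below):

\(h_t = \tanh(W_{ih} x_t + b_{ih} + W_{hh} h_{t-1} + b_{hh})\)

\(\hat{y} = h_T W + b\)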

For the underlying theory, refer to the paper. Code and comments:

# -*- coding: utf-8 -*-
# @time : 2019/11/9  15:12

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
# torch.autograd.Variable is deprecated: plain tensors now track gradients

dtype = torch.FloatTensor

sentences = ["i like dog", "i love coffee", "i hate milk"]

word_list = " ".join(sentences).split()
word_list = list(set(word_list))                        # unique vocabulary
word_dict = {w: i for i, w in enumerate(word_list)}     # word  -> index
number_dict = {i: w for i, w in enumerate(word_list)}   # index -> word
n_class = len(word_dict)                                # vocabulary size (7 here)

# TextRNN parameters
batch_size = len(sentences)
n_step = 2    # time steps per sample: each sentence has 3 words, the first 2 are input
n_hidden = 5  # number of hidden units in the RNN cell

def make_batch(sentences):
    input_batch = []
    target_batch = []

    for sen in sentences:
        word = sen.split()
        input = [word_dict[n] for n in word[:-1]]   # indices of all words but the last
        target = word_dict[word[-1]]                # index of the last word (the label)

        input_batch.append(np.eye(n_class)[input])  # one-hot rows for the input words
        target_batch.append(target)

    return input_batch, target_batch
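
np.eye(n_class)[input] indexes rows of the identity matrix, a compact way to one-hot encode a list of word indices. A small illustration (hypothetical values; the actual row order depends on how set() ordered word_list):

np.eye(4)[[0, 2]]
# array([[1., 0., 0., 0.],
#        [0., 0., 1., 0.]])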

# to torch tensors
input_batch, target_batch = make_batch(sentences)
input_batch = torch.Tensor(np.array(input_batch))
target_batch = torch.LongTensor(target_batch)
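
A quick shape sanity check (a hypothetical run; the sizes follow from the three example sentences above, which give n_class = 7):

print(input_batch.shape)   # torch.Size([3, 2, 7]) -> [batch_size, n_step, n_class]
print(target_batch.shape)  # torch.Size([3])       -> one target word index per sentence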

class TextRNN(nn.Module):
    def __init__(self):
        super(TextRNN, self).__init__()

        self.rnn = nn.RNN(input_size=n_class, hidden_size=n_hidden, batch_first=True)
        self.W = nn.Parameter(torch.randn([n_hidden, n_class]).type(dtype))
        self.b = nn.Parameter(torch.randn([n_class]).type(dtype))

    def forward(self, hidden, X):
        if self.rnn.batch_first:
            # X : [batch_size, time_step, word_vector]
            outputs, hidden = self.rnn(X, hidden)

            # outputs : [batch_size, time_step, num_directions(=1) * n_hidden]
            output = outputs[:, -1, :]  # last time step: [batch_size, n_hidden]
            model = torch.mm(output, self.W) + self.b  # logits: [batch_size, n_class]
            return model
        else:
            X = X.transpose(0, 1)  # X : [n_step, batch_size, n_class]
            outputs, hidden = self.rnn(X, hidden)
            # outputs : [n_step, batch_size, num_directions(=1) * n_hidden]
            # hidden  : [num_layers(=1) * num_directions(=1), batch_size, n_hidden]

            output = outputs[-1, :, :]  # last time step: [batch_size, n_hidden]
            model = torch.mm(output, self.W) + self.b  # logits: [batch_size, n_class]
            return model
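
As an aside, the hand-rolled W/b pair plus torch.mm is exactly the affine map that nn.Linear provides; an equivalent classification head would be (a sketch, not in the original post):

# In __init__:  self.fc = nn.Linear(n_hidden, n_class)
# In forward:   return self.fc(output)  # same as torch.mm(output, self.W) + self.b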

model = TextRNN()

criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

# Training
for epoch in range(5000):
    optimizer.zero_grad()

    # hidden : [num_layers * num_directions, batch, hidden_size]
    # re-created as zeros every epoch; the whole dataset is a single batch,
    # so there is no state to carry between iterations
    hidden = torch.zeros(1, batch_size, n_hidden)
    # input_batch : [batch_size, n_step, n_class]
    output = model(hidden, input_batch)

    # output : [batch_size, n_class], target_batch : [batch_size] (class indices, not one-hot)
    loss = criterion(output, target_batch)
    if (epoch + 1) % 1000 == 0:
        print('Epoch:', '%04d' % (epoch + 1), 'cost =', '{:.6f}'.format(loss.item()))

    loss.backward()
    optimizer.step()

# Predict
with torch.no_grad():
    hidden_initial = torch.zeros(1, batch_size, n_hidden)
    predict = model(hidden_initial, input_batch).max(1, keepdim=True)[1]
print([sen.split()[:2] for sen in sentences], '->', [number_dict[n.item()] for n in predict.squeeze()])
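
On this three-sentence dataset the model normally memorizes the training set well before 5000 epochs, so a typical final prediction print looks like:

[['i', 'like'], ['i', 'love'], ['i', 'hate']] -> ['dog', 'coffee', 'milk']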

Original post: https://www.cnblogs.com/dhName/p/11826541.html

