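1 Model description
(1) Definition of a language model, from Wikipedia
- A statistical language model is a probability distribution over sequences of words. It provides context for distinguishing words and phrases that sound alike. For example, the Chinese phrases “再给我两份葱,让我把记忆煎成饼” ("give me two more servings of scallion and let me fry my memories into a pancake") and “再给我两分钟,让我把记忆结成冰” ("give me two more minutes and let me freeze my memories into ice") sound similar but mean different things.
- Language models are used in many natural language processing applications, such as speech recognition, machine translation, part-of-speech tagging, parsing, handwriting recognition, and information retrieval. Because words and sentences can be combined to arbitrary length, a trained language model will meet strings that never occurred in its training data (the data sparsity problem), which makes it hard to estimate string probabilities from a corpus. This is why approximate, smoothed n-gram models are used.
- In speech recognition and data compression, such a model tries to capture the properties of a language and to predict the next word in a speech sequence.
- In speech recognition, sounds are matched with word sequences. Ambiguities are easier to resolve when evidence from the language model is combined with a pronunciation model and an acoustic model.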
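The data sparsity problem and the smoothing idea mentioned above can be illustrated with a tiny bigram model. This is only a sketch: the toy corpus, the helper names `train_bigram` and `bigram_prob`, and the choice of add-one (Laplace) smoothing are assumptions made for illustration, not something taken from the Wikipedia text or from the code later in this post.

```python
from collections import Counter

def train_bigram(sentences):
    """Count contexts and bigrams over tokenized sentences (toy example)."""
    unigrams, bigrams = Counter(), Counter()
    for words in sentences:
        tokens = ["<s>"] + words + ["<eos>"]
        unigrams.update(tokens[:-1])                  # every token that has a successor
        bigrams.update(zip(tokens[:-1], tokens[1:]))  # (previous word, next word) pairs
    # Words that can be predicted: all corpus words plus the end-of-sentence marker
    vocab = {w for words in sentences for w in words} | {"<eos>"}
    return unigrams, bigrams, vocab

def bigram_prob(prev, word, unigrams, bigrams, vocab):
    """Add-one smoothed estimate of P(word | prev)."""
    return (bigrams[(prev, word)] + 1) / (unigrams[prev] + len(vocab))

corpus = [["the", "cat", "sat"], ["the", "dog", "sat"]]  # hypothetical data
unigrams, bigrams, vocab = train_bigram(corpus)

seen = bigram_prob("the", "cat", unigrams, bigrams, vocab)    # pair occurs in the corpus
unseen = bigram_prob("cat", "dog", unigrams, bigrams, vocab)  # pair never occurs
print(seen, unseen)
```

Without smoothing, the unseen pair would get probability zero and so would every sentence containing it. With add-one smoothing it gets a small non-zero probability, and the probabilities for each context still sum to one.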
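(2) Dataset
- The data used here is the Penn Treebank dataset (the code loads its training text from data/train.txt).
- Put simply, a language model computes the probability of a sentence, that is, how likely a word sequence is to be natural human language. The higher the probability a model assigns to real sentences, the better the model and the lower its perplexity (from "深入浅出讲解语言模型"). A well-trained model therefore generates text that is close to natural language.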
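The link between sentence probability and perplexity can be made concrete with a short sketch. The per-word probabilities below are invented for illustration, and the function names `sentence_log_prob` and `perplexity` are hypothetical. The formula is the standard one: perplexity is the exponential of the average negative log-probability per word.

```python
import math

def sentence_log_prob(token_probs):
    """Log-probability of a sentence: the sum of log P(word | history) over its words."""
    return sum(math.log(p) for p in token_probs)

def perplexity(token_probs):
    """exp of the average negative log-probability per word."""
    return math.exp(-sentence_log_prob(token_probs) / len(token_probs))

# Hypothetical conditional probabilities a model might assign to each word
fluent = [0.2, 0.5, 0.4, 0.5]
garbled = [0.2, 0.01, 0.02, 0.05]

print(perplexity(fluent), perplexity(garbled))
```

A sentence whose words the model finds likely gets a higher probability and a lower perplexity. If a model assigned probability 0.25 to every word, its perplexity would be 4, as if it were choosing uniformly among four words at each step.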
2 Code
# Language model
# Parts of this code are adapted from:
# https://github.com/pytorch/examples/tree/master/word_language_model
import torch
import torch.nn as nn
import numpy as np
from torch.nn.utils import clip_grad_norm_
class Dictionary(object):
    def __init__(self):
        # Two-way mapping between words and integer ids
        self.word2idx = {}
        self.idx2word = {}
        self.idx = 0

    def add_word(self, word):
        if word not in self.word2idx:
            self.word2idx[word] = self.idx
            self.idx2word[self.idx] = word
            self.idx += 1

    def __len__(self):
        return len(self.word2idx)
class Corpus(object):
    def __init__(self):
        self.dictionary = Dictionary()

    def get_data(self, path, batch_size=20):
        # First pass: add words to the dictionary and count tokens
        with open(path, 'r') as f:
            tokens = 0
            for line in f:
                words = line.split() + ['<eos>']
                tokens += len(words)
                for word in words:
                    self.dictionary.add_word(word)

        # Second pass: encode every token as its word id
        ids = torch.LongTensor(tokens)
        token = 0
        with open(path, 'r') as f:
            for line in f:
                words = line.split() + ['<eos>']
                for word in words:
                    ids[token] = self.dictionary.word2idx[word]
                    token += 1

        # Drop the tail so the ids split evenly into batch_size rows
        num_batches = ids.size(0) // batch_size
        ids = ids[:num_batches*batch_size]
        return ids.view(batch_size, -1)
# Device configuration
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Hyper-parameters
embed_size = 128
hidden_size = 1024
num_layers = 1
num_epochs = 5
num_samples = 1000 # number of words to be sampled
batch_size = 20
seq_length = 30
learning_rate = 0.002
# Load "Penn Treebank" dataset
corpus = Corpus()
ids = corpus.get_data('data/train.txt', batch_size)
vocab_size = len(corpus.dictionary)
num_batches = ids.size(1) // seq_length
# RNN based language model
class RNNLM(nn.Module):
    def __init__(self, vocab_size, embed_size, hidden_size, num_layers):
        super(RNNLM, self).__init__()
        self.embed = nn.Embedding(vocab_size, embed_size)  # word id -> embedding vector
        self.lstm = nn.LSTM(embed_size, hidden_size, num_layers, batch_first=True)
        self.linear = nn.Linear(hidden_size, vocab_size)  # output layer: hidden state -> vocabulary logits

    def forward(self, x, h):
        # Embed word ids to vectors
        x = self.embed(x)

        # Forward propagate LSTM
        out, (h, c) = self.lstm(x, h)

        # Reshape output to (batch_size*sequence_length, hidden_size)
        out = out.reshape(out.size(0)*out.size(1), out.size(2))

        # Decode hidden states of all time steps
        out = self.linear(out)
        return out, (h, c)

model = RNNLM(vocab_size, embed_size, hidden_size, num_layers).to(device)
# Loss and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
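# Truncated backpropagation: cut the graph so gradients stop at the batch boundary
def detach(states):
    # Return a tuple, matching the (h, c) form that nn.LSTM expects
    return tuple(state.detach() for state in states)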
# Train the model
for epoch in range(num_epochs):
    # Set initial hidden and cell states
    states = (torch.zeros(num_layers, batch_size, hidden_size).to(device),
              torch.zeros(num_layers, batch_size, hidden_size).to(device))

    for i in range(0, ids.size(1) - seq_length, seq_length):
        # Get mini-batch inputs and targets (targets are inputs shifted by one word)
        inputs = ids[:, i:i+seq_length].to(device)
        targets = ids[:, (i+1):(i+1)+seq_length].to(device)

        # Forward pass
        states = detach(states)
        outputs, states = model(inputs, states)
        loss = criterion(outputs, targets.reshape(-1))

        # Backward and optimize
        optimizer.zero_grad()
        loss.backward()
        clip_grad_norm_(model.parameters(), 0.5)
        optimizer.step()

        step = (i+1) // seq_length
        if step % 100 == 0:
            print('Epoch [{}/{}], Step[{}/{}], Loss: {:.4f}, Perplexity: {:5.2f}'
                  .format(epoch+1, num_epochs, step, num_batches, loss.item(), np.exp(loss.item())))
# Test the model
with torch.no_grad():
    with open('sample.txt', 'w') as f:
        # Set initial hidden and cell states
        state = (torch.zeros(num_layers, 1, hidden_size).to(device),
                 torch.zeros(num_layers, 1, hidden_size).to(device))

        # Select one word id randomly
        prob = torch.ones(vocab_size)
        input = torch.multinomial(prob, num_samples=1).unsqueeze(1).to(device)

        for i in range(num_samples):
            # Forward propagate RNN
            output, state = model(input, state)

            # Sample a word id
            # exp(logits) is an unnormalized softmax; torch.multinomial accepts unnormalized weights
            prob = output.exp()
            word_id = torch.multinomial(prob, num_samples=1).item()

            # Fill input with sampled word id for the next time step
            input.fill_(word_id)

            # File write
            word = corpus.dictionary.idx2word[word_id]
            word = '\n' if word == '<eos>' else word + ' '
            f.write(word)

            if (i+1) % 100 == 0:
                print('Sampled [{}/{}] words and save to {}'.format(i+1, num_samples, 'sample.txt'))

# Save the model checkpoints
torch.save(model.state_dict(), 'model.ckpt')
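The way get_data and the training loop cut the corpus into inputs and targets is easy to miss in the full script. The sketch below mirrors that logic on plain Python lists, without PyTorch, so the shapes can be inspected by hand. The helper names `batchify` and `windows` and the toy ids are assumptions for illustration only.

```python
def batchify(ids, batch_size):
    """Trim ids to a multiple of batch_size and split them into batch_size rows,
    like ids.view(batch_size, -1) in get_data."""
    row_len = len(ids) // batch_size
    return [ids[r * row_len:(r + 1) * row_len] for r in range(batch_size)]

def windows(rows, seq_length):
    """Yield (inputs, targets) pairs as in the training loop:
    targets are the inputs shifted one position to the right."""
    row_len = len(rows[0])
    for i in range(0, row_len - seq_length, seq_length):
        inputs = [row[i:i + seq_length] for row in rows]
        targets = [row[i + 1:i + 1 + seq_length] for row in rows]
        yield inputs, targets

ids = list(range(14))                 # stand-in for 14 encoded words
rows = batchify(ids, batch_size=2)    # two parallel streams of 7 ids each
pairs = list(windows(rows, seq_length=3))

for inputs, targets in pairs:
    print(inputs, '->', targets)
```

Each row is a contiguous stretch of the corpus, and consecutive windows of a row follow one another in the text. That is why the script can carry the LSTM state from one step to the next, detaching it first so that gradients do not flow back past the current window.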
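3 Program output
The program prints the output below. Both values are computed on the current training mini-batch. As training progresses (it is slow), the loss (cross-entropy) and the perplexity (e raised to the power of the loss) keep falling, so the model's output moves closer to natural language.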
Epoch [1/5], Step[0/1549], Loss: 9.2150, Perplexity: 10046.61
Epoch [1/5], Step[100/1549], Loss: 6.0423, Perplexity: 420.85
Epoch [1/5], Step[200/1549], Loss: 5.9387, Perplexity: 379.44
Epoch [1/5], Step[300/1549], Loss: 5.7512, Perplexity: 314.56
Epoch [1/5], Step[400/1549], Loss: 5.6709, Perplexity: 290.30
Epoch [1/5], Step[500/1549], Loss: 5.1621, Perplexity: 174.54
Epoch [1/5], Step[600/1549], Loss: 5.1755, Perplexity: 176.89
Epoch [1/5], Step[700/1549], Loss: 5.3721, Perplexity: 215.32
Epoch [1/5], Step[800/1549], Loss: 5.1827, Perplexity: 178.17
Epoch [1/5], Step[900/1549], Loss: 5.0756, Perplexity: 160.06
Epoch [1/5], Step[1000/1549], Loss: 5.1428, Perplexity: 171.19
Epoch [1/5], Step[1100/1549], Loss: 5.3263, Perplexity: 205.67
Epoch [1/5], Step[1200/1549], Loss: 5.1895, Perplexity: 179.39
Epoch [1/5], Step[1300/1549], Loss: 5.0724, Perplexity: 159.56
Epoch [1/5], Step[1400/1549], Loss: 4.8528, Perplexity: 128.10
Epoch [1/5], Step[1500/1549], Loss: 5.1661, Perplexity: 175.22
Epoch [2/5], Step[0/1549], Loss: 5.4163, Perplexity: 225.05
Epoch [2/5], Step[100/1549], Loss: 4.5526, Perplexity: 94.88
Epoch [2/5], Step[200/1549], Loss: 4.6929, Perplexity: 109.17
Epoch [2/5], Step[300/1549], Loss: 4.6444, Perplexity: 104.00
Epoch [2/5], Step[400/1549], Loss: 4.5688, Perplexity: 96.42
Epoch [2/5], Step[500/1549], Loss: 4.1592, Perplexity: 64.02
Epoch [2/5], Step[600/1549], Loss: 4.4269, Perplexity: 83.67
Epoch [2/5], Step[700/1549], Loss: 4.3720, Perplexity: 79.20
Epoch [2/5], Step[800/1549], Loss: 4.4036, Perplexity: 81.74
Epoch [2/5], Step[900/1549], Loss: 4.1653, Perplexity: 64.41
Epoch [2/5], Step[1000/1549], Loss: 4.3449, Perplexity: 77.08
Epoch [2/5], Step[1100/1549], Loss: 4.4840, Perplexity: 88.59
Epoch [2/5], Step[1200/1549], Loss: 4.4659, Perplexity: 87.00
Epoch [2/5], Step[1300/1549], Loss: 4.1735, Perplexity: 64.94
Epoch [2/5], Step[1400/1549], Loss: 3.9952, Perplexity: 54.34
Epoch [2/5], Step[1500/1549], Loss: 4.2860, Perplexity: 72.67
Epoch [3/5], Step[0/1549], Loss: 4.4764, Perplexity: 87.91
Epoch [3/5], Step[100/1549], Loss: 3.8185, Perplexity: 45.54
Epoch [3/5], Step[200/1549], Loss: 4.0630, Perplexity: 58.15
Epoch [3/5], Step[300/1549], Loss: 3.8839, Perplexity: 48.62
Epoch [3/5], Step[400/1549], Loss: 3.9263, Perplexity: 50.72
Epoch [3/5], Step[500/1549], Loss: 3.4153, Perplexity: 30.43
Epoch [3/5], Step[600/1549], Loss: 3.8813, Perplexity: 48.49
Epoch [3/5], Step[700/1549], Loss: 3.7443, Perplexity: 42.28
Epoch [3/5], Step[800/1549], Loss: 3.7594, Perplexity: 42.92
Epoch [3/5], Step[900/1549], Loss: 3.4794, Perplexity: 32.44
Epoch [3/5], Step[1000/1549], Loss: 3.6235, Perplexity: 37.47
Epoch [3/5], Step[1100/1549], Loss: 3.7085, Perplexity: 40.79
Epoch [3/5], Step[1200/1549], Loss: 3.8110, Perplexity: 45.20
Epoch [3/5], Step[1300/1549], Loss: 3.4499, Perplexity: 31.50
Epoch [3/5], Step[1400/1549], Loss: 3.2214, Perplexity: 25.06
Epoch [3/5], Step[1500/1549], Loss: 3.5429, Perplexity: 34.57
Epoch [4/5], Step[0/1549], Loss: 3.6315, Perplexity: 37.77
Epoch [4/5], Step[100/1549], Loss: 3.2487, Perplexity: 25.76
Epoch [4/5], Step[200/1549], Loss: 3.5140, Perplexity: 33.58
Epoch [4/5], Step[300/1549], Loss: 3.3193, Perplexity: 27.64
Epoch [4/5], Step[400/1549], Loss: 3.4360, Perplexity: 31.06
Epoch [4/5], Step[500/1549], Loss: 2.9549, Perplexity: 19.20
Epoch [4/5], Step[600/1549], Loss: 3.3490, Perplexity: 28.48
Epoch [4/5], Step[700/1549], Loss: 3.3122, Perplexity: 27.45
Epoch [4/5], Step[800/1549], Loss: 3.2668, Perplexity: 26.23
Epoch [4/5], Step[900/1549], Loss: 2.9631, Perplexity: 19.36
Epoch [4/5], Step[1000/1549], Loss: 3.1250, Perplexity: 22.76
Epoch [4/5], Step[1100/1549], Loss: 3.2380, Perplexity: 25.48
Epoch [4/5], Step[1200/1549], Loss: 3.2806, Perplexity: 26.59
Epoch [4/5], Step[1300/1549], Loss: 2.9988, Perplexity: 20.06
Epoch [4/5], Step[1400/1549], Loss: 2.7011, Perplexity: 14.90
Epoch [4/5], Step[1500/1549], Loss: 3.1112, Perplexity: 22.45
Epoch [5/5], Step[0/1549], Loss: 3.0950, Perplexity: 22.09
Epoch [5/5], Step[100/1549], Loss: 2.8688, Perplexity: 17.62
Epoch [5/5], Step[200/1549], Loss: 3.1285, Perplexity: 22.84
Epoch [5/5], Step[300/1549], Loss: 2.9598, Perplexity: 19.29
Epoch [5/5], Step[400/1549], Loss: 3.1288, Perplexity: 22.85
Epoch [5/5], Step[500/1549], Loss: 2.6090, Perplexity: 13.58
Epoch [5/5], Step[600/1549], Loss: 3.0915, Perplexity: 22.01
Epoch [5/5], Step[700/1549], Loss: 2.9536, Perplexity: 19.18
Epoch [5/5], Step[800/1549], Loss: 2.9605, Perplexity: 19.31
Epoch [5/5], Step[900/1549], Loss: 2.6687, Perplexity: 14.42
Epoch [5/5], Step[1000/1549], Loss: 2.8161, Perplexity: 16.71
Epoch [5/5], Step[1100/1549], Loss: 2.9194, Perplexity: 18.53
Epoch [5/5], Step[1200/1549], Loss: 3.0538, Perplexity: 21.20
Epoch [5/5], Step[1300/1549], Loss: 2.6999, Perplexity: 14.88
Epoch [5/5], Step[1400/1549], Loss: 2.4688, Perplexity: 11.81
Epoch [5/5], Step[1500/1549], Loss: 2.7906, Perplexity: 16.29
Sampled [100/1000] words and save to sample.txt
Sampled [200/1000] words and save to sample.txt
Sampled [300/1000] words and save to sample.txt
Sampled [400/1000] words and save to sample.txt
Sampled [500/1000] words and save to sample.txt
Sampled [600/1000] words and save to sample.txt
Sampled [700/1000] words and save to sample.txt
Sampled [800/1000] words and save to sample.txt
Sampled [900/1000] words and save to sample.txt
Sampled [1000/1000] words and save to sample.txt
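The relation between the two columns of the log can be checked directly: each perplexity is e raised to the corresponding loss. The sketch below recomputes a few entries copied from the log. The uniform-guess baseline at the end assumes a vocabulary of 10,000 words, which is an assumption and is not printed by the script.

```python
import math

# (loss, perplexity) pairs copied from the log above
logged = [(9.2150, 10046.61), (6.0423, 420.85), (2.7906, 16.29)]

def perplexity_from_loss(loss):
    """Perplexity is the exponential of the cross-entropy loss (natural log)."""
    return math.exp(loss)

recomputed = [perplexity_from_loss(loss) for loss, _ in logged]
for (loss, ppl), value in zip(logged, recomputed):
    # Small differences come from the loss being rounded to four decimals in the log
    print(loss, ppl, round(value, 2))

# Baseline: a model that guesses uniformly over V words has loss ln(V) and perplexity V.
assumed_vocab_size = 10_000  # assumption, not taken from the script's output
uniform_loss = math.log(assumed_vocab_size)
print(uniform_loss)
```

Under that assumption the uniform baseline has a loss of about 9.21, which is close to the first logged loss of 9.2150. This is what one would expect from an untrained model that spreads probability almost evenly over the vocabulary.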
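An excerpt from sample.txt is shown below. The wording is mostly fluent at the phrase level, but the text lacks logical coherence. It can be improved by training on a larger corpus and training for longer.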
N repeal according to takeover experts
for the british and canada are insolvent interest to the cost of how the transactions will seek additional information on its u.k. business
the banks badly previously concluded that americans should slash the impact of income returns by the agency but they warn that it should be known for many small <unk>
if the central park world will be made only a german economy and the state industries which has <unk> to foreigners how the u.s. can then by how to pay political 's healthy benefit push for such tasks as thrifts often
while the baltimore for managers are among the <unk> concerned the average families of the cause in the area will translate to the higher costs for beneficiaries days
he favors a resignation of selling such matters openly that <unk> doubled made all penalties
there 's no answers at that time the proper has missed it the dead force is <unk> out of control
we call it and easy for rights to pay attention
i think we owe japanese investment and obviously not even agree
until then mr. bryant is a much higher degree and more important for the country
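Text like the excerpt above is produced by repeatedly drawing the next word from the softmax of the model's output logits. The script does this with output.exp() followed by torch.multinomial. The sketch below shows the same idea in plain Python. The function names `softmax` and `sample_id` are hypothetical, and subtracting the maximum logit is a numerical-stability step added here that the original script does not perform.

```python
import math
import random

def softmax(logits):
    """Turn logits into probabilities; subtracting the max avoids overflow in exp."""
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]

def sample_id(logits, rng):
    """Draw one index with probability proportional to exp(logit)."""
    probs = softmax(logits)
    return rng.choices(range(len(probs)), weights=probs, k=1)[0]

rng = random.Random(0)                 # seeded so the run is repeatable
logits = [2.0, 1.0, 0.1]               # hypothetical scores for a 3-word vocabulary
probs = softmax(logits)
draws = [sample_id(logits, rng) for _ in range(1000)]
print(probs, [draws.count(k) for k in range(3)])
```

Because words are sampled rather than chosen greedily, every run of the generation step yields different text, and low-probability words occasionally appear. This is one reason the output can be locally fluent yet drift in meaning.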