[PyTorch] Language Model


1 Model Description


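  • A statistical language model is a probability distribution over sequences of words. It provides context for telling apart words and phrases that sound alike. For example, the Chinese phrases "再给我两份葱,让我把记忆煎成饼" ("give me two more servings of scallions, let me fry my memories into a pancake") and "再给我两分钟,让我把记忆结成冰" ("give me two more minutes, let me freeze my memories into ice") sound similar but mean different things.
  • Language models are used in many natural language processing applications, such as speech recognition, machine translation, part-of-speech tagging, parsing, handwriting recognition, and information retrieval. Words and sentences can be combined into sequences of arbitrary length, so a trained language model will meet strings it never saw during training (the data sparsity problem). This also makes it hard to estimate the probability of a string from a corpus, which is why smoothed n-gram models are used as an approximation.
  • In speech recognition and data compression, such a model tries to capture the properties of a language and to predict the next word in a speech sequence.
  • In speech recognition, sounds are matched against word sequences. Ambiguities are easier to resolve when evidence from the language model is combined with a pronunciation model and an acoustic model.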
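The smoothed n-gram idea in the second point can be shown with a minimal bigram model that uses add-one (Laplace) smoothing. The text above does not say which smoothing method it means. Add-one is chosen here only because it is the simplest. The toy corpus and the function names are hypothetical.

```python
from collections import Counter

def train_bigram(sentences):
    """Count histories and bigrams; <s> and <eos> mark sentence boundaries."""
    unigrams, bigrams = Counter(), Counter()
    for sent in sentences:
        tokens = ['<s>'] + sent.split() + ['<eos>']
        unigrams.update(tokens[:-1])                  # how often each word is a history
        bigrams.update(zip(tokens[:-1], tokens[1:]))  # (previous word, next word) pairs
    vocab = {w for s in sentences for w in s.split()} | {'<eos>'}
    return unigrams, bigrams, vocab

def bigram_prob(prev, word, unigrams, bigrams, vocab):
    """Add-one smoothed P(word | prev): unseen bigrams still get non-zero probability."""
    return (bigrams[(prev, word)] + 1) / (unigrams[prev] + len(vocab))

corpus = ["the cat sat", "the dog sat"]
uni, bi, vocab = train_bigram(corpus)
print(bigram_prob('the', 'cat', uni, bi, vocab))  # seen bigram: (1+1)/(2+5) = 2/7
print(bigram_prob('cat', 'dog', uni, bi, vocab))  # unseen bigram: (0+1)/(1+5) = 1/6
```

Without smoothing, P(dog | cat) would be 0, and any sentence containing "cat dog" would get probability 0. Adding 1 to every count and |V| to every denominator keeps each conditional distribution summing to 1.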


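  • The data used here is the Penn Treebank (PTB) text corpus (data/train.txt). The model learns from the raw word sequences, not from the part-of-speech tags.
  • Put simply, a language model computes the probability of a sentence, i.e. how likely it is that a sequence of words is something a person would actually say. The higher the probability a model assigns to real sentences, the better the model and the lower its perplexity (from 深入浅出讲解语言模型, "An Accessible Explanation of Language Models"). A good language model therefore generates text that reads like natural human language.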
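Sentence probability and perplexity are linked as follows. By the chain rule, P(w_1 ... w_n) is the product of P(w_i | w_1 ... w_{i-1}) over all positions. Perplexity is the exponential of the average negative log-probability per word. This matches what the training code in section 2 prints as np.exp(loss.item()), because nn.CrossEntropyLoss averages the negative log-likelihood over the predicted words. The conditional probabilities below are made-up numbers for illustration.

```python
import math

def sentence_log_prob(cond_probs):
    """Chain rule: log P(w_1..w_n) = sum_i log P(w_i | w_1..w_{i-1})."""
    return sum(math.log(p) for p in cond_probs)

def perplexity(cond_probs):
    """exp of the average negative log-probability per word."""
    return math.exp(-sentence_log_prob(cond_probs) / len(cond_probs))

# Hypothetical probabilities a model assigns to each word of one 4-word sentence
probs = [0.2, 0.1, 0.5, 0.25]
print(math.exp(sentence_log_prob(probs)))  # sentence probability, about 0.2*0.1*0.5*0.25 = 0.0025
print(perplexity(probs))                   # 0.0025 ** (-1/4) = sqrt(20), about 4.47
```

A higher sentence probability means a smaller average negative log-probability, and so a lower perplexity.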

2 Code

# language model
# Some part of the code was referenced from below.
# https://github.com/pytorch/examples/tree/master/word_language_model 
import torch
import torch.nn as nn
import numpy as np
from torch.nn.utils import clip_grad_norm_  # gradient clipping

class Dictionary(object):
    def __init__(self):  # two-way mapping: word -> id and id -> word
        self.word2idx = {}
        self.idx2word = {}
        self.idx = 0
    def add_word(self, word):
        if word not in self.word2idx:
            self.word2idx[word] = self.idx
            self.idx2word[self.idx] = word
            self.idx += 1
    def __len__(self):
        return len(self.word2idx)

class Corpus(object):
    def __init__(self):
        self.dictionary = Dictionary()

    def get_data(self, path, batch_size=20):
        # Add words to the dictionary
        with open(path, 'r') as f:
            tokens = 0
            for line in f:
                words = line.split() + ['<eos>']
                tokens += len(words)
                for word in words:
                    self.dictionary.add_word(word)
        # Tokenize the file content:
        # encode every word as its dictionary index
        ids = torch.zeros(tokens, dtype=torch.long)
        token = 0
        with open(path, 'r') as f:
            for line in f:
                words = line.split() + ['<eos>']
                for word in words:
                    ids[token] = self.dictionary.word2idx[word]
                    token += 1
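        # Drop the tail so the token stream splits evenly into batch_size rows
        num_batches = ids.size(0) // batch_size
        ids = ids[:num_batches*batch_size]
        # Each row is one contiguous chunk of the corpus: shape (batch_size, num_batches)
        return ids.view(batch_size, -1)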

# Device configuration
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Hyper-parameters
embed_size = 128
hidden_size = 1024
num_layers = 1
num_epochs = 5
num_samples = 1000     # number of words to be sampled
batch_size = 20
seq_length = 30
learning_rate = 0.002

# Load "Penn Treebank" dataset
corpus = Corpus()
ids = corpus.get_data('data/train.txt', batch_size)
vocab_size = len(corpus.dictionary)
num_batches = ids.size(1) // seq_length

# RNN based language model
class RNNLM(nn.Module):
    def __init__(self, vocab_size, embed_size, hidden_size, num_layers):
        super(RNNLM, self).__init__()
        self.embed = nn.Embedding(vocab_size, embed_size)  # word id -> dense embedding vector
        self.lstm = nn.LSTM(embed_size, hidden_size, num_layers, batch_first=True)
        self.linear = nn.Linear(hidden_size, vocab_size)  # output layer: hidden state -> vocabulary logits
    def forward(self, x, h):
        # Embed word ids to vectors
        x = self.embed(x)
        # Forward propagate LSTM
        out, (h, c) = self.lstm(x, h)
        # Reshape output to (batch_size*sequence_length, hidden_size)
        out = out.reshape(out.size(0)*out.size(1), out.size(2))
        # Decode hidden states of all time steps
        out = self.linear(out)
        return out, (h, c)

model = RNNLM(vocab_size, embed_size, hidden_size, num_layers).to(device)

# Loss and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

# Truncated backpropagation through time: cut the graph between consecutive segments
def detach(states):
    return tuple(state.detach() for state in states)

# Train the model
for epoch in range(num_epochs):
    # Set initial hidden and cell states
    states = (torch.zeros(num_layers, batch_size, hidden_size).to(device),
              torch.zeros(num_layers, batch_size, hidden_size).to(device))
    for i in range(0, ids.size(1) - seq_length, seq_length):
        # Get mini-batch inputs and targets
        inputs = ids[:, i:i+seq_length].to(device)
        targets = ids[:, (i+1):(i+1)+seq_length].to(device)
        # Forward pass
        states = detach(states)
        outputs, states = model(inputs, states)
        loss = criterion(outputs, targets.reshape(-1))
        # Backward and optimize
        model.zero_grad()
        loss.backward()
        clip_grad_norm_(model.parameters(), 0.5)
        optimizer.step()

        step = (i+1) // seq_length
        if step % 100 == 0:
            print('Epoch [{}/{}], Step[{}/{}], Loss: {:.4f}, Perplexity: {:5.2f}'
                  .format(epoch+1, num_epochs, step, num_batches, loss.item(), np.exp(loss.item())))

# Test the model
with torch.no_grad():
    with open('sample.txt', 'w') as f:
        # Set initial hidden and cell states
        state = (torch.zeros(num_layers, 1, hidden_size).to(device),
                 torch.zeros(num_layers, 1, hidden_size).to(device))

        # Select one word id randomly
        prob = torch.ones(vocab_size)
        input = torch.multinomial(prob, num_samples=1).unsqueeze(1).to(device)

        for i in range(num_samples):
            # Forward propagate RNN 
            output, state = model(input, state)

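            # Sample a word id (exp of the logits is proportional to the softmax;
            # torch.multinomial accepts unnormalized weights)
            prob = output.exp()
            word_id = torch.multinomial(prob, num_samples=1).item()

            # Fill input with sampled word id for the next time step
            input.fill_(word_id)

            # File write
            word = corpus.dictionary.idx2word[word_id]
            word = '\n' if word == '<eos>' else word + ' '
            f.write(word)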

            if (i+1) % 100 == 0:
                print('Sampled [{}/{}] words and save to {}'.format(i+1, num_samples, 'sample.txt'))

# Save the model checkpoints
torch.save(model.state_dict(), 'model.ckpt')
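The data layout produced by Corpus.get_data and used by the training loop can be traced on a toy input. This is a minimal sketch with plain Python lists standing in for tensors. The helper names encode, batchify and windows are hypothetical; they mirror, rather than reuse, the code above. The sketch shows two things:
- each of the batch_size rows is one contiguous stretch of the corpus;
- the targets are the inputs shifted by one word.

Because each row continues its own stretch, the LSTM state can be carried over (after detach) from one window to the next.

```python
def encode(lines):
    """Mirror of Corpus.get_data: append <eos> to each line, assign ids in first-seen order."""
    word2idx, ids = {}, []
    for line in lines:
        for word in line.split() + ['<eos>']:
            if word not in word2idx:
                word2idx[word] = len(word2idx)
            ids.append(word2idx[word])
    return word2idx, ids

def batchify(ids, batch_size):
    """Trim to a multiple of batch_size, then split into batch_size contiguous rows
    (what ids.view(batch_size, -1) does to a 1-D tensor)."""
    n = len(ids) // batch_size
    ids = ids[:n * batch_size]
    return [ids[r * n:(r + 1) * n] for r in range(batch_size)]

def windows(rows, seq_length):
    """Same slicing as the training loop: targets are the inputs shifted one step right."""
    width = len(rows[0])
    for i in range(0, width - seq_length, seq_length):
        inputs = [row[i:i + seq_length] for row in rows]
        targets = [row[i + 1:i + 1 + seq_length] for row in rows]
        yield inputs, targets

word2idx, ids = encode(["a b c", "d e", "a f"])
rows = batchify(ids, batch_size=2)   # [[0, 1, 2, 3, 4], [5, 3, 0, 6, 3]]
for inputs, targets in windows(rows, seq_length=2):
    print(inputs, targets)
# [[0, 1], [5, 3]] [[1, 2], [3, 0]]
# [[2, 3], [0, 6]] [[3, 4], [6, 3]]
```

With the real data, the 1549 shown in the log is num_batches = ids.size(1) // seq_length, the number of length-30 windows in each row.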

3 Program Output


Epoch [1/5], Step[0/1549], Loss: 9.2150, Perplexity: 10046.61
Epoch [1/5], Step[100/1549], Loss: 6.0423, Perplexity: 420.85
Epoch [1/5], Step[200/1549], Loss: 5.9387, Perplexity: 379.44
Epoch [1/5], Step[300/1549], Loss: 5.7512, Perplexity: 314.56
Epoch [1/5], Step[400/1549], Loss: 5.6709, Perplexity: 290.30
Epoch [1/5], Step[500/1549], Loss: 5.1621, Perplexity: 174.54
Epoch [1/5], Step[600/1549], Loss: 5.1755, Perplexity: 176.89
Epoch [1/5], Step[700/1549], Loss: 5.3721, Perplexity: 215.32
Epoch [1/5], Step[800/1549], Loss: 5.1827, Perplexity: 178.17
Epoch [1/5], Step[900/1549], Loss: 5.0756, Perplexity: 160.06
Epoch [1/5], Step[1000/1549], Loss: 5.1428, Perplexity: 171.19
Epoch [1/5], Step[1100/1549], Loss: 5.3263, Perplexity: 205.67
Epoch [1/5], Step[1200/1549], Loss: 5.1895, Perplexity: 179.39
Epoch [1/5], Step[1300/1549], Loss: 5.0724, Perplexity: 159.56
Epoch [1/5], Step[1400/1549], Loss: 4.8528, Perplexity: 128.10
Epoch [1/5], Step[1500/1549], Loss: 5.1661, Perplexity: 175.22
Epoch [2/5], Step[0/1549], Loss: 5.4163, Perplexity: 225.05
Epoch [2/5], Step[100/1549], Loss: 4.5526, Perplexity: 94.88
Epoch [2/5], Step[200/1549], Loss: 4.6929, Perplexity: 109.17
Epoch [2/5], Step[300/1549], Loss: 4.6444, Perplexity: 104.00
Epoch [2/5], Step[400/1549], Loss: 4.5688, Perplexity: 96.42
Epoch [2/5], Step[500/1549], Loss: 4.1592, Perplexity: 64.02
Epoch [2/5], Step[600/1549], Loss: 4.4269, Perplexity: 83.67
Epoch [2/5], Step[700/1549], Loss: 4.3720, Perplexity: 79.20
Epoch [2/5], Step[800/1549], Loss: 4.4036, Perplexity: 81.74
Epoch [2/5], Step[900/1549], Loss: 4.1653, Perplexity: 64.41
Epoch [2/5], Step[1000/1549], Loss: 4.3449, Perplexity: 77.08
Epoch [2/5], Step[1100/1549], Loss: 4.4840, Perplexity: 88.59
Epoch [2/5], Step[1200/1549], Loss: 4.4659, Perplexity: 87.00
Epoch [2/5], Step[1300/1549], Loss: 4.1735, Perplexity: 64.94
Epoch [2/5], Step[1400/1549], Loss: 3.9952, Perplexity: 54.34
Epoch [2/5], Step[1500/1549], Loss: 4.2860, Perplexity: 72.67
Epoch [3/5], Step[0/1549], Loss: 4.4764, Perplexity: 87.91
Epoch [3/5], Step[100/1549], Loss: 3.8185, Perplexity: 45.54
Epoch [3/5], Step[200/1549], Loss: 4.0630, Perplexity: 58.15
Epoch [3/5], Step[300/1549], Loss: 3.8839, Perplexity: 48.62
Epoch [3/5], Step[400/1549], Loss: 3.9263, Perplexity: 50.72
Epoch [3/5], Step[500/1549], Loss: 3.4153, Perplexity: 30.43
Epoch [3/5], Step[600/1549], Loss: 3.8813, Perplexity: 48.49
Epoch [3/5], Step[700/1549], Loss: 3.7443, Perplexity: 42.28
Epoch [3/5], Step[800/1549], Loss: 3.7594, Perplexity: 42.92
Epoch [3/5], Step[900/1549], Loss: 3.4794, Perplexity: 32.44
Epoch [3/5], Step[1000/1549], Loss: 3.6235, Perplexity: 37.47
Epoch [3/5], Step[1100/1549], Loss: 3.7085, Perplexity: 40.79
Epoch [3/5], Step[1200/1549], Loss: 3.8110, Perplexity: 45.20
Epoch [3/5], Step[1300/1549], Loss: 3.4499, Perplexity: 31.50
Epoch [3/5], Step[1400/1549], Loss: 3.2214, Perplexity: 25.06
Epoch [3/5], Step[1500/1549], Loss: 3.5429, Perplexity: 34.57
Epoch [4/5], Step[0/1549], Loss: 3.6315, Perplexity: 37.77
Epoch [4/5], Step[100/1549], Loss: 3.2487, Perplexity: 25.76
Epoch [4/5], Step[200/1549], Loss: 3.5140, Perplexity: 33.58
Epoch [4/5], Step[300/1549], Loss: 3.3193, Perplexity: 27.64
Epoch [4/5], Step[400/1549], Loss: 3.4360, Perplexity: 31.06
Epoch [4/5], Step[500/1549], Loss: 2.9549, Perplexity: 19.20
Epoch [4/5], Step[600/1549], Loss: 3.3490, Perplexity: 28.48
Epoch [4/5], Step[700/1549], Loss: 3.3122, Perplexity: 27.45
Epoch [4/5], Step[800/1549], Loss: 3.2668, Perplexity: 26.23
Epoch [4/5], Step[900/1549], Loss: 2.9631, Perplexity: 19.36
Epoch [4/5], Step[1000/1549], Loss: 3.1250, Perplexity: 22.76
Epoch [4/5], Step[1100/1549], Loss: 3.2380, Perplexity: 25.48
Epoch [4/5], Step[1200/1549], Loss: 3.2806, Perplexity: 26.59
Epoch [4/5], Step[1300/1549], Loss: 2.9988, Perplexity: 20.06
Epoch [4/5], Step[1400/1549], Loss: 2.7011, Perplexity: 14.90
Epoch [4/5], Step[1500/1549], Loss: 3.1112, Perplexity: 22.45
Epoch [5/5], Step[0/1549], Loss: 3.0950, Perplexity: 22.09
Epoch [5/5], Step[100/1549], Loss: 2.8688, Perplexity: 17.62
Epoch [5/5], Step[200/1549], Loss: 3.1285, Perplexity: 22.84
Epoch [5/5], Step[300/1549], Loss: 2.9598, Perplexity: 19.29
Epoch [5/5], Step[400/1549], Loss: 3.1288, Perplexity: 22.85
Epoch [5/5], Step[500/1549], Loss: 2.6090, Perplexity: 13.58
Epoch [5/5], Step[600/1549], Loss: 3.0915, Perplexity: 22.01
Epoch [5/5], Step[700/1549], Loss: 2.9536, Perplexity: 19.18
Epoch [5/5], Step[800/1549], Loss: 2.9605, Perplexity: 19.31
Epoch [5/5], Step[900/1549], Loss: 2.6687, Perplexity: 14.42
Epoch [5/5], Step[1000/1549], Loss: 2.8161, Perplexity: 16.71
Epoch [5/5], Step[1100/1549], Loss: 2.9194, Perplexity: 18.53
Epoch [5/5], Step[1200/1549], Loss: 3.0538, Perplexity: 21.20
Epoch [5/5], Step[1300/1549], Loss: 2.6999, Perplexity: 14.88
Epoch [5/5], Step[1400/1549], Loss: 2.4688, Perplexity: 11.81
Epoch [5/5], Step[1500/1549], Loss: 2.7906, Perplexity: 16.29
Sampled [100/1000] words and save to sample.txt
Sampled [200/1000] words and save to sample.txt
Sampled [300/1000] words and save to sample.txt
Sampled [400/1000] words and save to sample.txt
Sampled [500/1000] words and save to sample.txt
Sampled [600/1000] words and save to sample.txt
Sampled [700/1000] words and save to sample.txt
Sampled [800/1000] words and save to sample.txt
Sampled [900/1000] words and save to sample.txt
Sampled [1000/1000] words and save to sample.txt
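The Perplexity column in the log is np.exp(loss.item()). Here loss is the mean cross-entropy (in nats) over the batch_size × seq_length predicted words of a single mini-batch, which is why individual steps are noisy.

The first logged loss also works as a sanity check. A model that spreads probability uniformly over V words has loss ln V and perplexity V. The sketch below assumes V = 10,000, the vocabulary size usually quoted for this version of PTB; the program itself never prints vocab_size.

```python
import math

def perplexity_from_loss(loss):
    """Perplexity as printed in the log: exp of the mean cross-entropy loss (in nats)."""
    return math.exp(loss)

V = 10000  # assumed vocabulary size (not printed by the program)
print(math.log(V))                   # about 9.2103: uniform guessing, close to the first logged loss 9.2150
print(perplexity_from_loss(9.2150))  # about 10046.7; the log shows 10046.61 because the printed loss is rounded
print(perplexity_from_loss(2.7906))  # about 16.29, the last logged step of epoch 5
```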


N repeal according to takeover experts 
for the british and canada are insolvent interest to the cost of how the transactions will seek additional information on its u.k. business 
the banks badly previously concluded that americans should slash the impact of income returns by the agency but they warn that it should be known for many small <unk> 
if the central park world will be made only a german economy and the state industries which has <unk> to foreigners how the u.s. can then by how to pay political 's healthy benefit push for such tasks as thrifts often 
while the baltimore for managers are among the <unk> concerned the average families of the cause in the area will translate to the higher costs for beneficiaries days 
he favors a resignation of selling such matters openly that <unk> doubled made all penalties 
there 's no answers at that time the proper has missed it the dead force is <unk> out of control 
we call it and easy for rights to pay attention 
i think we owe japanese investment and obviously not even agree 
until then mr. bryant is a much higher degree and more important for the country 
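The text above is the content of sample.txt written by the test loop. Starting from a uniformly random word, the model repeatedly samples the next word and writes it out, turning <eos> into a line break. torch.multinomial accepts unnormalized weights, so sampling from output.exp() draws each word with its softmax probability.

The stdlib sketch below reproduces one such sampling step. The 4-word vocabulary and the logits are made up. Subtracting the maximum logit before exponentiating is an addition the original code does not make: it leaves the distribution unchanged but avoids overflow.

```python
import math
import random

def sample_next(logits, idx2word, rng):
    """Draw one word with probability proportional to exp(logit), i.e. softmax(logits)."""
    m = max(logits)                               # shift for numerical safety; same distribution
    weights = [math.exp(x - m) for x in logits]   # unnormalized probabilities
    word_id = rng.choices(range(len(logits)), weights=weights, k=1)[0]
    word = idx2word[word_id]
    return word_id, ('\n' if word == '<eos>' else word + ' ')

idx2word = {0: 'the', 1: 'market', 2: 'fell', 3: '<eos>'}  # hypothetical vocabulary
logits = [2.0, 1.0, 0.5, -1.0]                              # hypothetical output for one step
rng = random.Random(0)
text = ''.join(sample_next(logits, idx2word, rng)[1] for _ in range(10))
print(repr(text))
```

Because each step samples instead of taking the most likely word, the output varies from run to run. This explains the locally fluent but globally incoherent sentences above.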

