如何使用Tensorflow的PTB模型示例?

2023-11-29

我正在尝试Tensorflow的rnn示例。 一开始遇到一些问题,我可以运行示例来训练 ptb,现在我已经训练了一个模型。

现在我该如何使用模型来创建句子,而不必每次都重新训练?

我用类似的命令运行它python ptb_word_lm.py --data_path=/home/data/ --model medium --save_path=/home/medium

有没有关于如何使用经过训练的模型造句的示例?


1.在最后一行添加以下代码PTBModel:__init__()功能:

self._output_probs = tf.nn.softmax(logits)

2.添加如下功能PTBModel:

@property
def output_probs(self):
    return self._output_probs

3.尝试运行以下代码:

raw_data = reader.ptb_raw_data(FLAGS.data_path)
train_data, valid_data, test_data, vocabulary, word_to_id, id_to_word = raw_data

eval_config = get_config()
eval_config.batch_size = 1
eval_config.num_steps = 1

sess = tf.Session()

initializer = tf.random_uniform_initializer(-eval_config.init_scale,
                                            eval_config.init_scale)
with tf.variable_scope("model", reuse=None, initializer=initializer):
    mtest = PTBModel(is_training=False, config=eval_config)

sess.run(tf.initialize_all_variables())

saver = tf.train.Saver()

ckpt = tf.train.get_checkpoint_state('/home/medium')  # __YOUR__MODEL__SAVE__PATH__
if ckpt and gfile.Exists(ckpt.model_checkpoint_path):
    msg = 'Reading model parameters from %s' % ckpt.model_checkpoint_path
    print(msg)
    saver.restore(sess, ckpt.model_checkpoint_path)

def pick_from_weight(weight, pows=1.0):
    weight = weight**pows
    t = np.cumsum(weight)
    s = np.sum(weight)
    return int(np.searchsorted(t, np.random.rand(1) * s))

while True:
    number_of_sentences = 10  # generate 10 sentences one time
    sentence_cnt = 0
    text = '\n'
    end_of_sentence_char = word_to_id['<eos>']
    input_char = np.array([[end_of_sentence_char]])
    state = sess.run(mtest.initial_state)
    while sentence_cnt < number_of_sentences:
        feed_dict = {mtest.input_data: input_char,
                     mtest.initial_state: state}
        probs, state = sess.run([mtest.output_probs, mtest.final_state],
                                       feed_dict=feed_dict)
        sampled_char = pick_from_weight(probs[0])
        if sampled_char == end_of_sentence_char:
            text += '.\n'
            sentence_cnt += 1
        else:
            text += ' ' + id_to_word[sampled_char]
        input_char = np.array([[sampled_char]])
    print(text)
    raw_input('press any key to continue ...')
本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

如何使用Tensorflow的PTB模型示例? 的相关文章

随机推荐