TensorFlow - Saver.restore 未恢复所有参数


我训练了一个双向 LSTM 类型的 RNN 近 24 小时,由于误差出现振荡,我决定先降低学习率,然后再让它继续训练。由于模型在每个 epoch 都会通过 Saver.save(sess,file) 保存,我在 CTC 损失降到大约 115 时终止了训练。

现在恢复模型后,我得到的初始错误率约为 162,这与我在第 7 个 epoch 中得到的错误率走势不一致,而且正是我在第一个 epoch 中得到的错误率。所以我的印象是 restore 函数没有起作用;或者即使它起了作用,也一定有其他因素使它未能生效。


    # Build the model graph: input placeholders, a stacked bidirectional LSTM,
    # a linear output layer, CTC loss / beam-search decoding, an edit-distance
    # error metric, and a momentum optimizer with global-norm gradient clipping.
    # Names such as batch_size, frame_length, num_units, hidden_size, n_layers,
    # n_chars, lr, momentum, max_grad_norm, data_train and seq_bn are not
    # defined in this snippet and must come from elsewhere in the file.
    graph = tf.Graph()
    with graph.as_default():
        # Graph creation
        graph_start = time.time()
        # Input features; the leading dimension is left unspecified and the
        # RNN below is called with time_major=True.
        seq_inputs = tf.placeholder(tf.float32, shape=     [None,batch_size,frame_length], name="sequence_inputs")
        # One length entry per batch element.
        seq_lens = tf.placeholder(shape=[batch_size],dtype=tf.int32)
        # seq_bn is defined outside this snippet -- presumably a
        # sequence-aware batch normalization; TODO confirm. If it creates
        # variables or keeps running statistics, they need to be checkpointed.
        seq_inputs = seq_bn(seq_inputs,seq_lens)

        # NOTE(review): `initializer` is not referenced by any visible line,
        # and both LSTMCell(...) calls below are left unclosed -- their
        # remaining arguments and closing parenthesis appear to be missing
        # from this snippet.
        initializer = tf.truncated_normal_initializer(mean=0,stddev=0.1)
        forward = tf.nn.rnn_cell.LSTMCell(num_units=num_units,
                                          num_proj = hidden_size,

        # `[forward] * n_layers` repeats the same cell object n_layers times.
        forward = tf.nn.rnn_cell.MultiRNNCell([forward] * n_layers, state_is_tuple=True)

        backward = tf.nn.rnn_cell.LSTMCell(num_units=num_units,
                                           num_proj= hidden_size,

        # Same pattern as the forward stack: one cell object repeated.
        backward = tf.nn.rnn_cell.MultiRNNCell([backward] * n_layers, state_is_tuple=True)

        # Run both directions; the returned final states are discarded.
        [fw_out,bw_out], _ = tf.nn.bidirectional_dynamic_rnn(cell_fw=forward, cell_bw=backward, inputs=seq_inputs,time_major=True, dtype=tf.float32,                                               sequence_length=tf.cast(seq_lens,tf.int64))

        # Batch normalize forward output
        # Moments are taken over axis 0 of the current feed only; offset (0.1)
        # and scale (1) are plain constants, so this normalization has no
        # variables for the Saver to store.
        mew,var_ = tf.nn.moments(fw_out,axes=[0])
        fw_out = tf.nn.batch_normalization(fw_out,mew,var_,0.1,1,1e-6)
        # fw_out = seq_bn(fw_out,seq_lens)

        # Batch normalize backward output
        mew,var_ = tf.nn.moments(bw_out,axes=[0])
        bw_out = tf.nn.batch_normalization(bw_out,mew,var_,0.1,1,1e-6)
        # bw_out = seq_bn(bw_out,seq_lens)

        # Reshaping forward, and backward outputs for affine transformation
        fw_out = tf.reshape(fw_out,[-1,hidden_size])
        bw_out = tf.reshape(bw_out,[-1,hidden_size])

        # Linear Layer params
        # One weight matrix per direction, drawn from a truncated normal with
        # stddev sqrt(2 / hidden_size).
        W_fw = tf.Variable(tf.truncated_normal(shape=[hidden_size,n_chars],stddev=np.sqrt(2.0 / (hidden_size))))
        W_bw = tf.Variable(tf.truncated_normal(shape=[hidden_size,n_chars],stddev=np.sqrt(2.0 / (hidden_size))))
        # The bias is a tf.constant, not a tf.Variable: it is never trained
        # and is not part of what the Saver checkpoints.
        b_out = tf.constant(0.1,shape=[n_chars])

        # Perform an affine transformation
        logits =  tf.add(tf.add(tf.matmul(fw_out,W_fw),tf.matmul(bw_out,W_bw)),b_out)
        # Restore the 3-D layout with batch_size as the middle dimension.
        logits = tf.reshape(logits,[-1,batch_size,n_chars])

        # Use CTC Beam Search Decoder to decode pred string from the prob map
        decoded, log_prob = tf.nn.ctc_beam_search_decoder(logits, seq_lens)

        # Target params
        # The three components of the sparse label tensor are fed separately.
        indices = tf.placeholder(dtype=tf.int64, shape=[None,2])
        values = tf.placeholder(dtype=tf.int32, shape=[None])
        shape = tf.placeholder(dtype=tf.int64,shape=[2])
        # Make targets
        targets = tf.SparseTensor(indices,values,shape)

        # Compute Loss
        loss = tf.reduce_mean(tf.nn.ctc_loss(logits, targets, seq_lens))
        # Compute error rate based on edit distance
        predicted = tf.to_int32(decoded[0])
        # NOTE(review): the statement below ends in a line continuation but
        # its divisor is missing from this snippet.
        error_rate = tf.reduce_sum(tf.edit_distance(predicted,targets,normalize=False))/ \

        # Clip gradients by global norm, then apply them with momentum SGD.
        tvars = tf.trainable_variables()
        grad, _ = tf.clip_by_global_norm(tf.gradients(loss,tvars),max_grad_norm)
        optimizer = tf.train.MomentumOptimizer(learning_rate=lr,momentum=momentum)
        train_step = optimizer.apply_gradients(zip(grad,tvars))
        graph_end = time.time()
        print("Time elapsed for creating graph: %.3f"%(round(graph_end-graph_start,3)))
        # steps per epoch
        start_time = 0
        # NOTE(review): the `except ..., e` syntax further down implies
        # Python 2, where `/` on two ints truncates before np.ceil is applied.
        steps = int(np.ceil(len(data_train.files)/batch_size))

        # History buffers that are written to results_fn by np.save later on.
        loss_tr = []
        log_tr = []
        loss_vl = []
        log_vl = []
        err_tr = []
        err_vl = []
        # Saver built with no arguments, after the optimizer, so variables the
        # optimizer created are included alongside the model weights.
        saver = tf.train.Saver()
        # Training session: resume at epoch 7, run one training step and one
        # validation evaluation per iteration, print a sample decoding and
        # save a checkpoint at the end of each epoch.
        with tf.Session(config=config) as sess:
            # NOTE(review): the checkpoint path is looked up here, but no
            # `saver.restore(sess, checkpt_path)` call (and no variable
            # initializer) is visible in this snippet. If the restore call is
            # really absent, no parameters are loaded -- confirm against the
            # full file.
            checkpt_path = tf.train.latest_checkpoint(checkpoint_dir)
            print("Model restore from 7th epoch 188th step")
            # Pre-declared so the exception handler below can reference them.
            feed = None
            epoch = None
            step = None
                # NOTE(review): this loop is indented one level deeper than
                # the statements above and is followed by an `except` clause;
                # the `try:` line appears to be missing from this snippet.
                for epoch in range(7,epochs+1):
                    # NOTE(review): as written, initial_step is set to 189 and
                    # immediately overwritten with 0; an `else:` branch appears
                    # to be missing between the two assignments.
                    if epoch==7:
                       initial_step = 189
                       initial_step = 0
                    transcript = []
                    loss_val = 0
                    l_pr = 0
                    start_time = time.time()
                    for step in range(initial_step,steps):
                        # Fetch the next training batch along with the three
                        # sparse-target components and per-sequence lengths.
                        train_data, transcript, \
                        targ_indices, targ_values, \
                        targ_shape, n_frames = data_train.next_batch()
                        n_frames = np.reshape(n_frames,[-1])
                        feed = {seq_inputs: train_data, indices:targ_indices, values:targ_values, shape:targ_shape, seq_lens:n_frames}
                        del train_data,targ_indices,targ_values,targ_shape,n_frames

                        # Evaluate loss value, decoded transcript, and log probability
                        # NOTE(review): the sess.run(...) call below is left
                        # unclosed; its feed_dict argument and closing
                        # parenthesis appear to be missing from this snippet.
                        _,loss_val,deco,l_pr,err_rt_tr = sess.run([train_step,loss,decoded,log_prob,error_rate],
                        del feed

                        # On validation set
                        val_data, val_transcript, \
                        targ_indices, targ_values, \
                        targ_shape, n_frames = data_val.next_batch()
                        n_frames = np.reshape(n_frames, [-1])
                        feed = {seq_inputs: val_data, indices: targ_indices,values: targ_values, shape: targ_shape, seq_lens: n_frames}
                        del val_data, val_transcript,targ_indices,targ_values,targ_shape,n_frames
                    # NOTE(review): the next statement is indented one level
                    # shallower than its neighbours, which is not valid here;
                    # it presumably belongs inside the step loop.
                    vl_loss, l_val_pr, err_rt_vl = sess.run([loss, log_prob, error_rate], feed_dict=feed)
                        del feed
                        # NOTE(review): loss_tr and loss_vl are never appended
                        # to in the visible code, so these means would be
                        # taken over empty lists.
                        print("epoch %d, step: %d, tr_loss: %.2f, vl_loss: %.2f, tr_err: %.2f, vl_err: %.2f"
                          % (epoch, step, np.mean(loss_tr), np.mean(loss_vl), err_rt_tr, err_rt_vl))

                    end_time = time.time()
                    elapsed = round(end_time - start_time, 3)

                    # On training set
                    # Select a random index within batch_size
                    sample_index = np.random.randint(0, batch_size)

                    # Fetch the target transcript
                    actual_str = [data_train.reverse_map[i] for i in transcript[sample_index]]

                    # Fetch the decoded path from probability map
                    # These tf.* calls run once per epoch inside the loop, so
                    # each epoch adds new ops to the graph.
                    pred_sparse = tf.SparseTensor(deco[0].indices, deco[0].values, deco[0].shape)
                    pred_dense = tf.sparse_tensor_to_dense(pred_sparse)
                    ans = pred_dense.eval()
                    #pred = [data_train.reverse_map[i] for i in ans[sample_index, :]]
                    pred = []
                    # NOTE(review): the `if` below has no body in this snippet;
                    # the lines that fill `pred` appear to be missing.
                    for i in ans[sample_index,:]:
                        if i == n_chars-1:
                    print("time_elapsed for 200 steps: %.3f, " % (elapsed))
                    # The sample decoding is printed on even-numbered epochs only.
                    if epoch%2 == 0:
                        print("Sample mini-batch results: \n" \
                              "predicted string: ", np.array(pred))
                        print("actual string: ", np.array(actual_str))
                    print("On training set, the loss: %.2f, log_pr: %.3f, error rate %.3f:"% (loss_val, np.mean(l_pr), err_rt_tr))
                    print("On validation set, the loss: %.2f, log_pr: %.3f, error rate: %.3f" % (vl_loss, np.mean(l_val_pr), err_rt_vl))

                    # Save the trainable parameters after the end of an epoch
                    # NOTE(review): `path` is assigned only when epoch > 7 but
                    # is printed unconditionally, so the first resumed epoch
                    # (7) reaches the print with `path` unassigned.
                    if epoch > 7:
                        path = saver.save(sess, 'model_%d' % epoch)
                    print("Session saved at: %s" % path)
                    np.save(results_fn, np.array([loss_tr, log_tr, loss_vl, log_vl, err_tr, err_vl], dtype=np.object))
            # Python 2 `except <types>, <name>` syntax. On any error or
            # interruption: report it, save a last checkpoint (only past
            # epoch 7) and write out the history arrays.
            except (KeyboardInterrupt, SystemExit, Exception), e:
                print("Error/Interruption: %s" % str(e))
                exc_type, exc_obj, exc_tb = sys.exc_info()
                print("Line no: %d" % exc_tb.tb_lineno)
                if epoch > 7:
                    print("Saving model: %s" % saver.save(sess, 'Last.cpkt'))
                print("Current batch: %d" % data_train.b_id)
                print("Current epoch: %d" % epoch)
                print("Current step: %d"%step)
                np.save(results_fn, np.array([loss_tr, log_tr, loss_vl, log_vl, err_tr, err_vl], dtype=np.object))
                print("Clossing TF Session...")
                print("Terminating Program...")



# History buffers for training / validation loss, log-probability and error
# rate. Every name is bound to its own fresh, empty list.
(loss_tr, log_tr, loss_vl,
 log_vl, err_tr, err_vl) = ([] for _ in range(6))

