TensorFlow - Saver.restore 未恢复所有参数

2023-12-29

我训练了双向 LSTM 类型的 RNN 近 24 小时,由于误差波动,我决定在允许其继续训练之前减少学习量。由于模型在每个时期都使用 Saver.save(sess,file) 保存,因此我终止了训练,CTC 损失已最小化至大约 115。

现在恢复模型后,我得到的初始错误率约为 162,这与我在第 7 个 epoch 中得到的错误率流不一致,也是我在第一个 epoch 中得到的错误率。所以我的印象是“恢复”功能不起作用,或者如果它起作用,那么一定有其他东西不允许它生效。

这是我的代码:

    graph = tf.Graph()
    with graph.as_default():
        # Graph creation
        graph_start = time.time()
        seq_inputs = tf.placeholder(tf.float32, shape=     [None,batch_size,frame_length], name="sequence_inputs")
        seq_lens = tf.placeholder(shape=[batch_size],dtype=tf.int32)
        seq_inputs = seq_bn(seq_inputs,seq_lens)

        initializer = tf.truncated_normal_initializer(mean=0,stddev=0.1)
        forward = tf.nn.rnn_cell.LSTMCell(num_units=num_units,
                                          num_proj = hidden_size,
                                          use_peepholes=use_peephole,
                                          initializer=initializer,
                                          state_is_tuple=True)

        forward = tf.nn.rnn_cell.MultiRNNCell([forward] * n_layers, state_is_tuple=True)

        backward = tf.nn.rnn_cell.LSTMCell(num_units=num_units,
                                           num_proj= hidden_size,
                                           use_peepholes=use_peephole,
                                           initializer=initializer,
                                           state_is_tuple=True)

        backward = tf.nn.rnn_cell.MultiRNNCell([backward] * n_layers, state_is_tuple=True)

        [fw_out,bw_out], _ = tf.nn.bidirectional_dynamic_rnn(cell_fw=forward, cell_bw=backward, inputs=seq_inputs,time_major=True, dtype=tf.float32,                                               sequence_length=tf.cast(seq_lens,tf.int64))


        # Batch normalize forward output
        mew,var_ = tf.nn.moments(fw_out,axes=[0])
        fw_out = tf.nn.batch_normalization(fw_out,mew,var_,0.1,1,1e-6)
        # fw_out = seq_bn(fw_out,seq_lens)

        # Batch normalize backward output
        mew,var_ = tf.nn.moments(bw_out,axes=[0])
        bw_out = tf.nn.batch_normalization(bw_out,mew,var_,0.1,1,1e-6)
        # bw_out = seq_bn(bw_out,seq_lens)

        # Reshaping forward, and backward outputs for affine transformation
        fw_out = tf.reshape(fw_out,[-1,hidden_size])
        bw_out = tf.reshape(bw_out,[-1,hidden_size])

        # Linear Layer params
        W_fw = tf.Variable(tf.truncated_normal(shape=[hidden_size,n_chars],stddev=np.sqrt(2.0 / (hidden_size))))
        W_bw = tf.Variable(tf.truncated_normal(shape=[hidden_size,n_chars],stddev=np.sqrt(2.0 / (hidden_size))))
        b_out = tf.constant(0.1,shape=[n_chars])

        # Perform an affine transformation
        logits =  tf.add(tf.add(tf.matmul(fw_out,W_fw),tf.matmul(bw_out,W_bw)),b_out)
        logits = tf.reshape(logits,[-1,batch_size,n_chars])

        # Use CTC Beam Search Decoder to decode pred string from the prob map
        decoded, log_prob = tf.nn.ctc_beam_search_decoder(logits, seq_lens)

        # Target params
        indices = tf.placeholder(dtype=tf.int64, shape=[None,2])
        values = tf.placeholder(dtype=tf.int32, shape=[None])
        shape = tf.placeholder(dtype=tf.int64,shape=[2])
        # Make targets
        targets = tf.SparseTensor(indices,values,shape)

        # Compute Loss
        loss = tf.reduce_mean(tf.nn.ctc_loss(logits, targets, seq_lens))
        # Compute error rate based on edit distance
        predicted = tf.to_int32(decoded[0])
        error_rate = tf.reduce_sum(tf.edit_distance(predicted,targets,normalize=False))/ \
         tf.to_float(tf.size(targets.values))    

        tvars = tf.trainable_variables()
        grad, _ = tf.clip_by_global_norm(tf.gradients(loss,tvars),max_grad_norm)
        optimizer = tf.train.MomentumOptimizer(learning_rate=lr,momentum=momentum)
        train_step = optimizer.apply_gradients(zip(grad,tvars))
        graph_end = time.time()
        print("Time elapsed for creating graph: %.3f"%(round(graph_end-graph_start,3)))
        # steps per epoch
        start_time = 0
        steps = int(np.ceil(len(data_train.files)/batch_size))

        loss_tr = []
        log_tr = []
        loss_vl = []
        log_vl = []
        err_tr = []
        err_vl = []
        saver = tf.train.Saver()
        with tf.Session(config=config) as sess:
            #sess.run(tf.initialize_all_variables())
            checkpt_path = tf.train.latest_checkpoint(checkpoint_dir)
            print(saver.restore(sess,checkpt_path))
            print("Model restore from 7th epoch 188th step")
            feed = None
            epoch = None
            step = None
            try:
                for epoch in range(7,epochs+1):
                    if epoch==7:
                       initial_step = 189
                    else:
                       initial_step = 0
                    transcript = []
                    loss_val = 0
                    l_pr = 0
                    start_time = time.time()
                    for step in range(initial_step,steps):
                        train_data, transcript, \
                        targ_indices, targ_values, \
                        targ_shape, n_frames = data_train.next_batch()
                        n_frames = np.reshape(n_frames,[-1])
                        feed = {seq_inputs: train_data, indices:targ_indices, values:targ_values, shape:targ_shape, seq_lens:n_frames}
                        del train_data,targ_indices,targ_values,targ_shape,n_frames

                        # Evaluate loss value, decoded transcript, and log probability
                        _,loss_val,deco,l_pr,err_rt_tr = sess.run([train_step,loss,decoded,log_prob,error_rate],
                                                            feed_dict=feed)
                        del feed
                        loss_tr.append(loss_val)
                        log_tr.append(l_pr)
                        err_tr.append(err_rt_tr)

                        # On validation set
                        val_data, val_transcript, \
                        targ_indices, targ_values, \
                        targ_shape, n_frames = data_val.next_batch()
                        n_frames = np.reshape(n_frames, [-1])
                        feed = {seq_inputs: val_data, indices: targ_indices,values: targ_values, shape: targ_shape, seq_lens: n_frames}
                        del val_data, val_transcript,targ_indices,targ_values,targ_shape,n_frames
                    vl_loss, l_val_pr, err_rt_vl = sess.run([loss, log_prob, error_rate], feed_dict=feed)
                        del feed
                        loss_vl.append(vl_loss)
                        log_vl.append(l_val_pr)
                        err_vl.append(err_rt_vl)
                        print("epoch %d, step: %d, tr_loss: %.2f, vl_loss: %.2f, tr_err: %.2f, vl_err: %.2f"
                          % (epoch, step, np.mean(loss_tr), np.mean(loss_vl), err_rt_tr, err_rt_vl))

                    end_time = time.time()
                    elapsed = round(end_time - start_time, 3)

                    # On training set
                    # Select a random index within batch_size
                    sample_index = np.random.randint(0, batch_size)

                    # Fetch the target transcript
                    actual_str = [data_train.reverse_map[i] for i in transcript[sample_index]]

                    # Fetch the decoded path from probability map
                    pred_sparse = tf.SparseTensor(deco[0].indices, deco[0].values, deco[0].shape)
                    pred_dense = tf.sparse_tensor_to_dense(pred_sparse)
                    ans = pred_dense.eval()
                    #pred = [data_train.reverse_map[i] for i in ans[sample_index, :]]
                    pred = []
                    for i in ans[sample_index,:]:
                        if i == n_chars-1:
                            pred.append(data_train.reverse_map[0])
                        else:
                            pred.append(data_train.reverse_map[i])
                    print("time_elapsed for 200 steps: %.3f, " % (elapsed))
                    if epoch%2 == 0:
                        print("Sample mini-batch results: \n" \
                              "predicted string: ", np.array(pred))
                        print("actual string: ", np.array(actual_str))
                    print("On training set, the loss: %.2f, log_pr: %.3f, error rate %.3f:"% (loss_val, np.mean(l_pr), err_rt_tr))
                    print("On validation set, the loss: %.2f, log_pr: %.3f, error rate: %.3f" % (vl_loss, np.mean(l_val_pr), err_rt_vl))

                    # Save the trainable parameters after the end of an epoch
                    if epoch > 7:
                        path = saver.save(sess, 'model_%d' % epoch)
                    print("Session saved at: %s" % path)
                    np.save(results_fn, np.array([loss_tr, log_tr, loss_vl, log_vl, err_tr, err_vl], dtype=np.object))
            except (KeyboardInterrupt, SystemExit, Exception), e:
                print("Error/Interruption: %s" % str(e))
                exc_type, exc_obj, exc_tb = sys.exc_info()
                print("Line no: %d" % exc_tb.tb_lineno)
                if epoch > 7:
                    print("Saving model: %s" % saver.save(sess, 'Last.cpkt'))
                print("Current batch: %d" % data_train.b_id)
                print("Current epoch: %d" % epoch)
                print("Current step: %d"%step)
                np.save(results_fn, np.array([loss_tr, log_tr, loss_vl, log_vl, err_tr, err_vl], dtype=np.object))
                print("Clossing TF Session...")
                sess.close()
                print("Terminating Program...")
                sys.exit(0)

我认为你需要为每个时期重新初始化你的累加器。

所以这些必须放在循环的开头。

loss_tr = []
log_tr = []
loss_vl = []
log_vl = []
err_tr = []
err_vl = []
本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

TensorFlow - Saver.restore 未恢复所有参数 的相关文章

随机推荐

  • 将 ListView 包装在 LinearLayout 中

    我正在尝试制作一个屏幕TextView在顶部 ListView在中间和一个Button在底部 我希望这样TextView始终位于屏幕顶部 按钮始终位于底部 然后ListView介于两者之间 当 的时候ListView超过 中间的空间 我希望
  • PHP mail() 变成垃圾邮件,可能是 DNS 问题?

    我正在通过 PHP 的 mail 发送一条带有正确且完整标头的消息 它拥有应有的一切 而且 Hotmail 喜欢电子邮件本身 然而 Hotmail 在消息来源中显示了这一点 X DKIM Result None X Message Stat
  • SQL Server - 计算 HH:MM:SS 格式的两个日期时间戳之间经过的时间

    我有一个包含 时间 列的 SQL Server 表 该表是一个日志表 其中包含状态消息和每条消息的时间戳 日志表通过批处理文件插入 有一个 ID 列将行分组在一起 每次运行批处理文件时 它都会初始化 ID 并写入记录 我需要做的是获取从 I
  • 所有文件夹和子文件夹的列表[关闭]

    Closed 这个问题是无关 help closed questions 目前不接受答案 在Linux中 我想找出所有文件夹 子文件夹名称并重定向到文本文件 我试过ls alR gt list txt 但它给出了所有文件 文件夹 您可以使用
  • 如何获得自定义工具栏上菜单项的连锁反应?

    我有一个具有以下布局的工具栏
  • 使用 Python 向 AWS Elasticsearch 发出签名的 HTTP 请求

    我正在尝试制作一个简单的 Python Lambda 来制作我们的 Elasticsearch 数据库的快照 这是通过Elasticsearch 的 REST API https www elastic co guide en elasti
  • Django:从表单示例保存到数据库

    看来我很难找到关于从表单将数据保存到数据库的良好来源 教程 随着事情的进展 我慢慢迷失了方向 我是 Django 新手 请指导我 我收到错误 赋值前引用的局部变量 store 这是我的相关代码 模型 py from django db im
  • 我们可以使用pm2来启动Vue cli的开发服务器吗?

    使用 vue cli 创建 Vue 项目后 我们可以使用以下命令运行它 yarn run serve 我无法开始使用 pm2 跑步 pm2 start yarn run serve 我遇到了一些崩溃并重新启动的情况 之后 pm2 将停止尝试
  • 如何从代码隐藏中显示隐藏的div C#

    我正在尝试初始化一个用户控件 其中包含一个在页面首次加载时隐藏的网格视图 当用户单击页面上的 搜索 按钮时 我想显示该用户控件内的网格 视图 我尝试了多种不同的方法来显示和隐藏用户控件 我尝试将用户控件放在 div 中 然后使用 style
  • 在Python中处理深度嵌套字典的便捷方法

    我在 python 中有一个深度嵌套的字典 占用了很多空间 有没有办法缩写这样的东西 master dictionary sub categories sub cat name attributes attribute name speci
  • 如何将依赖于 jQuery 的 Javascript 小部件嵌入到未知环境中

    我正在开发一个依赖于 jQuery 的 javascript 小部件 该小部件可能会也可能不会加载到已加载 jQuery 的页面上 在这种情况下会出现很多问题 如果网页没有jQuery 我必须加载我自己的jQuery 然而 这样做时似乎存在
  • 在rails应用程序中使用google图表api - 使用arrayToDataTable时如何在系列中指定空(缺失)值

    我的应用程序使用谷歌图表 API 绘制 4 个数据系列的图表 控制器加载一个数组 视图有谷歌图表 JavaScript 来绘制购物车 如果数组已满 它就可以工作 但当然有时数据系列会丢失一些点 并且我看不到如何指定系列中的 丢失 数据点 因
  • 如何在程序结束时关闭数据库连接?

    在Java程序中 我有一个单例类来保存数据库连接 该连接由整个程序使用 如何告诉Java在程序结束时关闭连接 我可以在 main 末尾放置一个 connection close 语句 但是如果程序意外结束 例如 由于程序中某处未捕获的异常或
  • PHP 在编辑模式下显示下拉列表中选定的值

    这个问题已经被问过 但我的问题很简单 在我的帐户页面中 我在下拉列表中显示了员工所在国家 地区 在编辑模式下如何选择组合中的值 假设您的用户所在国家 地区是 user country以及所有国家 地区的列表 all countries ar
  • 如何从c#返回List并通过com在vc++中使用它

    如何从 C 方法返回 List 以及在 C 中使用 List 返回值 您可以指导如何操作吗 我将按照我的完整方案进行操作 在 c DemoLib cs 中 usng System using System Collections Gener
  • 使用 ffmpeg 命令在视频中添加多个元数据

    添加单个元数据的命令工作正常 ffmpeg i var www html public uploads wp video akka mov metadata kKeyContentIdentifier com apple quicktime
  • 在 numpy 一维数组中查找拐点和驻点

    假设我有以下 numpy 数组 import numpy as np import matplotlib pyplot as plt x np array 11 53333333 11 86666667 11 1 10 66666667 1
  • 无法更改 unicode 字符的字体颜色

    确实很小的事情 但我在 Joomla 前端编辑页面的发布按钮上有这些日历图标 我为此使用 Unicode 字符 U 1F5D2 但我似乎无法使用 CSS 更改其颜色 我试图将其变成白色 浏览器检查员说它是白色的 但显然不是 请参阅此处的示例
  • 在 ARMv8-A Linux 上禁用 CPU 缓存 (L1/L2)

    我想在运行 Linux 的 ARMv8 A 平台上禁用低级缓存 以便独立于缓存访问来测量优化代码的性能 对于英特尔系统 我找到了以下资源 有没有办法在 Linux 系统上禁用 CPU 缓存 L1 L2 https stackoverflow
  • TensorFlow - Saver.restore 未恢复所有参数

    我训练了双向 LSTM 类型的 RNN 近 24 小时 由于误差波动 我决定在允许其继续训练之前减少学习量 由于模型在每个时期都使用 Saver save sess file 保存 因此我终止了训练 CTC 损失已最小化至大约 115 现在