Keras 通过设置种子获得不同的结果[重复]

2024-05-12

在keras中,每次运行都有很高的方差和不稳定的性能。为了解决这个问题,根据https://keras.io/getting-started/faq/#how-can-i-obtain-reproducible-results-using-keras-during-development https://keras.io/getting-started/faq/#how-can-i-obtain-reproducible-results-using-keras-during-development。我如图所示设置种子。

不幸的是,这没有帮助,我继续得到好坏参半的结果。任何指导都会有所帮助。

# Seed value (can actually be different for each attribution step)
seed_value= 0

# 1. Set `PYTHONHASHSEED` environment variable at a fixed value
import os
os.environ['PYTHONHASHSEED']=str(seed_value)

# 2. Set `python` built-in pseudo-random generator at a fixed value
import random
random.seed(seed_value)

# 3. Set `numpy` pseudo-random generator at a fixed value
import numpy as np
np.random.seed(seed_value)

# 4. Set `tensorflow` pseudo-random generator at a fixed value
import tensorflow as tf
tf.set_random_seed(seed_value)

# 5. Configure a new global `tensorflow` session
from keras import backend as K
session_conf = tf.ConfigProto(intra_op_parallelism_threads=1, inter_op_parallelism_threads=1)
sess = tf.Session(graph=tf.get_default_graph(), config=session_conf)
K.set_session(sess)
from itertools import permutations
import keras
from keras import optimizers
from keras.callbacks import Callback
from keras.initializers import glorot_uniform
from keras.layers import Input, LSTM, Dense, concatenate, Lambda, Multiply, Add, Dropout, multiply, TimeDistributed, Conv1D, GlobalMaxPooling1D, LeakyReLU
from keras.activations import softmax, sigmoid
from keras.models import Model
from keras_sequential_ascii import keras2ascii
from keras import backend as K

从 keras.callbacks 导入 ModelCheckpoint、EarlyStopping

# Custom loss to take full batch (of size beam) and apply a mask to calculate the true loss within the beam
beam_size = 10

K.clear_session()
def create_mask(y, yhat):
    idxs = list(permutations(range(beam_size), r=2))
    perms_y = tf.gather(y, idxs)
    perms_yhat = tf.gather(yhat, idxs)
    mask  = tf.where(tf.not_equal(perms_y[:,0], perms_y[:,1]))
    mask = tf.reduce_sum(mask, 1)
    uneq = tf.squeeze(tf.gather(perms_y, mask))
    yhat_uneq = tf.squeeze(tf.gather(perms_yhat, mask))
    return uneq, yhat_uneq

def mask_acc(y, yhat):
    uneq, yhat_uneq = create_mask(y, yhat)
    uneq = tf.argmax(uneq,1)
    yhat_uneq = tf.argmax(yhat_uneq, 1)
    #uneq = tf.Print(uneq, [uneq], summarize=-1)
    #yhat_uneq = tf.Print(yhat_uneq, [yhat_uneq], 'pred', summarize=-1)
    # argmax and compare
    #a = tf.Print(tf.reduce_mean(tf.cast(tf.equal(uneq, yhat_uneq), tf.float32)), [tf.reduce_mean(tf.cast(tf.equal(uneq, yhat_uneq), tf.float32))])
    return tf.reduce_mean(tf.cast(tf.equal(uneq, yhat_uneq), tf.float32))#tf.cond(tf.greater(tf.size(yhat_uneq), 1), lambda: tf.reduce_sum(tf.cast(tf.equal(uneq, yhat_uneq), tf.float32)), lambda: 100.)

def beam_acc(y, yhat):
    #a = tf.Print(yhat, [yhat], 'pred', summarize=-1)
    #yhat = tf.Print(yhat, [yhat],'\nSTART', summarize=-1)
    yhat_uneq = tf.argmax(yhat, 0)
    # argmax and compare
    # do possible indexes and predicted index
    y = tf.reshape(y, [-1])
    #y = tf.Print(y, [y], summarize=-1)
    possible = tf.where(tf.equal(y, tf.constant(1.0,dtype=tf.float32)))
    yhat_uneq = tf.Print(yhat_uneq, [yhat_uneq], 'prediction')
    possible = tf.reshape(possible, [-1])
    #possible = tf.Print(possible, [possible], 'actual')
    mean = tf.reduce_mean(tf.cast(tf.reduce_any(tf.equal(possible, yhat_uneq)), tf.float32))
    #mean = tf.Print(mean, [mean], 'mean\n')
    return mean#tf.reduce_mean(tf.cast(tf.reduce_any(tf.equal(possible, yhat_uneq)), tf.float32))#tf.cond(tf.equal(tf.reduce_sum(y), tf.constant(0.0)), true_fn=lambda: 0., false_fn=lambda: tf.reduce_mean(tf.cast(tf.equal(yhat_uneq, possible), tf.float32)))

def mask_loss(y, yhat):
    # Cosider weighted loss
    uneq, yhat_uneq = create_mask(y, yhat)
    #yhat_uneq = tf.Print(yhat_uneq, [yhat_uneq], summarize=-1)
    #create all permutations and zero out matches with mask
    total_loss = tf.reduce_mean(tf.losses.softmax_cross_entropy(onehot_labels=tf.cast(uneq, tf.int32), logits=yhat_uneq))
    #d = tf.Print(yhat_uneq, [yhat_uneq], summarize=-1)
    #total_loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(labels=uneq, logits=yhat_uneq))
    #total_loss = tf.Print(total_loss, [total_loss])
    return total_loss


x = Input((19,78))
lstm1 = LSTM(64, batch_input_shape=(10, 19, 78),return_sequences=True, unroll=True, activation='relu')(x)
#mult = multiply([encoded, ff])
#cat = concatenate([encoded, squeezed])
#dense2 = Dense(10)(encoded)
dense = Dense(1)(lstm1)
#mult = multiply([dense, prob])
#dense2 = Dense(1)(mult)
#print(dense2.shape)
output = Lambda(lambda x: K.sum(x, axis=1))(dense)
#output = Lambda(lambda x: K.squeeze(x, -1))(added)
#lam2 = Lambda(lambda x: K.sum(x, axis=1))(lam)
#probs_aug = Lambda(lambda x: x * .01)(probs)
#output = Add()([lam, probs])
sgd = optimizers.SGD(lr=0.01, nesterov=True, momentum=.9, decay=1e-5)
adam = optimizers.Adam(lr=0.001,decay=1e-5)#, nesterov=True, momentum=.9, decay=1e-5)
lstm_model = Model(inputs=[x], outputs=output)
lstm_model.compile(sgd, loss=mask_loss, metrics=[mask_acc, beam_acc])
filepath="weights.best.hdf5"
checkpoint = ModelCheckpoint(filepath, monitor='val_beam_acc', verbose=1, save_best_only=True, mode='max')
stop = EarlyStopping(monitor='val_beam_acc', patience=3) 
#lstm_model.fit([X_train], y_train, batch_size=10,epochs=10, verbose=1, shuffle=False,validation_data=([X_dev], y_dev))#, callbacks=[checkpoint, ])
           #, callbacks=[PlotLossesCallback()])

lstm_model.fit(X_train, y_train, batch_size=10,
               epochs=10, verbose=1, shuffle=False,validation_data=(X_dev, y_dev), callbacks=[checkpoint, stop])
               #, callbacks=[PlotLossesCallback()])

通过使用 ( 在操作系统级别设置 pythonhashseed 来解决这个问题使用 Keras 和 TensorFlow 后端可重现结果 https://stackoverflow.com/questions/48631576/reproducible-results-using-keras-with-tensorflow-backend):

# Seed value (can actually be different for each attribution step)
seed_value= 0

# 1. Set `PYTHONHASHSEED` environment variable at a fixed value
import os
os.environ['PYTHONHASHSEED']=str(seed_value)

# 2. Set `python` built-in pseudo-random generator at a fixed value
import random
random.seed(seed_value)

# 3. Set `numpy` pseudo-random generator at a fixed value
import numpy as np
np.random.seed(seed_value)

# 4. Set `tensorflow` pseudo-random generator at a fixed value
import tensorflow as tf
tf.set_random_seed(seed_value)

# 5. Configure a new global `tensorflow` session
from keras import backend as K
session_conf = tf.ConfigProto(intra_op_parallelism_threads=1, inter_op_parallelism_threads=1)
sess = tf.Session(graph=tf.get_default_graph(), config=session_conf)
K.set_session(sess)
本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

Keras 通过设置种子获得不同的结果[重复] 的相关文章

  • 如何选择单选按钮?

    我在用mechanize我正在尝试从单选按钮列表中选择一个按钮 该列表有 5 项 如何选择第一项 文档没有帮助我 gt gt gt br form
  • Matplotlib imshow:如何在矩阵上应用蒙版

    我正在尝试以图形方式分析二维数据 matplotlib imshow在这方面非常有用 但我觉得如果我可以从矩阵中排除一些单元格 超出感兴趣范围的值 我可以更多地利用它 我的问题是这些值使我感兴趣的范围内的色彩图 变平 排除这些值后 我可以获
  • pandas 系列值之间的过滤

    If s is a pandas Series http pandas pydata org pandas docs stable dsintro html series 我知道我可以这样做 b s lt 4 or b s gt 0 但我做
  • 有没有办法拥有租户特定的 JWT 令牌

    我目前正在开发一个 SPA 应用程序 角度 后端使用 Python Flask API 该应用程序将支持多个租户 我对安全概念有点挣扎 我目前正在使用 jwt extend 颁发的 JWT 令牌对所有租户都有效 我当然可以从令牌中获取用户
  • time.sleep - TypeError:需要一个浮点[关闭]

    Closed 这个问题是无法重现或由拼写错误引起 help closed questions 目前不接受答案 time sleep 2 TypeError a float is required 我该如何解决 我不确定我应该在这里做什么 您
  • Python 中字典的合并层次结构

    我有两本词典 而我想做的事情有点奇怪 基本上 我想合并它们 这很简单 但它们是字典的层次结构 我想以这样的方式合并它们 如果字典中的项目本身就是字典并且存在于两者中 我也想合并这些字典 如果它不是字典 我希望第二个字典中的值覆盖第一个字典中
  • 如何检查Docker中是否安装了python包?

    我使用Dockerfile成功构建了一个容器 但是 我的代码在容器中不起作用 如果我手动安装所有软件包 它确实有效 我假设我搞砸了一些导致 docker 没有正确安装软件包的事情 所以 我想检查Docker容器中是否安装了python包 最
  • 字典键中的通配符

    假设我有一本字典 rank dict V 1 A 2 V 3 A 4 正如您所看到的 我在一个 V 的末尾添加了一个 虽然 3 可能只是 V 的值 但我想要 V1 V2 V2234432 等的另一个密钥 我想检查它 checker V30
  • Python 请求包含有值的参数和没有值的参数

    我正在为 API 编写一个 Python 包装器 该 API 支持具有值的查询参数 例如param1如下 和查询参数do not有价值观 例如param2如下 即 https example com service param1 value
  • 将多个 csv 文件连接成具有相同标头的单个 csv

    我目前正在使用以下代码导入 6 000 个 csv 文件 带标题 并将它们导出到单个 csv 文件 带单个标题行 import csv files from folder path r data US market merged data
  • 调度算法,找到设定长度的所有非重叠区间

    我需要为我的管理应用程序实现一种算法 该算法将告诉我何时可以将任务分配给哪个用户 我实现了一个蛮力解决方案 它似乎有效 但我想知道是否有更有效的方法来做到这一点 为了简单起见 我重写了算法以对数字列表进行操作 而不是数据库查询等 下面我将尝
  • 从 SQL 数据库导入表并按日期过滤行时,将 Pandas 列解析为日期时间

    我有一个DataFrame列名为date 我们如何将 日期 列转换 解析为DateTime object 我使用 Postgresql 数据库加载日期列sql read frame 的一个例子date列是2013 04 04 我想做的是选择
  • Mxnet - 缓慢的数组复制到 GPU

    我的问题 我应该如何在 mxnet 中执行快速矩阵乘法 我的具体问题 数组复制到 GPU 的速度很慢 对此我们能做些什么呢 我创建随机数组 将它们复制到上下文中 然后相乘 import mxnet as mx import mxnet nd
  • 通过 Selenium 和 python 切换到 iframe

    我如何在硒中切换到这个 iframe 只知道 您可以使用 XPath 来定位 iframe driver find element by xpath iframe name Dialogue Window Then switch to th
  • 在 envoy 中使用 rm *(通配符):没有这样的文件或目录

    我正在使用 Python 和 Envoy 我需要删除目录中的所有文件 除了一些文件外 该目录是空的 在终端中 这将是 rm tmp my silly directory 常识表明 在特使中 这转化为 r envoy run rm tmp m
  • 当没有 main 函数时,为什么 sys.settrace 不触发?

    import sys def printer frame event arg print frame event arg return printer sys settrace printer x 1 sys settrace None 上
  • python nltk从句子中提取关键字

    我们要做的第一件事 就是杀掉所有律师 威廉 莎士比亚 鉴于上面的引用 我想退出 kill and lawyers 作为两个突出的关键词来描述句子的整体含义 我提取了以下名词 动词 POS 标签 First NNP thing NN do V
  • 我可以在某些网格中打印带有颜色的 pandas 数据框吗?

    我有一个 pandas DataFrame 我想突出显示一些数据 例如 In 1 import pandas as pd In 2 import numpy as np In 3 df pd DataFrame np reshape ran
  • 网站可以检测您何时将 Selenium 与 chromedriver 结合使用吗?

    我一直在使用 Chromedriver 测试 Selenium 我注意到有些页面可以检测到您正在使用 Selenium 即使根本没有自动化 即使我只是通过 Selenium 使用 Chrome 手动浏览 Xephyr https en wi
  • 如何将 pygame Surface 转换为 PIL 图像?

    我正在使用 PIL 来透视地变换屏幕的一部分 原始图像数据是一个 pygame Surface 需要转换为 PIL 图像 因此我发现了 pygame 的 tostring 函数就是为了这个目的而存在的 然而结果看起来很奇怪 见附图 这段代码

随机推荐

  • 使用 Python 开发时保护 MySQL 密码?

    我正在编写一个使用本地托管的 MySQL 数据库的 Python 脚本 该程序将以源代码形式提供 这样 MySQL 密码就肉眼可见 有没有好的办法来保护这个呢 这个想法是为了防止一些顽皮的人查看源代码 直接访问 MySQL 并做一些事情 好
  • 找不到文件异常..但它就在那里

    嘿 这将是那些愚蠢的问题之一 我试图在本地系统上获取一个文件 但我不断收到FileNotFoundException thrown 请有人纠正我 if File Exists C logs hw healthways prod 2009 0
  • 如何绑定到仅检查复选框?

    我有一个复选框 我正在使用 jquery 我想在用户选中复选框时弹出一个对话框 但是 如果他们取消选中该框 则不会弹出任何内容 我怎样才能做到这一点 另外 我需要使用 jquery live 或 livequery 因为页面加载时不会显示复
  • 如何从 httpclient 调用中获取内容正文?

    我一直在试图弄清楚如何读取 httpclient 调用的内容 但我似乎无法理解 我得到的响应状态是 200 但我不知道如何获取返回的实际 Json 这就是我所需要的 以下是我的代码 async Task
  • 从 python 执行 C++ 代码

    我是 python 的初学者 我不知道这是否可行 我在 python 中有一个简单的循环 它为我提供当前目录中的所有文件 我想要做的是从 python 执行我之前在目录中的所有这些文件上编写的 C 代码 建议的 python 循环应该是这样
  • 使用旧的 ruby​​gems 版本进行捆绑安装

    我遇到的问题似乎与1个月前的问题 https stackoverflow com questions 38279896 rubygems 2 0 14 is not threadsafe bunder install message whe
  • 开源隐形 reCAPTCHA 替代方案 [关闭]

    Closed 这个问题正在寻求书籍 工具 软件库等的推荐 不满足堆栈溢出指南 help closed questions 目前不接受答案 有没有像 Google 的 Invisible reCAPTCHA V2 这样接近或最好的开源解决方案
  • 如何在Unix中将相对路径转换为绝对路径[关闭]

    Closed 这个问题不符合堆栈溢出指南 help closed questions 目前不接受答案 我想转换 相对路径 home stevin data APP SERVICE datafile txt to 绝对路径 home stev
  • 从分离的头进行 Git 推送

    我以超然的态度做出了一些改变 我想用 Git 将这些更改推送到这个独立的头 我不希望我的更改进入开发分支 当然也不想进入主分支 我正在与另一个人一起处理一个文件 分支示例 develop master HEAD detached at or
  • 用css制作一个加号[重复]

    这个问题在这里已经有答案了 我有一个模型 用于制作看起来非常简单的加号 然而 我的 CSS 技能并不是很好 制作圆圈没什么大不了的 但在里面制作加号我似乎无法弄清楚 这就是我正在尝试做的事情 Mockup 这是我目前拥有的 到目前为止 这是
  • 节点异步循环 - 如何使该代码按顺序运行?

    我知道有几个关于此的帖子 但根据我发现的那些帖子 这应该可以正常工作 我想在循环中发出 http 请求 并且不希望循环迭代 直到触发请求回调 我正在使用异步库 如下所示 const async require async const req
  • C++ 从模板参数中解压参数包

    如何实现下面我想要的 我要解压的参数包不在函数参数列表中 而是在模板参数列表中 include
  • 从频率表生成 data.frame

    我在 2 4 数组中有包含 500 个观察值的合成数据 datax array c 120 181 50 43 41 33 24 8 dim c 2 4 dimnames datax list gender c male female pu
  • Orientation改变时如何处理Activity?

    我正在编写一个活动 它从服务器加载数据并使用 ArrayAdapter 将其显示为列表 为此 我显示了一个进度对话框 即加载 同时它从服务器加载所有数据 然后我在处理程序中关闭该对话框 我的问题是 当我更改方向时 会再次显示进度对话框 这是
  • R 抑制系统或 shell 命令的控制台输出

    我有这个 Windows 批处理文件 我使用 R 从 R 调用该文件shell 命令 该批处理文件执行一些计算并将它们写入磁盘上 也写入屏幕上 我只对磁盘输出感兴趣 我无法更改批处理文件 批处理文件可能有点愚蠢 例如 echo off ec
  • 发送WM_SETTEXT时如何避免EN_CHANGE通知?

    我有一个 CEdit 派生控件 当基本数据为空时 该控件显示字符串 N A 我最近添加了代码 以在控件获得焦点时清空控件 SetWindowText 并在用户离开焦点时将其设置回 N A SetWindowText N A 控空 唯一的问题
  • capistrano deploy.rb 中的 require 找不到文件

    我有一个 Rails 3 0 5 应用程序 我正在设置 capistrano 来使用配方 在我的配置目录中 我有一个名为 database capistrano rb 的文件 在我的deploy rb中 也在配置目录中 我有以下行 就在开头
  • Angular2 authguards 执行异步函数失败

    我想通过检查用户是否从服务器登录来保护我的路由 但异步函数不会被执行 这是我的代码 canActivate route ActivatedRouteSnapshot state RouterStateSnapshot Observable
  • 在 any() 语句中迭代一个小列表是否更快?

    在低长度迭代的限制下考虑以下操作 d 3 slice None None None slice None None None In 215 timeit any type i slice for i in d 1000000 loops b
  • Keras 通过设置种子获得不同的结果[重复]

    这个问题在这里已经有答案了 在keras中 每次运行都有很高的方差和不稳定的性能 为了解决这个问题 根据https keras io getting started faq how can i obtain reproducible res