如何提高自编码器的准确率?

2024-01-27

我有一个自动编码器,我使用不同的解决方案检查了模型的准确性,例如更改转换层的数量并增加它们,添加或删除批量归一化,更改激活函数,但所有这些解决方案的准确性都是相似的,并且不一样有任何奇怪的改进。我很困惑,因为我认为这些不同解决方案的准确度应该不同,但它是 0.8156。你能帮我看看有什么问题吗?我还用 10000 个 epoch 对其进行训练,但 50 个 epoch 的输出是相同的!我的代码是错误的还是不能变得更好?!准确度图 https://i.stack.imgur.com/KBnrp.png

我也不确定学习率衰减是否有效?! 我也把我的代码放在这里:

from keras.layers import Input, Concatenate, GaussianNoise,Dropout,BatchNormalization
from keras.layers import Conv2D
from keras.models import Model
from keras.datasets import mnist,cifar10
from keras.callbacks import TensorBoard
from keras import backend as K
from keras import layers
import matplotlib.pyplot as plt
import tensorflow as tf
import keras as Kr
from keras.callbacks import ReduceLROnPlateau
from keras.callbacks import EarlyStopping
import numpy as np
import pylab as pl
import matplotlib.cm as cm
import keract
from matplotlib import pyplot
from keras import optimizers
from keras import regularizers
from tensorflow.python.keras.layers import Lambda;

image = Input((28, 28, 1))
conv1 = Conv2D(16, (3, 3), activation='elu', padding='same', name='convl1e')(image)
conv2 = Conv2D(32, (3, 3), activation='elu', padding='same', name='convl2e')(conv1)
conv3 = Conv2D(16, (3, 3), activation='elu', padding='same', name='convl3e')(conv2)
#conv3 = Conv2D(8, (3, 3), activation='relu', padding='same', name='convl3e', kernel_initializer='Orthogonal',bias_initializer='glorot_uniform')(conv2)
BN=BatchNormalization()(conv3)
#DrO1=Dropout(0.25,name='Dro1')(conv3)
DrO1=Dropout(0.25,name='Dro1')(BN)
encoded =  Conv2D(1, (3, 3), activation='elu', padding='same',name='encoded_I')(DrO1)



#-----------------------decoder------------------------------------------------
#------------------------------------------------------------------------------
deconv1 = Conv2D(16, (3, 3), activation='elu', padding='same', name='convl1d')(encoded)
deconv2 = Conv2D(32, (3, 3), activation='elu', padding='same', name='convl2d')(deconv1)
deconv3 = Conv2D(16, (3, 3), activation='elu',padding='same', name='convl3d')(deconv2)
BNd=BatchNormalization()(deconv3)
DrO2=Dropout(0.25,name='DrO2')(BNd)
#DrO2=Dropout(0.5,name='DrO2')(deconv3)
decoded = Conv2D(1, (3, 3), activation='sigmoid', padding='same', name='decoder_output')(DrO2) 
#model=Model(inputs=[image,wtm],outputs=decoded)

#--------------------------------adding noise----------------------------------
#decoded_noise = GaussianNoise(0.5)(decoded)


watermark_extraction=Model(inputs=image,outputs=decoded)

watermark_extraction.summary()
#----------------------training the model--------------------------------------
#------------------------------------------------------------------------------
#----------------------Data preparation----------------------------------------

(x_train, _), (x_test, _) = mnist.load_data()
x_validation=x_train[1:10000,:,:]
x_train=x_train[10001:60000,:,:]
#(x_train, _), (x_test, _) = cifar10.load_data()
#x_validation=x_train[1:10000,:,:]
#x_train=x_train[10001:60000,:,:]
#
x_train = x_train.astype('float32') / 255.
x_test = x_test.astype('float32') / 255.
x_validation = x_validation.astype('float32') / 255.
x_train = np.reshape(x_train, (len(x_train), 28, 28, 1))  # adapt this if using `channels_first` image data format
x_test = np.reshape(x_test, (len(x_test), 28, 28, 1))  # adapt this if using `channels_first` image data format
x_validation = np.reshape(x_validation, (len(x_validation), 28, 28, 1))

#---------------------compile and train the model------------------------------
# is accuracy sensible metric for this model?
learning_rate = 0.1
decay_rate = learning_rate / 50
opt = optimizers.SGD(lr=learning_rate, momentum=0.9, decay=decay_rate, nesterov=False)

watermark_extraction.compile(optimizer=opt, loss=['mse'], metrics=['accuracy'])
es = EarlyStopping(monitor='val_loss', mode='min', verbose=1, patience=20)
#rlrp = ReduceLROnPlateau(monitor='val_loss', factor=0.1, patience=5, min_delta=1E-7, verbose=1)
history=watermark_extraction.fit(x_train, x_train,
          epochs=50,
          batch_size=32, 
          validation_data=(x_validation, x_validation),
          callbacks=[TensorBoard(log_dir='E:/output of tensorboard', histogram_freq=0, write_graph=False),es])
watermark_extraction.summary()
#--------------------visuallize the output layers------------------------------
#_, train_acc = watermark_extraction.evaluate(x_train, x_train)
#_, test_acc = watermark_extraction.evaluate([x_test[5000:5001],wt_expand], [x_test[5000:5001],wt_expand])
#print('Train: %.3f, Test: %.3f' % (train_acc, test_acc))
## plot loss learning curves
pyplot.subplot(211)
pyplot.title('MSE Loss', pad=-40)
pyplot.plot(history.history['loss'], label='train')
pyplot.plot(history.history['val_loss'], label='validation')
pyplot.legend()

pyplot.subplot(212)
pyplot.title('Accuracy', pad=-40)
pyplot.plot(history.history['acc'], label='train')
pyplot.plot(history.history['val_acc'], label='test')
pyplot.legend()
pyplot.show

既然你说你是一个初学者,我将尝试从下往上构建,并尝试尽可能多地用该解释来解释你的代码。

Part 1自动编码器由两部分组成(编码器和解码器)。自动编码器减少存储信息所需的变量数量,而解码器尝试从压缩形式中获取此信息。 (请注意,由于自动编码器的不确定性和数据依赖性,因此在实际数据压缩任务中不使用自动编码器)。

现在在你的代码中你保留padding一样。

conv1 = Conv2D(16, (3, 3), activation='elu', padding='same', name='convl1e')(image)

这基本上消除了自动编码器的压缩和扩展功能,即在每个步骤中,您都使用相同数量的变量来表示信息。

Part 2现在开始训练算法

history=watermark_extraction.fit(x_train, x_train,
          epochs=50,
          batch_size=32, 
          validation_data=(x_validation, x_validation),
          callbacks=[TensorBoard(log_dir='E:/PhD/thesis/deepwatermark/journal code/autoencoder_watermark/11-2-2019/output of tensorboard', histogram_freq=0, write_graph=False),es])

从这个表达式/语句/代码行我得出的结论是,您想要生成与您放入代码中的相同的图像,现在,由于图像存储在相同数量的变量中,您的模型只需传递相同的图像图像到每个步骤而不更改图像中的任何内容,这会激励您的模型将每个过滤器参数优化为 1。

Part 3现在棺材上最大的钉子来了,你已经实现了一个dropout层,首先你应该NEVER在卷积层中实现dropout。此链接解释了原因,并讨论了我认为如果您是初学者应该查看的各种想法。 https://towardsdatascience.com/dont-use-dropout-in-convolutional-networks-81486c823c16现在让我们看看为什么你使用 Dropout 的方式真的很糟糕。正如已经解释过的,最适合您模型的参数是学习值 1 的过滤器中的所有参数。现在发生的情况是您强制关闭其中一些过滤器,这除了关闭所讨论的一些过滤器之外没有任何作用在文章中,这一切都会降低下一层图像的强度。(因为 CNN 过滤器对所有输入通道取平均值)

DrO2=Dropout(0.25,name='DrO2')(BNd)

Part 4这只是一点建议,不会成为任何问题的根源BNd=BatchNormalization()(deconv3)

在这里,您尝试对批次中的数据进行标准化,在大多数情况下,数据标准化非常重要,因为您可能知道它不会让一个特征决定模型,并且每个特征在模型中获得平等的发言权,但在图像数据中每个点都已在 0 到 255 之间缩放,因此使用归一化将其缩放到 0 到 1 之间不会增加任何值,只会向模型添加不必要的计算。

我建议你逐步理解,如果有不清楚的地方,请在下面评论,尽量不要使用 CNN 来谈论自动编码器(无论如何它们没有任何实际应用),而是用它来理解 ConvNet 的各种复杂性( CNN),我选择写这样的答案来解释你的网络部分而不是代码的原因是因为你正在寻找的代码只需谷歌搜索即可,如果你对这个答案感兴趣并且想要了解 CNN 的具体工作原理,请查看此内容,如果您对此答案中的任何内容甚至对这些视频有任何疑问,请在下面评论。

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

如何提高自编码器的准确率? 的相关文章

  • 使用 openCV 对图像中的子图像进行通用检测

    免责声明 我是计算机视觉菜鸟 我看过很多关于如何在较大图像中查找特定子图像的堆栈溢出帖子 我的用例有点不同 因为我不希望它是具体的 而且我不确定如何做到这一点 如果可能的话 但我感觉应该如此 我有大量图像数据集 有时 其中一些图像是数据集的
  • 如何使用固定的 pandas 数据框进行动态 matplotlib 绘图?

    我有一个名为的数据框benchmark returns and strategy returns 两者具有相同的时间跨度 我想找到一种方法以漂亮的动画风格绘制数据点 以便它显示逐渐加载的所有点 我知道有一个matplotlib animat
  • 如何收集列表、字典等中重复计算的结果(或制作修改每个元素的列表的副本)?

    There are a great many existing Q A on Stack Overflow on this general theme but they are all either poor quality typical
  • 如何使用包含代码的“asyncio.sleep()”进行单元测试?

    我在编写 asyncio sleep 包含的单元测试时遇到问题 我要等待实际的睡眠时间吗 I used freezegun到嘲笑时间 当我尝试使用普通可调用对象运行测试时 这个库非常有用 但我找不到运行包含 asyncio sleep 的测
  • 如何使用 Scrapy 从网站获取所有纯文本?

    我希望在 HTML 呈现后 可以从网站上看到所有文本 我正在使用 Scrapy 框架使用 Python 工作 和xpath body text 我能够获取它 但是带有 HTML 标签 而且我只想要文本 有什么解决办法吗 最简单的选择是ext
  • keras加载模型错误尝试将包含17层的权重文件加载到0层的模型中

    我目前正在使用 keras 开发 vgg16 模型 我用我的一些图层微调 vgg 模型 拟合我的模型 训练 后 我保存我的模型model save name h5 可以毫无问题地保存 但是 当我尝试使用以下命令重新加载模型时load mod
  • 在 NumPy 中获取 ndarray 的索引和值

    我有一个 ndarrayA任意维数N 我想创建一个数组B元组 数组或列表 其中第一个N每个元组中的元素是索引 最后一个元素是该索引的值A 例如 A array 1 2 3 4 5 6 Then B 0 0 1 0 1 2 0 2 3 1 0
  • Python 中的二进制缓冲区

    在Python中你可以使用StringIO https docs python org library struct html用于字符数据的类似文件的缓冲区 内存映射文件 https docs python org library mmap
  • python pandas 中的双端队列

    我正在使用Python的deque 实现一个简单的循环缓冲区 from collections import deque import numpy as np test sequence np array range 100 2 resha
  • Python:字符串不会转换为浮点数[重复]

    这个问题在这里已经有答案了 我几个小时前写了这个程序 while True print What would you like me to double line raw input gt if line done break else f
  • 表达式中的 Python 'in' 关键字与 for 循环中的比较 [重复]

    这个问题在这里已经有答案了 我明白什么是in运算符在此代码中执行的操作 some list 1 2 3 4 5 print 2 in some list 我也明白i将采用此代码中列表的每个值 for i in 1 2 3 4 5 print
  • ExpectedFailure 被计为错误而不是通过

    我在用着expectedFailure因为有一个我想记录的错误 我现在无法修复 但想将来再回来解决 我的理解expectedFailure是它会将测试计为通过 但在摘要中表示预期失败的数量为 x 类似于它如何处理跳过的 tets 但是 当我
  • Numpy 优化

    我有一个根据条件分配值的函数 我的数据集大小通常在 30 50k 范围内 我不确定这是否是使用 numpy 的正确方法 但是当数字超过 5k 时 它会变得非常慢 有没有更好的方法让它更快 import numpy as np N 5000
  • Nuitka 未使用 nuitka --recurse-all hello.py [错误] 编译 exe

    我正在尝试通过 nuitka 创建一个简单的 exe 这样我就可以在我的笔记本电脑上运行它 而无需安装 Python 我在 Windows 10 上并使用 Anaconda Python 3 我输入 nuitka recurse all h
  • 在Python中重置生成器对象

    我有一个由多个yield 返回的生成器对象 准备调用该生成器是相当耗时的操作 这就是为什么我想多次重复使用生成器 y FunctionWithYield for x in y print x here must be something t
  • 对输入求 Keras 模型的导数返回全零

    所以我有一个 Keras 模型 我想将模型的梯度应用于其输入 这就是我所做的 import tensorflow as tf from keras models import Sequential from keras layers imp
  • 如何使用google colab在jupyter笔记本中显示GIF?

    我正在使用 google colab 想嵌入一个 gif 有谁知道如何做到这一点 我正在使用下面的代码 它并没有在笔记本中为 gif 制作动画 我希望笔记本是交互式的 这样人们就可以看到代码的动画效果 而无需运行它 我发现很多方法在 Goo
  • 在 Python 类中动态定义实例字段

    我是 Python 新手 主要从事 Java 编程 我目前正在思考Python中的类是如何实例化的 我明白那个 init 就像Java中的构造函数 然而 有时 python 类没有 init 方法 在这种情况下我假设有一个默认构造函数 就像
  • Python - 字典和列表相交

    给定以下数据结构 找出这两种数据结构共有的交集键的最有效方法是什么 dict1 2A 3A 4B list1 2A 4B Expected output 2A 4B 如果这也能产生更快的输出 我可以将列表 不是 dict1 组织到任何其他数
  • 改变字典的哈希函数

    按照此question https stackoverflow com questions 37100390 towards understanding dictionaries 我们知道两个不同的字典 dict 1 and dict 2例

随机推荐

  • jsTree 拖放按类限制文件夹

    如何通过类名 class locked 锁定文件夹上的拖动功能 同时锁定其他要拖到该文件夹 中的文件夹class locked 我想要一个既具有拖放功能又具有上下文菜单的设置 如果节点的类名 锁定 我只想禁用上下文菜单的编辑以及拖入此文件夹
  • 使用 python 有效提取 1-5 克

    我有一个 3 000 000 行的巨大文件 每行有 20 40 个单词 我必须从语料库中提取 1 到 5 个 ngram 我的输入文件是标记化的纯文本 例如 This is a foo bar sentence There is a com
  • 用于从 Google Sheets URL 中提取电子表格 ID 和工作表 ID 的 JavaScript 正则表达式

    我想要 Javascript 正则表达式从 google 表格 URL 中提取电子表格 ID 和工作表 ID Sheets google com 电子表格的 URL 如下所示 https docs google com spreadshee
  • 删除 d3js 不工作的事件侦听器

    我有一个 SVG 结构 里面有一些形状 我想在单击形状时触发一个事件 在 SVG 上单击时触发另一个事件 问题是 SVG 事件总是被触发 为了防止这种情况 我禁用了形状的事件冒泡 我还尝试使用 d3 禁用该事件 但似乎不起作用 还尝试使用本
  • 朱莉娅 git 错误

    几个月前我在使用 Julia 最近我想再次使用它 我想要一个新版本 所以我删除了以前的版本和我拥有的所有软件包 现在 安装新版本后 0 6 2 我无法使用任何 Pkg 命令 使用后会出现以下错误init add or update 错误 G
  • 通过 pod 访问 kubernetes python api

    所以我需要通过 pod 连接到 python kubernetes 客户端 我一直在尝试使用config load incluster config 基本上遵循以下示例here https github com kubernetes cli
  • Spearman 与底座 R 的尺距距离

    给定两个排列 v1 1 4 3 1 5 2 v2 1 2 3 4 5 1 如何计算以 R 为基数的 Spearman 尺尺距离 所有元素的总位移 可灵活用于任意两种尺寸排列n 例如 对于这两个向量 如下 1被感动了2地点来自v1 to v2
  • 如何为多个开发人员使用 git

    对于经验丰富的 Git 用户来说 这是一个非常简单的问题 我已经在 git 托管上创建了存储库并设置了我的电脑 git init git remote add origin git sourcerepo com git 然后 经过一些更改后
  • 爪哇。 GUI WindowBuilder 通过单击按钮从 JTextField 读取

    I m useing WindowBuilder and I want to ask how to search in a text file for specific word which I enter to JTextField by
  • 如何在 Python 中使用 Selenium 获取
    1. 元素的长度?

    我有一个 ol 在我的 HTML 中列出 如下所示 ol li class foo li li class foo li li class foo li li class foo li ol 我需要做的是验证 ol 列表包含 li 内的项目
  • ReaderWriterLockSlim 和 async\await

    我有一些问题ReaderWriterLockSlim 我无法理解它是如何发挥作用的 My code private async Task LoadIndex if File Exists FileName index txt return
  • 在 vi 中删除连续的重复行而不排序

    这个问题 https stackoverflow com questions 351161 removing duplicate rows in vi已经解决了如何删除重复行 但强制首先对列表进行排序 我想执行删除连续重复行步骤 即uniq
  • 带数组的 SwitchMap 运算符

    我正在尝试学习 rxjs 和 Observable 的一般概念 并且有一个场景 我有一类
  • 如何防止引用的包含搜索当前源文件的目录?

    海湾合作委员会提供 I 选项 其中 I之前的目录 I 搜索引用的包含 include foo h and I以下目录 I 搜索括号内的包含 include
  • 在verilog中将wire值转换为整数

    我想将电线中的数据转换为整数 例如 wire 2 0 w 3 b101 我想要一个将其转换为 5 并将其存储在整数中的方法 我怎样才能以比这更好的方式做到这一点 j 1 for i 0 i lt 2 i i 1 begin a a w i
  • 如何通过 Google Drive API 使用刷新令牌生成访问令牌?

    我已完成授权步骤并获得访问令牌和刷新令牌 接下来我应该做什么来使用我通过 google Drive API 存储的刷新令牌生成访问令牌 由于我在 Force com 上工作 因此我无法使用任何 sdk 因此请建议直接通过 API 实现它的方
  • 经典 asp - 仅接收肥皂响应的一部分

    我试图从经典 asp 调用肥皂请求 它将在稍后更新 但现在它仍然是经典 asp 但我只得到一半的响应 当我在 SoapUI 中使用请求字符串时 我得到了我正在寻找的响应 但在 asp 中我只收到了部分响应 ASP 请求 Set oXmlHT
  • scala:重写构造函数的隐式参数

    我有一个类 它采用隐式参数 该参数由类内部方法调用的函数使用 我希望能够覆盖该隐式参数 或者从其源复制隐式参数 举个例子 def someMethod implicit p List Int uses p class A implicit
  • 如何在市场上发布应用程序的两个版本?

    我想将我的应用程序的两个版本添加到 Android 市场 一种只需几美分 另一种是带有广告的免费版本 这是一种非常常见的做法 我目前正在将 AdMod 构建到我的应用程序中 看来我必须更改相当多的文件 因此最好为此制作一个单独的应用程序版本
  • 如何提高自编码器的准确率?

    我有一个自动编码器 我使用不同的解决方案检查了模型的准确性 例如更改转换层的数量并增加它们 添加或删除批量归一化 更改激活函数 但所有这些解决方案的准确性都是相似的 并且不一样有任何奇怪的改进 我很困惑 因为我认为这些不同解决方案的准确度应