Tensorflow 和 Keras 实现之间的比较(第 1 部分:模型)

2023-12-22

我正在尝试使用 Keras 重写 Tensorflow 网络。 Tensorflow中的模型定义为

def xavier_init(size):
    in_dim = size[0]
    xavier_stddev = 1. / tf.sqrt(in_dim / 2.)
    return tf.random_normal(shape=size, stddev=xavier_stddev)

def leaky_relu(x, alpha=0.2):
  return tf.nn.relu(x) - alpha * tf.nn.relu(-x)

X = tf.placeholder(tf.float32, shape=[None, 9, 15])
W1 = tf.Variable(xavier_init([135, 128]))
b1 = tf.Variable(tf.zeros(shape=[128]))
W11 = tf.Variable(xavier_init([128, 256]))
b11 = tf.Variable(tf.zeros(shape=[256]))
W12 = tf.Variable(xavier_init([256, 512]))
b12 = tf.Variable(tf.zeros(shape=[512]))
W13 = tf.Variable(xavier_init([512, 45]))
b13 = tf.Variable(tf.zeros(shape=[45]))

W2 = tf.Variable(xavier_init([135, 128]))
b2 = tf.Variable(tf.zeros(shape=[128]))
W21 = tf.Variable(xavier_init([128, 256]))
b21 = tf.Variable(tf.zeros(shape=[256]))
W22 = tf.Variable(xavier_init([256, 512]))
b22 = tf.Variable(tf.zeros(shape=[512]))
W23 = tf.Variable(xavier_init([512, 540]))
b23 = tf.Variable(tf.zeros(shape=[540]))

def fcn(x):
    out1 = tf.reshape(x, (-1, 135))
    out1 = leaky_relu(tf.matmul(out1, W1) + b1)
    out1 = leaky_relu(tf.matmul(out1, W11) + b11)
    out1 = leaky_relu(tf.matmul(out1, W12) + b12)
    out1 = leaky_relu(tf.matmul(out1, W13) + b13)
    out1 = tf.reshape(out1, (-1, 9, 5))

    out2 = tf.reshape(x, (-1, 135))
    out2 = leaky_relu(tf.matmul(out2, W2) + b2)
    out2 = leaky_relu(tf.matmul(out2, W21) + b21)
    out2 = leaky_relu(tf.matmul(out2, W22) + b22)
    out2 = leaky_relu(tf.matmul(out2, W23) + b23)
    out2 = tf.reshape(out2, [-1, 9, 4, 15])
    out2 = leaky_relu(tf.matmul(tf.transpose(out2, perm=[0, 2, 1, 3]), tf.transpose(out2, perm=[0, 2, 3, 1])))
    out2 = tf.transpose(out2, perm=[0, 2, 3, 1])
    return [out1, out2]

我已经“翻译”了这个,这是我的 Keras 实现

def keras_version():
    input = Input(shape=(135,), name='feature_input')
    out1 = Dense(128, kernel_initializer='glorot_normal', activation='linear')(input)
    out1 = LeakyReLU(alpha=.2)(out1)
    out1 = Dense(256, kernel_initializer='glorot_normal', activation='linear')(out1)
    out1 = LeakyReLU(alpha=.2)(out1)
    out1 = Dense(512, kernel_initializer='glorot_normal', activation='linear')(out1)
    out1 = LeakyReLU(alpha=.2)(out1)
    out1 = Dense(45, kernel_initializer='glorot_normal', activation='linear')(out1)
    out1 = LeakyReLU(alpha=.2)(out1)
    out1 = Reshape((9, 5))(out1)

    out2 = Dense(128, kernel_initializer='glorot_normal', activation='linear')(input)
    out2 = LeakyReLU(alpha=.2)(out2)
    out2 = Dense(256, kernel_initializer='glorot_normal', activation='linear')(out2)
    out2 = LeakyReLU(alpha=.2)(out2)
    out2 = Dense(512, kernel_initializer='glorot_normal', activation='linear')(out2)
    out2 = LeakyReLU(alpha=.2)(out2)
    out2 = Dense(540, kernel_initializer='glorot_normal', activation='linear')(out2)
    out2 = LeakyReLU(alpha=.2)(out2)
    out2 = Reshape((9, 4, 15))(out2)
    out2 = Lambda(lambda x: K.dot(K.permute_dimensions(x, (0, 2, 1, 3)), K.permute_dimensions(x, (0, 2, 3, 1))), output_shape=(4,9,9))(out2)
    out2 = Flatten()(out2)
    out2 = Dense(324, kernel_initializer='glorot_normal', activation='linear')(out2)
    # K.dot should be of size (-1, 4, 9, 9), so I set the output size to 324, and later on, reshaped data
    out2 = LeakyReLU(alpha=.2)(out2)
    out2 = Reshape((4, 9, 9))(out2)
    out2 = Lambda(lambda x: K.permute_dimensions(x, (0, 2, 3, 1)))(out2)

    out1 = Lambda(identity, name='output_1')(out1)
    out2 = Lambda(identity, name='output_2')(out2)

    return Model(input, [out1, out2])

我想知道这个实现是否正确,具体来说:

  1. 定义层尺寸的方式。
  2. 权重初始化的方式。
  3. 矩阵乘法被展平并重新整形的方式。

如果您能指出是否有错误实施或我理解不正确的地方,我将不胜感激。


编辑:这是摘要:

__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
==================================================================================================
feature_input (InputLayer)      (None, 135)          0                                            
__________________________________________________________________________________________________
dense_5 (Dense)                 (None, 128)          17408       feature_input[0][0]              
__________________________________________________________________________________________________
leaky_re_lu_5 (LeakyReLU)       (None, 128)          0           dense_5[0][0]                    
__________________________________________________________________________________________________
dense_6 (Dense)                 (None, 256)          33024       leaky_re_lu_5[0][0]              
__________________________________________________________________________________________________
leaky_re_lu_6 (LeakyReLU)       (None, 256)          0           dense_6[0][0]                    
__________________________________________________________________________________________________
dense_7 (Dense)                 (None, 512)          131584      leaky_re_lu_6[0][0]              
__________________________________________________________________________________________________
leaky_re_lu_7 (LeakyReLU)       (None, 512)          0           dense_7[0][0]                    
__________________________________________________________________________________________________
dense_1 (Dense)                 (None, 128)          17408       feature_input[0][0]              
__________________________________________________________________________________________________
dense_8 (Dense)                 (None, 540)          277020      leaky_re_lu_7[0][0]              
__________________________________________________________________________________________________
leaky_re_lu_1 (LeakyReLU)       (None, 128)          0           dense_1[0][0]                    
__________________________________________________________________________________________________
leaky_re_lu_8 (LeakyReLU)       (None, 540)          0           dense_8[0][0]                    
__________________________________________________________________________________________________
dense_2 (Dense)                 (None, 256)          33024       leaky_re_lu_1[0][0]              
__________________________________________________________________________________________________
reshape_2 (Reshape)             (None, 9, 4, 15)     0           leaky_re_lu_8[0][0]              
__________________________________________________________________________________________________
leaky_re_lu_2 (LeakyReLU)       (None, 256)          0           dense_2[0][0]                    
__________________________________________________________________________________________________
lambda_1 (Lambda)               (None, 4, 9, 9)      0           reshape_2[0][0]                  
__________________________________________________________________________________________________
dense_3 (Dense)                 (None, 512)          131584      leaky_re_lu_2[0][0]              
__________________________________________________________________________________________________
flatten_1 (Flatten)             (None, 324)          0           lambda_1[0][0]                   
__________________________________________________________________________________________________
leaky_re_lu_3 (LeakyReLU)       (None, 512)          0           dense_3[0][0]                    
__________________________________________________________________________________________________
dense_9 (Dense)                 (None, 324)          105300      flatten_1[0][0]                  
__________________________________________________________________________________________________
dense_4 (Dense)                 (None, 45)           23085       leaky_re_lu_3[0][0]              
__________________________________________________________________________________________________
leaky_re_lu_9 (LeakyReLU)       (None, 324)          0           dense_9[0][0]                    
__________________________________________________________________________________________________
leaky_re_lu_4 (LeakyReLU)       (None, 45)           0           dense_4[0][0]                    
__________________________________________________________________________________________________
reshape_3 (Reshape)             (None, 4, 9, 9)      0           leaky_re_lu_9[0][0]              
__________________________________________________________________________________________________
reshape_1 (Reshape)             (None, 9, 5)         0           leaky_re_lu_4[0][0]              
__________________________________________________________________________________________________
lambda_2 (Lambda)               (None, 9, 9, 4)      0           reshape_3[0][0]                  
__________________________________________________________________________________________________
output_1 (Lambda)               (None, 9, 5)         0           reshape_1[0][0]                  
__________________________________________________________________________________________________
output_2 (Lambda)               (None, 9, 9, 4)      0           lambda_2[0][0]                   
==================================================================================================
Total params: 769,437
Trainable params: 769,437
Non-trainable params: 0
__________________________________________________________________________________________________

None

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

Tensorflow 和 Keras 实现之间的比较(第 1 部分:模型) 的相关文章

  • Python 类型提示 Dict 语法错误 可变默认值是不允许的。使用“默认工厂”

    我不知道为什么解释器会抱怨这个类型的字典 对于这两个实例 我得到一个 不允许可变默认值 使用默认工厂 语法错误 我使用的是 python 3 7 3 from dataclasses import dataclass from typing
  • 让 VoiceChannel.members 和 Guild.members 返回完整列表的问题

    每当我尝试使用 VoiceChannel members 或 Guild members 时 它都不会提供适用成员的完整列表 我从文本命令的上下文中获取 VoiceChannel 和 Guild 如下所示 bot command name
  • Gunicorn 工作人员无论如何都会超时

    我正在尝试通过gunicorn运行一个简单的烧瓶应用程序 但是无论我做什么 我的工作人员都会超时 无论是否有针对应用程序的活动 工作人员在我设置任何内容后总是会超时timeout值到 是什么导致它们超时 当我发出请求时 请求成功通过 但工作
  • 在 Python 中将列表元素作为单独的项目返回

    Stackoverflow 的朋友们大家好 我有一个计算列表的函数 我想单独返回列表的每个元素 如下所示 接收此返回的函数旨在处理未定义数量的参数 def foo my list 1 2 3 4 return 1 2 3 4 列表中的元素数
  • matplotlib 图中点的标签

    所以这是一个关于已发布的解决方案的问题 我试图在我拥有的 matplotlib 散点图中的点上放置一些数据标签 我试图在这里模仿解决方案 是否有与 MATLAB 的 datacursormode 等效的 matplotlib https s
  • PyQt 使用 ctrl+Enter 触发按钮

    我正在尝试在我的应用程序中触发 确定 按钮 我当前尝试的代码是这样的 self okPushButton setShortcut ctrl Enter 然而 它不起作用 这是有道理的 我尝试查找一些按键序列here http ftp ics
  • Tensorboard SyntaxError:语法无效

    当我尝试制作张量板时 出现语法错误 尽管开源代码我还是无法理解 我尝试搜索张量板的代码 但不清楚 即使我不擅长Python 我这样写路径C Users jh902 Documents logs因为我正在使用 Windows 10 但我不确定
  • 打印数字时添加千位分隔符[重复]

    这个问题在这里已经有答案了 我真的不知道这个问题的 名称 所以它可能是一个不正确的标题 但问题很简单 如果我有一个数字 例如 number 23543 second 68471243 我想要它使print 像这样 23 54368 471
  • Python 内置的 super() 是否违反了 DRY?

    显然这是有原因的 但我没有足够的经验来认识到这一点 这是Python中给出的例子docs http docs python org 2 library functions html super class C B def method se
  • Python 3:将字符串转换为变量[重复]

    这个问题在这里已经有答案了 我正在从 txt 文件读取文本 并且需要使用我读取的数据之一作为类实例的变量 class Sports def init self players 0 location name self players pla
  • 使用 python/numpy 重塑数组

    我想重塑以下数组 gt gt gt test array 11 12 13 14 21 22 23 24 31 32 33 34 41 42 43 44 为了得到 gt gt gt test2 array 11 12 21 22 13 14
  • 未知错误:Chrome 无法启动:异常退出

    当我使用 chromedriver 对 Selenium 运行测试时 出现此错误 selenium common exceptions WebDriverException Message unknown error Chrome fail
  • 当字段是数字时怎么说...在 mongodb 中匹配?

    所以我的结果中有一个名为 城市 的字段 结果已损坏 有时它是一个实际名称 有时它是一个数字 以下代码显示所有记录 db zips aggregate project city substr city 0 1 sort city 1 我需要修
  • Django 视图中的“请求”是什么

    在 Django 第一个应用程序的 Django 教程中 我们有 from django http import HttpResponse def index request return HttpResponse Hello world
  • pandas - 包含时间序列数据的堆积条形图

    我正在尝试使用时间序列数据在 pandas 中创建堆积条形图 DATE TYPE VOL 0 2010 01 01 Heavy 932 612903 1 2010 01 01 Light 370 612903 2 2010 01 01 Me
  • Pandas 组合不同索引的数据帧

    我有两个数据框df 1 and df 2具有不同的索引和列 但是 有一些索引和列重叠 我创建了一个数据框df索引和列的并集 因此不存在重复的索引或列 我想填写数据框df通过以下方式 for x in df index for y in df
  • 重新分配唯一值 - pandas DataFrame

    我在尝试着assign unique值在pandas df给特定的个人 For the df below Area and Place 会一起弥补unique不同的价值观jobs 这些值将分配给个人 总体目标是使用尽可能少的个人 诀窍在于这
  • 如何将Python3设置为Mac上的默认Python版本?

    有没有办法将 Python 3 8 3 设置为 macOS Catalina 版本 10 15 2 上的默认 Python 版本 我已经完成的步骤 看看它安装在哪里 ls l usr local bin python 我得到的输出是这样的
  • JSON:TypeError:Decimal('34.3')不是JSON可序列化的[重复]

    这个问题在这里已经有答案了 我正在运行一个 SQL 查询 它返回一个小数列表 当我尝试将其转换为 JSON 时 出现类型错误 查询 res db execute SELECT CAST SUM r SalesVolume 1000 0 AS
  • 如何在Python脚本中从youtube-dl中提取文件大小?

    我是 python 编程新手 我想在下载之前提取视频 音频大小 任何 YouTube 视频 gt gt gt from youtube dl import YoutubeDL gt gt gt url https www youtube c

随机推荐