猫狗数据集

2023-11-09

import numpy as np
import pickle
import cv2
import pandas as pd
import tensorflow as tf
import matplotlib.pyplot as plt

#mnist = input_data.read_data_sets("MNIST_data/", one_hot=True)
train_data = {b'data': [], b'labels': []}
with open("D:/TensorFlow_gpu/animal.pickle", mode='rb') as file:
data = pickle.load(file, encoding='bytes')
train_data[b'data'] += list(data['train_images'])
train_data[b'labels'] += list(data['train_label'])

train_epochs = 802 # 训练轮数
batch_size = 40 # 随机出去数据大小
display_step = 10 # 显示训练结果的间隔
learning_rate = 0.000001 # 学习效率
drop_prob = 0.2 # 正则化,丢弃比例
fch_nodes = 256 # 全连接隐藏层神经元的个数

def weight_init(shape):
weights = tf.truncated_normal(shape, stddev=0.1, dtype=tf.float32)#符合正太分布mean=0
#weights = tf.truncated_normal(shape, mean=0.01, stddev=0.1, dtype=tf.float32)
return tf.Variable(weights)


# 偏置的初始化
def biases_init(shape):
biases = tf.random_normal(shape, dtype=tf.float32)
# biases = tf.random_normal(shape, mean=-0.01, stddev=0.1, dtype=tf.float32)
return tf.Variable(biases)


# 随机选取mini_batch
def get_random_batchdata(n_samples, batchsize):
start_index = np.random.randint(0, n_samples - batchsize)
return (start_index, start_index + batchsize)


def xavier_init(layer1, layer2, constant=1):
Min = -constant * np.sqrt(6.0 / (layer1 + layer2))
Max = constant * np.sqrt(6.0 / (layer1 + layer2))
return tf.Variable(tf.random_uniform((layer1, layer2), minval=Min, maxval=Max, dtype=tf.float32))


def conv2d(x, w):
return tf.nn.conv2d(x, w, strides=[1, 1, 1, 1], padding='SAME')


def max_pool_2x2(x):
return tf.nn.max_pool(x, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding='SAME')


x = tf.placeholder(tf.float32, [None, 224,224,3])
y = tf.placeholder(tf.float32, [None, 2])
# 把灰度图像一维向量,转换为28x28二维结构
x_image = x

w_conv1 = weight_init([3, 3, 3, 96]) # 3*3,深度为3,96
b_conv1 = biases_init([96])
h_conv1 = tf.nn.relu(conv2d(x_image, w_conv1) + b_conv1) # 输出张量的尺寸:112
h_pool1 = max_pool_2x2(h_conv1)

W_conv2 = weight_init([3, 3, 96, 96])
b_conv2 = biases_init([96])
h_conv2 = tf.nn.tanh(conv2d(h_pool1, W_conv2) + b_conv2)#输出是56
h_pool2 = max_pool_2x2(h_conv2)#池化后输出16*16*96
#2-1
W_conv3 = weight_init([3, 3, 96, 128])
b_conv3 = biases_init([128])
h_conv3 = tf.nn.relu(conv2d(h_pool2, W_conv3) + b_conv3)#输出28
h_pool3 = max_pool_2x2(h_conv3)#池化后输出16*16*96
#第2层卷积2-2

W_conv4 = weight_init([3, 3, 128, 128])
b_conv4 = biases_init([128])
h_conv4 = tf.nn.tanh(conv2d(h_pool3, W_conv4) + b_conv4)#14
h_pool4 = max_pool_2x2(h_conv4)#池化输出8*8*128
#3-1
W_conv5 = weight_init([3, 3, 128, 256])
b_conv5 = biases_init([256])
h_conv5 = tf.nn.relu(conv2d(h_pool4, W_conv5) + b_conv5)#7*7*256
h_pool5 = max_pool_2x2(h_conv5)#

h_pool5_flat = tf.reshape(h_pool5, [-1, 7 * 7 * 256])

w_fc1 = xavier_init(7 * 7 * 256, fch_nodes)
b_fc1 = biases_init([fch_nodes])
h_fc1 = tf.nn.relu(tf.matmul(h_pool5_flat, w_fc1) + b_fc1)

h_fc1_drop = tf.nn.dropout(h_fc1, keep_prob=drop_prob)

# 隐藏层与输出层权重初始化
w_fc2 = xavier_init(fch_nodes, 2)
b_fc2 = biases_init([2])

# 未激活的输出
y_ = tf.add(tf.matmul(h_fc1, w_fc2), b_fc2)
#y_ = tf.add(tf.matmul(h_fc1_drop, w_fc2), b_fc2)


# 激活后的输出
y_out = tf.nn.softmax(y_)
#y_out = tf.nn.sigmoid(y_)

cross_entropy = tf.reduce_mean(-tf.reduce_sum(y * tf.log(y_out), reduction_indices=[1]))
optimizer = tf.train.AdamOptimizer(learning_rate).minimize(cross_entropy)
#optimizer = tf.train.GradientDescentOptimizer(learning_rate).minimize(cross_entropy)

# 准确率
# 每个样本的预测结果是一个(1,10)的vector
correct_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(y_out, 1))
# tf.cast把bool值转换为浮点数
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
init = tf.global_variables_initializer()
#mnist = input_data.read_data_sets('MNIST/mnist', one_hot=True)
n_samples = int(1800)
total_batches = int(n_samples / batch_size)

#x_train = np.array(train_data[b'data']) / 255
x_train = np.array(train_data[b'data'])
y_train = np.array(pd.get_dummies(train_data[b'labels']))
#x_test = test_data[b'data'] / 255

转载于:https://www.cnblogs.com/TheKat/p/11115554.html

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

猫狗数据集 的相关文章

  • 下载 PyQt6 的 Qt Designer 并使用 pyuic6 将 .ui 文件转换为 .py 文件

    如何下载 PyQt6 的 QtDesigner 如果没有适用于 PyQt6 的 QtDesigner 我也可以使用 PyQt5 的 QtDesigner 但是如何将此 ui 文件转换为使用 PyQt6 库而不是 PyQt5 的 py 文件
  • Django REST序列化器:创建对象而不保存

    我已经开始使用 Django REST 框架 我想做的是使用一些 JSON 发布请求 从中创建一个 Django 模型对象 然后使用该对象而不保存它 我的 Django 模型称为 SearchRequest 我所拥有的是 api view
  • 如何在刻度标签和轴之间添加空间

    我已成功增加刻度标签的字体 但现在它们距离轴太近了 我想在刻度标签和轴之间添加一点呼吸空间 如果您不想全局更改间距 通过编辑 rcParams 并且想要更简洁的方法 请尝试以下操作 ax tick params axis both whic
  • DreamPie 不适用于 Python 3.2

    我最喜欢的 Python shell 是DreamPie http dreampie sourceforge net 我想将它与 Python 3 2 一起使用 我使用了 添加解释器 DreamPie 应用程序并添加了 Python 3 2
  • pandas 替换多个值

    以下是示例数据框 gt gt gt df pd DataFrame a 1 1 1 2 2 b 11 22 33 44 55 gt gt gt df a b 0 1 11 1 1 22 2 1 33 3 2 44 4 3 55 现在我想根据
  • 如何在Windows上模拟socket.socketpair

    标准Python函数套接字 套接字对 https docs python org 3 library socket html socket socketpair不幸的是 它在 Windows 上不可用 从 Python 3 4 1 开始 我
  • 为 pandas 数据透视表中的每个值列定义 aggfunc

    试图生成具有多个 值 列的数据透视表 我知道我可以使用 aggfunc 按照我想要的方式聚合值 但是如果我不想对两列求和或求平均值 而是想要一列的总和 同时求另一列的平均值 该怎么办 那么使用 pandas 可以做到这一点吗 df pd D
  • __del__ 真的是析构函数吗?

    我主要用 C 做事情 其中 析构函数方法实际上是为了销毁所获取的资源 最近我开始使用python 这真的很有趣而且很棒 我开始了解到它有像java一样的GC 因此 没有过分强调对象所有权 构造和销毁 据我所知 init 方法对我来说在 py
  • keras加载模型错误尝试将包含17层的权重文件加载到0层的模型中

    我目前正在使用 keras 开发 vgg16 模型 我用我的一些图层微调 vgg 模型 拟合我的模型 训练 后 我保存我的模型model save name h5 可以毫无问题地保存 但是 当我尝试使用以下命令重新加载模型时load mod
  • 当玩家触摸屏幕一侧时,如何让 pygame 发出警告?

    我使用 pygame 创建了一个游戏 当玩家触摸屏幕一侧时 我想让 pygame 给出类似 你不能触摸屏幕两侧 的错误 我尝试在互联网上搜索 但没有找到任何好的结果 我想过在屏幕外添加一个方块 当玩家触摸该方块时 它会发出警告 但这花了很长
  • 使用 OpenPyXL 迭代工作表和单元格,并使用包含的字符串更新单元格[重复]

    这个问题在这里已经有答案了 我想使用 OpenPyXL 来搜索工作簿 但我遇到了一些问题 希望有人可以帮助解决 以下是一些障碍 待办事项 我的工作表和单元格数量未知 我想搜索工作簿并将工作表名称放入数组中 我想循环遍历每个数组项并搜索包含特
  • Python:尝试检查有效的电话号码

    我正在尝试编写一个接受以下格式的电话号码的程序XXX XXX XXXX并将条目中的任何字母翻译为其相应的数字 现在我有了这个 如果启动不正确 它将允许您重新输入正确的数字 然后它会翻译输入的原始数字 我该如何解决 def main phon
  • Python - 在窗口最小化或隐藏时使用 pywinauto 控制窗口

    我正在尝试做的事情 我正在尝试使用 pywinauto 在 python 中创建一个脚本 以在后台自动安装 notepad 隐藏或最小化 notepad 只是一个示例 因为我将编辑它以与其他软件一起使用 Problem 问题是我想在安装程序
  • Numpy 优化

    我有一个根据条件分配值的函数 我的数据集大小通常在 30 50k 范围内 我不确定这是否是使用 numpy 的正确方法 但是当数字超过 5k 时 它会变得非常慢 有没有更好的方法让它更快 import numpy as np N 5000
  • 通过数据框与函数进行交互

    如果我有这样的日期框架 氮 EG 00 04 NEG 04 08 NEG 08 12 NEG 12 16 NEG 16 20 NEG 20 24 datum von 2017 10 12 21 69 15 36 0 87 1 42 0 76
  • 如何将 PIL 图像转换为 NumPy 数组?

    如何转换 PILImage来回转换为 NumPy 数组 这样我就可以比 PIL 进行更快的像素级转换PixelAccess允许 我可以通过以下方式将其转换为 NumPy 数组 pic Image open foo jpg pix numpy
  • 检查所有值是否作为字典中的键存在

    我有一个值列表和一本字典 我想确保列表中的每个值都作为字典中的键存在 目前我正在使用两组来确定字典中是否存在任何值 unmapped set foo set bar keys 有没有更Pythonic的方法来测试这个 感觉有点像黑客 您的方
  • 在 Python 类中动态定义实例字段

    我是 Python 新手 主要从事 Java 编程 我目前正在思考Python中的类是如何实例化的 我明白那个 init 就像Java中的构造函数 然而 有时 python 类没有 init 方法 在这种情况下我假设有一个默认构造函数 就像
  • Spark.read 在 Databricks 中给出 KrbException

    我正在尝试从 databricks 笔记本连接到 SQL 数据库 以下是我的代码 jdbcDF spark read format com microsoft sqlserver jdbc spark option url jdbc sql
  • Python - 字典和列表相交

    给定以下数据结构 找出这两种数据结构共有的交集键的最有效方法是什么 dict1 2A 3A 4B list1 2A 4B Expected output 2A 4B 如果这也能产生更快的输出 我可以将列表 不是 dict1 组织到任何其他数

随机推荐

  • RabbitMQ 消息有效期问题

    目录 一 默认情况 二 TTL Time To Live I TTL 的简介 II 单条消息过期 III 队列消息过期 IV 特殊情况 三 死信队列以及死信交换机 I 死信交换机 II 死信队列 III 具体操作 一 默认情况 在默认情况下
  • html 模板

    模板王 10000 免费网页模板 网站模板下载大全 mobanwang com http www mobanwang com
  • IEEE Transactions模板中参考文献作者缩写、期刊名缩写

    IEEE Transactions模板中参考文献作者缩写 期刊名缩写 本文章记录如何在IEEE Transactions的模板中 解决参考文献的作者缩写 期刊名字缩写的问题 目录 IEEE Transactions模板中参考文献作者缩写 期
  • python爬虫一:爬虫简介

    1 什么是爬虫 络爬 被称为 蜘蛛 络机器 就是模拟客户端发送 络请求 接收请求响应 种按照 定的规则 动地抓取互联 信息的程序 只要是浏览器能做的事情 原则上 爬 都能够做 可见即可爬 1 1爬虫有哪些用途 为其他数据提供数据源 像AI人
  • 数据挖掘的特点

    数据挖掘具有以下几个特点 1 基于大量数据 并非说小数据量上就不可以进行挖掘 实际上大多数数据挖掘的算法都可以在小数据量上运行并得到结果 但是 一方面过小的数据量完全可以通过人工分析来总结规律 另一方面来说 小数据量常常无法反映出真实世界中
  • kettle运行spoon.bat时找不到javaw文件

    我也遇到这问题了 分享一下解决方法吧以后没准还有人能用到 我机器的主要问题是环境变量JAVA HOME的值不对 应该写到jdk也就是C Program Files Java jdk1 7 0 25 并且 改完后要重启机器才行 这个很重要
  • DNS服务器的安装与配置

    一 DNS服务器的安装 步骤1 选择 开始 控制面板 添加或删除程序 添加 删除Windows组件 然后选取 网络服务 组件 再单击详细信息按钮 步骤2 选取 域名系统 DNS 组件后单击 确定 按钮 步骤3 回到前一个画面后 单击 下一步
  • vscode远程开发及公钥配置(告别密码登录)

    文章目录 vscode远程开发及公钥配置 简介 关于远程开发官网简介 关于SSH简介 环境 插件安装 配置服务器 找到配置文件 修改配置文件 连接服务器 配置密钥 简介 密钥生成 服务器上安装公钥 查看或配置打开密钥登录功能 服务器私钥复制
  • SSL/TLS 双向认证(一) -- SSL/TLS 工作原理

    本文部分参考 https www wosign com faq faq2016 0309 03 htm https www wosign com faq faq2016 0309 04 htm http blog csdn net hher
  • 四川计算机专业高职高考,四川职高计算机专业分数线

    类似问题答案 2016年贵州大学计算机类专业在四川录取分数线 学校 地 区 专业 年份 批次 类型 分数 贵州大学 四川 计算机类 2016 一批 理科 597 学校 地 区 专业 年份 批次 类型 分数 贵州大学 四川 计算机类 2016
  • 【华为OD统一考试B卷

    在线OJ 已购买本专栏用户 请私信博主开通账号 在线刷题 运行出现 Runtime Error 0Aborted 请忽略 华为OD统一考试A卷 B卷 新题库说明 2023年5月份 华为官方已经将的 2022 0223Q 1 2 3 4 统一
  • 分布式数据库核心原理 Zookeeper+Mysql

    原文 作者 1菩提行者1 笔者一直做java开发 由于技术演进做过大型微服务项目 微服务即将一个大的服务拆分成一个一个小的微服务 每个微服务自成生态 而在落地过程中紧紧只是应用层拆分 数据层往往用同一个库 有点形变神不变 当然将微服务与其对
  • JavaScript如何截取指定位置的字符串

    我们在日常开发中 经常需要对字符串进行删除截取增加的操作 我们这次说一下使用JavaScript截取指定位置的字符串 一 使用slice 截取 slice 方法可以通过指定的开始和结束位置 提取字符串的某个部分 并以新的字符串返回被提取的部
  • MIPI介绍(CSI DSI接口)

    MIPI介绍 CSI DSI接口 MIPI介绍 CSI DSI接口 视频接口 2 MIPI Solution mipi接口 缘来是你远去是我的博客 CSDN博客 MIPI LVDS RGB HDMI等接口对比 mipi和lvds区别 芒果5
  • socket编程

    socket 可以看做用户进程与内核网络协议栈的编程接口 可以用于本机进程间 网络上不同主机进程间的通信 对等通信 是全双工的 socket 异构系统 所以需要统一字节序统一后的字节序为大端字节序 x86为小端字节序 字节序转换函数 可以看
  • vsCode插件安装之汉化和浏览器打开

    一 汉化的方法 点击最左面第五个图标 在搜索框里面输入Chinese 点击如图第一个内容 点击Install 安装 安装后 重启软件即可 二 浏览器打开html 文件方法 在安装插件窗口搜索Browser 点击如图内容 点击install安
  • SpringBoot(十)SpringBoot自定义starter

    一个月的时间 转眼已经到了我的SpringBoot系列的第十篇文章 还记得我的第二篇文章SpringBoot 二 starter介绍 springboot的starter heart荼毒的博客 CSDN博客 曾经介绍过starter sta
  • mmdetection训练自己的VOC数据集 label=self.cat2label 报错解决方案

    废话不多说 直接上报错的图 看了GitHub上的大佬的回答 报错的原因是self cat2label值不对 所以根据大佬的建议 我print了self cat2label值 发现果然不对 类还是VOC数据集的类 而不是我自己的类 我的类是
  • ARM下高效C编程

    通过一定的风格来编写 C 程序 可以帮助 C 编译器生成执行速度更快的 ARM 代码 下面就是一些与性能相关的关键点 1 对局部变量 函数参数和返回值要使用 signed 和 unsigned int 类型 这样可以避免类型转换 而且可高效
  • 猫狗数据集

    import numpy as npimport pickleimport cv2import pandas as pdimport tensorflow as tfimport matplotlib pyplot as plt mnist