多层感知机的简洁实现

2023-11-10

import torch
from torch import nn
from torch.nn import init
import numpy as np
import sys
import torchvision
from torchvision import transforms
import time

class FlattenLayer(nn.Module):
    def __init__(self):
        super(FlattenLayer, self).__init__()
    def forward(self, x): # x shape: (batch, *, *, ...)
        return x.view(x.shape[0], -1)

def load_data_fashion_mnist(batch_size):
    mnist_train = torchvision.datasets.FashionMNIST(root='~/Datasets/FashionMNIST', train=True, download=True,
                                                    transform=transforms.ToTensor())
    mnist_test = torchvision.datasets.FashionMNIST(root='~/Datasets/FashionMNIST', train=False, download=True,
                                                   transform=transforms.ToTensor())

    if sys.platform.startswith('win'):
        num_workers = 0  # 0表示不用额外的进程来加速读取数据
    else:
        num_workers = 4
    train_iter = torch.utils.data.DataLoader(mnist_train, batch_size=batch_size, shuffle=True, num_workers=num_workers)
    test_iter = torch.utils.data.DataLoader(mnist_test, batch_size=batch_size, shuffle=False, num_workers=num_workers)
    start = time.time()
    for X, y in train_iter:
        continue
    print('load time: %.2f sec' % (time.time() - start))
    return train_iter, test_iter

def evaluate_accuracy(data_iter, net):
    acc_sum, n = 0.0, 0
    for X, y in data_iter:
        acc_sum += (net(X).argmax(dim=1) == y).float().sum().item()
        n += y.shape[0]
    return acc_sum / n

def train_ch3(net, train_iter, test_iter, loss, num_epochs, batch_size,
              params=None, lr=None, optimizer=None):
    for epoch in range(num_epochs):
        train_l_sum, train_acc_sum, n = 0.0, 0.0, 0
        for X, y in train_iter:
            y_hat = net(X)
            l = loss(y_hat, y).sum()

            # 梯度清零
            if optimizer is not None:
                optimizer.zero_grad()
            elif params is not None and params[0].grad is not None:
                for param in params:
                    param.grad.data.zero_()
            l.backward()
            optimizer.step()  # “softmax回归的简洁实现”一节将用到


            train_l_sum += l.item()
            train_acc_sum += (y_hat.argmax(dim=1) == y).sum().item()
            n += y.shape[0]  #有多少行标签就是有多少个样本
        test_acc = evaluate_accuracy(test_iter, net)
        print('epoch %d, loss %.4f, train acc %.3f, test acc %.3f'
              % (epoch + 1, train_l_sum / n, train_acc_sum / n, test_acc))

num_inputs, num_outputs, num_hiddens =  784, 10, 256

net = nn.Sequential(
    FlattenLayer(),
    nn.Linear(num_inputs,num_hiddens),
    nn.ReLU(),
    nn.Linear(num_hiddens,num_outputs),
)

for params in net.parameters():
    init.normal_(params, mean=0, std=0.01)

batch_size = 256
train_iter, test_iter = load_data_fashion_mnist(batch_size)
loss = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(net.parameters(), lr=0.5)

num_epochs = 5

train_ch3(net, train_iter, test_iter, loss, num_epochs, batch_size, None, None, optimizer)

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

多层感知机的简洁实现 的相关文章

  • python 类装饰器的好处_python 装饰器重要在哪

    1 什么是装饰器 要理解什么是装饰器 您首先需要熟悉Python处理函数的方式 从它的观点来看 函数和对象没有什么不同 它们有属性 可以重新分配 def func print hello from func func gt hello fr
  • 在QT的信号槽中使用自定义数据类型

    qt中使用信号槽来处理GUI与后台数据同步是不错的 耗时的任务可以在处理完数据后使用信号通知UI更新 对于qt中的已有类型 可以直接使用 但 多数时候都需要用到自定义类型 如果像内建类型那样使用 编译时正常 但运行时会报错 QObject
  • javax.xml.parsers.FactoryConfigurationError: Provider for javax.xml.parsers.

  • linux 打开.v文件,构建riscv上运行的linux系统

    在qemu上启动linux kernel 总述 最终目标还是要在RTL上跑linux系统 但做这个之前第一步先把系统工具链整清楚很重要 所以先在qemu上把相关的工具链 镜像搞定 为了完全这项任务 我们需要安装几个工具 qemu for r
  • VIM编辑命令

    转载 进入编辑模式 vim命令模式 vim实践 吉米乐享驿站 博客园 cnblogs com 5 5 进入编辑模式 所谓编辑模式就是进入到一个可以编辑文本文档的模式 常规的方式就是按小i进入编辑模式 左下角显示 insert插入 状态 此时
  • Java基础篇——入门

    转眼间这已经是自己工作的第五个年头了 期间做过安卓 做过web 现在又加入了小程序的阵营 可谓是历尽各种坎坷啊 让我不由得想起了 红楼梦 的开篇之句 满纸荒唐言 一把辛酸泪 期间的艰难困苦 实属夸张之感慨 无病呻吟 哈哈 只有自己能够体味啊
  • 语义分割制作自己的数据集——训练集、验证集、测试集

    用于语义分割的VOC数据集格式 语义分割任务voc数据集主要包括JPEGImages 存放原始图像 SegmentationClass 存放label ImageSets Segmentation 存放划分的数据集 包括train txt
  • VS2008, MFC 文件的操作3 - Win32 API 方式 文本方式打开

    接上一节笔记 VS2008 MFC 文件的操作2 C 语言方式 文本方式打开 1 代码 void Cvs2008 SX jiaocheng12View OnFileWritefile TODO 在此添加命令处理程序代码 Win32 API
  • cuda编程学习笔记 第二章 cuda memory management

    应用的性能可能有 75 都花费在内存相关问题上 NVPROF and NVVP 这俩是调试工具 不知道是不是基于CUPTI CUDA Profiler Tools Interface NVPROF是命令行工具 nvvp是可视化工具 nvvp
  • 12、适配器

    文章目录 package com example demo designpattern 又叫包装模式 Wrapper 各种 wrapper bridge 就是适配器模式 jbdc odbc bridge io 字节流字符流转换 角色 tar
  • STM32-ESP8266-12F与PC通信

    1 默认ESP8266的波特率是115200 2 指令及其返回值 3 使用PC的网络调试助手 协议类型选择TCP Server 端口号以80开头 表示TCP协议 如8080 8040等等 IP地址填PC的WI F网口的IP地址 配置完成后点
  • 解决浏览器设置代理IP无法上网的问题

    大家都知道 在当今信息时代 互联网已经成为了我们生活必不可少的一部分 而浏览器作为我们上网的窗口 更是被广泛使用 有时候 我们会遇到一些问题 例如设置了代理IP后无法正常上网 那么该如何解决这个问题呢 别担心 本文将为您一一解答 首先 让我
  • 易优cms:guestbookform 留言表单标签

    guestbookform 留言表单标签 基础用法 名称 guestbookform 功能 留言表单提交 语法 eyou guestbookform type default
  • 计算机组成中的阶符是什么意思,计算机中阶符,阶码,数符,尾数是什么?

    一般地 任一个二进制N 可表示为N 2j S 其中J为二进制数 叫阶码 J如果有正负号的话 正负号就叫阶符 S为纯小数 叫做尾数 数符 指的是N整个数的符号 二进制的 00101000 直接可以转换成16进制的 28 字节是电脑中的基本存储
  • 详解 Android 是如何启动的

    详解 Android 是如何启动的 2016 08 12 唐琪森 安卓开发 javascript void 0 来自 石头铺 微信号 Android Programmer 网站 www woaitqs cc 本文是 Android 系统学习
  • xilinx平台下DDR3映射为VFIFO

    FPGA开发中 数据采集 数据分析场景下需要用对高速ADC数据缓存 FPGA片内RAM无法做到大的容量 基于MIG IP做了个DDR3映射成FIFO的模块 以完成高速 量大的数据缓存应用 背景和选择 part1 官方也提供了类似功能的IP
  • makefile学习

    基本介绍 makefile编写的关键在于解决源文件的 文件依赖性 编译链接过程 源文件首先会生成中间目标文件 再由中间目标文件生成执行文件 在编译时 编译器只检测程序语法 和函数 变量是否被声明 如果函数未声明 编译器会给出一个警告 但可以
  • VS Code 打开时黑屏的恢复处理

    VS Code安装后一直黑屏的情况 一 兼容模式 Win10版本以下系统 右击VS Code打开属性窗口并在兼容性标签页内勾上以兼容模式运行这个程序 二 自动选择显卡 VS Code的渲染跟显卡设定有一定关系 打开NVIDIA控制面板 调整
  • shell脚本对硬盘进行分区——fdisk、blkid、mke2fs、mount、lsblk

    1 前言 本文介绍的是嵌入式设备烧录系统时 如何用shell脚本对硬盘进行分区 文章主要介绍的是制作烧录U盘的分区思路和关键的shell脚本语句 代码并不能直接拷贝使用 2 总体思路 1 用U盘进行系统的烧录 就是在U盘上制作一个可以运行的

随机推荐

  • (必备技能)使用Python实现屏幕截图

    必备技能 使用Python实现屏幕截图 文章目录 必备技能 使用Python实现屏幕截图 一 序言 二 环境配置 1 下载pyautogui包 2 下载opencv python包 3 下载PyQt5包 4 下载pypiwin32包 三 屏
  • CIFAR10数据集使用笔记

    CIFAR10数据集 1 数据集下载并转换为张量 train set torchvision datasets CIFAR10 root data path train True download True transform transf
  • JAVA语言强制类型转换要求

    JAVA语言强制类型转换要求 数据类型具有高低性的 顺序由低到高为 byte gt short gt char gt int gt long gt float gt double 1 由低到高需要强制类型转换 转换方式如下 public c
  • SIP中继对接

    freeswitch与各种设备对接的成功配置 需要的请参考 有错误的地方请指导 1 对接华为softco 中继配置 sip profiles external
  • 初探设计模式之Adapter模式

    文章目录 设计模式之Adapter模式 一 什么是Adapter模式 二 具体实例 1 使用Banner来表示高电压插座 2 使用Print来表示低电压电器 3 使用PrintBanner来表示适配器 使用的是继承 4 总体结构如下图所示
  • vivado.2019.1 安装教程

    vivado 2019 1 安装教程 下载链接 VIVADIO 2019 1 链接 https pan baidu com s 17 cPUahNzHmm 3xKsKQ7GQ 提取码 rop0 来自百度网盘超级会员V4的分享 1 解压所有文
  • JS实现随机抽奖功能

    点击开始按钮开始抽奖 div依次变红 下面是js代码 需要的自取
  • MLOps极致细节:4. MLFlow Projects 案例介绍(Gitee代码链接)

    MLOps极致细节 4 MLFlow Projects 案例介绍 Gitee代码链接 MLFlow Projects允许我们将代码及其依赖项打包为一个可以在其他平台上以可复制 reproducible 和可重用 reusable 的方式运行
  • 对称群与置换群 定义

    我刚接触抽象代数的那段时间 一直在考虑一个问题 抽象代数有什么实际应用 后来听说 群在研究一些具有对称性质的对象时有奇效 于是我试着用群去描述一些简单的几何变换 发现确实如此 这就是我在置换那篇文章的最后让大家思考等边三角形变换的原因 如果
  • C++多态概念和意义

    目录 一 什么叫重写 二 面向对象期望的重写 1 示例分析 2 所期望的重写 三 多态的概念和意义 1 多态的概念 2 C 如何支持多态概念 3 多态内部运行剖析 4 多态的意义 5 修改示例代码 四 静态联编和动态联编 五 小结 一 什么
  • DragGAN报错Setting up PyTorch plugin “bias_act_plugin“... Failed!和FAILED: bias_act.cuda.o解决办法

    问题 DragGAN终于开源了 于是下载安装结果报错了 查了一大堆资料 都没有解决办法 于是安装了个ChatGLM2 6B 在上面将自己的问题粘贴上去 给出了解决方案 结果直接解决了一天没有解决的问题 下面附上运行之后报的错误 File u
  • nestjs:改变debug端口

    目的 多个项目 如果不改调试端口 会出现无法同时调试的情况 说明 nest start debug port port 不写默认为9229
  • C++11列表初始化

    2023年7月17日 周一上午 今天在看GitHub上的源码时看到了这种用法 于是研究了一下 并把自己的研究成果记录成博客 目录 C 11为什么要推出列表初始化 举例说明 统一初始化语法 对象和容器的初始化得以用一种统一的方式来进行 防止窄
  • glsl语法整理

    glsl 语法 main 方法表示入口函数 标量 在GLSL中标量只有bool int和float三种 向量 共有vec2 vec3 vec4 ivec2 ivec3 ivec4 bvec2 bvec3和bvec4九种类型
  • MySQL8 EXPLAIN 命令输出的都是什么东西?这篇超详细!

    引子 小扎刚毕业不久 在一家互联网公司工作 由于是新人 做的也都是简单的CRUD 刚来的时候还有点不适应 做了几个月之后 就变成了熟练工了 左复制 右粘贴 然后改改就是自己的代码了 生活真美好 有一天 领导说他做的有个列表页面速度很慢 半天
  • 结构体的总结

    目录 一 结构体的定义 二 对结构体的重命名 三 结构体和指针的结合 四 结构体和数组的结合 五 结构体大小 六 结构体嵌套 七 动态内存与结构体 八 总结 用结构体对学生成绩实行升序排序 一 结构体的定义 我们之前接触的数据类型有 基本数
  • GreenDAO数据库版本升级

    GreenDAO在进行默认的数据库升级时 会采取先删除所有的表 再全部重新建的操作 这就意味着所有的数据都会遗失 public void onUpgrade SQLiteDatabase db int oldVersion int newV
  • AttributeError: ‘NoneType‘ object has no attribute ‘shape‘

    在运行训练文件时 出现了这样的问题 AttributeError NoneType object has no attribute shape 后来参考了大神文章后发现是因为有的text文件路径不对 改了文件路径后运行没问题了 还可能有以下
  • 卡尔曼滤波(Kalman filter)及预测

    参考文章 https blog csdn net baidu 38172402 article details 82289998 https www jianshu com p 2768642e3abf kalman滤波的作用 1 数据 滤
  • 多层感知机的简洁实现

    import torch from torch import nn from torch nn import init import numpy as np import sys import torchvision from torchv