softmax用于分类问题/逻辑回归

2023-11-16

参考:d2l


请添加图片描述

线性回归问题最后输出一个参数用于预测。多分类问题最后输出多个维度的数据(多少个output_channels就有多少个类别)。

softmax是一种激活函数,它常见于分类问题的最后一层激活函数,目的是让输出属于一个概率密度函数。我们首先从数学方面推导了softmax函数;再自己手撕了一个分类网络;最后调用torch的API解决了softmax分类问题。

  1. 数学推导

  1. 从零实现softmax

使用数据集fashion-MNIST。

'''
从零开始手撸一个 softmax 分类问题

逻辑回归, softmax回归都是分类问题, 他们是强监督学习的另一个大类。对此他们采用的损失函数(cross-entropy loss交叉熵损失, 交叉熵损失可以通过信息论了解)和最后一层的激活函数(softmax)是不一样的。损失函数同样是由极大似然估计(MLE)得到的。

分类问题,网络模型最终的输出会有多个channel且引入one-hot编码的概念。使用Fashion-MNIST

手撸一个 softmax 分类问题
'''


import torch
import torchvision
from torch.utils import data
from torchvision import transforms
from timer import Timer
from d2l import torch as d2l


# 1. 下载数据,并了解transforms.ToTensor()方法,查看一下数据等操作
# 通过ToTensor实例将图像数据从PIL类型变换成32位浮点数格式,
# 并除以255使得所有像素的数值均在0〜1之间
trans = transforms.ToTensor()
mnist_train = torchvision.datasets.FashionMNIST(
    root="/home/yingmuzhi/_learning/d2l/data/Fashion_MNIST", train=True, transform=trans, download=False)
mnist_test = torchvision.datasets.FashionMNIST(
    root="/home/yingmuzhi/_learning/d2l/data/Fashion_MNIST", train=False, transform=trans, download=False)

print(len(mnist_train), len(mnist_test))

print(mnist_train[0][0].shape)


def get_fashion_mnist_labels(labels): #@save
    """返回Fashion-MNIST数据集的⽂本标签"""
    text_labels = ['t-shirt', 'trouser', 'pullover', 'dress', 'coat', 'sandal', 'shirt', 'sneaker', 'bag', 'ankle boot']
    return [text_labels[int(i)] for i in labels]



# # 展示图片
# def show_images(imgs, num_rows, num_cols, titles=None, scale=1.5): #@save
#     """绘制图像列表"""
#     figsize = (num_cols * scale, num_rows * scale)
#     _, axes = d2l.plt.subplots(num_rows, num_cols, figsize=figsize)
#     axes = axes.flatten()
#     for i, (ax, img) in enumerate(zip(axes, imgs)):
#         if torch.is_tensor(img):
#             # 图⽚张量
#             ax.imshow(img.numpy())
#         else:
#             # PIL图⽚
#             ax.imshow(img)
#         ax.axes.get_xaxis().set_visible(False)
#         ax.axes.get_yaxis().set_visible(False)
#         if titles:
#             ax.set_title(titles[i])
#     return axes


batch_size = 256
def get_dataloader_workers(): #@save
    """使⽤4个进程来读取数据"""
    return 4
train_iter = data.DataLoader(mnist_train, batch_size, shuffle=True,
num_workers=get_dataloader_workers())


# 查看一下读取时间
timer = Timer()
for X, y in train_iter:
    continue
print(f'{timer.stop():.2f} sec')




# 2. 整合上面所有的组件 - 形成一个加载dataloader的脚本
def load_data_fashion_mnist(batch_size, resize=None): #@save
    """下载Fashion-MNIST数据集,然后将其加载到内存中"""
    trans = [transforms.ToTensor()]
    if resize:
        trans.insert(0, transforms.Resize(resize))
    trans = transforms.Compose(trans)
    mnist_train = torchvision.datasets.FashionMNIST(
        root="/home/yingmuzhi/_learning/d2l/data/Fashion_MNIST", train=True, transform=trans, download=False)
    mnist_test = torchvision.datasets.FashionMNIST(
        root="/home/yingmuzhi/_learning/d2l/data/Fashion_MNIST", train=False, transform=trans, download=False)
    return (data.DataLoader(
        mnist_train, batch_size, shuffle=True,
        num_workers=get_dataloader_workers()),  # 设置num_workers的个数,使得DataLoader加载的数据更多,更快
        data.DataLoader(mnist_test, batch_size, shuffle=False,
        num_workers=get_dataloader_workers()))

train_iter, test_iter = load_data_fashion_mnist(32, resize=64)
for X, y in train_iter:
    print(X.shape, X.dtype, y.shape, y.dtype)
    break


""" 
(3). 开始softmax分类,包括:

1. 加载DataLoader数据
2. 初始化权重
3. 定义激活函数softmax
4. 定义网络模型
5. 定义损失函数交叉熵损失
6. 用Accuracy做评价指标
"""

# 从0开始softmax 分类(逻辑回归)
batch_size = 256
train_iter, test_iter = load_data_fashion_mnist(batch_size)


# 初始化权重,正态分布weight 和 为0的bias
num_inputs = 784
num_outputs = 10
W = torch.normal(0, 0.01, size=(num_inputs, num_outputs), requires_grad=True) 
b = torch.zeros(num_outputs, requires_grad=True)


# 使用sum()手动定义激活函数softmax
def softmax(X):
    X_exp = torch.exp(X)
    partition = X_exp.sum(1, keepdim=True)
    return X_exp / partition # 这⾥应⽤了⼴播机制

X = torch.normal(0, 1, (2, 5))
X_prob = softmax(X)
X_prob, X_prob.sum(1)


# 定义模型Y = softmax(Xw + b)
def net(X):
    return softmax(torch.matmul(X.reshape((-1, W.shape[0])), W) + b)


# 定义损失函数:还是由极大似然估计得到的**交叉熵损失**作为 损失函数
y = torch.tensor([0, 2])
y_hat = torch.tensor([[0.1, 0.3, 0.6], [0.3, 0.2, 0.5]])
y_hat[[0, 1], y]

def cross_entropy(y_hat, y):
    return - torch.log(y_hat[range(len(y_hat)), y])
cross_entropy(y_hat, y)


# 用Accuracy做评价指标
def accuracy(y_hat, y): #@save
    """计算预测正确的数量"""
    if len(y_hat.shape) > 1 and y_hat.shape[1] > 1:
        y_hat = y_hat.argmax(axis=1)
    cmp = y_hat.type(y.dtype) == y
    return float(cmp.type(y.dtype).sum())

def evaluate_accuracy(net, data_iter): #@save
    """计算在指定数据集上模型的精度"""
    if isinstance(net, torch.nn.Module):
        net.eval() # 将模型设置为评估模式
    metric = Accumulator(2) # 正确预测数、预测总数
    with torch.no_grad():
        for X, y in data_iter:
            metric.add(accuracy(net(X), y), y.numel())
    return metric[0] / metric[1]

class Accumulator: #@save
    """在n个变量上累加"""
    def __init__(self, n):
        self.data = [0.0] * n
    def add(self, *args):
        self.data = [a + float(b) for a, b in zip(self.data, args)]
    def reset(self):
        self.data = [0.0] * len(self.data)
    def __getitem__(self, idx):
        return self.data[idx]

print(evaluate_accuracy(net, test_iter))


"""
(4). 开始训练

"""

# 开始训练
def train_epoch_ch3(net, train_iter, loss, updater): #@save
    """训练模型⼀个迭代周期(定义⻅第3章)"""
    # 将模型设置为训练模式
    if isinstance(net, torch.nn.Module):
        net.train()
    # 训练损失总和、训练准确度总和、样本数
    metric = Accumulator(3) # 累计计算
    for X, y in train_iter:
        # 计算梯度并更新参数
        y_hat = net(X)
        l = loss(y_hat, y)
        if isinstance(updater, torch.optim.Optimizer):
            # 使⽤PyTorch内置的优化器和损失函数
            updater.zero_grad()
            l.mean().backward()
            updater.step()
        else:# 使⽤定制的优化器和损失函数
            l.sum().backward()
            updater(X.shape[0])
            metric.add(float(l.sum()), accuracy(y_hat, y), y.numel())
    # 返回训练损失和训练精度
    return metric[0] / metric[2], metric[1] / metric[2]

# optimizer to minimize cost function
def sgd(params, lr, batch_size):
    """mini batchsize SGD"""
    with torch.no_grad():
        for param in params:
            param -= lr * param.grad / batch_size
            param.grad.zero_()

def train_ch3(net, train_iter, test_iter, loss, num_epochs, updater): #@save
    """训练模型(定义⻅第3章)"""
    # animator = Animator(xlabel='epoch', xlim=[1, num_epochs], ylim=[0.3, 0.9],
        # legend=['train loss', 'train acc', 'test acc'])
    for epoch in range(num_epochs):
        train_metrics = train_epoch_ch3(net, train_iter, loss, updater)
        test_acc = evaluate_accuracy(net, test_iter)
        # animator.add(epoch + 1, train_metrics + (test_acc,))
    train_loss, train_acc = train_metrics
    assert train_loss < 0.5, train_loss
    assert train_acc <= 1 and train_acc > 0.7, train_acc
    assert test_acc <= 1 and test_acc > 0.7, test_acc

lr = 0.1
def updater(batch_size):
    return sgd([W, b], lr, batch_size)

num_epochs = 10
train_ch3(net, train_iter, test_iter, cross_entropy, num_epochs, updater)


"""
(5). 预测

"""

# 做预测
def predict_ch3(net, test_iter, n=6): #@save
    """预测标签(定义⻅第3章)"""
    for X, y in test_iter:
        break
    trues = get_fashion_mnist_labels(y)
    preds = get_fashion_mnist_labels(net(X).argmax(axis=1))
    titles = [true +'\n' + pred for true, pred in zip(trues, preds)]
    print(titles)

predict_ch3(net, test_iter)
  1. API softmax
'''
softmax的简洁表示

'''
import torch.nn as nn
import torch, d2l
"""(1). 初始化权重参数 """
# PyTorch不会隐式地调整输⼊的形状。因此,
# 我们在线性层前定义了展平层(flatten),来调整⽹络输⼊的形状

net = nn.Sequential(nn.Flatten(), nn.Linear(784, 10))

def init_weights(m):
    if type(m) == nn.Linear:
        nn.init.normal_(m.weight, std=0.01)

net.apply(init_weights);


""" (2). 设置损失函数 """
loss = nn.CrossEntropyLoss(reduction='none')


""" (3). 设置优化函数 """
trainer = torch.optim.SGD(net.parameters(), lr=0.1)


""" (4). 训练 """
# 开始训练
num_epochs = 10
# d2l.train_ch3(net, train_iter, test_iter, loss, num_epochs, trainer) # 不用d2l
本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

softmax用于分类问题/逻辑回归 的相关文章

  • 人工智能边缘计算:连接智能的边界

    导言 人工智能边缘计算是将智能计算推向数据源头的重要发展方向 本文将深入探讨边缘计算与人工智能的交融 以及在未来数字化社会中的前景 1 边缘计算的基础 分布式计算 边缘计算通过将计算任务推送至数据产生的地方 实现更高效的分布式计算 低延迟通
  • 软件测试/测试开发/人工智能丨机器学习中特征的含义,什么是离散特征,什么是连续特征。

    在机器学习中 特征 Feature 是输入数据中的属性或变量 用于描述样本或数据点 特征对于机器学习模型而言是输入的一部分 模型通过学习样本的特征与其对应的标签 或输出 之间的关系来做出预测或分类 特征可以分为不同类型 其中两个主要的类型是
  • 【YOLO算法训练数据集处理】缩放训练图片的大小,同时对图片的标签txt文件中目标的坐标进行同等的转换

    背景 在训练一个自己的yolo模型目标检测模型时 使用公共数据集时 通常要将图片缩放处理 而此时图片对应的标签文件中目标的坐标也应进行同等的变换 这样才能保证模型的正确训练 当然 如果是自建的数据集 则将图片进行缩放后 使用Labelimg
  • .h5文件简介

    一 简介 HDF5 Hierarchical Data Format version 5 是一种用于存储和组织大量数据的文件格式 它支持高效地存储和处理大规模科学数据的能力 HDF5 是一种灵活的数据模型 可以存储多种数据类型 包括数值数据
  • 一网打尽目前常用的聚类方法,详细介绍了每一种聚类方法的基本概念、优点、缺点!!

    目前常用的聚类方法 1 K 均值聚类 K Means Clustering 2 层次聚类 Hierarchical Clustering 3 DBSCAN聚类 DBSCAN Clustering 4 谱聚类 Spectral Cluster
  • 基于生成式对抗网络的视频生成技术

    随着人工智能的快速发展 生成式对抗网络 GAN 作为一种强大的生成模型 已经在多个领域展现出了惊人的能力 其中 基于GAN的视频生成技术更是引起了广泛的关注 本文将介绍基于生成式对抗网络的视频生成技术的原理和应用 探索其对电影 游戏等领域带
  • HWSD中国土壤数据库

    数据名称 HWSD中国土壤数据库 数据时间 2009年 数据格式 Shp和Tiff 数据坐标系 WGS1984和krasovsky 1940 Albers 数据介绍 数据来源于联合国粮农组织 FAO 和维也纳国际应用系统研究所 IIASA
  • 2023年30米分辨率土地利用遥感监测数据

    改革开放以来 中国经济的快速发展对土地利用模式产生了深刻的影响 同时 中国又具有复杂的自然环境背景和广阔的陆地面积 其土地利用变化不仅对国家发展 也对全球环境变化产生了深刻的影响 为了恢复和重建我国土地利用变化的现代过程 更好地预测 预报土
  • 【最新】2023年30米分辨率土地利用遥感监测数据

    改革开放以来 中国经济的快速发展对土地利用模式产生了深刻的影响 同时 中国又具有复杂的自然环境背景和广阔的陆地面积 其土地利用变化不仅对国家发展 也对全球环境变化产生了深刻的影响 为了恢复和重建我国土地利用变化的现代过程 更好地预测 预报土
  • 机器学习之迁移学习(Transfer Learning)

    概念 迁移学习 Transfer Learning 是一种机器学习方法 其核心思想是将从一个任务中学到的知识应用到另一个相关任务中 传统的机器学习模型通常是从头开始训练 使用特定于任务的数据集 而迁移学习则通过利用已经在一个任务上学到的知识
  • 基于生成式对抗网络的视频生成技术

    随着人工智能的快速发展 生成式对抗网络 GAN 作为一种强大的生成模型 已经在多个领域展现出了惊人的能力 其中 基于GAN的视频生成技术更是引起了广泛的关注 本文将介绍基于生成式对抗网络的视频生成技术的原理和应用 探索其对电影 游戏等领域带
  • 第二部分相移干涉术

    典型干涉图 相移干涉术 相移干涉术的优点 1 测量精度高 gt 1 1000 条纹 边缘跟踪仅为 1 10 边缘 2 快速测量 3 低对比度条纹测量结果良好 4 测量结果不受瞳孔间强度变化的影响 独立于整个瞳孔的强度变化 5 在固定网格点获
  • 澳鹏干货解答!“关于机器学习的十大常见问题”

    探索机器学习的常见问题 了解机器学习和人工智能的基本概念 原理 发展趋势 用途 方法和所需的数据要求从而发掘潜在的商机 什么是机器学习 机器学习即教授机器如何学习的过程 为机器提供指导 帮助它们自己开发逻辑 访问您希望它们访问的数据 机器学
  • 详解数据科学自动化与机器学习自动化

    过去十年里 人工智能 AI 构建自动化发展迅速并取得了多项成就 在关于AI未来的讨论中 您可能会经常听到人们交替使用数据科学自动化与机器学习自动化这两个术语 事实上 这些术语有着不同的定义 如今的自动化机器学习 即 AutoML 特指模型构
  • 澳鹏干货解答!“关于机器学习的十大常见问题”

    探索机器学习的常见问题 了解机器学习和人工智能的基本概念 原理 发展趋势 用途 方法和所需的数据要求从而发掘潜在的商机 什么是机器学习 机器学习即教授机器如何学习的过程 为机器提供指导 帮助它们自己开发逻辑 访问您希望它们访问的数据 机器学
  • ResNet实战:CIFAR-10数据集分类

    本节将使用ResNet实现CIFAR 10数据集分类 7 2 1 CIFAR 10 数据集简介 CIFAR 10数据集共有60000幅彩色图像 这些图像是32 32像素的 分为10个类 每类6000幅图 这里面有50000幅用于训练 构成了
  • 图神经网络与智能教育:创新教育技术的未来

    导言 图神经网络 GNNs 和智能教育技术的结合为教育领域注入新活力 本文深入研究二者的结合可能性 涉及各自侧重 当前研究动态 技术运用 实际场景 未来展望 并提供相关链接 1 图神经网络与智能教育的结合方向 1 1 图神经网络在教育技术中
  • 机器学习算法实战案例:LSTM实现多变量多步负荷预测

    文章目录 1 数据处理 1 1 数据集简介 1 2 数据集处理 2 模型训练与预测 2
  • AI在广告中的应用——预测性定位和调整

    营销人员的工作就是在恰当的时间将适合的产品呈现在消费者面前 从而增加他们购买的可能性 随着时间的推移 营销人员能够深入挖掘越来越精准的客户细分市场 他们不仅具备了实现上述目标的能力 而且这种能力还在呈指数级提升 在AI技术帮助下 现在的营销
  • 5_机械臂运动学基础_矩阵

    上次说的向量空间是为矩阵服务的 1 学科回顾 从科技实践中来的数学问题无非分为两类 一类是线性问题 一类是非线性问题 线性问题是研究最久 理论最完善的 而非线性问题则可以在一定基础上转化为线性问题求解 线性变换 数域 F 上线性空间V中的变

随机推荐

  • STM32F103C8T6详细引脚表

    今天准备画一个STM32F103C8T6的最小系统板 就去STM32F103C8的数据手册查看了一下相应的引脚 因为数据手册里面的引脚表有中容量的多种封装描述 看上去比较麻烦 我就单独做了一个LQFP48脚的引脚表 方便后期自己画封装 就图
  • Spring+Mybatis 查询所有数据时发生异常:org.apache.ibatis.reflection.ReflectionException: There is no getter for

    Spring Mybatis框架整合时 根据条件查询数据 发生异常 Caused by org apache ibatis reflection ReflectionException There is no getter for prop
  • JavaScript分支语句总结

    注 js变量算术运算符和逻辑运算符知识点的补充 1 的区别 表示值相等 表示值相等 数据类型也必须相等 案例 的区别 表示值相等 表示值相等 数据类型也必须相等 var x 10 var y 10 console log x y true
  • 图像降质

    1 逆滤波和维纳滤波 附Matlab完整代码 https blog csdn net weixin 41730407 article details 80455612 2 python 运动模糊 退化模型 点扩散函数 逆滤波与维纳滤波 ht
  • GG-CNN代码学习

    文章目录 1 源码网址 https github com dougsm ggcnn 2 数据集格式转化 下载后的康奈尔数据集 解压完之后里面的格式 里面的 tiff图像通过 txt文件转化得到 python m utils dataset
  • layui 数据表格 sort排序,filter过滤——soulTable

    1 效果图 2 页面代码 div class fp table style margin left 0 5 width 86 table style margin bottom 0px table div 3 js代码 引入扩展组件 lay
  • 【学vue跟玩一样】快速搞懂vue渲染

    Vue的渲染分为条件渲染和列表渲染 那究竟什么式渲染呢 1 条件渲染 1 v if写法 1 v if 表达式 2 v else if 表达式 3 v else 表达式 和我们曾经学过的JavaScript里面的if语句几乎一样 适用于 切换
  • Quartz misfire详解

    一 前言 最近在学习Quartz 看到misfire这一部分 发现官方文档上讲解的很简单 没有看明白 然后去搜索了一下网上的讲解 发现讲的也都大同小异 也没有看明白 最后只能自己动手做测试 总结了一下 这篇文章把自己总结的记录下来 方便自己
  • 使用 HEX 参数在 Python 中实现六边形图像的显示数据关系

    使用 HEX 参数在 Python 中实现六边形图像的显示数据关系 在数据可视化中 六边形图被广泛应用于显示多元数据之间的关系 本文将介绍如何使用 Python 中的 hex 参数来设置六边形图像 并展示如何使用这种方法来显示数据的关系 首
  • Spring Boot —— Security 控制按钮权限

    文章目录 Spring Boot Security 控制按钮权限 前言 实现 引入对应的依赖 配置标签 Spring Boot Security 控制按钮权限 前言 在freemarker中 通过Security根据用户角色控制页面按钮或菜
  • win8.1仅允许运行使用网络级别身份认证的远程桌面计算机连接,使用Win10通过Mstsc远程连接 Server 2012 R2 时出现 身份验证错误,要求的函数不受支持,这可能是由于CredSSP...

    使用Win10通过Mstsc远程连接 Server 2012 R2 时出现 身份验证错误 要求的函数不受支持 这可能是由于CredSSP加密Oracle修正 最终解决方法 原因 因为CVE 2018 0886 的 CredSSP 2018
  • unity shader 之基础四 数学

    4 2 笛卡尔坐标系 笛卡尔坐标系分为二维和三维坐标系 4 2 1二位坐标系 OpenGL 和 DirectX 二位坐标系是不同的 OpenGL 和 DirectX 是不同的图形访问接口 用来和硬件交互的 二维坐标系 是可以相互转换的 既
  • 【经典】centos 安装 mysql

    CentOS第一次安装MySQL的完整步骤 目录 1 官方安装文档 2 下载 Mysql yum包 3 安转软件源 4 安装mysql服务端 5 首先启动mysql 6 接着检查mysql 的运行状态 7 修改临时密码 7 1 获取MySQ
  • [转] 英文写作中分号和冒号的使用

    我们先来了解下分号和冒号的作用 分号的主要作用是来连接两个在语法上平等的成分 冒号的主要作用是引起读者对冒号后面内容的注意力 下面总结下规则 用分号的情况 1 用分号连接两个独立的句子 两个独立的句子不能够用逗号隔开 如果用逗号 必须逗号后
  • idea忽略.iml文件

    1 点击file文件下的设置中 2 点下file types 文件类型 进入到file types窗口 如图 然后点击忽略文件那添加需要忽略的类型
  • 自用HTML+CSS学习笔记

    HTML CSS学习笔记 1 Web标准 Web标准也称为网页标准 由一系列的标准组成 大部分由W3C World Wide Web Consortium 万维网联盟 负责制定 由三个组成部分 HTML 负责网页的结构 页面元素和内容 CS
  • IT的教育

    IT的教育 李颜芯 CSDN的网友大家好 欢迎大家收看这一起的CSDN视频访谈节目 今天我们请到了两位嘉宾 一位是 金旭亮 老师 一位是 金戈 老师 两位老师作一下自我介绍怎么样 金旭亮 我先介绍一下吧 我叫金旭亮是北京理工大学的讲师 我在
  • 怎样把pdf转换成word-多语言ocr支持

    http jingyan baidu com article 86fae34699bb4e3c49121a23 html PDF格式良好的视觉阅读性和通用性使得PDF文件的使用越来越广泛了 网络上的PDF资料也越来越多 但是我们往往想要提出
  • 【大屏】 amap + echarts 踩坑以及避免办法

    amap echarts 踩坑以及避免办法 大屏 踩坑 代码 大屏 html body container margin 0 padding 0 width 5376px height 1944px background color 000
  • softmax用于分类问题/逻辑回归

    参考 d2l 线性回归问题最后输出一个参数用于预测 多分类问题最后输出多个维度的数据 多少个output channels就有多少个类别 softmax是一种激活函数 它常见于分类问题的最后一层激活函数 目的是让输出属于一个概率密度函数 我