手撕/手写/自己实现 BN层/batch norm/BatchNormalization python torch pytorch

2023-10-29

计算过程

在卷积神经网络中,BN 层输入的特征图维度是 (N,C,H,W), 输出的特征图维度也是 (N,C,H,W)
N 代表 batch size
C 代表 通道数
H 代表 特征图的高
W 代表 特征图的宽

我们需要在通道维度上做 batch normalization,
在一个 batch 中,
使用 所有特征图 相同位置上的 channel 的 所有元素,计算 均值和方差,
然后用计算出来的 均值和 方差,更新对应特征图上的 channel , 生成新的特征图

如下图所示:
对于4个橘色的特征图,计算所有元素的均值和方差,然后在用于更新4个特征图中的元素(原来元素减去均值,除以方差)
![[attachments/BN示意图.png]]

代码

def my_batch_norm_2d_detail(features, eps=1e-5):
    '''
        这个函数的写法是为了帮助理解 BatchNormalization 具体运算过程
        实际使用时这样写会比较慢
    '''
    
    n,c,h,w = features.shape
    features_copy = features.clone()
    running_var = torch.randn(c)
    running_mean = torch.randn(c)
    for ci in range(c):# 分别 处理每一个通道
        mean = 0 # 均值
        var = 0 # 方差
        
        _sum = 0 
        # 对一个 batch 中,特征图相同位置 channel 的每一个元素求和
        for ni in range(n):            
            for hi in range(h):
                for wi in range(w):
                    _sum += features[ni,ci, hi, wi]
        mean = _sum / (n * h * w) 
        running_mean[ci] = mean
        

        _sum = 0
        # 对一个 batch 中,特征图相同位置 channel 的每一个元素求平方和,用于计算方差 
        for ni in range(n):            
            for hi in range(h):
                for wi in range(w):
                    _sum += (features[ni,ci, hi, wi] - mean) ** 2
        var = _sum / (n * h * w )
        running_var[ci] = _sum / (n * h * w - 1)

        # 更新元素
        for ni in range(n):            
            for hi in range(h):
                for wi in range(w):
                    features_copy[ni,ci, hi, wi] = (features_copy[ni,ci, hi, wi] - mean) / torch.sqrt(var + eps) 
        
    return features_copy, running_mean, running_var

if __name__ == "__main__":


    torch.set_printoptions(precision=7)

    torch_bn = nn.BatchNorm2d(4)  # 设置 channel 数
    torch_bn.momentum = None
    features = torch.randn(4, 4, 2, 2) # (N,C,H,W)
        
    torch_bn_output = torch_bn(features)    
    my_bn_output, running_mean, running_var = my_batch_norm_2d_detail(features)        
            
    print(torch.allclose(torch_bn_output, my_bn_output))
    print(torch.allclose(torch_bn.running_mean, running_mean))
    print(torch.allclose(torch_bn.running_var, running_var))

注意事项

方差计算

需要注意的是,在训练的过程中,方差有两种不同的计算方式,

在训练时,用于更新特征图的是 有偏方差
而 running_var 的计算,使用的是 无偏方差
在这里插入图片描述

相关链接

官方人员手写BN

"""
Comparison of manual BatchNorm2d layer implementation in Python and
nn.BatchNorm2d

@author: ptrblck
"""

import torch
import torch.nn as nn


def compare_bn(bn1, bn2):
    err = False
    if not torch.allclose(bn1.running_mean, bn2.running_mean):
        print('Diff in running_mean: {} vs {}'.format(
            bn1.running_mean, bn2.running_mean))
        err = True

    if not torch.allclose(bn1.running_var, bn2.running_var):
        print('Diff in running_var: {} vs {}'.format(
            bn1.running_var, bn2.running_var))
        err = True

    if bn1.affine and bn2.affine:
        if not torch.allclose(bn1.weight, bn2.weight):
            print('Diff in weight: {} vs {}'.format(
                bn1.weight, bn2.weight))
            err = True

        if not torch.allclose(bn1.bias, bn2.bias):
            print('Diff in bias: {} vs {}'.format(
                bn1.bias, bn2.bias))
            err = True

    if not err:
        print('All parameters are equal!')


class MyBatchNorm2d(nn.BatchNorm2d):
    def __init__(self, num_features, eps=1e-5, momentum=0.1,
                 affine=True, track_running_stats=True):
        super(MyBatchNorm2d, self).__init__(
            num_features, eps, momentum, affine, track_running_stats)

    def forward(self, input):
        self._check_input_dim(input)

        exponential_average_factor = 0.0

        if self.training and self.track_running_stats:
            if self.num_batches_tracked is not None:
                self.num_batches_tracked += 1
                if self.momentum is None:  # use cumulative moving average
                    exponential_average_factor = 1.0 / float(self.num_batches_tracked)
                else:  # use exponential moving average
                    exponential_average_factor = self.momentum

        # calculate running estimates
        if self.training:
            mean = input.mean([0, 2, 3])
            # use biased var in train
            var = input.var([0, 2, 3], unbiased=False)
            n = input.numel() / input.size(1)
            with torch.no_grad():
                self.running_mean = exponential_average_factor * mean\
                    + (1 - exponential_average_factor) * self.running_mean
                # update running_var with unbiased var
                self.running_var = exponential_average_factor * var * n / (n - 1)\
                    + (1 - exponential_average_factor) * self.running_var
        else:
            mean = self.running_mean
            var = self.running_var

        input = (input - mean[None, :, None, None]) / (torch.sqrt(var[None, :, None, None] + self.eps))
        if self.affine:
            input = input * self.weight[None, :, None, None] + self.bias[None, :, None, None]

        return input


# Init BatchNorm layers
my_bn = MyBatchNorm2d(3, affine=True)
bn = nn.BatchNorm2d(3, affine=True)

compare_bn(my_bn, bn)  # weight and bias should be different
# Load weight and bias
my_bn.load_state_dict(bn.state_dict())
compare_bn(my_bn, bn)

# Run train
for _ in range(10):
    scale = torch.randint(1, 10, (1,)).float()
    bias = torch.randint(-10, 10, (1,)).float()
    x = torch.randn(10, 3, 100, 100) * scale + bias
    out1 = my_bn(x)
    out2 = bn(x)
    compare_bn(my_bn, bn)

    torch.allclose(out1, out2)
    print('Max diff: ', (out1 - out2).abs().max())

# Run eval
my_bn.eval()
bn.eval()
for _ in range(10):
    scale = torch.randint(1, 10, (1,)).float()
    bias = torch.randint(-10, 10, (1,)).float()
    x = torch.randn(10, 3, 100, 100) * scale + bias
    out1 = my_bn(x)
    out2 = bn(x)
    compare_bn(my_bn, bn)

    torch.allclose(out1, out2)
    print('Max diff: ', (out1 - out2).abs().max())
本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

手撕/手写/自己实现 BN层/batch norm/BatchNormalization python torch pytorch 的相关文章

  • 没有名为 crypto.cipher 的模块

    我现在正在尝试加密一段时间 我最近得到了这个基于 python 的密码器 名为PythonCrypter https github com jbertman PythonCrypter 我对 Python 相当陌生 当我尝试通过终端打开 C
  • 通过 Scrapy 抓取 Google Analytics

    我一直在尝试使用 Scrapy 从 Google Analytics 获取一些数据 尽管我是一个完全的 Python 新手 但我已经取得了一些进展 我现在可以通过 Scrapy 登录 Google Analytics 但我需要发出 AJAX
  • Python 的键盘中断不会中止 Rust 函数 (PyO3)

    我有一个使用 PyO3 用 Rust 编写的 Python 库 它涉及一些昂贵的计算 单个函数调用最多需要 10 分钟 从 Python 调用时如何中止执行 Ctrl C 好像只有执行结束后才会处理 所以本质上没什么用 最小可重现示例 Ca
  • Python(Selenium):如何通过登录重定向/组织登录登录网站

    我不是专业程序员 所以请原谅任何愚蠢的错误 我正在做一些研究 我正在尝试使用 Selenium 登录数据库来搜索大约 1000 个术语 我有两个问题 1 重定向到组织登录页面后如何使用 Selenium 登录 2 如何检索数据库 在我解决
  • 如何在 Python 中检索 for 循环中的剩余项目?

    我有一个简单的 for 循环迭代项目列表 在某些时候 我知道它会破裂 我该如何退回剩余的物品 for i in a b c d e f g try some func i except return remaining items if s
  • python 相当于 R 中的 get() (= 使用字符串检索符号的值)

    在 R 中 get s 函数检索名称存储在字符变量 向量 中的符号的值s e g X lt 10 r lt XVI s lt substr r 1 1 X get s 10 取罗马数字的第一个符号r并将其转换为其等效整数 尽管花了一些时间翻
  • 根据列值突出显示数据框中的行?

    假设我有这样的数据框 col1 col2 col3 col4 0 A A 1 pass 2 1 A A 2 pass 4 2 A A 1 fail 4 3 A A 1 fail 5 4 A A 1 pass 3 5 A A 2 fail 2
  • Python 函数可以从作用域之外赋予新属性吗?

    我不知道你可以这样做 def tom print tom s locals locals def dick z print z name z name z guest Harry print z guest z guest print di
  • 在Python中获取文件描述符的位置

    比如说 我有一个原始数字文件描述符 我需要根据它获取文件中的当前位置 import os psutil some code that works with file lp lib open path to file p psutil Pro
  • IO 密集型任务中的 Python 多线程

    建议仅在 IO 密集型任务中使用 Python 多线程 因为 Python 有一个全局解释器锁 GIL 只允许一个线程持有 Python 解释器的控制权 然而 多线程对于 IO 密集型操作有意义吗 https stackoverflow c
  • 如何在Python中对类别进行加权随机抽样

    给定一个元组列表 其中每个元组都包含一个概率和一个项目 我想根据其概率对项目进行采样 例如 给出列表 3 a 4 b 3 c 我想在 40 的时间内对 b 进行采样 在 python 中执行此操作的规范方法是什么 我查看了 random 模
  • 将图像分割成多个网格

    我使用下面的代码将图像分割成网格的 20 个相等的部分 import cv2 im cv2 imread apple jpg im cv2 resize im 1000 500 imgwidth im shape 0 imgheight i
  • 如何在 Python 中追加到 JSON 文件?

    我有一个 JSON 文件 其中包含 67790 1 kwh 319 4 现在我创建一个字典a dict我需要将其附加到 JSON 文件中 我尝试了这段代码 with open DATA FILENAME a as f json obj js
  • 有没有办法检测正在运行的代码是否正在上下文管理器内执行?

    正如标题所述 有没有办法做到这样的事情 def call back if called inside context print running in context else print called outside context 这将
  • 有人用过 Dabo 做过中型项目吗? [关闭]

    Closed 这个问题是基于意见的 help closed questions 目前不接受答案 我们正处于一个新的 ERP 风格的客户端 服务器应用程序的开始阶段 该应用程序是作为 Python 富客户端开发的 我们目前正在评估 Dabo
  • Python:如何将列表列表的元素转换为无向图?

    我有一个程序 可以检索 PubMed 出版物列表 并希望构建一个共同作者图 这意味着对于每篇文章 我想将每个作者 如果尚未存在 添加为顶点 并添加无向边 或增加每个合著者之间的权重 我设法编写了第一个程序 该程序检索每个出版物的作者列表 并
  • 如何计算 pandas 数据帧上的连续有序值

    我试图从给定的数据帧中获取连续 0 值的最大计数 其中包含来自 pandas 数据帧的 id date value 列 如下所示 id date value 354 2019 03 01 0 354 2019 03 02 0 354 201
  • Scrapy:如何使用元在方法之间传递项目

    我是 scrapy 和 python 的新手 我试图将 parse quotes 中的项目 item author 传递给下一个解析方法 parse bio 我尝试了 request meta 和 response meta 方法 如 sc
  • 在 Qt 中自动调整标签文本大小 - 奇怪的行为

    在 Qt 中 我有一个复合小部件 它由排列在 QBoxLayouts 内的多个 QLabels 组成 当小部件调整大小时 我希望标签文本缩放以填充标签区域 并且我已经在 resizeEvent 中实现了文本大小的调整 这可行 但似乎发生了某
  • 如何使用 Pycharm 安装 tkinter? [复制]

    这个问题在这里已经有答案了 I used sudo apt get install python3 6 tk而且效果很好 如果我在终端中打开 python Tkinter 就可以工作 但我无法将其安装在我的 Pycharm 项目上 pip

随机推荐

  • IDEA导入Eclipse项目步骤详解

    IDEA导入Eclipse项目步骤详解 文章目录 IDEA导入Eclipse项目步骤详解 首先在idea里file gt new gt Project from Existing Sources 选中到要导入的项目 这里我选用创建新的 Cl
  • 情感分析概述

    情感分析主要研究观点挖掘 倾向性分析等 一 为什么需要观点挖掘和倾向性分析 文本信息主要包括两类 客观性事实 主观性观点 但是已有的文本分析方法主要侧重在客观性文本内容的分析和挖掘 二 什么是观点挖掘与倾向性分析 观点挖掘与倾向性分析就是从
  • Java多线程进阶(十九)—— J.U.C之synchronizer框架:CyclicBarrier

    本文首发于一世流云专栏 https segmentfault com blog 一 CyclicBarrier简介 CyclicBarrier是一个辅助同步器类 在JDK1 5时随着J U C一起引入 这个类的功能和我们之前介绍的Count
  • Jmeter录制脚本

    性能关注点 接口响应时间 50毫秒 1000毫秒 吞度量 10000万每天 tPs 每秒处理事务数 压测需求与业务操作步骤 压测对象 http news baidu com 压测页面 首页 国际频道 财经频道 步骤 访问首页 单击 国际频道
  • 测试用例的优先级

    刚接触软件测试 先熟悉一下测试用例的优先级的概念 有时会听到0级别case的说法 其实这是对具有一定优先级的测试用例的说法 在这际测试实践中 测试用例根据重要性分成一定的等级 在不通的公司 可能测试用例的等级划分有所差异 但是基本大同小异
  • 积分计算两条曲线围绕y坐标轴旋转形成的立体体积

    积分计算两条曲线围绕y坐标轴旋转形成的立体体积 和附录文章1类似 计算两条曲线y x 2和y 2x围绕y坐标轴形成的立方体体积 首先要计算积分的上限和下限 根据两者相交的点求出 0 4 外层大圆R y y 1 2 和内层小圆r y y 2的
  • 使用iptables进行入站流量过滤

    iptables是Linux内置的流量过滤工具 同时也是多种防火墙的底层实现 如fw3 在本次应用中 iptables通过丢弃不符合规则的数据包 使得未注册设备在DHCP获取ip阶段失败 无法连接到专用内网 保证系统安全 iptables使
  • 10年软件测试工程师感悟——写给还在迷茫中的朋友

    这两天和朋友谈到软件测试的发展 其实软件测试已经在不知不觉中发生了非常大的改变 前几年的软件测试行业还是一个风口 随着不断地转行人员以及毕业的大学生疯狂地涌入软件测试行业 目前软件测试行业 缺口 已经基本饱和 当然 我说的是最基础的功能测试
  • QT之D指针

    什么是D指针 如果你已经看过到Qt源码 你会发现它经常使用Q D和Q Q 宏 本文介绍了这些宏的用途 该Q D和Q Q宏是一个设计模式的一部分被称为d 指针 也称为 不透明的指针 其中一个库的实现细节可以从它的用户 并转移到执行被隐藏 另外
  • LLVM每日谈之二 LLVM IR

    作者 snsn1984 在介绍LLVM IR之前 我们需要先了解下LLVM的结构 传统的静态编译器分为三个阶段 前端 优化和后端 LLVM的三阶段设计是这样的 这样做的优点是如果需要支持一种新的编程语言 那么我们只需要实现一种新的前端 如果
  • 0基础java入门:第二十五节.面向对象思想理解思路。

    0基础java入门 第二十五节 面向对象思想理解思路 本章需要时间和代码积累才能理解通透 不要着急 先来了解 敲上三年代码再回来看 面向对象是现在大部分编程语言中都会提及和使用到的一种思想方式 有人说很难理解 但个人觉得其实不难 因为面向对
  • element ui tabs 修改成hover触发点击

    Element UI tabs标签页 将点击选择改成鼠标指到就点击 类似hover 1 单个组件 在el tabs里添加个ref 删去el tab pane里的 name绑定 然后在mounted里添加代码 mounted this nex
  • f12获取网页文本_网页上的文字不能复制怎么办?有这5招轻松复制

    有时候我们需要一些辅助资料时 会经常使用搜索工具查坎相关网页文件 但遇到一些需要用到的段落却不能直接复制时 一个字一个字的敲肯定是不现实 有什么方法可以让其直接进行复制呢 方法1 打印网页 这种方式相对比较简单 而且电脑也不需要真的安装打印
  • 串行通信协议---HART协议

    实际应用中 HART协议是仅次于Modbus协议的最接近统一现场总线的标准 主要是在4 20mA电流信号上面叠加数字信号 物理层采用Bell 202标准的FSK技术成功实现模拟信号和数字信号双向同时通信而互不干扰 HART协议规定了传输的物
  • 怎么启用windwos无线网驱动

    重启windwos无线网驱动 说明 进入系统窗口 打开设备管理器 在设备管理器目录中找到网络适配器 找到 Realtek 8822BE Wireless LAN 802 11ac PCI ENIC 左键选中Realtek 8822BE Wi
  • 【QT5】tslib移植

    tslib全称应该是Touch Screen Library 也就是专门针对触摸屏创建的开源库 tslib的最新工程的github地址为 https github com libts tslib 感谢牛人的开源工程 clone下来 进入源码
  • 使用Visual Studio开发Linux程序

    首先我们使用visual studio创建项目 这里我使用的是visual studio 2022 visual studio 2019的也一样 如下创建项目即可 然后我们需要在visual studio中连接我们的Linux服务器 点击
  • 刷脸支付顺应时代各种优惠政策出现

    相比于人工合成的二维码扫码支付 刷脸支付采用的是生物信息识别技术 在安全性上后者要比前者高很多 刷脸支付自从出世以来就受到广大创业者 商家的关注 自从去年支付宝推出刷脸支付并在实体店投入运营 到今年刷脸支付得到快速的发展 微信也加入刷脸支付
  • 后台运行VirtualBox虚拟机

    运行一个VirtualBox虚拟机最常见的方式是 打开VirtualBox 点击对应的虚拟机来运行 使用这种传统方式运行的虚拟机通常都有一个前台界面 可以像操作本地电脑一样进行操作 但是Linuxer有时候更喜欢通过终端远程接入 而不是在虚
  • 手撕/手写/自己实现 BN层/batch norm/BatchNormalization python torch pytorch

    计算过程 在卷积神经网络中 BN 层输入的特征图维度是 N C H W 输出的特征图维度也是 N C H W N 代表 batch size C 代表 通道数 H 代表 特征图的高 W 代表 特征图的宽 我们需要在通道维度上做 batch