用C++实现softmax函数(面试经验)

2023-11-10

背景

今天面试字节算法岗时被问到的问题,让我用C++实现一个softmax函数。softmax是逻辑回归在多分类问题上的推广。大概的公式如下:
i n p u t : { x 1 , x 2 , ⋯   , x n } s o f t m a x ( x t ) = e x t ∑ i = 1 n e x i input: \{x_1, x_2,\cdots, x_n\}\\ softmax(x_t)=\frac{e^{x_t}}{\sum_{i=1}^{n}e^{x_i}} input:{x1,x2,,xn}softmax(xt)=i=1nexiext
即判断该变量在总体变量中的占比。

第一次实现

实现

我们用vector来封装输入和输出,简单的按公式复现。

vector<double> softmax(vector<double> input)
{
	double total=0;
	for(auto x:input)
	{
		total+=exp(x);
	}
	vector<double> result;
	for(auto x:input)
	{
		result.push_back(exp(x)/total);
	}
	return result;
}

测试

test 1

  • 测试用例1: {1, 2, 3, 4, 5}
  • 测试输出1: {0.0116562, 0.0316849, 0.0861285, 0.234122, 0.636409}

经过简单测试是正常的。
在这里插入图片描述

test 2

但是这时面试官提出了一个问题,即如果有较大输入变量时会怎么样?

  • 测试用例2: {1, 2, 3, 4, 5, 1000}
  • 测试输出2: {0, 0, 0, 0, 0, nan}

由于 e 1000 e^{1000} e1000已经溢出了双精度浮点(double)所能表示的范围,所以变成了NaN(not a number)。
在这里插入图片描述

第二次实现(改进)

改进原理

我们注意观察softmax的公式:
i n p u t : { x 1 , x 2 , ⋯   , x n } s o f t m a x ( x t ) = e x t ∑ i = 1 n e x i input: \{x_1, x_2,\cdots, x_n\}\\ softmax(x_t)=\frac{e^{x_t}}{\sum_{i=1}^{n}e^{x_i}} input:{x1,x2,,xn}softmax(xt)=i=1nexiext
如果我们给上下同时乘以一个很小的数,最后答案的值是不变的。
那我们可以给每一个输入 x i x_i xi都减去一个值 a a a,防止爆精度。
大致表示如下:
e x t ∑ i = 1 n e x i = e x t ⋅ e − a e − a ⋅ ∑ i = 1 n e x i = e x t ⋅ e − a ∑ i = 1 n e x i ⋅ e − a = e x t − a ∑ i = 1 n e x i − a \frac{e^{x_t}}{\sum_{i=1}^{n}e^{x_i}}= \frac{e^{x_t}\cdot e^{-a}}{e^{-a}\cdot \sum_{i=1}^{n}e^{x_i}}= \frac{e^{x_t}\cdot e^{-a}}{ \sum_{i=1}^{n}e^{x_i}\cdot e^{-a}}= \frac{e^{x_t-a}}{ \sum_{i=1}^{n}e^{x_i-a}} i=1nexiext=eai=1nexiextea=i=1nexieaextea=i=1nexiaexta
那我们如何取这个 a a a的值呢?直接取输入中最大的那个即 m a x ( x i ) max(x_i) max(xi)就好啦,这样所有的 e x i − a e^{x_i-a} exia的值都不会超过 e 0 = 1 e^0=1 e0=1,更不可能爆精度了。

实现

vector<double> softmax(vector<double> input)
{
	double total=0;
	double MAX=input[0];
	for(auto x:input)
	{
		MAX=max(x,MAX);
	}
	for(auto x:input)
	{
		total+=exp(x-MAX);
	}
	vector<double> result;
	for(auto x:input)
	{
		result.push_back(exp(x-MAX)/total);
	}
	return result;
}

测试

test 1

  • 测试用例1: {1, 2, 3, 4, 5, 1000}
  • 测试输出1: {0, 0, 0, 0, 0, 1}
    在这里插入图片描述

test 2

  • 测试用例1: {0, 19260817, 19260817}
  • 测试输出1: {0, 0.5, 0.5}

在这里插入图片描述

我们发现结果正常了。

完整代码

#include <iostream>
#include <vector>
#include <math.h>
using namespace std;

vector<double> softmax(vector<double> input)
{
	double total=0;
	double MAX=input[0];
	for(auto x:input)
	{
		MAX=max(x,MAX);
	}
	for(auto x:input)
	{
		total+=exp(x-MAX);
	}
	vector<double> result;
	for(auto x:input)
	{
		result.push_back(exp(x-MAX)/total);
	}
	return result;
}

int main(int argc, char *argv[])
{
	int n;
	cin>>n;
	vector<double> input;
	while(n--)
	{
		double x;
		cin>>x;
		input.push_back(x);
	}
	for(auto y:softmax(input))
	{
		cout<<y<<' ';
	}
}
本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

用C++实现softmax函数(面试经验) 的相关文章

  • 使用 std::packaged_task/std::exception_ptr 时,线程清理程序报告数据争用

    我遇到了线程清理程序 TSan 的一些问题 抱怨某些生产代码中的数据争用 其中 std packaged task 通过将它们包装在 std function 中而移交给调度程序线程 对于这个问题 我简化了它在生产中的作用 同时触发 TSa
  • 使用 Newtonsoft 和 C# 反序列化嵌套 JSON

    我正在尝试解析来自 Rest API 的 Json 响应 我可以获得很好的响应并创建了一些类模型 我正在使用 Newtonsoft 的 Json Net 我的响应中不断收到空值 并且不确定我的模型设置是否正确或缺少某些内容 例如 我想要获取
  • 获取两个工作日之间的天数差异

    这听起来很简单 但我不明白其中的意义 那么获取两次之间的天数的最简单方法是什么DayOfWeeks当第一个是起点时 如果下一个工作日较早 则应考虑在下周 The DayOfWeek 枚举 http 20 20 5B1 5D 3a 20htt
  • 在 C 中初始化变量

    我知道有时如果你不初始化int 如果打印整数 您将得到一个随机数 但将所有内容初始化为零似乎有点愚蠢 我问这个问题是因为我正在评论我的 C 项目 而且我对缩进非常直接 并且它可以完全编译 90 90 谢谢 Stackoverflow 但我想
  • qdbusxml2cpp 未知类型

    在使用 qdbusxml2cpp 程序将以下 xml 转换为 Qt 类时 我收到此错误 qdbusxml2cpp c ObjectManager a ObjectManager ObjectManager cpp xml object ma
  • 为什么调用非 const 成员函数而不是 const 成员函数?

    为了我的目的 我尝试包装一些类似于 Qt 共享数据指针的东西 经过测试 我发现当应该调用 const 函数时 会选择它的非 const 版本 我正在使用 C 0x 选项进行编译 这是一个最小的代码 struct Data int x con
  • 标准化 UTF-8 到底是什么?

    The 重症监护室项目 http userguide icu project org transforms normalization 现在也有一个PHP库 http us php net manual en class normalize
  • 在一个平台上,对于所有数据类型,所有数据指针的大小是否相同? [复制]

    这个问题在这里已经有答案了 Are char int long 甚至long long 大小相同 在给定平台上 不能保证它们的大小相同 尽管在我有使用经验的平台上它们通常是相同的 C 2011 在线草稿 http www open std
  • C#:帮助理解 UML 类图中的 <>

    我目前正在做一个项目 我们必须从 UML 图编写代码 我了解 UML 类图的剖析 但我无法理解什么 lt
  • C# 中的合并运算符?

    我想我记得看到过类似的东西 三元运算符 http msdn microsoft com en us library ty67wk28 28VS 80 29 aspx在 C 中 它只有两部分 如果变量值不为空 则返回变量值 如果为空 则返回默
  • 为什么 std::strstream 被弃用?

    我最近发现std strstream已被弃用 取而代之的是std stringstream 我已经有一段时间没有使用它了 但它做了我当时需要做的事情 所以很惊讶听到它的弃用 我的问题是为什么做出这个决定 有什么好处std stringstr
  • CMake 无法确定目标的链接器语言

    首先 我查看了this https stackoverflow com questions 11801186 cmake unable to determine linker language with c发帖并找不到解决我的问题的方法 我
  • “接口”类似于 boost::bind 的语义

    我希望能够将 Java 的接口语义与 C 结合起来 起初 我用过boost signal为给定事件回调显式注册的成员函数 这非常有效 但后来我发现一些函数回调池是相关的 因此将它们抽象出来并立即注册所有实例的相关回调是有意义的 但我了解到的
  • 使用管道时,如果子进程数量大于处理器数量,进程是否会被阻塞?

    当子进程数量很大时 我的程序停止运行 我不知道问题是什么 但我猜子进程在运行时以某种方式被阻止 下面是该程序的主要工作流程 void function int process num int i initial variables for
  • 使用 C# 读取 Soap 消息

  • 方法优化 - C#

    我开发了一种方法 允许我通过参数传入表 字符串 列数组 字符串 和值数组 对象 然后使用这些参数创建参数化查询 虽然它工作得很好 但代码的长度以及多个 for 循环散发出一种代码味道 特别是我觉得我用来在列和值之间插入逗号的方法可以用不同的
  • 无法接收 UDP Windows RT

    我正在为 Windows 8 RT 编写一个 Windows Store Metro Modern RT 应用程序 需要在端口 49030 上接收 UDP 数据包 但我似乎无法接收任何数据包 我已按照使用教程进行操作DatagramSock
  • Oracle Data Provider for .NET 不支持 Oracle 19.0.48.0.0

    我们刚刚升级到 Oracle 19c 19 3 0 所有应用程序都停止工作并出现以下错误消息 Oracle Data Provider for NET 不支持 Oracle 19 0 48 0 0 我将 Oracle ManagedData
  • 当从finally中抛出异常时,Catch块不会被评估

    出现这个问题的原因是之前在 NET 4 0 中运行的代码在 NET 4 5 中因未处理的异常而失败 部分原因是 try finallys 如果您想了解详细信息 请阅读更多内容微软连接 https connect microsoft com
  • 如何将 PostgreSql 与 EntityFramework 6.0.2 集成? [复制]

    这个问题在这里已经有答案了 我收到以下错误 实体框架提供程序类型的 实例 成员 Npgsql NpgsqlServices Npgsql 版本 2 0 14 2 文化 中性 PublicKeyToken 5d8b90d52f46fda7 没

随机推荐

  • JAVA基础知识点

    一 概述 JAVA语言是美国Sun公司 Stanford University Network 在1995年推出的高级变成语言 2009年Oracle甲骨文公司收购Sun公司 并于2011年发布Java7版本 DOS命令 Win R cmd
  • 2022 年度软件质量保障行业调查报告

    2022 年度软件质量保障行业调查报告 TesterHome https testerhome com topics 35615 覆盖的测试类型 个人提升工作效率的方式 优秀测试人员应该具备的能力 测试同行们的未来计划 阻碍测试进度的因素
  • CMake进阶(一)设置编译选项

    CMake 进阶 一 设置编译选项 CMake设置编译选项 构建Debug版本和Release版本 CMake文件设置 编译过程 CMake设置编译选项 在cmake脚本中 设置编译选项可以通过add compile options命令 也
  • 【超细节】Vue3的属性传递——Props

    目录 前言 一 定义 二 使用 1 在 setup 中 推荐 2 非 setup 中 3 对象写法的校验类型 4 使用ts进行类型约束 5 使用ts时props的默认值 三 注意事项 1 Prop 名字格式 2 对象或数组类型的默认值 3
  • 第十届蓝桥杯 修改数组 (研究生组)

    修改数组 问题描述 给定一个长度为 N 的数组 A A1 A2 AN 数组中有可能有重复出现的整数 现在小明要按以下方法将其修改为没有重复整数的数组 小明会依次修改 A2 A3 AN 当修改 Ai 时 小明会检查 Ai 是否在 A1 Ai
  • hp服务器g5 u盘装系统,hp 440g5怎么装系统

    惠普probook440g5为一款14英寸高性能商务办公本 在升级了英特尔酷睿i7 8代系列处理器后 配合显卡迸发出超凡的性能 很适合外出携带使用 那这款惠普笔记本怎么安装操作系统 今天小编就为大家分享hp 440g5怎么装系统 hp 44
  • PowerMod@快速幂取模

    图片链接 快速幂取模使用心得 看到过于大的数不要害怕 要学会细致分析 想想取模的作用 不就是帮你把大数化小了吗 include
  • 最强自动化测试框架Playwright(25)-浏览器

    Browser Playwright Python 方法 创建page页面 from playwright sync api import sync playwright def run playwright firefox playwri
  • 深度学习正则化

    在设计机器学习算法时不仅要求在训练集上误差 且希望在新样本上 的泛化能 强 许多机器学习算法都采 相关的策略来减 测试误差 这 些策略被统称为正则化 因为神经 络的强 的表示能 经常遇到过拟 合 所以需要使 不同形式的正则化策略 正则化通过
  • JavaWeb-通过表格显示数据库的信息(jsp+mysql)

    login jsp h2 登录 h2 br
  • python+numpy+pandas数据类型+类/对象

    写代码时逻辑明确 但是被各种数据类型以及对象类型搞蒙了 补习并简单记录一下 在进行数据分析之前需要对数据进行数据处理 其中就包含转化数据格式 可以先查看数据信息 再依据分析需求对进行处理 编写python程序时各种数据类型以及对象的类型以及
  • 微信测试账号 (2)-消息验证sha1签名

    在第1篇中实现了收发微信消息 但是没有做验证 本篇将介绍微信如何使用sha签名 对消息进行认证 其中安全相关的概念 如sha1散列值 签名等 可参考web安全 1 验证参数 GetMapping handler public String
  • live555 server 搭建

    一 直接下载live555MediaServer可执行程序 二 Live555在linux平台上编译 下载源码包 http www live555 com liveMedia live555 latest tar gz 1 解压 2 生成M
  • CCPC-南阳比赛总结

    打铁归来 感触好多 记得英语课上 收到晓红老师的消息 说给我们争取下了国赛的名额 我们感觉好幸运 没想到最后打铁了 这个结果也不太意外 下面说下 这几天的行程 反思 下一步的目标 15号 淄博到济南 15 16号 济南 郑州 南阳 17号
  • 如何过滤 map

    获取 EntrySet 然后正常使用 stream 的 filter 过滤 Entry 最后再转为 Map 即可 对 map 过滤 filter Test public void testMapFilter Map
  • OCJP题库1Z0-851(21/30)

    这套题我参考了这篇文章 以及一些百度上找的内容 再加上自己的解释 总结下来的详解 第1题 1 Given a pre generics implementation of a method 11 public static int sum
  • 申请被拒模板 (六)

    这里只是模板 仅供学习 出现任何问题 与博主无关 Hi xxxxx We really appreciate the time and effort you took to connect with us and apply for the
  • nginx配置详解

    一 什么是nginx nginx是一款自由的 开源的 高性能的HTTP服务器和反向代理服务器 同时也是一个IMAP POP3 SMTP代理服务器 Nginx作为一个HTTP服务器进行网络的发布处理 另外Nginx可以作为反向代理进行负载均衡
  • 【数据分析】为什么要学习分析方法?

    为什么要学习分析方法 如果你有以下这些症状 没有数据分析意识 工作由拍脑袋决定 而不是靠数据分析来支持决策 统计时的数据分析 做了很多图表 却发现不了业务中存在的问题 只会使用工具的数据分析 谈起使用工具的技巧头头是道 但是面对问题 还是不
  • 用C++实现softmax函数(面试经验)

    背景 今天面试字节算法岗时被问到的问题 让我用C 实现一个softmax函数 softmax是逻辑回归在多分类问题上的推广 大概的公式如下 i n p u t