PyTorch:用于训练和测试/验证的不同前向方法

2023-12-22

我目前正在尝试延长a model https://github.com/microsoft/MASS这是基于 FairSeq/PyTorch 的。在训练过程中,我需要训练两个编码器:一个使用目标样本,另一个使用源样本。

所以当前的forward函数看起来像这样:

def forward(self, src_tokens=None, src_lengths=None, prev_output_tokens=None, **kwargs):
    encoder_out = self.encoder(src_tokens, src_lengths=src_lengths, **kwargs)
    decoder_out = self.decoder(prev_output_tokens, encoder_out=encoder_out, **kwargs)
    return decoder_out

并以此为基础这个想法 https://github.com/golsun/SpaceFusion我想要这样的东西:

def forward_test(self, src_tokens=None, src_lengths=None, prev_output_tokens=None, **kwargs):
    encoder_out = self.encoder(src_tokens, src_lengths=src_lengths, **kwargs)
    decoder_out = self.decoder(prev_output_tokens, encoder_out=encoder_out, **kwargs)
    return decoder_out

def forward_train(self, src_tokens=None, src_lengths=None, prev_output_tokens=None, **kwargs):
    encoder_out = self.encoder(src_tokens, src_lengths=src_lengths, **kwargs)
    autoencoder_out = self.encoder(tgt_tokens, src_lengths=src_lengths, **kwargs)
    concat = some_concatination_func(encoder_out, autoencoder_out)
    decoder_out = self.decoder(prev_output_tokens, encoder_out=concat, **kwargs)
    return decoder_out

有什么办法可以做到这一点吗?

编辑: 这些是我所面临的限制,因为我需要扩展FairseqEncoderDecoder模型:

@register_model('transformer_mass')
class TransformerMASSModel(FairseqEncoderDecoderModel):
    def __init__(self, encoder, decoder):
        super().__init__(encoder, decoder) 

编辑2: 传递给 Fairseq 中的前向函数的参数可以通过实现您自己的标准来更改,请参见示例交叉熵准则 https://github.com/pytorch/fairseq/blob/master/fairseq/criterions/cross_entropy.py#L28, where sample['net_input']被传递到__call__模型的函数,它调用forward method.


首先你应该总是使用和定义forward不是您调用的其他方法torch.nn.Module实例。

绝对不能超载eval()如图所示trsvchn https://stackoverflow.com/a/58659193/10886420因为它是 PyTorch 定义的评估方法(see here https://pytorch.org/docs/stable/nn.html#torch.nn.Module.eval).此方法允许将模型内的层置于评估模式(例如,对层的特定更改,例如推理模式)Dropout or BatchNorm).

此外,你应该用__call__魔法方法。为什么?因为钩子和其他 PyTorch 特定的东西是以这种方式正确注册的。

其次,不要使用一些外来的东西。mode字符串变量,如建议的那样@阿南特·米塔尔 https://stackoverflow.com/questions/58655207/pytorch-different-forward-methods-for-train-and-test-validation/58655415#58655415。就是这样trainPyTorch中的变量是for,通过它来区分模型是否在eval模式或train mode.

话虽这么说,你最好这样做:

import torch


class Network(torch.nn.Module):
    def __init__(self):
        super().__init__()
        ...

    # You could split it into two functions but both should be called by forward
    def forward(
        self, src_tokens=None, src_lengths=None, prev_output_tokens=None, **kwargs
    ):
        encoder_out = self.encoder(src_tokens, src_lengths=src_lengths, **kwargs)
        if self.train:
            return self.decoder(prev_output_tokens, encoder_out=encoder_out, **kwargs)
        autoencoder_out = self.encoder(tgt_tokens, src_lengths=src_lengths, **kwargs)
        concat = some_concatination_func(encoder_out, autoencoder_out)
        return self.decoder(prev_output_tokens, encoder_out=concat, **kwargs)

您可以(并且可以说应该)将上述内容拆分为两个单独的方法,但这还不错,因为该函数相当短且可读。如果可能的话,只要坚持 PyTorch 的处理方式即可,而不是一些临时解决方案。不,反向传播不会有问题,为什么会有问题呢?

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

PyTorch:用于训练和测试/验证的不同前向方法 的相关文章

  • 如何获取 Flask 中当前的基本 URI? [复制]

    这个问题在这里已经有答案了 在下面的代码中 我想将 URL 存储在变量中以检查发生 URL 错误的错误 app route flights methods GET def get flight flight data mongo db fl
  • 如何将 xlsx 读取为 pandas 数据框,并将公式作为字符串

    我有一个包含一些计算列的 Excel 文件 例如 我在 a 列中有一些数据 而 b 列是使用 a 列中的值计算的 我需要将新数据附加到 a 列并计算 b 列并保存文件 import pandas as pd df pd DataFrame
  • 使 pycaffe 致命错误:找不到“Python.h”文件

    我在运行 OSX 10 9 5 的 Mac 上编译了 caffe 并且我知道尝试编译 pycaffe 当我在 caffe 根文件夹中运行 make pycaffe 时 我得到 CXX LD o python caffe caffe so p
  • PyInstaller 可执行文件无法获取 TorchScript 源代码

    我正在尝试使包含 PyTorch 的脚本在 Windows 中可执行 我的脚本的导入是 import numpy core multiarray which is a workaround for ImportError numpy cor
  • 在 Pandas 数据框中显示对图

    我试图通过从 pandas 数据框中的 scatter matrix 创建来显示一对图 这就是创建配对图的方式 Create dataframe from data in X train Label the columns using th
  • 使用 xlwings 排序(pywin32)

    我需要使用 python 按给定行对 Excel 电子表格进行排序 为了进行测试 我使用以下数据 在名为 xlwings sorting xlsx 的文件中 Numbers Letters Letters 2 7 A L 6 B K 5 C
  • 使用 imaplib 库连接到电子邮件时遇到 AUTHENTICATIONFAILED 错误

    如何连接到 imaplib 库而不遇到 AUTHENTICATIONFAILE 错误 通过网络浏览器登录时 我的 Gmail 收件箱显示严重的安全警报 登录尝试被阻止 IMAP SERVER imap gmail com USERNAME
  • 根据caffe中的“badness”缩放损失值

    我想根据训练期间 当前预测 与 正确标签 的接近 远近来缩放每个图像的损失值 例如 如果正确的标签是 猫 而网络认为它是 狗 那么惩罚 损失 应该小于网络认为它是 汽车 的情况 我正在做的方式如下 1 我定义了标签之间距离的矩阵 2 将该矩
  • python sqlite3从excel创建数据库

    我正在尝试从 Excel 电子表格创建数据库 我有下面的代码 问题是当我运行代码时 我的数据库为每列创建一个表 我想为工作簿中列出的每个电子表格创建一个表格 工作表名称为工作表 1 和工作表 2 import sqlite3 import
  • 将 Pytorch 模型 .pth 转换为 onnx 模型

    我有一个预训练的模型 其格式为 pth 扩展名 我想将其转换为 Tensorflow protobuf 但我没有找到任何方法来做到这一点 我见过 onnx 可以将模型从 pytorch 转换为 onnx 然后从 onnx 转换为 Tenso
  • python3中“super”对象没有属性“__getattr__”

    如何覆盖 getattr 使用 python 3 和继承 当我使用以下内容时 class MixinA def getattr self item Process item and return value if known if item
  • 导入错误 - 发生了什么?

    Python 导入 再次 我有这个文件结构 test start py from scripts import main scripts init py empty main py from import install install p
  • Python 3 如何知道如何 pickle 扩展类型,尤其是 Numpy 数组?

    Numpy 数组是扩展类型 也称为使用 C API 扩展定义的 声明了 Python 解释器范围之外的附加字段 例如data属性 这是一个Buffer Structure 如 Numpy 中所述阵列接口 https docs scipy o
  • xlwings: 删除一个列 | Excel 中的行

    如何删除 Excel 中的一行 wb xw Book Shipment xlsx wb sheets Page1 1 range 1 1 clear clear 用于删除内容 我想删除该行 我很惊讶 clear 函数有效 但 delete
  • 互补DNA序列

    我在编写这个循环时遇到问题 它似乎在第二个序列之后停止了 我想返回给定 DNA 序列的互补 DNA 序列 例如 AGATTC gt TCTAAG 其中 A T 和 C G def get complementary sequence dna
  • 无法使用 beautifulsoup 模块 python 从 HTML 检索温度值

    我正在使用 BeautifulSoup4 来解析此 HTML 查看源代码 https weather com en IN weather today l 17 39 78 49 https weather com en IN weather
  • Python 类型提示:typing.Mapping 与typing.Dict

    我正在开发一个 python 项目 我们使用typing整个模块类型提示 看来我们用的是typing Dict and typing Mapping几乎可以互换 有理由选择其中一种而不是另一种吗 我自己设法回答了这个问题 typing Di
  • Python 3 中“map”类型的对象没有 len()

    我在使用 Python 3 时遇到问题 我得到了 Python 2 7 代码 目前我正在尝试更新它 我收到错误 类型错误 map 类型的对象没有 len 在这部分 str len seed candidates 在我像这样初始化它之前 se
  • 类型错误:只能使用标量值执行操作

    如果您能让我知道如何为所提供的表格绘制一些信息丰富的图表 我将不胜感激here https www iasplus com en resources ifrs topics use of ifrs 例如 我需要一个名为 国内非上市公司 非上
  • 设置 torch.gather(...) 调用的结果

    我有一个形状为 n x m 的 2D pytorch 张量 我想使用索引列表来索引第二个维度 可以使用 torch gather 完成 然后然后还设置新值到索引的结果 Example data torch tensor 0 1 2 3 4

随机推荐