保存具有自定义前向功能的 Bert 模型并将其置于 Huggingface 上

2024-05-11

我创建了自己的 BertClassifier 模型,从预训练开始,然后添加由不同层组成的我自己的分类头。微调后,我想使用 model.save_pretrained() 保存模型,但是当我打印它并从预训练上传时,我看不到我的分类器头。 代码如下。如何保存模型上的所有结构并使其完全可访问 AutoModel.from_preatrained('folder_path')? 。谢谢!

class BertClassifier(PreTrainedModel):
    """Bert Model for Classification Tasks."""
    config_class = AutoConfig
    def __init__(self,config, freeze_bert=True): #tuning only the head
        """
         @param    bert: a BertModel object
         @param    classifier: a torch.nn.Module classifier
         @param    freeze_bert (bool): Set `False` to fine-tune the BERT model
        """
        #super(BertClassifier, self).__init__()
        super().__init__(config)

        # Instantiate BERT model
        # Specify hidden size of BERT, hidden size of our classifier, and number of labels
        self.D_in = 1024 #hidden size of Bert
        self.H = 512
        self.D_out = 2
 
        # Instantiate the classifier head with some one-layer feed-forward classifier
        self.classifier = nn.Sequential(
            nn.Linear(self.D_in, 512),
            nn.Tanh(),
            nn.Linear(512, self.D_out),
            nn.Tanh()
        )
 


    def forward(self, input_ids, attention_mask):


         # Feed input to BERT
        outputs = self.bert(input_ids=input_ids,
                             attention_mask=attention_mask)
         
         # Extract the last hidden state of the token `[CLS]` for classification task
        last_hidden_state_cls = outputs[0][:, 0, :]
 
         # Feed input to classifier to compute logits
        logits = self.classifier(last_hidden_state_cls)
 
        return logits

configuration=AutoConfig.from_pretrained('Rostlab/prot_bert_bfd')
model = BertClassifier(config=configuration,freeze_bert=False)

微调后保存模型

model.save_pretrained('path')

加载微调模型

model = AutoModel.from_pretrained('path') 

加载后打印模型显示我有以下最后一层,并且缺少我的 2 个线性层:

 (output): BertOutput(
          (dense): Linear(in_features=4096, out_features=1024, bias=True)
          (LayerNorm): LayerNorm((1024,), eps=1e-12, elementwise_affine=True)
          (dropout): Dropout(p=0.0, inplace=False)
          (adapters): ModuleDict()
          (adapter_fusion_layer): ModuleDict()
        )
      )
    )
  )
  (pooler): BertPooler(
    (dense): Linear(in_features=1024, out_features=1024, bias=True)
    (activation): Tanh()
  )
  (prefix_tuning): PrefixTuningPool(
    (prefix_tunings): ModuleDict()
  )
)

也许有问题config_class你里面的属性BertClassifier班级。根据文档,您需要创建一个继承 form 的附加配置类PretrainedConfig并初始化model_type属性与您的自定义模型的名称。

The BertClassifier's config_class必须与您的自定义配置类类型一致。 之后,您可以通过以下调用注册您的配置和模型:

AutoConfig.register('CustomModelName', CustomModelConfigClass)
AutoModel.register(CustomModelConfigClass, CustomModelClass)

并加载您的微调模型AutoModel.from_pretrained('YourCustomModelName')

基于您的代码的不完整示例可能如下所示:

class BertClassifierConfig(PretrainedConfig):
    model_type="BertClassifier"


class BertClassifier(PreTrainedModel):
    config_class = BertClassifierConfig
    # ...


configuration = BertClassifierConfig()
bert_classifier = BertClassifier(configuration)

# do your finetuning and save your custom model
bert_classifier.save_pretrained("CustomModels/BertClassifier")

# register your config and your model
AutoConfig.register("BertClassifier", BertClassifierConfig)
AutoModel.register(BertClassifierConfig, BertClassifier)

# load your model with AutoModel
bert_classifier_model = AutoModel.from_pretrained("CustomModels/BertClassifier")

打印模型输出应与此类似:

    (pooler): BertPooler(
      (dense): Linear(in_features=768, out_features=768, bias=True)
      (activation): Tanh()
    )
  )
  (classifier): Sequential(
    (0): Linear(in_features=1024, out_features=512, bias=True)
    (1): Tanh()
    (2): Linear(in_features=512, out_features=2, bias=True)
    (3): Tanh()
    (4): Linear(in_features=2, out_features=512, bias=True)
    (5): Tanh()
  )

希望这可以帮助。

https://huggingface.co/docs/transformers/custom_models#registering-a-model-with-custom-code-to-the-auto-classes https://huggingface.co/docs/transformers/custom_models#registering-a-model-with-custom-code-to-the-auto-classes

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

保存具有自定义前向功能的 Bert 模型并将其置于 Huggingface 上 的相关文章

  • Vimeo API:获取下载所有视频文件的链接列表

    再会 我正在尝试从 Vimeo 帐户获取所有视频文件的列表 直接下载的链接 有没有办法在 1 GET 请求中做到这一点 好的 如果是API限制的话 就100倍 我有硬编码脚本 我在其中发出 12 个 GET 请求 1100 多个视频 根据文
  • 导入错误:无法导入名称“FFProbe”

    我无法获取ffprobe包 https github com simonh10 ffprobe在 Python 3 6 中工作 我使用 pip 安装它 但是当我输入import ffprobe it says Traceback most
  • 从 torch.autograd.gradcheck 导入 zero_gradients

    我想复制代码here https github com LTS4 DeepFool blob master Python deepfool py 并且我在 Google Colab 中运行时收到以下错误 ImportError 无法导入名称
  • GUI 测试工具 PyUseCase 与 Dogtail 相比如何?

    GUI测试工具如何Py用例 http pypi python org pypi PyUseCase重命名为故事文本 http pypi python org pypi StoryText 相比于Dogtail http en wikiped
  • 如何同时运行多个功能[关闭]

    Closed 这个问题需要多问focused help closed questions 目前不接受答案 我有以下代码 my func1 my func2 my func3 my func4 my func5 是否可以同时计算函数的数据 而
  • 使用管理员权限打开cmd(Windows 10)

    我有自己的 python 脚本来管理我的计算机上的 IP 地址 它主要在命令行 Windows 10 中执行netsh命令 您必须具有管理员权限 这是我自己的计算机 我是管理员 运行脚本时我已经使用管理员类型的用户 Adrian 登录 我无
  • Python:json_normalize pandas 系列给出 TypeError

    我在 pandas 系列中有数万行像这样的 json 片段df json IDs lotId 1 Id 123456 date 2009 04 17 bidsCount 2 IDs lotId 2 Id 123456 date 2009 0
  • Python3将模块从文件夹导入到另一个文件夹

    我的结构字典是 mainFolder folder1 init py file1 py file2 py folder2 init py file3 py file4 py setup py init py 我需要将 file4 py 从f
  • Python Selenium 打印另存为 PDF 等待文件名输入

    我正在尝试通过打印对话框将网站另存为 PDF 我的代码允许我另存为pdf 但要求我输入文件名 我不知道如何将文件名传递到弹出框 附上我的代码 import time from selenium import webdriver import
  • 如何用函数记录一个文件?

    我有一个带有函数 lib py 但没有类的python 文件 每个函数都有以下样式 def fnc1 a b c This fonction does something param a lalala type a str param b
  • 使用 Tkinter 打开网页

    因此 我的应用程序需要能够打开其中的单个网页 并且它必须来自互联网并且未保存 特别是我想使用 Tkinter GUI 工具包 因为它是我最熟悉的工具包 最重要的是 我希望能够在窗口中生成事件 例如单击鼠标 但无需实际使用鼠标 有什么好的方法
  • 会话数据库表清理

    该表是否需要清除或者由 Django 自动处理 Django 不提供自动清除功能 然而 有一个方便的命令可以帮助您手动完成此操作 Django 文档 清除会话存储 https docs djangoproject com en dev to
  • Python在没有pandas的情况下解码excel表

    我正在尝试在 python 中读取 excel 文件而不使用pandas or xlrd 我一直在尝试将结果转换为bytes to utf 8没有任何成功 xls 文件中的数据 colA colB colC spc 1D0 20190705
  • pandas groupby 操作缺少数据

    在 pandas 数据框中 我有一列如下所示 0 M 1 E 2 L 3 M 1 4 M 2 5 M 3 6 E 1 7 E 2 8 E 3 9 E 4 10 L 1 11 L 2 12 M 1 a 13 M 1 b 14 M 1 c 15
  • Eclipse/PyDev 中未使用导入警告,尽管已使用

    我正在我的文件中导入一个绘图包 如下所示 import matplotlib pyplot as plt 稍后我会在我的代码中成功使用此导入 fig plt figure figsize 16 10 然而 Eclipse 告诉我 未使用的导
  • 将图与热图(可能是对数)配对?

    How to create a pair plot in Python like the following but with heat maps instead of points or instead of a hex bin plot
  • 如何创建增量加载网页

    我正在编写一个处理大量数据的页面 它会永远持续到我的结果页面加载 几乎无限 因为返回的数据太大了 因此 我需要实现一个增量加载页面 例如 url 中的页面 http docs python org http docs python org
  • 根据标点符号列表替换数据框中的标点符号[重复]

    这个问题在这里已经有答案了 使用 Canopy 和 Pandas 我有数据框 a 其定义如下 a pd read csv text txt df pd DataFrame a df columns test test txt 是一个单列文件
  • 将 Django 中的所有视图限制为经过身份验证的用户

    我是 Django 新手 我正在开发一个项目 该项目有一个登录页面作为其索引和一个注册页面 其余页面都必须仅限于登录用户 如果未经身份验证的用户尝试访问这些页面 则必须将他 她重定向到登录页面 我看到 login required装饰器会将
  • py2exe ImportError:没有名为 的模块

    我已经实现了一个名为 myUtils 的包 它由文件夹 myUtils 文件 组成 init py 和许多名称为 myUtils 的 py 文件 该包包含在 myOtherProject py 中 当我从 Eclipse 运行它们时可以找到

随机推荐