mllib NaiveBayes 中的类数量有限制吗?调用 model.save() 时出错

2023-12-03

我正在尝试训练一个模型来预测文本输入数据的类别。我使用以下方法遇到了似乎数值不稳定的问题pyspark.ml.classification.NaiveBayes当类别数量超过一定数量时,对词袋进行分类。

在我的现实世界项目中,我有大约 10 亿条记录和大约 50 个类。我能够训练我的模型并做出预测,但是当我尝试使用保存它时出现错误model.save()。从操作上来说,这很烦人,因为我每次都必须从头开始重新训练我的模型。

在尝试调试时,我将数据缩小到大约 10k 行,并在尝试保存时遇到了相同的问题。但是,如果我减少类标签的数量,保存效果很好。

这让我相信标签的数量是有限的。我无法重现我的确切问题,但下面的代码是相关的。如果我设置num_labels任何大于 31 的值,model.fit()抛出错误。

我的问题:

  1. 班级人数有限制吗mllib实施NaiveBayes?
  2. 如果我可以成功地使用模型进行预测,那么我无法保存模型的原因可能是什么?
  3. 如果确实存在限制,是否可以将我的数据分成更小的类别组,训练单独的模型,然后组合?

完整的工作示例

创建一些虚拟数据。

我要使用nltk.corpus.comparitive_sentences and nltk.corpus.sentence_polarity。请记住,这只是一个带有无意义数据的说明性示例 - 我不关心拟合模型的性能。

import pandas as pd
from pyspark.sql.types import StringType

# create some dummy data
from nltk.corpus import comparative_sentences, sentence_polarity
df = pd.DataFrame(
    {
        'sentence': [" ".join(s) for s in cs.sents() + sp.sents()]
    }
)

# assign a 'category' to each row
num_labels = 31  # seems to be the upper limit
df['category'] = (df.index%num_labels).astype(str)

# make it into a spark dataframe
spark_df = sqlCtx.createDataFrame(df)

数据准备管道

from pyspark.ml.feature import NGram, Tokenizer, StopWordsRemover
from pyspark.ml.feature import HashingTF, IDF, StringIndexer, VectorAssembler
from pyspark.ml import Pipeline
from pyspark.ml.linalg import Vector

indexer = StringIndexer(inputCol='category', outputCol='label')
tokenizer = Tokenizer(inputCol="sentence", outputCol="sentence_tokens")
remove_stop_words = StopWordsRemover(inputCol="sentence_tokens", outputCol="filtered")
unigrammer = NGram(n=1, inputCol="filtered", outputCol="tokens") 
hashingTF = HashingTF(inputCol="tokens", outputCol="hashed_tokens")
idf = IDF(inputCol="hashed_tokens", outputCol="tf_idf_tokens")

clean_up = VectorAssembler(inputCols=['tf_idf_tokens'], outputCol='features')

data_prep_pipe = Pipeline(
    stages=[indexer, tokenizer, remove_stop_words, unigrammer, hashingTF, idf, clean_up]
)
transformed = data_prep_pipe.fit(spark_df).transform(spark_df)
clean_data = transformed.select(['label','features'])

训练模型

from pyspark.ml.classification import NaiveBayes
nb = NaiveBayes()
(training,testing) = clean_data.randomSplit([0.7,0.3], seed=12345)
model = nb.fit(training)
test_results = model.transform(testing)

评估模型

from pyspark.ml.evaluation import MulticlassClassificationEvaluator
acc_eval = MulticlassClassificationEvaluator()
acc = acc_eval.evaluate(test_results)
print("Accuracy of model at predicting label was: {}".format(acc))

在我的机器上,打印:

Accuracy of model at predicting label was: 0.0305764788269

错误信息

如果我改变num_labels到 32 或更高,这是我调用时收到的错误model.fit():

Py4JJavaError:调用 o1336.fit 时发生错误。 : org.apache.spark.SparkException:作业由于阶段失败而中止: 阶段 86.0 中的任务 0 失败了 4 次,最近一次失败:丢失任务 0.3 阶段 86.0(TID 1984,someserver.somecompany.net,执行器 22):org.apache.spark.SparkException:Kryo 序列化失败:缓冲区 溢出。可用:7,必需:8 序列化跟踪:值 (org.apache.spark.ml.linalg.DenseVector)。为了避免这种情况,增加 Spark.kryoserializer.buffer.最大值。 ... ... 等等等等更多永远持续下去的java东西

Notes

  • 在此示例中,如果我添加二元组的功能,则在以下情况下会发生错误num_labels> 15. 我想知道这也是小于 2 的幂 1 是否是巧合。
  • 在我的实际项目中,我在尝试调用时也会遇到错误model.theta。 (我认为错误本身没有意义 - 它们只是从 java/scala 方法传回的异常。)

硬限制:

Number of features * Number of classes has to be lower Integer.MAX_VALUE (231 - 1). You are nowhere near these value.

软限制:

Theta 矩阵(条件概率)的大小为特征数 * 类数。 Theta 既存储在驱动程序本地(作为模型的一部分),又序列化并发送给工作人员。这意味着所有机器至少需要足够的内存来序列化或反序列化并存储结果。

Since you use default settings for HashingTF.numFeatures (220) each additional class adds 262144 - it is not that much, but quickly adds up. Based on the partial traceback you've posted, it looks like the failing component is Kryo serializer. The same traceback also suggests the solution, which is increasing spark.kryoserializer.buffer.max.

您还可以通过设置尝试使用标准 Java 序列化:

 spark.serializer org.apache.spark.serializer.JavaSerializer 

由于您使用 PySparkpyspark.ml and pyspark.sql在没有显着性能损失的情况下,这可能是可以接受的。

除了配置之外,我将重点关注功能工程组件。使用二进制CountVetorizer(请参阅关于HashingTF下)与ChiSqSelector可能提供一种既提高可解释性又有效减少特征数量的方法。您还可以考虑更复杂的方法(确定特征重要性并仅在数据子集上应用朴素贝叶斯、更高级的文本处理(例如词形还原/词干提取)或使用自动编码器的某些变体来获得更紧凑的向量表示)。

Notes:

  • 请记住,跨国朴素贝叶斯仅考虑二元特征。NaiveBayes将在内部处理这个问题,但我仍然建议使用setBinary为了清楚起见。
  • 可以说HashingTF这里是相当无用的。除了哈希冲突之外,高度稀疏的特征和本质上无意义的特征,使得它作为预处理步骤的糟糕选择NaiveBayes.
本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

mllib NaiveBayes 中的类数量有限制吗?调用 model.save() 时出错 的相关文章

  • 如何在刻度标签和轴之间添加空间

    我已成功增加刻度标签的字体 但现在它们距离轴太近了 我想在刻度标签和轴之间添加一点呼吸空间 如果您不想全局更改间距 通过编辑 rcParams 并且想要更简洁的方法 请尝试以下操作 ax tick params axis both whic
  • Python、Tkinter、更改标签颜色

    有没有一种简单的方法来更改按钮中文本的颜色 I use button text input text here 更改按下后按钮文本的内容 是否存在类似的颜色变化 button color red Use the foreground设置按钮
  • InterfaceError:连接已关闭(使用 django + celery + Scrapy)

    当我在 Celery 任务中使用 Scrapy 解析函数 有时可能需要 10 分钟 时 我得到了这个信息 我用 姜戈 1 6 5 django celery 3 1 16 芹菜 3 1 16 psycopg2 2 5 5 我也使用了psyc
  • 将字符串转换为带有毫秒和时区的日期时间 - Python

    我有以下 python 片段 from datetime import datetime timestamp 05 Jan 2015 17 47 59 000 0800 datetime object datetime strptime t
  • Python PAM 模块的安全问题?

    我有兴趣编写一个 PAM 模块 该模块将利用流行的 Unix 登录身份验证机制 我过去的大部分编程经验都是使用 Python 进行的 并且我正在交互的系统已经有一个 Python API 我用谷歌搜索发现pam python http pa
  • 如何使用固定的 pandas 数据框进行动态 matplotlib 绘图?

    我有一个名为的数据框benchmark returns and strategy returns 两者具有相同的时间跨度 我想找到一种方法以漂亮的动画风格绘制数据点 以便它显示逐渐加载的所有点 我知道有一个matplotlib animat
  • 如何使用包含代码的“asyncio.sleep()”进行单元测试?

    我在编写 asyncio sleep 包含的单元测试时遇到问题 我要等待实际的睡眠时间吗 I used freezegun到嘲笑时间 当我尝试使用普通可调用对象运行测试时 这个库非常有用 但我找不到运行包含 asyncio sleep 的测
  • Spark的distinct()函数是否仅对每个分区中的不同元组进行洗牌

    据我了解 distinct 哈希分区 RDD 来识别唯一键 但它是否针对仅移动每个分区的不同元组进行了优化 想象一个具有以下分区的 RDD 1 2 2 1 4 2 2 1 3 3 5 4 5 5 5 在此 RDD 上的不同键上 所有重复键
  • 为 pandas 数据透视表中的每个值列定义 aggfunc

    试图生成具有多个 值 列的数据透视表 我知道我可以使用 aggfunc 按照我想要的方式聚合值 但是如果我不想对两列求和或求平均值 而是想要一列的总和 同时求另一列的平均值 该怎么办 那么使用 pandas 可以做到这一点吗 df pd D
  • 安装后 Anaconda 提示损坏

    我刚刚安装张量流GPU创建单独的后环境按照以下指示here https github com antoniosehk keras tensorflow windows installation 但是 安装后当我关闭提示窗口并打开新航站楼弹出
  • keras加载模型错误尝试将包含17层的权重文件加载到0层的模型中

    我目前正在使用 keras 开发 vgg16 模型 我用我的一些图层微调 vgg 模型 拟合我的模型 训练 后 我保存我的模型model save name h5 可以毫无问题地保存 但是 当我尝试使用以下命令重新加载模型时load mod
  • 在 NumPy 中获取 ndarray 的索引和值

    我有一个 ndarrayA任意维数N 我想创建一个数组B元组 数组或列表 其中第一个N每个元组中的元素是索引 最后一个元素是该索引的值A 例如 A array 1 2 3 4 5 6 Then B 0 0 1 0 1 2 0 2 3 1 0
  • Java 中的“Lambdifying”scala 函数

    使用Java和Apache Spark 已用Scala重写 面对旧的API方法 org apache spark rdd JdbcRDD构造函数 其参数为 AbstractFunction1 abstract class AbstractF
  • 在pyyaml中表示具有相同基类的不同类的实例

    我有一些单元测试集 希望将每个测试运行的结果存储为 YAML 文件以供进一步分析 YAML 格式的转储数据在几个方面满足我的需求 但测试属于不同的套装 结果有不同的父类 这是我所拥有的示例 gt gt gt rz shorthand for
  • 当玩家触摸屏幕一侧时,如何让 pygame 发出警告?

    我使用 pygame 创建了一个游戏 当玩家触摸屏幕一侧时 我想让 pygame 给出类似 你不能触摸屏幕两侧 的错误 我尝试在互联网上搜索 但没有找到任何好的结果 我想过在屏幕外添加一个方块 当玩家触摸该方块时 它会发出警告 但这花了很长
  • 表达式中的 Python 'in' 关键字与 for 循环中的比较 [重复]

    这个问题在这里已经有答案了 我明白什么是in运算符在此代码中执行的操作 some list 1 2 3 4 5 print 2 in some list 我也明白i将采用此代码中列表的每个值 for i in 1 2 3 4 5 print
  • 如何在 Django 中使用并发进程记录到单个文件而不使用独占锁

    给定一个在多个服务器上同时执行的 Django 应用程序 该应用程序如何记录到单个共享日志文件 在网络共享中 而不保持该文件以独占模式永久打开 当您想要利用日志流时 这种情况适用于 Windows Azure 网站上托管的 Django 应
  • 检查所有值是否作为字典中的键存在

    我有一个值列表和一本字典 我想确保列表中的每个值都作为字典中的键存在 目前我正在使用两组来确定字典中是否存在任何值 unmapped set foo set bar keys 有没有更Pythonic的方法来测试这个 感觉有点像黑客 您的方
  • Spark.read 在 Databricks 中给出 KrbException

    我正在尝试从 databricks 笔记本连接到 SQL 数据库 以下是我的代码 jdbcDF spark read format com microsoft sqlserver jdbc spark option url jdbc sql
  • Python - 字典和列表相交

    给定以下数据结构 找出这两种数据结构共有的交集键的最有效方法是什么 dict1 2A 3A 4B list1 2A 4B Expected output 2A 4B 如果这也能产生更快的输出 我可以将列表 不是 dict1 组织到任何其他数

随机推荐