该解决方案基于 `scipy.stats.ks_2samp` 的源代码改写，支持带权重的样本。
其运行时间约为下文所对比的 `ks_w` 实现的 1/10000（notebook：https://colab.research.google.com/drive/1KPTakjZpCx0VGcPEI53UzdyFlNkxY7x-?usp=sharing ）：
import numpy as np
def _weighted_ecdf_table(data, wei, label):
    """Sort one weighted sample and build its cumulative-weight lookup table.

    Returns ``(sorted_data, table)`` where ``table[k]`` is the normalised
    weight of the ``k`` smallest observations, so ``table[0] == 0.0`` and
    ``table[-1] == 1.0``.  ``label`` ("1" or "2") is used only in error text.
    """
    data = np.asarray(data)
    wei = np.asarray(wei, dtype=float)
    if data.ndim != 1 or data.shape != wei.shape:
        raise ValueError(
            f"data{label} and wei{label} must be 1-D arrays of equal length"
        )
    if data.size == 0:
        raise ValueError(f"data{label} must not be empty")

    order = np.argsort(data)
    cum = np.cumsum(wei[order])
    total = cum[-1]
    # `not total > 0` also rejects NaN totals, which would poison every CDF value.
    if not total > 0:
        raise ValueError(f"wei{label} must have a positive total weight")
    # Leading 0 so a searchsorted index of 0 (x below every observation) maps
    # to F(x) = 0.  Dividing by cum[-1] instead of a separately computed sum
    # makes the last entry exactly 1.0.
    return data[order], np.concatenate(([0.0], cum / total))


def ks_w2(data1, data2, wei1, wei2):
    """Return the two-sample Kolmogorov-Smirnov statistic for weighted samples.

    Follows the vectorised approach of ``scipy.stats.ks_2samp``.  Both
    weighted empirical CDFs are evaluated at every pooled observation with
    ``np.searchsorted(..., side='right')``, and the largest absolute gap is
    returned.

    Parameters
    ----------
    data1, data2 : array_like, 1-D
        The two samples.  They may differ in length and need not be sorted.
        They are not modified.
    wei1, wei2 : array_like, 1-D
        Weights, one per observation of ``data1`` / ``data2`` respectively.

    Returns
    -------
    numpy.float64
        ``max |F1(x) - F2(x)|`` over all pooled observations ``x``, where
        ``F1`` and ``F2`` are the weight-normalised empirical CDFs.

    Raises
    ------
    ValueError
        If a sample is empty, is not 1-D, does not match its weight array's
        length, or has a total weight that is not positive.
    """
    sorted1, table1 = _weighted_ecdf_table(data1, wei1, "1")
    sorted2, table2 = _weighted_ecdf_table(data2, wei2, "2")
    pooled = np.concatenate([sorted1, sorted2])
    # side='right' counts every observation <= x, so tied values are fully
    # included in F(x).  The tables are indexed with the plain index array;
    # wrapping it in a list (``table[[idx]]``) is deprecated non-tuple
    # multidimensional indexing.
    cdf1 = table1[np.searchsorted(sorted1, pooled, side="right")]
    cdf2 = table2[np.searchsorted(sorted2, pooled, side="right")]
    return np.max(np.abs(cdf1 - cdf2))
以下是与 `ks_w` 对比的准确性和性能测试：
# Accuracy / performance comparison of ks_w2 against ks_w, written as
# IPython/Jupyter notebook input: the `%timeit` lines are IPython magics and
# are not valid in a plain Python script.
# The data are drawn from NumPy's global RNG without a seed, so the statistic
# shown below is the output of one particular run and will differ between runs.
# Sample 1: 10,000 uniform [0, 1) values; sample 2: 40,000 standard-normal
# values shifted by +0.2.
ds1 = np.random.rand(10000)
ds2 = np.random.randn(40000) + .2
# Per-observation weights drawn uniformly from [1, 2).
we1 = np.random.rand(10000) + 1.
we2 = np.random.rand(40000) + 1.
ks_w2(ds1, ds2, we1, we2)
# 0.4210415232236593   (example output from one unseeded run)
# NOTE(review): `ks_w` is not defined in this snippet; presumably it is the
# slower implementation being compared against -- it must be defined first.
ks_w(ds1, ds2, we1, we2)
# 0.4210415232236593   (same value as ks_w2 on the same data)
%timeit ks_w2(ds1, ds2, we1, we2)
# 100 loops, best of 3: 17.1 ms per loop
%timeit ks_w(ds1, ds2, we1, we2)
# 1 loop, best of 3: 3min 44s per loop