Tensorflow中可调用函数tf.truncated_normal来进行截断高斯分布的采样(什么是截断高斯分布,看下图,分布在-0.1和0.1处被截断了),具体如下
import tensorflow as tf
import matplotlib.pyplot as plt
c = tf.truncated_normal(shape=[10000, ], mean=0, stddev=0.05)
with tf.Session() as sess:
sess.run(c)
data = c.eval()
plt.hist(x=data, bins=100, color='steelblue', edgecolor='black')
plt.show()
上述代码生成的数据介于[mean-2*stddev, mean+2*stddev]之间,本例中为,[-0.1,0.1]采样结果大致为
Python中还可通过scipy包实现与上述功能相同的采样,具体为
import scipy.stats as stats
import matplotlib.pyplot as plt
mu, sigma = 0, 0.05
lower, upper = mu - 2 * sigma, mu + 2 * sigma
x = stats.truncnorm(
(lower - mu) / sigma, (upper - mu) / sigma, loc=mu, scale=sigma)
plt.hist(x.rvs(10000), bins=100, color='red', edgecolor='black')
plt.show()
图示结果为
颜色很骚气...
本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)