-
建模中经常使用线性最小二乘法,实际上就是求超定线性方程组(未知数少,方程个数多)的最小二乘解,前面已经使用pinv()求超定线性方程组的最小二乘解.下面再举两个求最小二乘解的例子,并使用numpy.linalg模块的lstsq()函数 求解.
-
先要明确这个函数的原义是用来求超定线性方程组的:
例如下面的方程组:
系数矩阵的第一列相当于给定了x的观测值 X=[0,1,2,3].transpose
右边的结果矩阵相当于给定了y的观测值 Y=[-1,0.2,0.9,2.1].transpose
然后使用两个观测值来拟合经验函数 y=mx+c
系数矩阵的第二列存在的意义有点类似于机器学习中的偏置θ0,用于和C相乘,注意这是必要的,在只给定观测值的情况下,我们也常常需要np.ones_like(X的长度来构建有这一“无效列”的矩阵.
- **lstsq(a,b,rcond=“warn”)**函数的参数详解(下面的矩阵都是array_like(类数组对象)):
1. a是一个M行N列的系数矩阵,前面说过需要构造np.ones_like(M)
2. b是一个(M,)或者(M,K),如果b是一个M行K列的二维矩阵,函数会逐个计算每一列的最小二乘法
3. rcond这个参数是可选的,是用于奇异矩阵的处理的,感兴趣的可以自行查看源码,官方推荐我们一般用 rcond=None
返回值:以下提到的所有矩阵都是ndarray, NumPy 最重要的一个特点是其 N 维数组对象 ndarray,它是一系列同类型数据的集合,以 0 下标为开始进行集合中元素的索引):
-
x : {(N,), (N, K)} ndarray (我们所要的结果,如果前面的b是二维的,那么这里也会有k列的a和b结果)
-
residuals : {(1,), (K,), (0,)} ndarray
-
rank: int
-
a 的奇异值
返回值重点关注返回集合中的x就行,所以我们一般的用法是lstsq()[0]
官方的使用栗子:
Examples
--------
Fit a line, ``y = mx + c``, through some noisy data-points:
>>> x = np.array([0, 1, 2, 3])
>>> y = np.array([-1, 0.2, 0.9, 2.1])
By examining the coefficients, we see that the line should have a
gradient of roughly 1 and cut the y-axis at, more or less, -1.
We can rewrite the line equation as ``y = Ap``, where ``A = [[x 1]]``
and ``p = [[m], [c]]``. Now use `lstsq` to solve for `p`:
>>> A = np.vstack([x, np.ones(len(x))]).T
>>> A
array([[ 0., 1.],
[ 1., 1.],
[ 2., 1.],
[ 3., 1.]])
>>> m, c = np.linalg.lstsq(A, y, rcond=None)[0]
>>> print(m, c)
1.0 -0.95
Plot the data along with the fitted line:
>>> import matplotlib.pyplot as plt
>>> plt.plot(x, y, 'o', label='Original data', markersize=10)
>>> plt.plot(x, m*x + c, 'r', label='Fitted line')
>>> plt.legend()
>>> plt.show()
-
下面举个栗子:
给定一组实验数据
0 |
27. |
1 |
26.8 |
2 |
26.5 |
3 |
26.3 |
4 |
26.1 |
5 |
25.7 |
6 |
25.3 |
|
24.8 |
我们来进行一元线性拟合 y=at+b
import numpy as np
import numpy.linalg as LA
import matplotlib.pyplot as plt
t=np.arange(8)
y=np.array([27.0,26.8,26.5,26.3,26.1,25.7,25.3,24.8])
A=np.c_[t, np.ones_like(t)]
print(np.ones_like(t))
ab=LA.lstsq(A,y,rcond=None)[0]
print(ab);
plt.rc('font',size=16)
plt.plot(t,y,'o',label='Original data',markersize=5)
plt.plot(t,A.dot(ab),'r',label="Fitted line")
plt.legend();
plt.show();
**感兴趣的可以自己运行试试看**