

We want to know whether, among users of a friend-matching platform, the number of friends and the time spent online are positively related. We assume the relationship is linear and fit it by minimizing the cost function (least squares). For this problem we test the following solutions.
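Concretely, writing the model as $y = \beta_0 + \beta x$, least squares means choosing the parameters that minimize the cost

$$
J(\beta_0, \beta) = \sum_{i} \left( y_i - \beta_0 - \beta x_i \right)^2
$$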

  1. Analytic solution:
    A linear regression model fitted by least squares is guaranteed to have a closed-form solution. For single-variable regression the slope is
    $$
    \beta = \frac{\mathrm{Corr}(x,y)\,\sigma(y)}{\sigma(x)} \tag{1}
    $$
    where Corr(x, y) is the Pearson correlation coefficient and σ(·) the standard deviation; numerically this is np.corrcoef(x,y)[0][1] * np.std(y) / np.std(x). The intercept is
    $$
    \beta_0 = E[Y] - \beta E[X] \tag{2}
    $$
    where E[X] and E[Y] are the means of x and y.

  2. Mini-batch gradient descent:
    Using the techniques from this section, we fit the same model with mini-batch gradient descent (different mini-batch sizes can be tested). The routine lives in a separate GD_mini_batch module; a rough sketch of such a class is given after this list.


  3. scikit-learn LinearRegression:
    In Python, machine learning in practice mostly relies on existing libraries. Here we compare sklearn.linear_model.LinearRegression against our own results. When calling it, note that x and y must have the shapes
    x = np.array([[x1],[x2],[x3],...]) , y = np.array([y1,y2,y3,...])
    i.e. x is a column vector (np.array(x).reshape(-1,1) produces the same shape).
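
The GD_mini_batch module imported by the script below is not reproduced in this post. As a rough sketch only (the constructor and predict signatures, and the fact that the fitted [intercept, slope] ends up in self.theta, are inferred from how the class is used below; the real implementation may differ), it could look like this:

import random

class GradientDescent(object):
    """Sketch only: mini-batch gradient descent for y = theta[0] + theta[1] * x,
    with the interface inferred from the usage in the script below."""
    def __init__(self, data, theta):
        self.data = list(data)    # list of (x, y) pairs
        self.theta = list(theta)  # [intercept, slope]

    def predict(self, mini_batch_size, iter_no, tol, alpha):
        for _ in range(iter_no):
            batch = random.sample(self.data, mini_batch_size)
            # gradient of the squared-error cost, summed over the mini-batch
            residuals = [(self.theta[0] + self.theta[1] * x - y, x) for x, y in batch]
            step0 = alpha * sum(2 * r for r, _ in residuals)
            step1 = alpha * sum(2 * r * x for r, x in residuals)
            self.theta = [self.theta[0] - step0, self.theta[1] - step1]
            if abs(step0) < tol and abs(step1) < tol:
                break  # stop once the update is smaller than tol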


Code

## linear regression

from collections import Counter
from sklearn import linear_model
import numpy as np
import matplotlib.pyplot as plt
import GD_mini_batch as gd

num_friends = [100,49,41,40,25,21,21,19,19,18,18,16,15,15,15,15,14,14,13,13,13,13,12,12,11,10,10,10,10,10,10,10,10,10,10,10,10,10,10,10,9,9,9,9,9,9,9,9,9,9,9,9,9,9,9,9,9,9,8,8,8,8,8,8,8,8,8,8,8,8,8,7,7,7,7,7,7,7,7,7,7,7,7,7,7,7,6,6,6,6,6,6,6,6,6,6,6,6,6,6,6,6,6,6,6,6,6,6,5,5,5,5,5,5,5,5,5,5,5,5,5,5,5,5,5,4,4,4,4,4,4,4,4,4,4,4,4,4,4,4,4,4,4,4,4,3,3,3,3,3,3,3,3,3,3,3,3,3,3,3,3,3,3,3,3,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1]
daily_minutes = [1,68.77,51.25,52.08,38.36,44.54,57.13,51.4,41.42,31.22,34.76,54.01,38.79,47.59,49.1,27.66,41.03,36.73,48.65,28.12,46.62,35.57,32.98,35,26.07,23.77,39.73,40.57,31.65,31.21,36.32,20.45,21.93,26.02,27.34,23.49,46.94,30.5,33.8,24.23,21.4,27.94,32.24,40.57,25.07,19.42,22.39,18.42,46.96,23.72,26.41,26.97,36.76,40.32,35.02,29.47,30.2,31,38.11,38.18,36.31,21.03,30.86,36.07,28.66,29.08,37.28,15.28,24.17,22.31,30.17,25.53,19.85,35.37,44.6,17.23,13.47,26.33,35.02,32.09,24.81,19.33,28.77,24.26,31.98,25.73,24.86,16.28,34.51,15.23,39.72,40.8,26.06,35.76,34.76,16.13,44.04,18.03,19.65,32.62,35.59,39.43,14.18,35.24,40.13,41.82,35.45,36.07,43.67,24.61,20.9,21.9,18.79,27.61,27.21,26.61,29.77,20.59,27.53,13.82,33.2,25,33.1,36.65,18.63,14.87,22.2,36.81,25.53,24.62,26.25,18.21,28.08,19.42,29.79,32.8,35.99,28.32,27.79,35.88,29.06,36.28,14.1,36.63,37.49,26.9,18.58,38.48,24.48,18.95,33.55,14.24,29.04,32.51,25.63,22.22,19,32.73,15.16,13.9,27.2,32.01,29.27,33,13.74,20.42,27.32,18.23,35.35,28.48,9.08,24.62,20.12,35.26,19.92,31.02,16.49,12.16,30.7,31.22,34.65,13.13,27.51,33.2,31.57,14.1,33.42,17.44,10.12,24.42,9.82,23.39,30.93,15.03,21.67,31.09,33.29,22.61,26.89,23.48,8.38,27.81,32.35,23.84]
friend_counts = Counter(num_friends)

## cleaning data: drop the outlier (100 friends but only 1 minute online per day)
outlier = num_friends.index(100)

num_friends_good = [x 
                    for i, x in enumerate(num_friends) 
                    if i != outlier]

daily_minutes_good = [x 
                      for i, x in enumerate(daily_minutes) 
                      if i != outlier]
data = zip(num_friends_good,daily_minutes_good)
data = sorted(data)
x,y = zip(*data)
xarray = np.array([[e] for e in x]) # column-vector shape expected by sklearn LinearRegression

yarray = np.array(y)

## mini-batch gradient descent

theta_initial = [0,1]  # starting guess: [intercept, slope]
gdp = gd.GradientDescent(data,theta_initial)

mini_batch_size = 10; iter_no = 10**4; tol = 10**-4; alpha = 10**-4
gdp.predict(mini_batch_size, iter_no, tol, alpha)  # fitted [intercept, slope] ends up in gdp.theta


## scikit-learn linear model

lm = linear_model.LinearRegression()
lm.fit(xarray,yarray)
lm_coef = [lm.intercept_,lm.coef_]
## analytic solution: y = beta x + beta0

corr = np.corrcoef(x,y)[0][1] # correlation of x,y 

stdx = np.std(x)
stdy = np.std(y)

beta = corr*stdy / stdx
beta0 = np.mean(y) - beta*np.mean(x)
y_anal = xarray*beta + beta0
analytic_coef = [beta0,beta]
## plot


y_lm = lm.predict(xarray)
y_gd = xarray*gdp.theta[1] + gdp.theta[0]  # xarray is already an ndarray

plt.plot(x,y_lm,label='LinearRegression(scikit-learn)',linewidth=2)
plt.scatter(x,y,color='black' ,label='data sets')
plt.plot(x,y_gd,label='mini-batch(n={}) gradient descent'.format(mini_batch_size),linewidth=2)
plt.plot(x,y_anal,'--',label='analytic',linewidth =2)
plt.legend(loc='upper left')

plt.xlabel('# of friends')
plt.ylabel('minutes per day')
plt.axis([0,50,0,100])
plt.show()

## print


print "gradient descent: {}, sklearn:{}, analytic:{} ".format(gdp.theta,lm_coef,analytic_coef)