試しながら学ぶ統計・機械学習メモ

統計、機械学習、数理最適化の理論や実装に関する疑問について、実際に試しながら学んでいく過程を残したメモ

ロジスティック回帰はなぜ分散処理できる?

重回帰はなぜ分散処理できる? - qz70224の統計学メモ

のロジスティック回帰版です。

根本の考え方は基本的に重回帰と同じなので、詳しくはそちらをご覧ください。

段取りは以下の通りです。

ニュートン法については、以下の資料を参考にさせていただきました。
http://www.chokkan.org/publication/survey/prml_chapter4_discriminative_slides.pdf
以下、プログラムです。

サンプルデータ作成

library(MASS)
mu<-c(0,0,0,0)
Sigma<-matrix(NA,4,4)
diag(Sigma)<-1
Sigma[lower.tri(Sigma)]<-c(0.4,0.3,0.1,0.2,0.3,0.6)
Sigma[upper.tri(Sigma)]<-c(0.4,0.3,0.1,0.2,0.3,0.6)
data<-mvrnorm(n=1000000,mu=mu,Sigma=Sigma,empirical=F)
for(m in 1:nrow(data)){
  if(pnorm(data[m,4])>0.5){
    data[m,4]=1
  }else{
    data[m,4]=0  
  }
}
lst<-split(as.data.frame(data),1:10)

計算

#lst=リスト形式のデータ,num=従属変数になる列番号
mpLR<-function(lst,num){
  
  #デザイン行列を作成する(対象列番号を除外して切片をつける)
  dzmatlst<-split(rep(NA,length(lst)),1:length(lst),1)
  for(k in 1:length(dzmatlst)){
    dzmatlst[[k]] <- as.matrix( cbind(1,lst[[k]][,-num]) )
  }
  
  #リスト毎に重回帰を実施して初期値決定
  #リスト数×パラメータ数の行列を作成
  wlm<-matrix(NA,ncol(dzmatlst[[1]]),length(lst))
  #リスト毎に重回帰実施し重みを行列に格納
  for(a in 1:length(lst)){
    wlm[,a]<-solve( t(dzmatlst[[a]])%*%dzmatlst[[a]] ) %*% t(dzmatlst[[a]]) %*% lst[[a]][,num]
  }
  #他に特にいい方法が思いつかなかったので、すべての重回帰のパラメータ推定値を平均して初期値にした
  w<-apply(wlm,1,mean)
  
  #以下よりループ処理
  repeat{
    
    yhat <- split(rep(NA,length(lst)),1:length(lst),1)
    dfb  <- matrix(NA,ncol(dzmatlst[[1]]),length(lst))
    hess <- array(NA,dim=c(ncol(dzmatlst[[1]]),ncol(dzmatlst[[1]]),length(lst)))
    
    for(a in 1:length(lst)){
      #リスト毎に予測値定義
      yhat[[a]] <- 1 / ( 1 + exp(-1 * (dzmatlst[[a]] %*% w) ) )

      #リスト毎に勾配
      dfb[,a] <- t(dzmatlst[[a]]) %*% (yhat[[a]]-lst[[a]][,num])

      #リスト毎にヘッセ行列
      R<-yhat[[a]]*(1-yhat[[a]])
      tXR<-matrix(NA,ncol(dzmatlst[[a]]),nrow(dzmatlst[[a]]))
      for(k in 1:nrow(dzmatlst[[a]])){
        tXR[,k]<-dzmatlst[[a]][k,]*R[k] 
      }
      hess[,,a]<-tXR%*%dzmatlst[[a]]
    }
  
    #勾配、ヘッセ行列の和を求める
    dfb<-rowSums(dfb)
    hess<-apply(hess,c(1,2),sum)

    #重みの更新
    w1 <- w-solve(hess)%*%dfb
    
    #収束判定
    if(max(w1-w)>0.0001){
      w<-w1
    }else{
      break
    }
  }
  se_w<-sqrt(diag(solve(hess)))
  tval<-w/se_w
  for(l in 1:length(lst)){
    if(l==1){
      loglik<-sum( log(yhat[[l]])*lst[[l]][,num] + log(1-yhat[[l]])*(1-lst[[l]][,num]) )      
    }else{
      loglik<-loglik + sum( log(yhat[[l]])*lst[[l]][,num] + log(1-yhat[[l]])*(1-lst[[l]][,num]) )
    }
  }
  AIC<--2*(loglik-ncol(dzmatlst[[1]]))
  
  #出力結果
  out<-list(t(w),se_w,t(tval),-2*loglik,AIC)
  names(out)<-c("w","se_w","tval","-2loglik","AIC")
  return(out)
}

試してみる

start<-proc.time()
res<-mpLR(lst,4)
mpLRtime<-proc.time()-start
> res
$w
             [,1]       [,2]      [,3]     [,4]
[1,] 0.0004474997 -0.4053877 0.5541994 1.335781

$se_w
[1] 0.002353378 0.002715095 0.002707285 0.003178775

$tval
          [,1]      [,2]     [,3]     [,4]
[1,] 0.1901521 -149.3088 204.7067 420.2187

$`-2loglik`
[1] 1073944

$AIC
[1] 1073952

> mpLRtime
   ユーザ   システム       経過  
     34.69       0.06      35.12 


start<-proc.time()
LR<-glm(data[,4]~data[,1]+data[,2]+data[,3],
        family=binomial(link="logit"))
summary(LR)
LRtime<-proc.time()-start
> summary(LR)

Call:
glm(formula = data[, 4] ~ data[, 1] + data[, 2] + data[, 3], 
    family = binomial(link = "logit"))

Deviance Residuals: 
    Min       1Q   Median       3Q      Max  
-3.1390  -0.8898   0.1131   0.8901   3.0114  

Coefficients:
              Estimate Std. Error z value Pr(>|z|)    
(Intercept)  0.0004475  0.0023534    0.19    0.849    
data[, 1]   -0.4053878  0.0027151 -149.31   <2e-16 ***
data[, 2]    0.5541995  0.0027072  204.71   <2e-16 ***
data[, 3]    1.3357808  0.0031787  420.23   <2e-16 ***
---
Signif. codes:  0 ‘***’ 0.001 ‘**’ 0.01 ‘*’ 0.05 ‘.’ 0.1 ‘ ’ 1 

(Dispersion parameter for binomial family taken to be 1)

    Null deviance: 1386294  on 999999  degrees of freedom
Residual deviance: 1073944  on 999996  degrees of freedom
AIC: 1073952

Number of Fisher Scoring iterations: 4

パラメータ推定値は一致しました。
glm()より処理速度が遅いです。(笑)

あと、一台のマシンですべての処理をしているので、
テーブルを細かく分ければ分けるほど、処理時間は長くなります(笑)