motivicのチラ裏

非モテの非モテによる非モテのためのブログ

RでWAICを強引に計算させてみた

"R Advent Calendar 2013" 13日目の記事です。

どうも、13日の金曜日に記事を書くことになった幸運の持ち主のmotivicです。

先日WAICについてJapan.RでLTをしたので、まずはこちらをご覧ください。



ということで、WAICスゴイ!早速Rで強引に計算してみましょう。

渡辺先生のmatlabのコードはこちらから辿れます。
http://watanabe-www.math.dis.titech.ac.jp/users/swatanab/dicwaic.html

以下のRのコードはこれの昔のバージョンのものを翻訳したものです。ここでは混合分布のdelicate caseのWAICを計算しています。

研究室の高スペックなコンピュータでも計算するのに37分かかったので、普通のパソコンだと計算に数時間かかるかもしれません。

#True : q=(1-AA)N(0,STD^2) + AA N(BB,STD^2),  BB=(B1,B2,B3)

TRIAL_N <- 50 # Independent Trial Number
NNN <- 200  # Number of training samples # NNN=100, 200, ... 
TTT <- 5000 # Number of testing samples

#delicate case
AA <- 0.5  # True mixture ratio
B1 <- 0.3    # True b1
B2 <- 0.3    # True b2
B3 <- 0.3    # True b3

STD <- 1.0        # standard deviation of each normal distribution
BURNIN <- 20000   # Burn-in Number in MCMC
SIZEM <- 500      # Number of MCMC parameter samples
INTER <- 100      # MCMC sampling interval
MCMC <- BURNIN+SIZEM*INTER     # Total MCMC trial number 
PRIOR <- 0.02    # STD of prior of (b1,b2,b3) is set as 5 : exp(-PRIOR*(b1^2+b2^2+b3^2))
# Note: Prior of aa is the uniform distribution on [0,1]
BETA <- 1.0       # Inverse temperature of MCMC
MCMC_A <- 0.05    # Markov Chain Step for aa. 
MCMC_B <- 0.05    # Markov Chain Step for (b1,b2,b3).

waa <- matrix(0,1,SIZEM)
wb1 <- matrix(0,1,SIZEM)
wb2 <- matrix(0,1,SIZEM)
wb3 <- matrix(0,1,SIZEM)
EEE <- matrix(0,1,SIZEM)
poste <- matrix(0,1,SIZEM)
DIC1 <- matrix(0,1,TRIAL_N)
DIC2 <- matrix(0,1,TRIAL_N)
TE <- matrix(0,1,TRIAL_N)
WAIC <- matrix(0,1,TRIAL_N)
GE <- matrix(0,1,TRIAL_N)
EXCHANGE <- matrix(0,1,TRIAL_N)
CV <- matrix(0,1,TRIAL_N)
CC <- matrix(1,1,NNN)
XX1 <- matrix(0,1,NNN)
XX2 <- matrix(0,1,NNN)
XX3 <- matrix(0,1,NNN)
QQ <- matrix(0,1,NNN)

XT1 <- STD*rnorm(TTT,mean=0,sd=1)
XT2 <- STD*rnorm(TTT,mean=0,sd=1)
XT3 <- STD*rnorm(TTT,mean=0,sd=1)
for (i in 1:TTT){
  if(i<AA*TTT+1){
    XT1[i]<-XT1[i]+B1
    XT2[i]<-XT2[i]+B2
    XT3[i]<-XT3[i]+B3 
  }
}

sss <- 1/(2*STD*STD)
ttt <- 1/(sqrt(2*pi*STD*STD)^3)

pmodel <- function(x,y,z,a,b1,b2,b3){
  return(ttt*((1-a)*exp(-sss*(x*x+y*y+z*z))
              +a*exp(-sss*((x-b1)*(x-b1)+(y-b2)*(y-b2)+(z-b3)*(z-b3)))))
}

HHH <- function(a,b1,b2,b3,XX1,XX2,XX3){
  return(det(PRIOR*(b1^2+b2^2+b3^2)-log((1-a)*exp(-sss*(XX1^2+XX2^2+XX3^2))
    + a*exp(-sss*(XX1^2+XX2^2+XX3^2-2*b1*XX1-2*b2*XX2-2*b3*XX3+(b1^2+b2^2+b3^2)*CC)))%*%t(CC)))
}

for(trial in 1:TRIAL_N){
  XX1 <- STD*rnorm(NNN,mean=0,sd=1)
  XX2 <- STD*rnorm(NNN,mean=0,sd=1)
  XX3 <- STD*rnorm(NNN,mean=0,sd=1)
  YY <- runif(NNN)
  
  #training samples generated
  for(i in 1:NNN){
    if(YY[i]<AA){
      XX1[i] <- XX1[i]+B1
      XX2[i] <- XX2[i]+B2
      XX3[i] <- XX3[i]+B3
    }
  }

  #MCMC process begins (Metropolis Method)
  aa0 <- AA
  b10 <- B1
  b20 <- B2
  b30 <- B3
  
  hh0 <- HHH(aa0,b10,b20,b30,XX1,XX2,XX3)
  k <- 0
  exchange <- 0
  
  for(t in 1:MCMC){
    aa <- aa0+MCMC_A*rnorm(1,mean=0,sd=1)
    bb1 <- b10+MCMC_B*rnorm(1,mean=0,sd=1)
    bb2 <- b20+MCMC_B*rnorm(1,mean=0,sd=1)
    bb3 <- b30+MCMC_B*rnorm(1,mean=0,sd=1)
    while(aa<0){
      aa <- aa+1
    }
    while(aa>1){
      aa <- aa-1
    }
    hh1 <- HHH(aa,bb1,bb2,bb3,XX1,XX2,XX3)
    DD <- hh1-hh0
    
    if(exp(-BETA*DD)>runif(1)){
     aa0 <- aa
     b10 <- bb1
     b20 <- bb2
     b30 <- bb3
     hh0 <- hh1
     if(t>BURNIN){
       exchange <- exchange+1
     }
    }
    if(t%%INTER==0 && t>BURNIN){
      k <- k+1
      waa[k] <- aa0
      wb1[k] <- b10
      wb2[k] <- b20
      wb3[k] <- b30
      EEE[k] <- hh0
    }
  }
  
  EXCHANGE[trial] <- exchange/(MCMC-BURNIN)
  
  #Posterior weight normalization 
  mine <- 0
  for(k in 1:SIZEM){
    if(k==1|mine>EEE[k]){
      mine <- EEE[k]
    }
  }
  
  zzz <- 0
  for(k in 1:SIZEM){
    zzz <- zzz+exp(-(1-BETA)*(EEE[k]-mine))
  }
  
  for(k in 1:SIZEM){
    poste[k] <- exp(-(1-BETA)*(EEE[k]-mine))/zzz 
    #weight of MCMC samples
  }
  
  #Calculation of Training Error
  te <- 0
  for(i in 1:NNN){
    pre <- 0
    for(k in 1:SIZEM){
      pre <- pre+poste[k]*pmodel(XX1[i],XX2[i],XX3[i],waa[k],wb1[k],wb2[k],wb3[k])
    }
    QQ[i] <- pmodel(XX1[i],XX2[i],XX3[i],AA,B1,B2,B3)
    te <- te+log(QQ[i]/pre)  
  }
  
  TE[trial] <- te/NNN
 
  #Calculation of WAIC and DIC
  avaa <- 0
  avb1 <- 0
  avb2 <- 0
  avb3 <- 0
  for(k in 1:SIZEM){
    #DIC1 average parameter
    avaa <- avaa+poste[k]*waa[k] 
    avb1 <- avb1+poste[k]*wb1[k]
    avb2 <- avb2+poste[k]*wb2[k]
    avb3 <- avb3+poste[k]*wb3[k]
  }
  
  VV <- 0
  #WAIC  Functional Variance 
  eff_num <- 0
  #DIC1 effective number of parameters
  CrVa <- 0
  #Importance Sampling Cross Validation
  for(i in 1:NNN){
    pow1 <- 0
    pow2 <- 0
    cvcv <- 0
    for(k in 1:SIZEM){
      tmpmodel <- pmodel(XX1[i],XX2[i],XX3[i],waa[k],wb1[k],wb2[k],wb3[k])
      tmp <- log(tmpmodel)
      pow1 <- pow1+poste[k]*tmp
      pow2 <- pow2+poste[k]*tmp*tmp
      cvcv <- cvcv+poste[k]/tmpmodel    
    }
    VV <- VV+pow2-pow1*pow1
    eff_num <- eff_num - 2.0*( pow1 - log(pmodel(XX1[i],XX2[i],XX3[i],avaa,avb1,avb2,avb3)) )
    CrVa <- CrVa+log(cvcv)+log(QQ[i]) 
  }
  
  #VV is the effective number of parameters in WAIC. This is not equal to the real log canonical threshold.
  #eff_num is the effective number of parameters defined in DIC1.
  DIC1[trial] <- TE[trial]+eff_num/NNN   
  #Training error + effective number of parameters / training sample number
  WAIC[trial] <- TE[trial]+VV/NNN        
  #Training error + functional variance / training sample number
  CV[trial] <- CrVa/NNN                  
  #Importance Sampling Cross Validation
  
  #DIC2
  pow1 <- 0
  pow2 <- 0
  for(k in 1:SIZEM){
    tmp <- 0
    for(i in 1:NNN){
      tmp <- tmp+log(pmodel(XX1[i],XX2[i],XX3[i],waa[k],wb1[k],wb2[k],wb3[k]))    
    }
    pow1 <- pow1+poste[k]*tmp
    pow2 <- pow2+poste[k]*tmp*tmp
  }
  
  DIC2[trial] <- TE[trial]+2*(pow2-pow1*pow1)/NNN 
  #Training error + effective number / training sample number
  
  #Calculation of Generalization Error
  ge <- 0
  for(i in 1:TTT){
    pre <- 0
    for(k in 1:SIZEM){
      pre <- pre+poste[k]*pmodel(XT1[i],XT2[i],XT3[i],waa[k],wb1[k],wb2[k],wb3[k])
    }
    qq <- pmodel(XT1[i],XT2[i],XT3[i],AA,B1,B2,B3)
    ge <- ge+log(qq/pre)+pre/qq-1  
    #Kullback Leibler = int qlog(q/p) = int q(log(q/p)+p/q-1)
  }  
  GE[trial] <- ge/TTT
}

library(ggplot2)
library(reshape2)

plottmp <- data.frame(trial=1:50,WAIC=t(WAIC),GE=t(GE),DIC1=t(DIC1),DIC2=t(DIC2))
plottmp2 <- melt(plottmp, id.var = "trial")

ggplot(plottmp2, aes(trial, value, group=variable)) + 
  geom_line(aes(colour = variable),size=1) + 
  xlab("trial") + 
  ylab("value")

mean(GE)
mean(WAIC)
mean(DIC1)
mean(DIC2)

結果はこんな感じです。

f:id:motivic:20131213230132j:plain

 >mean(GE)
[1] 0.01264311
 >mean(WAIC)
[1] 0.01147506
 >mean(DIC1)
[1] 0.008621296
 >mean(DIC2)
[1] 0.02011052


このプログラムは汎化誤差との比較をするためにWAICのTとして学習誤差で計算しているので注意が必要ですね。学習誤差は真の分布を知らないと計算できないので。

実際にWAICを計算する際にはTとして学習損失を使います。コード的にはteの計算式からlog(QQ[i]])を除くだけです。これは真の分布のみに依存してモデルには依存しないので、除いてもモデル選択には影響しません。

そもそも上記のコードでWAICを計算させるには遅すぎて使い物になりません(追記:と思っていたのですが、1分以内に計算が終わるからありっちゃありかも)。RでMCMCをさせるならRStanでやるのが一番良さそう。ということで、RStanでWAICの計算をと思い、上記の混合分布のWAICの計算を試していたのですが間に合いませんでした… RStanでWAICの計算はまたいずれ。


WAICの計算で事前分布使っていてくぁwせdrfgtyふじこl;という方は以下のリンクからどうぞ
http://watanabe-www.math.dis.titech.ac.jp/users/swatanab/prior.html

WAIC/WBICの簡単な解説はこちらから
http://watanabe-www.math.dis.titech.ac.jp/users/swatanab/waicwbic_j.html
http://watanabe-www.math.dis.titech.ac.jp/users/swatanab/waicwbic.html


RStanのコードまでいけなくて申し訳なかったですが、以上、"R Advent Calendar 2013" 13日目の記事でした。

(12/15追記:RStanによるWAICの計算もしてみました http://motivic.hateblo.jp/entry/2013/12/15/232856