RでWAICを強引に計算させてみた
"R Advent Calendar 2013" 13日目の記事です。
どうも、13日の金曜日に記事を書くことになった幸運の持ち主のmotivicです。
先日WAICについてJapan.RでLTをしたので、まずはこちらをご覧ください。
ということで、WAICスゴイ!早速Rで強引に計算してみましょう。
渡辺先生のmatlabのコードはこちらから辿れます。
http://watanabe-www.math.dis.titech.ac.jp/users/swatanab/dicwaic.html
以下のRのコードはこれの昔のバージョンのものを翻訳したものです。ここでは混合分布のdelicate caseのWAICを計算しています。
研究室の高スペックなコンピュータでも計算するのに37分かかったので、普通のパソコンだと計算に数時間かかるかもしれません。
#True : q=(1-AA)N(0,STD^2) + AA N(BB,STD^2), BB=(B1,B2,B3) TRIAL_N <- 50 # Independent Trial Number NNN <- 200 # Number of training samples # NNN=100, 200, ... TTT <- 5000 # Number of testing samples #delicate case AA <- 0.5 # True mixture ratio B1 <- 0.3 # True b1 B2 <- 0.3 # True b2 B3 <- 0.3 # True b3 STD <- 1.0 # standard deviation of each normal distribution BURNIN <- 20000 # Burn-in Number in MCMC SIZEM <- 500 # Number of MCMC parameter samples INTER <- 100 # MCMC sampling interval MCMC <- BURNIN+SIZEM*INTER # Total MCMC trial number PRIOR <- 0.02 # STD of prior of (b1,b2,b3) is set as 5 : exp(-PRIOR*(b1^2+b2^2+b3^2)) # Note: Prior of aa is the uniform distribution on [0,1] BETA <- 1.0 # Inverse temperature of MCMC MCMC_A <- 0.05 # Markov Chain Step for aa. MCMC_B <- 0.05 # Markov Chain Step for (b1,b2,b3). waa <- matrix(0,1,SIZEM) wb1 <- matrix(0,1,SIZEM) wb2 <- matrix(0,1,SIZEM) wb3 <- matrix(0,1,SIZEM) EEE <- matrix(0,1,SIZEM) poste <- matrix(0,1,SIZEM) DIC1 <- matrix(0,1,TRIAL_N) DIC2 <- matrix(0,1,TRIAL_N) TE <- matrix(0,1,TRIAL_N) WAIC <- matrix(0,1,TRIAL_N) GE <- matrix(0,1,TRIAL_N) EXCHANGE <- matrix(0,1,TRIAL_N) CV <- matrix(0,1,TRIAL_N) CC <- matrix(1,1,NNN) XX1 <- matrix(0,1,NNN) XX2 <- matrix(0,1,NNN) XX3 <- matrix(0,1,NNN) QQ <- matrix(0,1,NNN) XT1 <- STD*rnorm(TTT,mean=0,sd=1) XT2 <- STD*rnorm(TTT,mean=0,sd=1) XT3 <- STD*rnorm(TTT,mean=0,sd=1) for (i in 1:TTT){ if(i<AA*TTT+1){ XT1[i]<-XT1[i]+B1 XT2[i]<-XT2[i]+B2 XT3[i]<-XT3[i]+B3 } } sss <- 1/(2*STD*STD) ttt <- 1/(sqrt(2*pi*STD*STD)^3) pmodel <- function(x,y,z,a,b1,b2,b3){ return(ttt*((1-a)*exp(-sss*(x*x+y*y+z*z)) +a*exp(-sss*((x-b1)*(x-b1)+(y-b2)*(y-b2)+(z-b3)*(z-b3))))) } HHH <- function(a,b1,b2,b3,XX1,XX2,XX3){ return(det(PRIOR*(b1^2+b2^2+b3^2)-log((1-a)*exp(-sss*(XX1^2+XX2^2+XX3^2)) + a*exp(-sss*(XX1^2+XX2^2+XX3^2-2*b1*XX1-2*b2*XX2-2*b3*XX3+(b1^2+b2^2+b3^2)*CC)))%*%t(CC))) } for(trial in 1:TRIAL_N){ XX1 <- STD*rnorm(NNN,mean=0,sd=1) XX2 <- STD*rnorm(NNN,mean=0,sd=1) XX3 <- STD*rnorm(NNN,mean=0,sd=1) YY <- runif(NNN) #training samples generated for(i in 1:NNN){ if(YY[i]<AA){ XX1[i] <- XX1[i]+B1 XX2[i] <- XX2[i]+B2 XX3[i] <- XX3[i]+B3 } } #MCMC process begins (Metropolis Method) aa0 <- AA b10 <- B1 b20 <- B2 b30 <- B3 hh0 <- HHH(aa0,b10,b20,b30,XX1,XX2,XX3) k <- 0 exchange <- 0 for(t in 1:MCMC){ aa <- aa0+MCMC_A*rnorm(1,mean=0,sd=1) bb1 <- b10+MCMC_B*rnorm(1,mean=0,sd=1) bb2 <- b20+MCMC_B*rnorm(1,mean=0,sd=1) bb3 <- b30+MCMC_B*rnorm(1,mean=0,sd=1) while(aa<0){ aa <- aa+1 } while(aa>1){ aa <- aa-1 } hh1 <- HHH(aa,bb1,bb2,bb3,XX1,XX2,XX3) DD <- hh1-hh0 if(exp(-BETA*DD)>runif(1)){ aa0 <- aa b10 <- bb1 b20 <- bb2 b30 <- bb3 hh0 <- hh1 if(t>BURNIN){ exchange <- exchange+1 } } if(t%%INTER==0 && t>BURNIN){ k <- k+1 waa[k] <- aa0 wb1[k] <- b10 wb2[k] <- b20 wb3[k] <- b30 EEE[k] <- hh0 } } EXCHANGE[trial] <- exchange/(MCMC-BURNIN) #Posterior weight normalization mine <- 0 for(k in 1:SIZEM){ if(k==1|mine>EEE[k]){ mine <- EEE[k] } } zzz <- 0 for(k in 1:SIZEM){ zzz <- zzz+exp(-(1-BETA)*(EEE[k]-mine)) } for(k in 1:SIZEM){ poste[k] <- exp(-(1-BETA)*(EEE[k]-mine))/zzz #weight of MCMC samples } #Calculation of Training Error te <- 0 for(i in 1:NNN){ pre <- 0 for(k in 1:SIZEM){ pre <- pre+poste[k]*pmodel(XX1[i],XX2[i],XX3[i],waa[k],wb1[k],wb2[k],wb3[k]) } QQ[i] <- pmodel(XX1[i],XX2[i],XX3[i],AA,B1,B2,B3) te <- te+log(QQ[i]/pre) } TE[trial] <- te/NNN #Calculation of WAIC and DIC avaa <- 0 avb1 <- 0 avb2 <- 0 avb3 <- 0 for(k in 1:SIZEM){ #DIC1 average parameter avaa <- avaa+poste[k]*waa[k] avb1 <- avb1+poste[k]*wb1[k] avb2 <- avb2+poste[k]*wb2[k] avb3 <- avb3+poste[k]*wb3[k] } VV <- 0 #WAIC Functional Variance eff_num <- 0 #DIC1 effective number of parameters CrVa <- 0 #Importance Sampling Cross Validation for(i in 1:NNN){ pow1 <- 0 pow2 <- 0 cvcv <- 0 for(k in 1:SIZEM){ tmpmodel <- pmodel(XX1[i],XX2[i],XX3[i],waa[k],wb1[k],wb2[k],wb3[k]) tmp <- log(tmpmodel) pow1 <- pow1+poste[k]*tmp pow2 <- pow2+poste[k]*tmp*tmp cvcv <- cvcv+poste[k]/tmpmodel } VV <- VV+pow2-pow1*pow1 eff_num <- eff_num - 2.0*( pow1 - log(pmodel(XX1[i],XX2[i],XX3[i],avaa,avb1,avb2,avb3)) ) CrVa <- CrVa+log(cvcv)+log(QQ[i]) } #VV is the effective number of parameters in WAIC. This is not equal to the real log canonical threshold. #eff_num is the effective number of parameters defined in DIC1. DIC1[trial] <- TE[trial]+eff_num/NNN #Training error + effective number of parameters / training sample number WAIC[trial] <- TE[trial]+VV/NNN #Training error + functional variance / training sample number CV[trial] <- CrVa/NNN #Importance Sampling Cross Validation #DIC2 pow1 <- 0 pow2 <- 0 for(k in 1:SIZEM){ tmp <- 0 for(i in 1:NNN){ tmp <- tmp+log(pmodel(XX1[i],XX2[i],XX3[i],waa[k],wb1[k],wb2[k],wb3[k])) } pow1 <- pow1+poste[k]*tmp pow2 <- pow2+poste[k]*tmp*tmp } DIC2[trial] <- TE[trial]+2*(pow2-pow1*pow1)/NNN #Training error + effective number / training sample number #Calculation of Generalization Error ge <- 0 for(i in 1:TTT){ pre <- 0 for(k in 1:SIZEM){ pre <- pre+poste[k]*pmodel(XT1[i],XT2[i],XT3[i],waa[k],wb1[k],wb2[k],wb3[k]) } qq <- pmodel(XT1[i],XT2[i],XT3[i],AA,B1,B2,B3) ge <- ge+log(qq/pre)+pre/qq-1 #Kullback Leibler = int qlog(q/p) = int q(log(q/p)+p/q-1) } GE[trial] <- ge/TTT } library(ggplot2) library(reshape2) plottmp <- data.frame(trial=1:50,WAIC=t(WAIC),GE=t(GE),DIC1=t(DIC1),DIC2=t(DIC2)) plottmp2 <- melt(plottmp, id.var = "trial") ggplot(plottmp2, aes(trial, value, group=variable)) + geom_line(aes(colour = variable),size=1) + xlab("trial") + ylab("value") mean(GE) mean(WAIC) mean(DIC1) mean(DIC2)
結果はこんな感じです。
>mean(GE) [1] 0.01264311 >mean(WAIC) [1] 0.01147506 >mean(DIC1) [1] 0.008621296 >mean(DIC2) [1] 0.02011052
このプログラムは汎化誤差との比較をするためにWAICのTとして学習誤差で計算しているので注意が必要ですね。学習誤差は真の分布を知らないと計算できないので。
実際にWAICを計算する際にはTとして学習損失を使います。コード的にはteの計算式からlog(QQ[i]])を除くだけです。これは真の分布のみに依存してモデルには依存しないので、除いてもモデル選択には影響しません。
そもそも上記のコードでWAICを計算させるには遅すぎて使い物になりません(追記:と思っていたのですが、1分以内に計算が終わるからありっちゃありかも)。RでMCMCをさせるならRStanでやるのが一番良さそう。ということで、RStanでWAICの計算をと思い、上記の混合分布のWAICの計算を試していたのですが間に合いませんでした… RStanでWAICの計算はまたいずれ。
WAICの計算で事前分布使っていてくぁwせdrfgtyふじこl;という方は以下のリンクからどうぞ
http://watanabe-www.math.dis.titech.ac.jp/users/swatanab/prior.html
WAIC/WBICの簡単な解説はこちらから
http://watanabe-www.math.dis.titech.ac.jp/users/swatanab/waicwbic_j.html
http://watanabe-www.math.dis.titech.ac.jp/users/swatanab/waicwbic.html
RStanのコードまでいけなくて申し訳なかったですが、以上、"R Advent Calendar 2013" 13日目の記事でした。
(12/15追記:RStanによるWAICの計算もしてみました http://motivic.hateblo.jp/entry/2013/12/15/232856)