RStanでWAICの計算をしてみた
前回の続きです。
前回の記事ではRのfor文でMCMCをしていたので計算が遅かったのですが、RにはRStanというハミルトニアンモンテカルロ法で高速なMCMCをしてくれるライブラリがあるので、今回はこれを使ってWAICを計算してみましょう。
やっていることは前回と全く同じです。ただ、今回はWAICとGE(Generalized Error)のみの計算をして、DICは計算していません。
library(rstan) TRIAL_N <- 50 # Independent Trial Number NNN=200 # Number of training samples TTT <- 5000 # Number of testing samples AA=0.5 # True mixture ratio B1=0.3 # True b1 B2=0.3 # True b2 B3=0.3 # True b3 STD=1.0 # standard deviation of each normal distribution sss=1/(2*STD*STD) ttt=1/(sqrt(2*pi*STD*STD)^3) pmodel<-function(x,y,z,a,b1,b2,b3){ return(ttt*((1-a)*exp(-sss*(x*x+y*y+z*z)) +a*exp(-sss*((x-b1)*(x-b1)+(y-b2)*(y-b2)+(z-b3)*(z-b3))))) } model_normal_mixture <- ' data{ int N; real W[N]; real X[N]; real Y[N]; real Z[N]; int<lower=1> K; } parameters{ real<lower=0,upper=1> a; real b1; real b2; real b3; } model{ // Priors a ~ uniform(0,1); b1 ~ normal(0,5); b2 ~ normal(0,5); b3 ~ normal(0,5); // normal mixture model for(i in 1:N){ real ps[K]; // temp for log component densities ps[1] <- log(1-a) + normal_log(W[i],X[i],1) + normal_log(W[i],Y[i],1) + normal_log(W[i],Z[i],1); ps[2] <- log(a) + normal_log(W[i],X[i]-b1,1) + normal_log(W[i],Y[i]-b2,1) + normal_log(W[i],Z[i]-b3,1); increment_log_prob(log_sum_exp(ps)); } } ' XT1 <- STD*rnorm(TTT,mean=0,sd=1) XT2 <- STD*rnorm(TTT,mean=0,sd=1) XT3 <- STD*rnorm(TTT,mean=0,sd=1) for (i in 1:TTT){ if(i<AA*TTT+1){ XT1[i]<-XT1[i]+B1 XT2[i]<-XT2[i]+B2 XT3[i]<-XT3[i]+B3 } } WAIC <- vector() GE <- vector() for(j in 1:TRIAL_N){ XX1 <- STD*rnorm(NNN,mean=0,sd=1) XX2 <- STD*rnorm(NNN,mean=0,sd=1) XX3 <- STD*rnorm(NNN,mean=0,sd=1) YY <- runif(NNN) for(i in 1:NNN){ if(YY[i]<AA){ XX1[i] <- XX1[i]+B1 XX2[i] <- XX2[i]+B2 XX3[i] <- XX3[i]+B3 } } data_training <- list( N = NNN, W = pmodel(XX1,XX2,XX3,AA,B1,B2,B3), X = XX1, Y = XX2, Z = XX3, K = 2 ) if(j==1){ stan.model <- stan(model_code=model_normal_mixture, data=data_training, chains=3) } else{ stan.model.redo <- stan(fit=stan.model, data=data_training, chains=3) } if(j==1){ param <- extract(stan.model) } else{ param <- extract(stan.model.redo) } VV <- 0 for(i in 1:NNN){ loglik <- log(pmodel(XX1[i],XX2[i],XX3[i], param$a, param$b1, param$b2, param$b3)) VV <- VV + var(loglik) } TE <- 0 for(i in 1:NNN){ pred <- mean(pmodel(XX1[i],XX2[i],XX3[i], param$a, param$b1, param$b2, param$b3)) QQ <- pmodel(XX1[i],XX2[i],XX3[i],AA,B1,B2,B3) TE <- TE + log(QQ/pred) } WAIC[j] <- (TE + VV)/NNN ge <- 0 for(i in 1:TTT){ pre <- mean(pmodel(XT1[i], XT2[i], XT3[i], param$a, param$b1, param$b2, param$b3)) qq <- pmodel(XT1[i], XT2[i], XT3[i],AA,B1,B2,B3) ge <- ge+log(qq/pre)+pre/qq-1 } GE[j] <- ge/TTT } library(ggplot2) library(reshape2) plottmp <- data.frame(trial=1:50,WAIC=WAIC,GE=GE) plottmp2 <- melt(plottmp, id.var = "trial") ggplot(plottmp2, aes(trial, value, group=variable)) + geom_line(aes(colour = variable),size=1) + xlab("trial") + ylab("value")
結果はこんな感じです。
デバッグのため正則なケース(AA=0.5, B1=3, B2=3, B3=3)でも計算したので、ついでに載せておきます。
(やっぱ正則モデルって美しいですねー(ウットリ))
さて、気になる計算時間ですが、前回37分かかったのに対し、今回は7分と5倍位速くなりました。
(DICの計算量は多くないのでDIC分はさほど影響はないはず)
実は前回のMCMCは真の値から始めるというチートをしていたので、実際に計算するには前回のコードだとダメダメなんですね。なので、実際は5倍よりも遥かに高速化に成功していて、実務上使うならRStan一択です!
あと実際にWAICを計算する際には前回も言ったようにTEの中のQQを除いて計算するのでご注意を。