motivicのチラ裏

非モテの非モテによる非モテのためのブログ

RStanでWAICの計算をしてみた

前回の続きです。

前回の記事ではRのfor文でMCMCをしていたので計算が遅かったのですが、RにはRStanというハミルトニアンモンテカルロ法で高速なMCMCをしてくれるライブラリがあるので、今回はこれを使ってWAICを計算してみましょう。

やっていることは前回と全く同じです。ただ、今回はWAICとGE(Generalized Error)のみの計算をして、DICは計算していません。

library(rstan)

TRIAL_N <- 50 # Independent Trial Number
NNN=200 # Number of training samples
TTT <- 5000 # Number of testing samples
AA=0.5  # True mixture ratio
B1=0.3  # True b1
B2=0.3  # True b2
B3=0.3  # True b3
STD=1.0 # standard deviation of each normal distribution
sss=1/(2*STD*STD)
ttt=1/(sqrt(2*pi*STD*STD)^3)

pmodel<-function(x,y,z,a,b1,b2,b3){
  return(ttt*((1-a)*exp(-sss*(x*x+y*y+z*z))
              +a*exp(-sss*((x-b1)*(x-b1)+(y-b2)*(y-b2)+(z-b3)*(z-b3)))))
}

model_normal_mixture <- '
data{
int N;
real W[N];
real X[N];
real Y[N];
real Z[N];
int<lower=1> K;
}

parameters{
real<lower=0,upper=1> a;
real b1;
real b2;
real b3;
}

model{
// Priors
a ~ uniform(0,1);
b1 ~ normal(0,5);
b2 ~ normal(0,5);
b3 ~ normal(0,5);
// normal mixture model
for(i in 1:N){
real ps[K]; // temp for log component densities
ps[1] <- log(1-a) + normal_log(W[i],X[i],1) + normal_log(W[i],Y[i],1) + normal_log(W[i],Z[i],1);
ps[2] <- log(a)   + normal_log(W[i],X[i]-b1,1) + normal_log(W[i],Y[i]-b2,1) + normal_log(W[i],Z[i]-b3,1);
increment_log_prob(log_sum_exp(ps));
}
}
'

XT1 <- STD*rnorm(TTT,mean=0,sd=1)
XT2 <- STD*rnorm(TTT,mean=0,sd=1)
XT3 <- STD*rnorm(TTT,mean=0,sd=1)
for (i in 1:TTT){
  if(i<AA*TTT+1){
    XT1[i]<-XT1[i]+B1
    XT2[i]<-XT2[i]+B2
    XT3[i]<-XT3[i]+B3 
  }
}

WAIC <- vector()
GE <- vector()

for(j in 1:TRIAL_N){
  XX1 <- STD*rnorm(NNN,mean=0,sd=1)
  XX2 <- STD*rnorm(NNN,mean=0,sd=1)
  XX3 <- STD*rnorm(NNN,mean=0,sd=1)
  YY  <- runif(NNN)
  
  for(i in 1:NNN){
    if(YY[i]<AA){
      XX1[i] <- XX1[i]+B1
      XX2[i] <- XX2[i]+B2
      XX3[i] <- XX3[i]+B3
    }
  }
  
  data_training <- list(
    N = NNN,
    W = pmodel(XX1,XX2,XX3,AA,B1,B2,B3),
    X = XX1,
    Y = XX2,
    Z = XX3,
    K = 2
  )
  
  if(j==1){
    stan.model <- stan(model_code=model_normal_mixture, data=data_training, chains=3)
  } else{
    stan.model.redo <- stan(fit=stan.model, data=data_training, chains=3)
  }
    
  if(j==1){
    param <- extract(stan.model)
  } else{
    param <- extract(stan.model.redo)
  }
  
  VV <- 0
  for(i in 1:NNN){
    loglik <- log(pmodel(XX1[i],XX2[i],XX3[i], param$a, param$b1, param$b2, param$b3))
    VV <- VV + var(loglik)
  }
  
  TE <- 0
  for(i in 1:NNN){
    pred <- mean(pmodel(XX1[i],XX2[i],XX3[i], param$a, param$b1, param$b2, param$b3))
    QQ <- pmodel(XX1[i],XX2[i],XX3[i],AA,B1,B2,B3)
    TE <- TE + log(QQ/pred)
  }
  WAIC[j] <- (TE + VV)/NNN
  
  ge <- 0
  for(i in 1:TTT){
    pre <- mean(pmodel(XT1[i], XT2[i], XT3[i], param$a, param$b1, param$b2, param$b3))
    qq <- pmodel(XT1[i], XT2[i], XT3[i],AA,B1,B2,B3)
    ge <- ge+log(qq/pre)+pre/qq-1  
  }
  GE[j] <- ge/TTT
}

library(ggplot2)
library(reshape2)

plottmp <- data.frame(trial=1:50,WAIC=WAIC,GE=GE)
plottmp2 <- melt(plottmp, id.var = "trial")

ggplot(plottmp2, aes(trial, value, group=variable)) + 
  geom_line(aes(colour = variable),size=1) + 
  xlab("trial") + 
  ylab("value")

結果はこんな感じです。
f:id:motivic:20131215232637j:plain

デバッグのため正則なケース(AA=0.5, B1=3, B2=3, B3=3)でも計算したので、ついでに載せておきます。
(やっぱ正則モデルって美しいですねー(ウットリ))
f:id:motivic:20131215225553j:plain

さて、気になる計算時間ですが、前回37分かかったのに対し、今回は7分と5倍位速くなりました。
(DICの計算量は多くないのでDIC分はさほど影響はないはず)

実は前回のMCMCは真の値から始めるというチートをしていたので、実際に計算するには前回のコードだとダメダメなんですね。なので、実際は5倍よりも遥かに高速化に成功していて、実務上使うならRStan一択です!

あと実際にWAICを計算する際には前回も言ったようにTEの中のQQを除いて計算するのでご注意を。