motivicのチラ裏

非モテの非モテによる非モテのためのブログ

Rで3層線形ニューラルネットワークのWBICの計算をしてみた

WAICを計算したついでにWBICも計算しちゃおうってことでしてみた。

元ネタは渡辺先生のHPのこちら。
http://watanabe-www.math.dis.titech.ac.jp/users/swatanab/wbic2012.html

ここでやろうとしていることは、3層線形ニューラルネットワーク(縮小ランク回帰)で2層目の中間素子が3個である真のモデルから発生させたデータから、中間素子が1~6個である候補モデルからWBICで正しくモデル選択ができるかを見ています。

また、実対数閾値λの推定値も算出しています。

# Constants in Reduced Rank Regression Y=B*A*X+noise
MM <- 6     # Input Dimension
HMAX <- 6   # Hidden Dimension
NN <- 6     # Output Dimension
H0 <- 3     # True Hidden Dimension
SD1 <- 3    # Standard Deviation of Input Distribution 
SD2 <- 0.1  # Standard Deviation of Output Noise
SD3 <- 10.0 # Standard Deviation of Prior Distribution
SD6 <- 0.2  # Standard Deviation of True Parameter Making

# Constants in Metropolis Method
KK <- 2000         # Number of Parameters from Posterior Distribution
n <- 500          # Number of training samples
BURNIN <- 50000   # Burn-in in Metroplois method 
INTER <- 100      # sampling interval in Metroplois method
MONTEC <- BURNIN+KK*INTER # Total Samling in Metroplis method
BETA <- 1/log(n)  # inverse temperature 
SMALLVAL <- 0.5   # Constant for calculation of RLCT
SD4 <- 0.0012     # Standard Deviation of Monte Carlo Step for A
SD5 <- 0.0012     # Standard Deviation of Monte Carlo Step for B
# SD4 and SD5 are determined so that 0.1< acceptance probability <0.9. 

#Ture parameter is determined
A0 <- matrix(0,H0,MM)
B0 <- matrix(0,NN,H0)
X  <- matrix(0,MM,n)
Y  <- matrix(0,NN,n)
C  <- matrix(0,NN,n)

for(i in 1:H0){
  A0[i,] <- SD6*rnorm(MM)
}
for(i in 1:NN){
  B0[i,] <- SD6*rnorm(H0)
}

#Input and Output are determined
for(i in 1:MM){
  X[i,] <- SD1*rnorm(n)  # random inputs 
}
for(i in 1:NN){
  C[i,] <- SD2*rnorm(n)
}
Y <- B0%*%A0%*%X + C # random outputs

# Functions for Likelihood and prior
loglik <- function(A,B){
  return((1/(2*SD2*SD2)*(sum(diag(((Y-B%*%A%*%X)%*%t(Y-B%*%A%*%X))))))+NN*n*log(SD2))
}

prior <- function(A,B){
  return(1/(2*SD3*SD3)*(sum(diag((A%*%t(A))))+sum(diag((B%*%t(B))))))
}

LLL=matrix(0,1,KK)   #Log Likelihood for parameters

Learner_Rank <- vector()
WBIC <- vector()
RLCT_th <- vector()
RLCT_est <- vector()
Accept_Prob <- vector()

#Model Selection Start
for(HH in 1:HMAX){
  #Initial parameters
  A=matrix(0,HH,MM)
  B=matrix(0,NN,HH)

  # Metropolis preparation
  ENERGY1=loglik(A,B)
  HAMILTON1=BETA*ENERGY1+prior(A,B)
  rec=0
  acceptance=0
  maxlll=0

  #Metropolis BEGIN
  for(t in 1:MONTEC){
    # Metropolis Step
    AA <- matrix(0,HH,MM)
    BB <- matrix(0,NN,HH)
    AAE <- matrix(0,HH,MM)
    BBE <- matrix(0,NN,HH)
    
    for(i in 1:HH){
      AAE[i,] <- SD4*rnorm(MM)
    }
    for(i in 1:NN){
      BBE[i,] <- SD5*rnorm(HH)
    }
    
    AA=A+AAE
    BB=B+BBE
    ENERGY2=loglik(AA,BB)
    HAMILTON2=BETA*ENERGY2+prior(AA,BB)
    DELTA=HAMILTON2-HAMILTON1

    #Accept or Reject
    if(exp(-DELTA)>runif(1)){
      A=AA
      B=BB
      HAMILTON1=HAMILTON2
      ENERGY1=ENERGY2
      if(t>BURNIN){
        acceptance=acceptance+1
      }
    }
    # Record Likelihood
    if(t%%INTER==0 && t>BURNIN){
      rec=rec+1
      LLL[rec]=ENERGY1
      if(t==1 | ENERGY1>maxlll){
        maxlll=ENERGY1
      }
    }
  }#Metropolis END

  # WBIC
  sum1 = mean(LLL)
  
  #Estimate Real Canonical Log Threshold(実対数閾値) BEGIN
  sum2 = mean(LLL*exp((-SMALLVAL*BETA*(LLL-maxlll))))
  sum3 = mean(exp(-SMALLVAL*BETA*(LLL-maxlll)))
  sum2 = sum2/sum3
  lambda2=(sum1-sum2)/((1-(1/(1+SMALLVAL)))*log(n))
  #Estimate RLCT End

  #Theoretical Real Canonical Log Threshold of reduced rank regression BEGIN
  lambda1=(2*(HH+H0)*(MM+NN)-(MM-NN)*(MM-NN)-(HH+H0)*(HH+H0))/8
  if(((MM+NN+HH+H0)%%2)==1){
    lambda1=(2*(HH+H0)*(MM+NN)-(MM-NN)*(MM-NN)-(HH+H0)*(HH+H0)+1)/8
  }
  if(HH<H0){
    lambda1=HH*(MM+NN-HH)/2  
  }

  acceptr=acceptance/(MONTEC-BURNIN)
  Learner_Rank[HH] <- HH
  WBIC[HH] <- sum1
  RLCT_th[HH] <- lambda1
  RLCT_est[HH] <- lambda2
}

result <- data.frame(LR=Learner_Rank,WBIC=WBIC,RLCT_th=RLCT_th,RLCT_est=RLCT_est)
result

結果はこちら

  LR      WBIC RLCT_th  RLCT_est
1  1 13346.171     5.5  5.286158
2  2 -2247.739    10.0 10.170242
3  3 -5389.231    13.5 13.913074
4  4 -5383.064    15.0 14.700155
5  5 -5375.218    16.0 16.354419
6  6 -5371.550    17.0 17.293711

ってことで、中間素子が3個の時がWBICが最も小さく、正しくモデル選択できました!
また、実対数閾値の推定値は理論値に近い数字となっています。

ちなみに計算時間は6分30秒でした。