Restricted Boltzmann Machineを実装してみた
R言語で実装してみました。特にパッケージは使ってないのでコピペすれば動くと思います。
このプログラムは隠れノード数やCD-kのkや更新ステップのηや収束条件のパラメータ等を変えて試すことが出来ます。
また学習後のRBMを使ってreconstruct(後述)したり隠れノードのサンプリングや確率計算をしたり出来ます。
ソース
### Restricted Boltzmann Machine implementation by isobe sigmoid <- function(x) 1/(1+exp(-x)) rbm <- function(obs,n_hidden,eta=0.05, epsilon=0.05,maxiter=100, CD_k=1,reconstruct_trial=10, verbose=0) { L <- nrow(obs) N <- ncol(obs) M <- n_hidden # initial values assinment # cf) Chapter 8 in # http://www.cs.toronto.edu/~hinton/absps/guideTR.pdf pn <- apply(obs,2,function(x) min(0.9,sum(x)/L)) bn <- log(pn/(1-pn)) bm <- rep(0,M) W <- matrix(rnorm(N*M,0,0.01),N,M) pv_h <- function(i,h) { sigmoid(sum(W[i,]*h)+bn[i]) } ph_v <- function(i,v) { sigmoid(sum(W[,i]*v)+bm[i]) } gs_step <- function(x,n,p_func) { r<-c() for (i in 1:n) { r<-c(r,rbinom(1,1,p_func(i,x))) } return(r) } gs_v <- function(h) gs_step(h,N,pv_h) gs_h <- function(v) gs_step(v,M,ph_v) cd_k <- function(v) { v1 <- v for (i in 1:CD_k) { h1 <- gs_h(v1) v1 <- gs_v(h1) } # R has immutable value and lexical scope, # so we can overwrite locally. for (i in 1:N) for (j in 1:M) { W[i,j] <- ph_v(j,v)*v[i]-ph_v(j,v1)*v1[i] } bn <- v-v1 for (j in 1:M) { bm[j] <- ph_v(j,v)-ph_v(j,v1) } return(list(W=W,bn=bn,bm=bm)) } theta_step <- function() { W <- matrix(0,N,M) bn <- rep(0,N) bm <- rep(0,M) for (i in 1:L) { if (verbose>=3) cat(paste("theta for obs ",i,"\n")) d <- cd_k(obs[i,]) W <- W+d$W bn <- bn+d$bn bm <- bm+d$bm } return(list(W=W,bn=bn,bm=bm)) } reconstruct <- function(v) gs_v(gs_h(v)) recon_error <- function() { r <- 0 for (t in 1:reconstruct_trial) for (i in 1:L) { v <- obs[i,] v1 <- reconstruct(v) r <- r+sum(abs(v-v1)) } return(r/(N*L*reconstruct_trial)) } err <- 1 count <- 0 cat("init OK. \n") while (err>epsilon && count<maxiter) { if (verbose>=2) cat(paste("step =",count,"\n")) d <- theta_step() backup <- list(W=W,bn=bn,bm=bm,err=err) W <- W + eta*d$W bn <- bn + eta*d$bn bm <- bm + eta*d$bm count <- count+1 err <- recon_error() if (backup$err<err) { W <- backup$W bn <- backup$bn bm <- backup$bm err <- backup$err } else if (verbose) { if (verbose>=1) print(paste("step",count,": err=",err)) } } hidden_prob <- function(v) { apply(rbind(1:M),1,function(i) ph_v(i,v)) } learn_info=paste("step",count,": err=",err) obj <- list(W=W,bn=bn,bm=bm, learn_info=learn_info, hidden_prob=hidden_prob, hidden_sample=gs_h, reconstruct=reconstruct) class(obj) <- 'rbm' return(obj) } print.rbm <- function(rbm) { cat("edge weights:\n") print(rbm$W) cat("\nbias for observable nodes:\n") print(rbm$bn) cat("\nbias for hidden nodes:\n") print(rbm$bm) cat(paste("\n",rbm$learn_info,"\n",sep='')) } rbm_hidden_prob <- function(obj,obs) obj$hidden_prob(obs) rbm_hidden_sample <- function(obj,obs) obj$hidden_sample(obs) rbm_reconstruct <- function(obj,obs) obj$reconstruct(obs) ### test program test <- function() { obs <- rbind(c(1,0,1), c(1,1,0), c(1,0,1), c(0,1,1)) net <- rbm(obs,2,verbose=1,maxiter=3000) print(net) x <- c(1,1,0) trial <- 5 cat("original") print(x) for (t in 1:trial) { cat("reconstructed") print(rbm_reconstruct(net,x)) } } test()
テスト
上のソースには観測ノード数が3つ、隠れノード数が3つでXORすると1になるパリティチェック的なものを構成できるかのテストが書いてあります。実行すると以下のような感じに。
init OK. [1] "step 1 : err= 0.408333333333333" [1] "step 2 : err= 0.35" [1] "step 46 : err= 0.35" [1] "step 53 : err= 0.35" [1] "step 104 : err= 0.341666666666667" [1] "step 105 : err= 0.325" [1] "step 158 : err= 0.325" [1] "step 321 : err= 0.325" [1] "step 363 : err= 0.283333333333333" [1] "step 444 : err= 0.283333333333333" [1] "step 1181 : err= 0.266666666666667" edge weights: [,1] [,2] [1,] 0.07643391 0.08229904 [2,] 0.02417799 0.02915906 [3,] -0.05989687 -0.07663989 bias for observable nodes: [1] 1.2486123 0.0500000 0.9486123 bias for hidden nodes: [1] -9.421639e-04 -5.015886e-05 step 3000 : err= 0.266666666666667 original[1] 1 1 0 reconstructed[1] 0 0 1 reconstructed[1] 1 1 1 reconstructed[1] 1 0 1 reconstructed[1] 1 0 1 reconstructed[1] 1 1 1
3ビットのパターンでパリティが0になる4つを与えて、学習させた結果で110というビットを食わせて観測ノード→隠れノード→観測ノードとしたときの確率でサンプリング(reconstructionというそうです)を何度かやってみるというものです。
RBMはそもそも1層で何かを学習させるものではないと思うのでまぁあんまりうまくいってないですね。他の関数として8bitの2進数表現で4以上の数値を与えるようにするとパリティ関数よりは少し精度があがりました。
というわけで次はこれを使ってDeep Belief Networkにチャレンジしてみます。ワクワク!