粒子フィルタを実装してみた

カルマンフィルターを試したので、次のステップとして粒子フィルタをR言語で実装してみました。

ソースはChiral's Gistに置いてあります。

粒子フィルタとは?

状態空間モデルで、隠れ状態の遷移と観測モデルのいずれもが非線形な場合に使える状態推定アルゴリズムです。

PRMLにも解説が載ってますが、樋口知之著「予測にいかす統計モデリングの基本」がわかりやすいです。

実装

## particle filter implementation by isobe
 
particle_filter <- function(x0,y,f_noise,f_like,N,M=1) {
  tmax <- nrow(y)
  D <- length(x0) # == ncol(y)
  
  do_noise <- function(x) {
    x1 <- c()
    for (i in 1:N) {
      for (j in 1:M) {
        v <- f_noise(D)
        x1<-rbind(x1,x[i,]+v)
      }
    }
    return(x1)
  }
  
  do_like <- function(x,t) {
    apply(x,1,function(xi) f_like(xi,y[t,]))
  }
 
  resampling <- function(x,w) {
    # input dim(x) = (N*M,D)
    # --> output dim = (D,N)
    wsum <- c()
    for (i in 1:(N*M)) {
      wsum <- c(wsum,sum(w[1:i]))
    }
    total <- sum(w)
    pos <- (1:N) * total/N
    r <- runif(1,0,total) # roulette
    pos <- (pos+r) %% total
    ret <- c()
    for (i in 1:N) {
      j <- which(wsum>=pos[i])[1]
      ret <- cbind(ret,x[j,])
    }
    return(ret)
  }
  
  xx <- list(matrix(x0,D,N))
  
  for (t in 1:tmax) {
    x <- xx[[t]]         # --> dim(x)=(D,N)
    x <- do_noise(t(x)) # --> dim(x)=(N*M,D)
    w <- do_like(x,t) # --> length(w)=N*M
    xx[[t+1]] <- resampling(x,w) # --> dim(x)=(D,N)
  }
  
  return(xx)
}
 
 
##### test program  #####
 
test <- function() {
  x0 <- c(0,0)
  
  y <- rbind(c(4,4),
             c(8,6),
             c(6,-1),
             c(-2,-5),
             c(-8,-9),
             c(-6,0),
             c(-7,3),
             c(-3,6),
             c(0,4))
  
  f_like = function(xt,yt) {
    return(exp(-sum((yt-xt)**2)))
  }
  
  f_noise = function(D) {
    rnorm(D,0,3)
  }
  xx <- particle_filter(x0,y,f_noise=f_noise,f_like=f_like,N=1000)
  
  par(mfrow=c(3,3))
  
  xlim = c(-10,10)
  ylim = c(-10,10)
  for (t in 1:nrow(y)) {
    x <- xx[[t+1]]
    plot(y[1:t,1],y[1:t,2],type="b",col="3",xlab="",ylab="",xlim=xlim,ylim=ylim)
    par(new=T)
    plot(x[1,],x[2,],col="2",xlab=paste("t =",t),ylab="",xlim=xlim,ylim=ylim)
  }
}
 
test()

実行結果

上のコードには2次元平面上でのテストが含まれていて、実行すると以下のようになります。

粒子数N=1000(赤の点)で 時刻10単位分の隠れ状態(緑の点)を推定するという設定です。うまく動いてるようで、良かった。

所感

粒子フィルタはアルゴリズムが簡単な上にうまく推定してくれるのでいいアルゴリズムだと思います。