2012-02-15 M-Hアルゴリズムによるポアソン分布推定コードのチューニング において
「目的関数を簡素化してしまう方が効果的です。」 ごもっともです。
その他, 何回も同じ引数で関数を呼ぶとか,関数呼び出しのオーバーヘッドも馬鹿にならないこともあり,以下のようにすると約 40% のスピードアップ
set.seed(1631697)
lambda <- 1
length_x <- 100
x <- rpois(length_x, lambda)
sum_x <- sum(x)
lpoi <- function(p) {
(p > 0) * exp(-length_x * p) * p^sum_x
}
gc()
gc()
system.time({
set.seed(2658817)
s0 <- 0.1
LL0 <- lpoi(s0) ## ここはこのままにしておこう
s <- numeric(100000 + 500)
s[1] <- s0
r <- rnorm(length(s))
ru <- runif(length(s)) ## 追加
for (n in 2:length(s)) {
s1 <- s0 + r[n]
LL1 <- (s1 > 0) * exp(-length_x * s1) * s1^sum_x ## lpoi(s1)
if (min(1, LL1 / LL0) > ru[n]) {
s0 <- s1
LL0 <- LL1
}
s[n] <- s0
}
s <- s[-(1:500)]
cat(sprintf("lambda:%.5f, variance:%.5f\n", mean(s), var(s))) ## %1.5f はおかしい
})
このブログは,プログラムの速度を追求するのではなく,「えれがんと」なプログラミングを目指しているので誤解ないように。「えれがんと」が何を意味するかは範囲は広い。
で,「2012-01-10 Rと手作業で覚える最尤法」の「2. 尤度関数、対数尤度関数、一階条件から最尤法を試みる」について
このプログラムでは,導関数の定義を expression で行っておいてから,使うときに eval で評価している。なるほど,このようにするとある意味「えれがんと」だなあ。
でも,直接指定しても「えれがんと」でなくなるとも見えない。結果としてわずかだが速いし。
なお,連立方程式を解くとき,solve で逆用列を求めて定数項と行列掛け算するのではなく,定数項を solve の引数に渡す方がよいというのは,数値演算の定石。
もとのプログラム(整形済み)
> system.time({for (i in 1:10000) {
+ x <- c(11, 13, 23)
+ n <- length(x)
+ f1 <- expression(-sum(x - mu) / s2)
+ f2 <- expression(-n / (2 * s2) + sum((x - mu) ^ 2) / (2 * (s2 ^ 2)))
+ g11 <- expression(n / s2)
+ g12 <- expression(sum(x - mu) / (s2 ^ 2))
+ g21 <- expression(sum(mu - x) / (s2 ^ 2))
+ g22 <- expression(n / (2 * (s2 ^ 2)) - sum((x - mu) ^ 2) / (s2 ^ 3))
+ mu <- 10
+ s2 <- 10
+ for (i in 1:10) {
+ m <- matrix(c(mu, s2), 2, 1)
+ f <- matrix(c(eval(f1), eval(f2)), 2, 1)
+ j <- matrix(c(eval(g11), eval(g21), eval(g12), eval(g22)), 2, 2)
+ m <- m - solve(j) %*% f
+ mu <- m[1]
+ s2 <- m[2]
+ # print(sprintf("[%d] (mu,s2)=(%f,%f)", i, mu, s2))
+ }
+ #print(sprintf("平均%2.3f、分散%2.3f(標準偏差%2.3f)の正規分布",
+ # mu, s2, sqrt(s2)))
+ }})
ユーザ システム 経過
14.484 0.062 14.270
書き換えたプログラム
> system.time({for (i in 1:10000) {
+ x <- c(11, 13, 23)
+ n <- length(x)
+ mu <- 10
+ s2 <- 10
+ for (i in 1:10) {
+ m <- c(mu, s2)
+ f <- c(-sum(x - mu)/s2,
+ -n/(2 * s2) + sum((x - mu)^2)/(2 * (s2^2)))
+ j <- matrix(c(n/s2,
+ sum(mu - x)/(s2^2),
+ sum(x - mu)/(s2^2),
+ n/(2 * (s2^2)) - sum((x - mu)^2)/(s2^3)), 2, 2)
+ m <- m - solve(j, f)
+ mu <- m[1]
+ s2 <- m[2]
+ # cat(sprintf("[%d] (mu, s2) = (%f, %f)\n", i, mu, s2))
+ }
+ #cat(sprintf("平均%.3f、分散%.3f(標準偏差%.3f)の正規分布\n",
+ # mu, s2, sqrt(s2)))
+ }})
ユーザ システム 経過
8.878 0.022 8.879