裏 RjpWiki

文字通り,RjpWiki の裏を行きます
R プログラム コンピュータ・サイエンス 統計学

Julia に翻訳--020

2021年03月04日 | ブログラミング

#==========
Julia の修行をするときに,いろいろなプログラムを書き換えるのは有効な方法だ。
以下のプログラムを Julia に翻訳してみる。

シンプレックス法によるパラメータ推定
http://aoki2.si.gunma-u.ac.jp/R/simplex.html

ファイル名: simplex.jl  関数名: simplex

翻訳するときに書いたメモ

loops は for ループ中のローカル変数なので,ループが回りきったかどうかを知るために for ループの外で参照しようとしてもできない(loops は存在しないというエラーメッセージが出る)。
その他にも,元の R プログラムでは求めるパラメータもローカル変数なので for ループの外で利用することができないなどのため,パラメータが求まったときの処理を for ループ内で行うように変更した。

==========#

using Plots

function simplex(fun, start, x, y; MAXIT = 10000, EPSILON = 1e-7,
                 LO = 0.8, HI = 1.2, plotflag = false)
    # one line function definition
    residual(x, y, p) = sum((y .- fun(x, p)) .^ 2)

    ip = length(start)
    ip1 = ip + 1
    ip2 = ip + 2
    ip3 = ip + 3
    pa = reshape(repeat(start, ip3), ip, :)
    for i = 1:ip
        pa[i, i] = start[i] * rand(1)[1] * (HI - LO) + LO
    end
    res = vcat([residual(x, y, pa[:, i]) for i = 1:ip1], 0, 0)
    converge = false
    for loops = 1:MAXIT
        res0 = res[1:ip1]
        mx = argmax(res0)
        mi = argmin(res0)
        s = sum(pa[:, 1:ip1], dims = 2)
        if res[mx] < EPSILON || res[mi] < EPSILON ||
           (res[mx] - res[mi]) / res[mi] < EPSILON
            converge = true
            parameters = pa[:, mi]
            residuals = residual(x, y, parameters)
            if plotflag
                pyplot()
                plt = scatter(x, y, tick_direction = :out, grid = false,
                    markercolor = :blue, label = "")
                x = range(minimum(x), maximum(x), length = 1000)
                plot!(x, fun(x, parameters), label = "")
            end
            display(plt)
            return Dict(:converge => converge, :parameters => parameters, :residuals => residuals)
        end
        i = ip2
        pa[:, ip2] = (2 * s - ip2 * pa[:, mx]) / ip
        res[ip2] = residual(x, y, pa[:, ip2])
        if res[ip2] < res[mi]
            pa[:, ip3] = (3 * s - (2 * ip1 + 1) * pa[:, mx]) / ip
            res[ip3] = residual(x, y, pa[:, ip3])
            if res[ip3] <= res[ip2]
                i = ip3
            end
        elseif res[ip2] > res[mx]
            pa[:, ip3] = s / ip1
            res[ip3] = residual(x, y, pa[:, ip3])
            if res[ip3] >= res[mx]
                for i = 1:ip1
                    if i != mi
                        pa[:, i] = (pa[:, i] + pa[:, mi]) * 0.5
                        res[i] = residual(x, y, pa[:, i])
                    end
                end
                i = 0
            else
                i = ip3
            end
        end
        if i > 0
            pa[:, mx] = pa[:, i]
            res[mx] = res[i]
        end
    end
    println("not converged!")
end

x = 1:10; # x 値
y = [3, 8, 15, 35, 57, 80, 92, 95, 99, 100]; # y 値

# あてはめるモデル関数
fun(x, p) = p[1] ./ (1 .+ p[2] .* exp.(-p[3] .* x))

simplex(fun, [80, 70, 0.5], x, y, plotflag = true)

コメント   この記事についてブログを書く
« Julia に翻訳--019 | トップ | Julia に翻訳--021 »
最新の画像もっと見る

コメントを投稿

ブログ作成者から承認されるまでコメントは反映されません。

ブログラミング」カテゴリの最新記事