成人av在线资源一区,亚洲av日韩av一区,欧美丰满熟妇乱XXXXX图片,狠狠做五月深爱婷婷伊人,桔子av一区二区三区,四虎国产精品永久在线网址,国产尤物精品人妻在线,中文字幕av一区二区三区欲色
    您正在使用IE低版瀏覽器,為了您的雷峰網(wǎng)賬號安全和更好的產(chǎn)品體驗,強烈建議使用更快更安全的瀏覽器
    此為臨時鏈接,僅用于文章預覽,將在時失效
    人工智能開發(fā)者 正文
    發(fā)私信給skura
    發(fā)送

    0

    基于JAX的大規(guī)模并行MCMC:CPU25秒就可以處理10億樣本

    本文作者: skura 2020-01-14 11:43
    導語:在概率編程中,JAX 有很多優(yōu)勢

    JAX 的表現(xiàn)出乎所有人的意料,在極端情況下,最大性能可提高 20 倍。由于 JAX 的 JIT 編譯開銷,Numpy 在少樣本、少量鏈的情況下會勝出。我報告了 tensorflow probability (TFP) 的結(jié)果,但請記住,這種比較是不公平的,因為它實現(xiàn)的隨機游走 metroplis 比我們的包含更多的功能。

    重現(xiàn)結(jié)果所需的代碼可以在這里找到。使代碼運行得更快的技巧值得學習。

    矢量化 MCMC

    Colin Carroll 最近發(fā)布了一篇有趣的博文,使用 Numpy 和隨機游走 metropolis 算法 (RWMH) 的矢量化版本來生成大量的樣本,同時運行多個鏈以便對算法的收斂性進行后驗檢驗。這通常是通過在多線程機器上每個線程運行一個鏈來實現(xiàn)的,在 Python 中使用 joblib 或自定義后端。這么做很麻煩,但它能完成任務。

    Colin 的 文章讓我感到非常興奮,因為我可以在幾乎不增加成本的情況下,同時對成千上萬的鏈進行取樣。他在文章中詳細介紹了幾個這一方法的應用,但我有一種直覺,它可以完成更多的事情。

    大約在同一時間,我偶然發(fā)現(xiàn)了 JAX。JAX 在概率編程語言環(huán)境中似乎很有趣,原因如下:

    • 在大多數(shù)情況下,它完全可以替代 Numpy;

    • Autodiff 很簡單;

    • 它的正向微分模式使得計算高階導數(shù)變得容易;

    • JAX 使用 XLA 執(zhí)行 JIT 編譯,即使在 CPU 上也可以加速代碼的運行;

    • 使用 GPU 和 TPU 非常簡單;

    • 這是一個偏好問題,但它更傾向于函數(shù)式編程。

    在開始使用 JAX 實現(xiàn)一個框架之前,我想做一些基準測試,以了解我要注冊的是什么。這里我將進行比較:

    • Numpy

    • Jax

    • Tensorflow Probability (TFP)

    • XLA 編譯的 Tensorflow Probability

    關于基準測試

    在給出結(jié)果之前,首先需要聲明的是:

    1. 報告的時間是在我的筆記本電腦上運行 10 次的平均值,除了終端打開外,沒有任何其它操作。除了編譯后的 JAX 運行外,所有運行的時間都是使用 hyperfine 命令行工具測量的。

    2. 我的代碼可能不是最優(yōu)的,對于 TFP 來說尤其如此。

    3. 實驗是在 CPU 上進行的。JAX 和 TFP 可以運行在 GPU/TPU 上,所以可以期待額外的加速。

    4. 對于 Numpy 和 JAX 來說,采樣器是一個生成器,樣本不保存在內(nèi)存中但對 TFP 來說并非如此,因此在大型實驗期間,計算機會耗盡內(nèi)存。如果 TFP 沒有在堆棧上預先分配內(nèi)存,不斷地分配內(nèi)存也會影響性能。

    5. 在概率編程中重要的度量是每秒有效采樣的數(shù)量,而不是每秒采樣數(shù)量,前者后者更像是你使用的算法。這個基準測試仍然可以很好地反映不同框架的原始性能。

    設置和結(jié)果

    我在對一個含有 4 個分量的任意高斯混合樣本進行采樣。使用 Numpy:

    import numpy as np
    from scipy.stats import norm
    from scipy.special import logsumexp

    def mixture_logpdf(x):
        loc = np.array([[-2, 0, 3.2, 2.5]]).T
        scale = np.array([[1.2, 1, 5, 2.8]]).T
        weights = np.array([[0.2, 0.3, 0.1, 0.4]]).T

        log_probs = norm(loc, scale).logpdf(x)

        return -logsumexp(np.log(weights) - log_probs, axis=0)

    Numpy

    Colin Carroll 的 MiniMC 是我見過的最簡單、最易讀的大都市隨機游走  Metropolis 和 Hamiltonian Monte Carlo 的實現(xiàn)。我的 Numpy 實現(xiàn)是他的一個迭代:

    import numpy as np

    def rw_metropolis_sampler(logpdf, initial_position):
        position = initial_position
        log_prob = logpdf(initial_position)
        yield position

        while True:
            move_proposals = np.random.normal(0, 0.1, size=initial_position.shape)
            proposal = position + move_proposals
            proposal_log_prob = logpdf(proposal)

            log_uniform = np.log(np.random.rand(initial_position.shape[0], initial_position.shape[1]))
            do_accept = log_uniform < proposal_log_prob - log_prob

            position = np.where(do_accept, proposal, position)
            log_prob = np.where(do_accept, proposal_log_prob, log_prob)
            yield position

    JAX

    JAX 的實現(xiàn)與 Numpy 非常相似:

    from functools import partial

    import jax
    import jax.numpy as np

    @partial(jax.jit, static_argnums=(0, 1))
    def rw_metropolis_kernel(rng_key, logpdf, position, log_prob):
        move_proposals = jax.random.normal(rng_key, shape=position.shape) * 0.1
        proposal = position + move_proposals
        proposal_log_prob = logpdf(proposal)

        log_uniform = np.log(jax.random.uniform(rng_key, shape=position.shape))
        do_accept = log_uniform < proposal_log_prob - log_prob

        position = np.where(do_accept, proposal, position)
        log_prob = np.where(do_accept, proposal_log_prob, log_prob)
        return position, log_prob


    def rw_metropolis_sampler(rng_key, logpdf, initial_position):
        position = initial_position
        log_prob = logpdf(initial_position)
        yield position

        while True:
            position, log_prob = rw_metropolis_kernel(rng_key, logpdf, position, log_prob)
            yield position

    如果你熟悉 Numpy,那么你應該非常熟悉它的語法。JAX 和它有一些不同之處:

    •  jax.numpy 充當 numpy 的替代。對于只涉及數(shù)組操作的函數(shù),用 import jax.numpy as np 替換 import numpy as np,這會給你帶來性能上的提升。

    • JAX 處理隨機數(shù)生成的方式與其他 Python 包不同,這是有原因的 (請閱讀這篇文章:https://github.com/google/jax/blob/master/design_notes/prng.md ) 。每個發(fā)行版都以一個 PRNG 鍵作為輸入。

    • 因為 JAX 不能編譯生成器,我從采樣器中提取內(nèi)核。因此,我們提取并 JIT 完成所有繁重工作的函數(shù):rw_metropolis_kernel。

    • 我們需要對 JAX 的編譯器提供一點幫助,即指出當函數(shù)多次運行時哪些參數(shù)不會改變:@partial(jax.jit, argnums=(0, 1))。如果將函數(shù)作為參數(shù)傳遞,這是必需的,并且可以啟用進一步的編譯時優(yōu)化。

    Tensorflow Probability

    對于 TFP,我們使用庫中實現(xiàn)的隨機游走 Metropolis 算法:

    from functools import partial

    import numpy as np
    import tensorflow as tf
    import tensorflow_probability as tfp
    tfd = tfp.distributions

    def run_raw_metropolis(n_dims, n_samples, n_chains, target):
        samples, _ = tfp.mcmc.sample_chain(
            num_results=n_samples,
            current_state=np.zeros((n_dims, n_chains), dtype=np.float32),
            kernel=tfp.mcmc.RandomWalkMetropolis(target.log_prob, seed=42),
            num_burnin_steps=0,
            parallel_iterations=8,
        )
        return samples

    run_mcm = partial(run_tfp_mcmc, n_dims, n_samples, n_chains, target)

    ## Without XLA
    run_mcm()

    ## With XLA compilation
    tf.xla.experimental.compile(run_mcm)

    結(jié)果

    我們有兩個自由維度:樣本的數(shù)量和鏈的數(shù)量,第一個依賴于原始的數(shù)字處理能力,第二個也依賴于向量化的實現(xiàn)方式。因此,我決定在兩個維度上對算法進行基準測試。

    我考慮以下情況:

    1. Numpy 實現(xiàn);

    2. JAX 實現(xiàn);

    3. 減去編譯時間的 JAX 實現(xiàn)。這只是一個假設的情況,目的是顯示編譯帶來的改進。

    4. Tensorflow Probability;

    5. 實驗 XLA 編譯的 Tensorflow Probability。

    用 1000 條鏈繪制越來越多的樣本

    我們固定鏈的數(shù)量,并改變樣本的數(shù)量。

    基于JAX的大規(guī)模并行MCMC:CPU25秒就可以處理10億樣本

    你將注意到 TFP 實現(xiàn)的缺失點。由于 TFP 算法存儲所有的樣本,所以它會耗盡內(nèi)存。這在 XLA 編譯的版本中沒有發(fā)生,可能是因為它使用了內(nèi)存效率更高的數(shù)據(jù)結(jié)構。

    對于少于 1000 個樣本,普通的 TFP 和 Numpy 實現(xiàn)比它們的編譯副本要快。這是由于編譯開銷造成的:當你減去 JAX 的編譯時間 (從而獲得綠色曲線) 時,它會大大加快速度。只有當樣本的數(shù)量變得很大,并且總抽樣時間取決于抽取樣本的時間時,你才開始從編譯中獲益。

    沒有什么神奇的:JIT 編譯意味著一個明顯的、但不變的計算開銷。

    我建議在大多數(shù)情況下使用 JAX。只有當相同的代碼執(zhí)行超過 10 次時,在 0.3 秒而不是 3 秒內(nèi)進行采樣的差異才會產(chǎn)生影響。然而,編譯是只會發(fā)生一次。在這種情況下,計算開銷將在你達到 10 次迭代之前得到回報。實際上,JAX 贏了。

    用越來越多的鏈繪制 1000 個樣本

    在這里,我們固定樣本的數(shù)量,改變鏈的數(shù)量。

    基于JAX的大規(guī)模并行MCMC:CPU25秒就可以處理10億樣本

    JAX 仍然明顯地贏了:只要鏈的數(shù)量達到 10,000,它就比 Numpy 更快。你將注意到 JAX 曲線上有一個凸起,這完全是由于編譯造成的 (綠色曲線沒有這個凸起)。我不知道為什么,如果有答案請告訴我!

    這就是令人興奮的亮點:

    JAX 可以在 25 秒內(nèi)在 CPU 上生成 10 億個樣本,比 Numpy 快 20 倍!

    結(jié)論

    對于允許我們用純 python 編寫代碼的項目,JAX 的性能是令人難以置信的。Numpy 仍然是一個不錯的選擇,特別是對于那些 JAX 的大部分執(zhí)行時間都花在編譯上的項目來說尤其如此。

    但是,Numpy 不適合概率編程語言。如 Hamiltonian Monte Carlo 這樣的高效抽樣算 Uber 優(yōu)步的團隊開始和 JAX 在 Numpyro 上合作。

    不要過多地解讀 Tensorflow Probability 的拙劣表現(xiàn)。當從分布中采樣時,重要的不是原始速度,而是每秒有效采樣的數(shù)量。TFP 的實現(xiàn)包括更多的附加功能,我希望它在每秒有效采樣樣本數(shù)方面更具競爭力。

    最后,請注意,用鏈的數(shù)量乘以樣本的數(shù)量要比用樣本的數(shù)量乘以樣本的數(shù)量容易得多。我們還不知道如何處理這些鏈,但我有一種直覺,一旦我們這樣做了,概率編程將會有另一個突破。

    via:https://rlouf.github.io/post/jax-random-walk-metropolis/

    雷鋒網(wǎng)雷鋒網(wǎng)雷鋒網(wǎng)

    雷峰網(wǎng)版權文章,未經(jīng)授權禁止轉(zhuǎn)載。詳情見轉(zhuǎn)載須知

    基于JAX的大規(guī)模并行MCMC:CPU25秒就可以處理10億樣本

    分享:
    相關文章
    當月熱門文章
    最新文章
    請?zhí)顚懮暾埲速Y料
    姓名
    電話
    郵箱
    微信號
    作品鏈接
    個人簡介
    為了您的賬戶安全,請驗證郵箱
    您的郵箱還未驗證,完成可獲20積分喲!
    請驗證您的郵箱
    立即驗證
    完善賬號信息
    您的賬號已經(jīng)綁定,現(xiàn)在您可以設置密碼以方便用郵箱登錄
    立即設置 以后再說