SSE2でマンデルブロ集合

下のコードは、SSE2マンデルブロ集合の計算をするもの。かなり高速に動作すると思う。
このエントリでは、このコードの解説をするにょ。
うるりの物置きのMandel100.zipにあるSSEのコードが元になっている。

; MASM 8.0 を使用
.686
.xmm
.model flat, c

.data
align       16
D4  real8   4.0, 4.0
L1  oword   -1

.code
mandel_sse2 proc x:ptr real8, y:ptr real8, m:dword, count:ptr dword

    mov         eax, x
    movapd      xmm0, [eax]
    mov         eax, y
    movapd      xmm1, [eax]
    mov         ecx, m
    movapd      xmm6, xmm0
    movapd      xmm7, xmm1
    pxor        xmm5, xmm5
    jmp         brotloop

    align       16
brotloop:                   ; xmmx = a b ? ? temp count x y
    movapd      xmm2, xmm0  ; xmm2 = a
    mulpd       xmm0, xmm0  ; xmm0 = aa
    movapd      xmm3, xmm1  ; xmm3 = b
    mulpd       xmm1, xmm2  ; xmm1 = ab, free xmm2
    movapd      xmm2, xmm0  ; xmm2 = aa
    mulpd       xmm3, xmm3  ; xmm3 = bb
    addpd       xmm0, xmm6  ; xmm0 = aa+x
    addpd       xmm2, xmm3  ; xmm2 = aa+bb = rr
    addpd       xmm1, xmm1  ; xmm1 = 2ab
    cmpnlepd    xmm2, D4    ; xmm2 = (rr > 4) ? -1 : 0
    subpd       xmm0, xmm3  ; xmm0 = aa-bb+x = a', free xmm3
    movmskpd    eax, xmm2
    addpd       xmm1, xmm7  ; xmm1 = 2ab+y = b'
    test        eax, eax
    jnz         escaped
reentry:
    dec         ecx
    jnz         brotloop
    jmp         exit

escaped:
    movd        xmm4, ecx
    pshufd      xmm4, xmm4, 0
    pand        xmm4, xmm2  ; 反復回数をマスク
    por         xmm5, xmm4  ; 発散した点へ回数を書き込む
    pxor        xmm2, L1    ; ビット反転ってこれしか方法ないんだろうか
    andpd       xmm0, xmm2
    andpd       xmm1, xmm2
    andpd       xmm6, xmm2
    andpd       xmm7, xmm2  ; 発散した点をクリアし二度とここに来ないようにする
    pxor        xmm4, xmm4
    pcmpeqd     xmm4, xmm5  ; xmm4 = (xmm5 == 0) ? -1 : 0
    pmovmskb    eax, xmm4
    test        eax, eax    ; xmm5に0が
    jnz         reentry     ; 一つでもあったらループに戻る

exit:
    mov         eax, m
    movd        xmm4, eax
    pshufd      xmm4, xmm4, 0
    psubd       xmm4, xmm5
    mov         eax, count
    movdqa      [eax], xmm4
    ret

mandel_sse2 endp

end

ここから、MASM用のソースコードを少しずつ読み進めていく。

; MASM 8.0 を使用
.686
.xmm
.model flat, c

「;」は行コメント(行の途中からでも使える)。
.686と.xmmはアセンブル対象の命令を指示する。
SSEを使うので、.xmmを指定。これでSSE2の命令もアセンブルしてくれる。
.modelではメモリモデルと関数の呼出規約を指定。
呼出規約はC言語に合わせてcdeclとした。これなら面倒がない。

.data
align       16
D4  real8   4.0, 4.0
L1  oword   -1

.data。コード内で使う定数をメモリに置いてもらう。
SSE2のmovapdで読み込むためには、アドレス値が16の倍数である必要があるので、
align 16 としてデータの置き場所をそのように指定する。
D4が定数の名前、real8が型、4.0がメモリに置かれる値。
owordというのは16byteの型で、-1と書くことで128bit全てを1にしている。

.code

ここからがコード。

mandel_sse2 proc x:ptr real8, y:ptr real8, m:dword, count:ptr dword
//C++から呼び出すときはこんな感じ
extern "C" void mandel_sse2(double *x, double *y, long m, long *c);
int hoge()
{
    double x[2], y[2];
    long m, count[2];
    …
    mandel_sse2(x, y, m, count);
    …

mandel_sse2という名前の関数が始まる。引数は4つ。x, y, count の3つはポインタ。

    mov         eax, x
    movapd      xmm0, [eax]
    mov         eax, y
    movapd      xmm1, [eax]
    mov         ecx, m
    movapd      xmm6, xmm0
    movapd      xmm7, xmm1
    pxor        xmm5, xmm5

まず、ポインタxをeaxへコピーし、そのアドレスからdouble値2つをxmm0へロードする。
yも同様。mはecxへ読み込んでカウンタとして使う。
xmm0をxmm6へコピー、xmm1をxmm7へコピー、xmm5の値を0にする。

//C++版を書くとすればこんな感じ
long mandel_cpp(double x, double y, long m)
{
    long k, count;
    double a = x, b = y, t;
    for (k = 0; k < m; k++){
        if (a * a + b * b > 4.0) break;
        t = a * a - b * b + x;
        b = 2.0 * a * b + y;
        a = t;
    }
    count = k;
    return count;
}

ここからの処理内容をC++で書くと上のようになる。
MASM用コードのレジスタと、この変数たちとの対応を下の表で示す。

レジスタ C++版の変数
ecx(ダウンカウント) k
xmm0 a
xmm1 b
xmm5 count
xmm6 x
xmm7 y
    jmp         brotloop

    align       16

さて、次にループの開始を16byte境界にアラインさせておく。
CPUが少しでも命令を読みやすいように。
ここからが、速度的にクリティカルな部分だ。

brotloop:                   ; xmmx = a b ? ? temp count x y
    movapd      xmm2, xmm0  ; xmm2 = a
    mulpd       xmm0, xmm0  ; xmm0 = aa
    movapd      xmm3, xmm1  ; xmm3 = b
    mulpd       xmm1, xmm2  ; xmm1 = ab, free xmm2
    movapd      xmm2, xmm0  ; xmm2 = aa
    mulpd       xmm3, xmm3  ; xmm3 = bb
    addpd       xmm0, xmm6  ; xmm0 = aa+x
    addpd       xmm2, xmm3  ; xmm2 = aa+bb = rr
    addpd       xmm1, xmm1  ; xmm1 = 2ab

brotloop: はラベル。ループの開始地点。
複素数のa+biを2乗してx+yiを足す操作を繰り返すところ。
ここは全てSSE2の命令だが、加算・乗算・コピーだけなので意味は簡単。
ただし、依存関係のある演算同士はできるだけ離して置くようにする。
例えば、mulpd xmm0, xmm0 でaの2乗を計算した後、
xmm0(ていうかaの2乗)を使う演算はできるだけ離して配置したい。
そうすればレイテンシの長い命令を並列実行してくれる。
CPU内には命令を溜めておける場所もあるけど、最適化の効果は十分ある。

    cmpnlepd    xmm2, D4    ; xmm2 = (rr > 4) ? -1 : 0
    subpd       xmm0, xmm3  ; xmm0 = aa-bb+x = a', free xmm3
    movmskpd    eax, xmm2
    addpd       xmm1, xmm7  ; xmm1 = 2ab+y = b'
    test        eax, eax
    jnz         escaped

命令が入り交じってわかりにくいが、複素数の絶対値が4を超えたかの判定にのみ触れる。
cmpnlepdで、xmm2の値が4以下でない場合に全ビットを1にする。
「超える」ではなく「以下でない」というのは回りくどい表現だが、
レジスタの状態が通常の数でなかった場合の扱いが違う(ここでは関係ない)。
「D4」は、先に宣言しておいた定数。ここに4.0が2つ入っていて、これと比較する。
これを条件分岐に使いたいため、movmskpdで情報を汎用レジスタ(eax)に移す。
現在、SSE2でdouble値2つを同時に計算しているが、そのうち1つでも4.0を超えたら分岐する。
test命令でゼロフラグが立たない、つまり誰かが4.0を超えたときにjnzでジャンプする。
実はここで、cmpnlepdの代わりにcmplepdを使い、psubd xmm5, xmm2を入れれば、
escaped:への面倒な分岐をせずに同等の計算ができる。
だが、その1命令を削って高速化するのが今回の思想。
17命令のループだが、psubdは軽い命令なので、1%とかのレベルでしか変わらない。

reentry:
    dec         ecx
    jnz         brotloop
    jmp         exit

reentry: は後で出てくる。処理完了時以外に、ここへ戻ってくるための場所。
decとjnzは、ダウンカウントループの定型文。
SSE命令はフラグを変更しないので、必要ならdecの後にaddpdとかを入れてもいい。
指定の回数(m回)に達したらループを抜け、jmpで終了処理へ飛ぶ。

escaped:
    movd        xmm4, ecx
    pshufd      xmm4, xmm4, 0
    pand        xmm4, xmm2  ; 反復回数をマスク
    por         xmm5, xmm4  ; 発散した点へ回数を書き込む

少なくとも1つが4.0を超えた場合にここへ来る。
まず、現在のカウンタの値をxmm4へコピー。
pshufd xmm4, xmm4, 0 は、xmmレジスタの最下位32bitを全体へコピーする。
ここで、さっきのxmm2(4を超えた場合all 1)をマスクとして使う。
それを、反復回数を記録しておくxmm5へporで書き込めば、要素毎に回数を管理できる。

    pxor        xmm2, L1    ; ビット反転ってこれしか方法ないんだろうか
    andpd       xmm0, xmm2
    andpd       xmm1, xmm2
    andpd       xmm6, xmm2
    andpd       xmm7, xmm2  ; 発散した点をクリアし二度とここに来ないようにする
    pxor        xmm4, xmm4
    pcmpeqd     xmm4, xmm5  ; xmm4 = (xmm5 == 0) ? -1 : 0
    pmovmskb    eax, xmm4
    test        eax, eax    ; xmm5に0が
    jnz         reentry     ; 一つでもあったらループに戻る

今度は、4を超えてないケースを残したいので、xmm2をxmm2の否定に置き換える。
SIMD命令にはnot命令がないので、仕方なく1とのxorをとることで反転させる。
ここはループの外なので、それほど速度を気にしなくていい。
4を超えたケースについては、ここで消えてもらう(0にする、もう4は超えない)。
次に、反復回数を記録しているxmm5を0と等しいか比較して、
0(回数未記録)が1つでもあったらループへ再突入する。
ecxが非0の状態でしかここへ飛んでこないことを使っている。
もし全て(と言っても2つだが)の点について計算が終わっていたら、jnzをスルーして終了処理へ。

exit:
    mov         eax, m
    movd        xmm4, eax
    pshufd      xmm4, xmm4, 0
    psubd       xmm4, xmm5
    mov         eax, count
    movdqa      [eax], xmm4
    ret

計算は終わった。しかし、xmm5に入っている回数は加工が必要(ダウンカウントなので)。
xmm4をmの値で埋め、そこからxmm5を減算することで仕上がる。
それをcountのアドレスへ書き込んで終了。
ちょっと手抜きなので、回数はcount[0]とcount[2]に記録される。
これはdouble2個の処理だが、float4個の処理へも簡単に変更できる。
addpdをaddpsにするなどの機械的な変更だけで、ほぼ通る。
手元のPentiumMはSSE2スループットが悪く、FPUと同等なのだが、
それでもレジスタの増加や表現力が高まったことによりFPU比で1.5倍高速化した。
精度が低くていいなら、4並列のSSE版で更に2倍くらい速くなる。
コードの見た目、若干最適化が甘いようにも見えるが、
PentiumMではパワー不足でこれ以上は詰め込めない感触だ。