FMAでマンデルブロ集合

x64のAVXでマンデルブロ集合 - merom686's blog
↑前回の。

今回は、FMAを使った。あまり本質的ではないが、反復回数もAVX2を使うようにした。

xmmA EQU xmm10
xmmB EQU xmm11
xmmC EQU xmm12
xmmD EQU xmm13
xmmE EQU xmm14
xmmF EQU xmm15
ymmA EQU ymm10
ymmB EQU ymm11
ymmC EQU ymm12
ymmD EQU ymm13
ymmE EQU ymm14
ymmF EQU ymm15

.data
align         16
D4  real8     4.0, 4.0, 4.0, 4.0
S8  dword     0, 2, 4, 6, 0, 0, 0, 0

.code
mandel_avx proc

    sub          rsp, 168         ; この時点でrsp%16==8なので8byte余分に確保する
    vmovapd      [rsp    ], xmm6
    vmovapd      [rsp+ 16], xmm7
    vmovapd      [rsp+ 32], xmm8
    vmovapd      [rsp+ 48], xmm9
    vmovapd      [rsp+ 64], xmmA
    vmovapd      [rsp+ 80], xmmB
    vmovapd      [rsp+ 96], xmmC
    vmovapd      [rsp+112], xmmD
    vmovapd      [rsp+128], xmmE  ; xmm6〜xmm15をスタックに保存
    vmovapd      [rsp+144], xmmF  ; 汎用レジスタは揮発性のものしか使っていないので保存不要

    vzeroall                      ; raxを保存するxmm5とxmmDはゼロにする必要がある

    vmovapd      ymm0, [rcx   ]   ; aをxで初期化
    vmovapd      ymm8, [rcx+32]
    vmovapd      ymm1, [rcx+64]   ; bをyで初期化
    vmovapd      ymm9, [rcx+96]
    mov          eax, [rcx+160]   ; m 使うのはraxだがゼロ拡張される

    vmovapd      ymm6, ymm0       ; x
    vmovapd      ymmE, ymm8
    vmovapd      ymm7, ymm1       ; y
    vmovapd      ymmF, ymm9

    vmovupd      ymm4, D4         ; 4.0 注:ymmCは別用途
    jmp          lp
                                  ; -:今使ってはいけない領域, ?:今後使われないデータ
    align        16               ; ymm = a b ? ? - - x y
lp:
    vmulpd       ymm2, ymm0, ymm0 ; ymm = a b aa ? - - x y
    vmulpd       ymmA, ymm8, ymm8
    vaddpd       ymm3, ymm0, ymm0 ; ymm = a b aa 2a - - x y
    vaddpd       ymmB, ymm8, ymm8
    vfmadd213pd  ymm0, ymm0, ymm6 ; ymm = aa+x b aa 2a - - x y
    vfmadd213pd  ymm8, ymm8, ymmE
    vfmadd231pd  ymm2, ymm1, ymm1 ; ymm = aa+x b aa+bb 2a - - x y
    vfmadd231pd  ymmA, ymm9, ymm9
    vfnmadd231pd ymm0, ymm1, ymm1 ; ymm = aa-bb+x b aa+bb 2a - - x y
    vfnmadd231pd ymm8, ymm9, ymm9
    vfmadd213pd  ymm1, ymm3, ymm7 ; ymm = aa-bb+x 2ab+y aa+bb ? - - x y
    vfmadd213pd  ymm9, ymmB, ymmF
    vcmpnlepd    ymm2, ymm2, ymm4 ; ymm = aa-bb+x 2ab+y aa+bb>4 ? 4 - x y
    vcmpnlepd    ymmA, ymmA, ymm4
    vptest       ymm2, ymm2       ; 1個も4を超えなかったらZFが立つ
    jnz          dive             ; ZFが立っていなかったらジャンプ
cont:
    vptest       ymmA, ymmA
    jnz          dive2
cont2:
    dec          rax
    jnz          lp               ; m回ループしたら終わり
    jmp          exit

dive:
    vandnpd      ymm0, ymm2, ymm0 ; 発散した点をゼロクリアし二度とここへ来ないようにする
    vandnpd      ymm1, ymm2, ymm1 ; 4を超えた場所でビットが立っていることを利用しandnでそこだけ消す
    vandnpd      ymm6, ymm2, ymm6
    vandnpd      ymm7, ymm2, ymm7

    vmovq        xmmC, rax        ; ymm = - - mask - 4/rax count - -
    vpbroadcastq ymmC, xmmC       ; ymmC = rax rax rax rax
    vpand        ymmC, ymmC, ymm2 ; 反復回数をマスク、4を超えたところだけ残す
    vpxor        ymm5, ymm5, ymmC ; 発散した点へ回数を書き込む

    vpxor        ymm2, ymm2, ymm2 ; 比較用のゼロ

    vpcmpeqq     ymmC, ymmD, ymm2 ; ymmC = ymmD == 0 ? -1 : 0
    vptest       ymmC, ymmC
    jnz          cont             ; ymmDに0が一つでもあったらループに戻る
    vpcmpeqq     ymmC, ymm5, ymm2
    vptest       ymmC, ymmC
    jnz          cont2            ; ymmD側が完了していればcont2へ戻ってもよい
    jmp          exit             ; 全ての点の計算が完了してたら終わり

dive2:
    vandnpd      ymm8, ymmA, ymm8
    vandnpd      ymm9, ymmA, ymm9
    vandnpd      ymmE, ymmA, ymmE
    vandnpd      ymmF, ymmA, ymmF

    vmovq        xmmC, rax
    vpbroadcastq ymmC, xmmC
    vpand        ymmC, ymmC, ymmA
    vpxor        ymmD, ymmD, ymmC

    vpxor        ymmA, ymmA, ymmA
    
    vpcmpeqq     ymmC, ymm5, ymmA
    vptest       ymmC, ymmC
    jnz          cont2
    vpcmpeqq     ymmC, ymmD, ymmA
    vptest       ymmC, ymmC
    jnz          cont2            ; 終わりならそのままexit:へ

exit:
    vmovdqu      ymmC, ymmword ptr S8
    vpermd       ymm5, ymmC, ymm5 ; 反復回数の下位32bitを集める
    vpermd       ymmD, ymmC, ymmD
    vinserti128  ymm0, ymm5, xmmD, 1

    vpbroadcastd ymmC, dword ptr [rcx+160]
    vpsubq       ymm0, ymmC, ymm0 ; ymm0 = 各点の反復回数
    vmovdqu      ymmword ptr [rcx+128], ymm0

    vmovapd      xmm6, [rsp    ]
    vmovapd      xmm7, [rsp+ 16]
    vmovapd      xmm8, [rsp+ 32]
    vmovapd      xmm9, [rsp+ 48]
    vmovapd      xmmA, [rsp+ 64]
    vmovapd      xmmB, [rsp+ 80]
    vmovapd      xmmC, [rsp+ 96]
    vmovapd      xmmD, [rsp+112]
    vmovapd      xmmE, [rsp+128]
    vmovapd      xmmF, [rsp+144]
    add          rsp, 168
    vzeroupper                    ; SSEを使えるように依存関係を断ち切っておく
    ret

mandel_avx endp

end