倍精度の更に倍の精度でマンデルブロ集合

MULXを使うと、多倍長乗算を書くのがかなり楽になる(もちろん速くもなる)。今回は、64bitの汎用レジスタ2個で128bitの符号付き固定小数点数を表し、その加減算と乗算をXbyakで実装した(これは多倍長なのか?)。

これまでは速さを求めていたけど、速いとどんどん拡大させてしまい、すぐ計算精度の限界に達してしまう。そこで、ずっとやりたかった精度アップだ(64bitのCPUとOS、それにできればMULXが欲しかった)。精度を大幅に上げると計算時間がかかりすぎるので、最低限の桁数にした。精度の桁数が2倍になった代わりに、前回の20倍くらい遅くなった。

Xbyakを使っているといっても、実行時にコードを生成している意味はほとんどなくて、単に高級アセンブラとして使った感じ。いつものマンデルブロ集合用コードより複雑なので、速度はあまり追求せず(ていうかできず)気持ちよく書けた。

Fixed128は128bitの符号付き固定小数点数を表す構造体。符号を別に持つ方法もあったが、16byteに収まったほうが楽だと思い、単に符号付き整数として格納している。乗算では、負数だったら符号反転して非負整数同士の掛け算にしている。速度的にどちらが有利なのかは知らん。

加算は小数点の位置を気にすることなく普通にキャリー付き加算をすればいいから簡単(符号を別に持つとここがややこしくなる)。減算も同様。符号反転は、ビット否定してから1を足せばよい。

どこを描画するかの変数も、doubleではなくFixed128で保持する必要があるので、C++から簡単に加減乗算が使えるようにしている。除算で精度が必要になることはないので、割り算をしたければ逆数をdoubleからキャストして掛ければよい。拡大率はdoubleでいい(ここは浮動小数点の強みが出ている)。

役立つ命令が用意されていて、しかも最近は所要クロック数も減ってきてるのでありがたい。MULXはフラグを変更しない乗算。結果を出力するレジスタも選べるようになっている。SHRDは、まさに固定小数点のためにあるような命令で、これがまた速い(instlatx64で見ただけで自分で測ったわけではないが)。

慣れないことするからバグ取りもけっこう大変だった。rbpを保存してるつもりで使ってたとか、絶対値の大きい値は扱えないのにピクセル数をFixed128にキャストしてたり。しかし低級(アセンブラ)なことを高級(C++)に書けるのは気持ちがいい。

コンパイルするときは、xbyak.hの他にxbyak_mnemonic.h(とxbyak_bin2hex.h?)が必要。ヘッダだけで使えてこんな快適に書けるのはすごい。

#include "xbyak.h"

struct Fixed128 {
    static constexpr int P = 7; // 固定小数点の位置

    using op = void (*)(Fixed128 *, const Fixed128 *, const Fixed128 *);
    static op add, mul; // 速度は不要だがどうせ下で書くので流用する

    uint64_t u[2];

    Fixed128& operator+=(Fixed128& t) {
        return *this = *this + t;
    }
    Fixed128& operator*=(Fixed128& t) {
        return *this = *this * t;
    }
    Fixed128 operator+(Fixed128& t) const {
        Fixed128 ret;
        add(&ret, this, &t);
        return ret;
    }
    Fixed128 operator*(Fixed128& t) const {
        Fixed128 ret;
        mul(&ret, this, &t);
        return ret;
    }

    Fixed128 operator-(Fixed128& t) const {
        return *this + (-t);
    }
    Fixed128 operator-() const {
        Fixed128 t;
        t.u[0] = ~u[0];
        t.u[1] = ~u[1];
        if (++t.u[0] == 0) t.u[1]++;
        return t;
    }

    explicit operator double() const {
        const bool p = (u[1] & 1ULL << 63) == 0;
        const Fixed128 t = p ? *this : -*this;
        return (t.u[0] * std::pow(0.5, 128 - P) + t.u[1] * std::pow(0.5, 64 - P)) * (p ? 1 : -1);
    }
    explicit Fixed128(const double s) {
        double i, d;
        d = std::modf(std::fabs(s) * std::pow(2.0, 64 - P), &i);
        u[0] = (uint64_t)(d * std::pow(2.0, 64));
        u[1] = (uint64_t)i;
        if (s < 0) *this = -*this;
    }
    Fixed128() {}
};

class Mandel128 : public Xbyak::CodeGenerator {
    const Xbyak::Reg64& r(int i, int h) const {
        const Xbyak::Reg64 *const reg[] = { /*&rax, &rcx, &rdx, &rbx, &rsp,*/ &rbp, &rsi, &rdi,/* &r8, &r9, &r10,*/ &r11, &r12, &r13, &r14, &r15 };
        return *reg[i * 2 + h];
    }
    void push_8() {
        push(rbp);
        push(rbx);
        push(rsi);
        push(rdi);
        push(r12);
        push(r13);
        push(r14);
        push(r15);
    }
    void pop_8() {
        pop(r15);
        pop(r14);
        pop(r13);
        pop(r12);
        pop(rdi);
        pop(rsi);
        pop(rbx);
        pop(rbp);
    }
    void f_mov(int i, int j) {
        mov(r(i, 0), r(j, 0));
        mov(r(i, 1), r(j, 1));
    }
    void f_mov(int i, Xbyak::RegExp e) {
        mov(r(i, 0), ptr[e]);
        mov(r(i, 1), ptr[e + 8]);
    }
    void f_mov(Xbyak::RegExp e, int i) {
        mov(ptr[e], r(i, 0));
        mov(ptr[e + 8], r(i, 1));
    }
    void f_add(int i, int j) {
        add(r(i, 0), r(j, 0));
        adc(r(i, 1), r(j, 1));
    }
    void f_add(int i, Xbyak::RegExp e) {
        add(r(i, 0), ptr[e]);
        adc(r(i, 1), ptr[e + 8]);
    }
    void f_sub(int i, int j) {
        sub(r(i, 0), r(j, 0));
        sbb(r(i, 1), r(j, 1));
    }
    void f_neg(int i) {
        not(r(i, 0));
        not(r(i, 1));
        add(r(i, 0), 1);
        adc(r(i, 1), 0);
    }
    void f_mul(int i, int j) {
        xor(r10, r10);

        if (i == j) {
            mov(r8, r(i, 1));
            add(r8, r8);
            jnc("@f");
            f_neg(i);
            L("@@");

            mov(rdx, r(i, 0)); // free r(i, 0)
            mulx(r(i, 0), r(i, 0), r(i, 0));
            mulx(r9, r8, r(i, 1));
            add(r8, r8);
            adc(r9, r9);
            adc(r10, r10);
            add(r(i, 0), r8);

            mov(rdx, r(i, 1)); // free r(i, 1)
            mulx(r8, r(i, 1), r(i, 1));
            adc(r(i, 1), r9);
            adc(r8, r10);

            shrd(r(i, 0), r(i, 1), 64 - Fixed128::P);
            shrd(r(i, 1),      r8, 64 - Fixed128::P);

        } else {
            mov(r8, r(i, 1));
            add(r8, r8);
            jnc("@f");
            f_neg(i);
            xor(r10, 1);
            L("@@");

            mov(r8, r(j, 1));
            add(r8, r8);
            jnc("@f");
            f_neg(j);
            xor(r10, 1);
            L("@@");

            mov(rdx, r(j, 0)); // free r(j, 0)
            mulx(r(j, 0), r(j, 0), r(i, 0));
            mulx(r9, r8, r(i, 1));
            add(r8, r(j, 0));

            mov(rdx, r(j, 1)); // free r(j, 1)
            mulx(r(j, 1), r(j, 0), r(i, 1)); // free r(i, 1)
            adc(r9, r(j, 0));
            adc(r(j, 1), 0);

            mulx(r(i, 1), r(i, 0), r(i, 0)); // free r(i, 0)
            add(r(i, 0), r8);
            adc(r(i, 1), r9);
            adc(r(j, 1), 0);

            shrd(r(i, 0), r(i, 1), 64 - Fixed128::P);
            shrd(r(i, 1), r(j, 1), 64 - Fixed128::P);

            test(r10, r10);
            jz("@f");
            f_neg(i);
            L("@@");
        }
    }

public:
    Mandel128(int m) {
        push_8();

        f_mov(0, rcx);
        f_mov(1, rcx + 16);
        mov(ebx, m);
        mov(rax, 4);
        shl(rax, 64 - Fixed128::P);

        L("loop");   // a b ? ?
        f_mov(2, 0); // a b a ?
        f_mul(0, 0); // aa b a ?
        f_mov(3, 1); // aa b a b
        f_mul(3, 3); // aa b a bb
        f_mul(1, 2); // aa ab ? bb
        f_mov(2, 0); // aa ab aa bb
        f_add(2, 3); // aa ab aa+bb bb

        cmp(r(2, 1), rax);
        jge("exit");

        f_sub(0, 3);        // aa-bb ab ? ?
        f_add(1, 1);        // aa-bb 2ab ? ?
        f_add(0, rcx);      // aa-bb+x 2ab ? ?
        f_add(1, rcx + 16); // aa-bb+x 2ab+y ? ?

        dec(ebx);
        jnz("loop");

        L("exit");
        mov(eax, m);
        sub(eax, ebx); // 戻り値(反復回数)

        pop_8();
        ret();
    }
    void init_Fixed128() {
        // 引数はrcx, rdx, r8に入っている
        Fixed128::add = (Fixed128::op)getCurr();
        push_8();

        f_mov(0, rdx);
        f_add(0, r8);
        f_mov(rcx, 0);

        pop_8();
        ret();

        Fixed128::mul = (Fixed128::op)getCurr();
        push_8();

        f_mov(0, rdx);
        f_mov(1, r8);
        f_mul(0, 1); // 最初f_addって書いてました すません
        f_mov(rcx, 0);

        pop_8();
        ret();
    }
};

Fixed128::op Fixed128::add, Fixed128::mul;

// // こんな感じで初期化する
// ma = new Mandel128(m);
// f = (int (*)(Fixed128 *))ma->getCode();
// ma->init_Fixed128();