round to even

ちょっと古いものなんだけど浮動小数点数の加算についてのround to evenをテストしてみたりなんかしたコードが埋もれてたのを発見。せっかくなので置いておくことにした。
よくある四捨五入っていうのは0.5丁度のときに問答無用で切り上げるけれども、round to even の場合、これをeven方向に丸める。1.5 だったら 2に。2.5も 2に。3.5は4に、とかいう感じ。これだけ見てると「なんか変じゃね?」と思うかもしれないけれども、あくまでこれは ?.5ちょうどのときのお話で、それ以外は普通に誤差が小さいほうに丸められる。
肝の部分は、四捨五入のままだと、 ?.5のときは必ず +0.5になってしまうわけで、たとえば?.5が連発するようなデータの和なんかの場合、?.5が来るたびにどんどん +0.5ずつ増えていってしまう。科学技術計算なんかではとてもまずい。これを統計的に +-0.0にしようとするのがこの方法。だそうな。普段我々が使っているIEEE形式の仕様もこうなっている。
2進の世界では .10000..0 となっていた場合に前のビットが 1 だったら切り上げ。

回路で作る場合、ビット位置は自由自在だし、条件もわかりやすい。たとえば以下のような小数点以下5桁の値を整数部で丸める場合、

 abcd.efghi

整数部の最下位ビット d と、小数部の最上位ビット e、あと小数部の残り全ビットのand f&g&h&i(rount to even条件を判定するため)、を出すだけ、と結構簡単。さて、ソフトでやったらどうなる?というのと、さらに加算だけならlong longを用意しなくても、加算する前に round出せるなあ、とかやってみて、結果大して回路と変わらない上に、かえってあちこちで変数が増えて、こんなことなら素直に64bit演算すればよかったとかになったのが下のコードというわけ。なんか恥ずかしい。

#include <stdio.h>
#include <stdlib.h>

typedef enum { false, true } bool;

typedef union {
    unsigned int intval;
    float floatval;
    struct  {
	unsigned int num :23;
	unsigned int exp : 8;
	unsigned int sign: 1;
    } structval;
} FYPE;

int round_to(int num, int btm, int diff, bool enable)
{
    const int 
	upper =  num & 1,
	lower = (btm >> (diff-1)) & 1,
	rest  = btm & ((1 << (diff-1)) -1);

    if (enable && !rest && lower) {
	// round to even
	return  upper? 1: 0;
    } else {
	return  lower? 1: 0;
    }
}

FYPE add_float (FYPE a, FYPE b, bool enable)
{
    int num_a = a.structval.num+0x00800000,
        num_b = b.structval.num+0x00800000,
        num_c, head, bottom;
    int exp_diff = a.structval.exp - b.structval.exp;
    FYPE c;

    printf("[debug]:input: num_a:%08x  num_b:%08x\n", num_a, num_b);
    printf("[debug]:input: exp_a:%-8d  exp_b:%-8d\n", a.structval.exp, b.structval.exp);
    if (exp_diff == 0) {
	c.structval.exp = a.structval.exp;
    } else if (exp_diff > 0) {
	exp_diff = abs(exp_diff);
	if (exp_diff > 23) num_b = 0;
	else  {
	    bottom = num_b;
	    num_b   >>= exp_diff;
	}
	c.structval.exp = a.structval.exp;
    } else {
	exp_diff = abs(exp_diff);
	if (exp_diff > 23) num_a = 0;
	else {
	    bottom = num_a;
	    num_a   >>= exp_diff;
	}
	c.structval.exp = b.structval.exp;
    }
    if (a.structval.sign == 1) num_a = -(num_a);
    if (b.structval.sign == 1) num_b = -(num_b);
    printf("[debug]:unite: num_a:%08x  num_b:%08x\n", num_a, num_b);

    num_c = num_a + num_b;
    printf("[debug]:added     : %08x\n", num_c);

    if (num_c < 0) {
	num_c = -num_c;
	c.structval.sign= 1;
    }	else  {
	c.structval.sign= 0;
    }

    if (num_c >= 0x1000000) {
	++exp_diff;
	bottom &= ~(1 << exp_diff);
	bottom |= ((num_c&1) << exp_diff);
	num_c >>= 1;
	++c.structval.exp;
    } else if (num_c) {
	while (!(num_c & 0x800000)) {
	    c.structval.exp--;
	    num_c <<= 1;
	}
    } else { c.structval.exp = 0; } 

    num_c += round_to(num_c, bottom, exp_diff, enable);
    printf("[debug]:rounded   : %08x\n", num_c);
    c.structval.num = num_c & 0x7fffff;

    return c;
}

main(int argc, char** argv)
{
    FYPE a,b;
    FYPE c,d;
    bool enable = true;

    if (argc >= 3) {
	a.floatval = atof(argv[1]);
	b.floatval = atof(argv[2]);
	if (argc > 3) enable = false;

	c = add_float(a,b, enable);
	printf("----------\n");
	printf("kuro:%08x + %08x = %08x (%.12f)\n", a.intval, b.intval, 
		c.intval, c.floatval);
	d.floatval = a.floatval + b.floatval;
	printf("fpu :%08x + %08x = %08x (%.12f)\n", a.intval, b.intval, 
		d.intval, d.floatval);
	printf("%s\n", d.intval == c.intval? "OK": "*********NG********");

    } else { 
	fprintf(stderr, "you must specify two FP values\n");
    }
    return 0;
}

こんな感じに最後にFPU演算結果と照合している。

quartz% ./float_add 1.45 5.8 
[debug]:input: num_a:00b9999a  num_b:00b9999a
[debug]:input: exp_a:127       exp_b:129     
[debug]:unite: num_a:002e6666  num_b:00b9999a
[debug]:added     : 00e80000
[debug]:rounded   : 00e80000
----------
kuro:3fb9999a + 40b9999a = 40e80000 (7.250000000000)
fpu :3fb9999a + 40b9999a = 40e80000 (7.250000000000)
OK

quartz% ./float_add 1.45 5.8  disable
[debug]:input: num_a:00b9999a  num_b:00b9999a
[debug]:input: exp_a:127       exp_b:129     
[debug]:unite: num_a:002e6666  num_b:00b9999a
[debug]:added     : 00e80000
[debug]:rounded   : 00e80001
----------
kuro:3fb9999a + 40b9999a = 40e80001 (7.250000476837)
fpu :3fb9999a + 40b9999a = 40e80000 (7.250000000000)
*********NG********

というようにdisableにするとちゃんと 0.5 (というか 0b0.1000) のときにNGになる。