x86上128位二进制乘法最快速算法征解
可作为学习MMX、SSE2等系列汇编的案例教材使用C/C++代码和汇编在指令集限于SSE2以下的情况下
(x86, MMX, SSE, SSE2代码)
对128X128 二进制整数运算写出最快速的代码
函数形式
voidmul128(DWORD * result, DWORD * left, DWORD * right)
运算1000万次的累加时间在各机器上最少的为胜利 一年前,我曾发过千分求汇编优化:UInt96x96To192(...),
吸引了不少高手参与,用到的汇编指令包括ALU、SSE2等(我本人的SSE2汇编就是从这个讨论入门的)
上周与楼主聊天时提到了这个。楼主想来个更大的:void UInt128x128To256( UINT32 result, const UINT32 left, const UINT32 right );(我想像的函数形式,数组下标从小到大依次对应于变量的由低到高位)
楼主所给的函数原型不知是否与上述相同?
对出入参数组的大小是否限定?即result是否限定于8个DWORD。。。 是
但最好写做指针格式
也不考虑内存分配啊等
只考虑4 * 4 =8算法 为了统一计时
函数计算
FFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFF * FFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFF
= FFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFE00000000000000000000000000000001
作为测试样例, 但函数应该能处理任何输入
:) 看了你们的96位算法的讨论
发现两点问题
1、这么重要的问题, 代码里竟然没验证功能部分
2、没验证或者考虑MMX寄存器作为存储
PS:
可利用寄存器总共是下面几个
EAX, ECX, EDX, ECX
ESI, EDI
MM0-MM7
XMM0-XMM7
测试程序模板
感觉这次擂台比上次的 UInt96x96To192() 更有意义,<br>为了大家方便交流和测试,我做了一个模板:仅一个.c文件再加一个dsp文件即可。
<br>
<br>其中为了方便大家自测算法的正确性,我已完成了一个标准c的版本,
<br>该版本的结果与我用HugeCalc的一致,并经过反复推敲论证过的,应该是无bug且高效的。
<br>
代码如下:
<pre>
/************************************************************************/
/* UInt128x128To256.c */
/************************************************************************/
#include < wtypes.h >
#include < stdio.h >
#if 0
# define _RAND_TEST /* 随机测试 */
#endif
#ifndef _DEBUG /* 测试次数 */
#define _TEST_TIMES 10000000UL
#else
#define _TEST_TIMES 1UL /* 先DEBUG算法的正确性 */
#endif
typedef void ( *lpfn_UInt128x128To256 )( UINT32 * const result,
const UINT32 * const left,
const UINT32 * const right );
static BOOL s_bSupport_MMX = FALSE;
static BOOL s_bSupport_SSE = FALSE;
static BOOL s_bSupport_SSE2 = FALSE;
static UINT64 s_u64Frequency = 1;
static UINT64 s_u64Start, s_u64End;
static BYTE buffer[(4+4+8)*4+15];
static UINT32 * left = NULL;
static UINT32 * right = NULL;
static UINT32 * result = NULL;
void initParam( void )
{
/* 16字节对齐,以便于 SSE/SSE2 优化 */
left = (UINT32 *)( ((UINT32)( buffer+15 )) & -16 );
right = left + 4;
result = right + 4;
/* 待测函数中不得假定三个指针的相对偏移量情况 */
#ifndef _RAND_TEST
left = 0xFFFFFFFF;
left = 0xFFFFFFFF;
left = 0xFFFFFFFF;
left = 0xFFFFFFFF;
right = 0xFFFFFFFF;
right = 0xFFFFFFFF;
right = 0xFFFFFFFF;
right = 0xFFFFFFFF;
#else
/* 随机数确保可充满32bits */
# define RAND_VAL() (( (UINT32)(rand()) << 17 ) | rand())
srand(GetTickCount());
left = RAND_VAL();
left = RAND_VAL();
left = RAND_VAL();
left = RAND_VAL();
right = RAND_VAL();
right = RAND_VAL();
right = RAND_VAL();
right = RAND_VAL();
#endif
/* 测试 CPU 支持的指令集 */
__asm
{
mov eax, 1;
cpuid;
mov ecx, 800000h; /* 23 bit */
and ecx, edx;
neg ecx;
sbb ecx, ecx;
neg ecx;
mov dword ptr[ s_bSupport_MMX ], ecx;
mov ecx, 2000000h;/* 25 bit */
and ecx, edx;
neg ecx;
sbb ecx, ecx;
neg ecx;
mov dword ptr[ s_bSupport_SSE ], ecx;
mov ecx, 4000000h;/* 26 bit */
and ecx, edx;
neg ecx;
sbb ecx, ecx;
neg ecx;
mov dword ptr[ s_bSupport_SSE2 ], ecx;
}
QueryPerformanceFrequency((LARGE_INTEGER *)&s_u64Frequency );
}
/************************************************************************/
/* UInt128x128To256 ANSI C 版,经过了严格测试 */
/************************************************************************/
void UInt128x128To256_ANSI_C32( UINT32 * const result,
const UINT32 * const left,
const UINT32 * const right )
{
typedef union tag_UINT64
{
UINT64 u64Val;
UINT32 u32LH;
} UInt64;
UInt64 u64_0x0, u64_0x1, u64_0x2, u64_0x3;
UInt64 u64_1x0, u64_1x1, u64_1x2, u64_1x3;
UInt64 u64_2x0, u64_2x1, u64_2x2, u64_2x3;
UInt64 u64_3x0, u64_3x1, u64_3x2, u64_3x3;
u64_0x0.u64Val = UInt32x32To64( left, right );
u64_0x1.u64Val = UInt32x32To64( left, right );
u64_0x2.u64Val = UInt32x32To64( left, right );
u64_0x3.u64Val = UInt32x32To64( left, right );
u64_1x0.u64Val = UInt32x32To64( left, right );
u64_1x1.u64Val = UInt32x32To64( left, right );
u64_1x2.u64Val = UInt32x32To64( left, right );
u64_1x3.u64Val = UInt32x32To64( left, right );
u64_2x0.u64Val = UInt32x32To64( left, right );
u64_2x1.u64Val = UInt32x32To64( left, right );
u64_2x2.u64Val = UInt32x32To64( left, right );
u64_2x3.u64Val = UInt32x32To64( left, right );
u64_3x0.u64Val = UInt32x32To64( left, right );
u64_3x1.u64Val = UInt32x32To64( left, right );
u64_3x2.u64Val = UInt32x32To64( left, right );
u64_3x3.u64Val = UInt32x32To64( left, right );
/* FF FE00 01 --
FF FE00 01 --
FF FE00 01 --
FF FE00 01 --
FF FE00 01 --
FF FE00 01 --
FF FE00 01 --
FF FE00 01 --
FF FE00 01 --
FF FE00 01 --
FF FE00 01 --
FF FE00 01 --
FF FE00 01 --
FF FE00 01 --
FF FE00 01 --
FF FE00 01 --表示FFFF FFFE0000 0001*/
u64_0x1.u64Val += u64_0x0.u32LH;
u64_0x1.u64Val += u64_1x0.u64Val;
u64_2x0.u32LH += ( u64_0x1.u64Val < u64_1x0.u64Val );
u64_0x2.u64Val += u64_0x1.u32LH;
u64_0x2.u64Val += u64_1x1.u64Val;
u64_3x0.u32LH += ( u64_0x2.u64Val < u64_1x1.u64Val );
u64_0x2.u64Val += u64_2x0.u64Val;
u64_2x1.u32LH += ( u64_0x2.u64Val < u64_2x0.u64Val );
u64_0x3.u64Val += u64_0x2.u32LH;
u64_0x3.u64Val += u64_1x2.u64Val;
u64_3x1.u32LH += ( u64_0x3.u64Val < u64_1x2.u64Val );
u64_0x3.u64Val += u64_2x1.u64Val;
u64_2x2.u32LH += ( u64_0x3.u64Val < u64_2x1.u64Val );
u64_0x3.u64Val += u64_3x0.u64Val;
u64_2x3.u64Val += ( u64_0x3.u64Val < u64_3x0.u64Val );
u64_1x3.u64Val += u64_0x3.u32LH;
u64_1x3.u64Val += u64_2x2.u64Val;
u64_3x2.u32LH += ( u64_1x3.u64Val < u64_2x2.u64Val );
u64_1x3.u64Val += u64_3x1.u64Val;
u64_3x3.u64Val += ( u64_1x3.u64Val < u64_3x1.u64Val );
u64_2x3.u64Val += u64_1x3.u32LH;
u64_2x3.u64Val += u64_3x2.u64Val;
u64_3x3.u32LH += ( u64_2x3.u64Val < u64_3x2.u64Val );
u64_3x3.u64Val += u64_2x3.u32LH;
result = u64_0x0.u32LH;
result = u64_0x1.u32LH;
result = u64_0x2.u32LH;
result = u64_0x3.u32LH;
result = u64_1x3.u32LH;
result = u64_2x3.u32LH;
result = u64_3x3.u32LH;
result = u64_3x3.u32LH;
}
/************************************************************************/
/* 测试函数代码粘贴区开始 begin{ */
/*----------------------------------------------------------------------*/
/* 待测函数命名规范(推荐):UInt128x128To256_{1}_{2}(...) */
/* 其中{1}代表用到的最高级指令集; */
/* {2}代表发帖楼层,如果写错了可以自行修订或请管理员帮助修改 */
/* 以后大家只需发自己待测函数的代码即可,不必再贴全测试代码 */
/************************************************************************/
_declspec(naked)
void UInt128x128To256_MMX_xxxF( UINT32 * const result,
const UINT32 * const left,
const UINT32 * const right )
{
__asm
{
/* do something ... */
ret;
}
}
_declspec(naked)
void UInt128x128To256_SSE_xxxF( UINT32 * const result,
const UINT32 * const left,
const UINT32 * const right )
{
__asm
{
/* do something ... */
ret;
}
}
_declspec(naked)
void UInt128x128To256_SSE2_11F( UINT32 * const result,
const UINT32 * const left,
const UINT32 * const right )
{
__asm
{
/* 具体代码在 11#,需登陆才可见 */
ret;
}
}
/************************************************************************/
/* }end 测试函数代码粘贴区结束 */
/************************************************************************/
void testFun( const lpfn_UInt128x128To256 pFun,
const LPCTSTR lpszFunName,
const UINT32 u32TestTimes )
{
UINT32 i;
printf( "\nTest function: %s(..) %u times... \n",
lpszFunName, u32TestTimes );
QueryPerformanceCounter((LARGE_INTEGER *)&s_u64Start );
for ( i = 0; i < u32TestTimes; ++i )
{
(*pFun)( result, left, right );
}
QueryPerformanceCounter((LARGE_INTEGER *)&s_u64End );
i = (UINT32)(( s_u64End - s_u64Start ) * 1000000UL / s_u64Frequency );
printf( "Elapsed time: %d.%03u ms\n", i / 1000, i % 1000 );
printf( "%08X %08X %08X %08X * %08X %08X %08X %08X\n"
"= %08X %08X %08X %08X %08X %08X %08X %08X\n",
left, left, left, left,
right, right, right, right,
result, result, result, result,
result, result, result, result );
}
int main(int argc, char* argv[])
{
initParam();
/* 标准结果 */
testFun( UInt128x128To256_ANSI_C32,
"UInt128x128To256_ANSI_C32", /*1*/ _TEST_TIMES );
/* MMX 版本测试 */
if ( s_bSupport_MMX )
{
testFun( UInt128x128To256_MMX_xxxF,
"UInt128x128To256_MMX_xxxF", _TEST_TIMES );
/* test other functions:
testFun( ... );
*/
}
/* SSE 版本测试 */
if ( s_bSupport_SSE )
{
testFun( UInt128x128To256_SSE_xxxF,
"UInt128x128To256_SSE_xxxF", _TEST_TIMES );
/* test other functions:
testFun( ... );
*/
}
/* SSE2 版本测试 */
if ( s_bSupport_SSE2 )
{
testFun( UInt128x128To256_SSE2_11F,
"UInt128x128To256_SSE2_11F", _TEST_TIMES );
/* test other functions:
testFun( ... );
*/
}
printf( "\n" );
system( "pause" );
return 0;
}</pre>统一测试模版,形成测试标准,方便交流。
<br>完整的测试模板包如下:(包含一个.c文件,及dsp文件)
<br>
<br>请有兴趣的朋友在本地机上编译运行;如发现任何问题,请及时反馈,以便修正,谢谢! 一共四个字
l0, l0, l2, l3
r0, r1, r2, r3
esi = l
edi = r
ebx = result 下面是乘法三角形
l3 r3 l2 r2 l1 r1 l0 r0
l2 r3 l1 r2 l0 r1
l3 r2 l2 r1 l1 r0
l1 r3 l0 r2
l3 r1l2r0
l0 r3
l3 r0
左面是最高位
每个数字宽度为32位
从左面开始 相邻符号表示两个数字乘积
比如l3 r3表示 l3 * r3 可以判断
从最右上角乘并向左推进是最节约的方法 现在的问题是
存在最多4个四字加
这么做?
假设mm0 + mm1 + mm2 + mm3
xor mm4,mm4
xor mm5,mm5
xor mm6. mm6
xor mm7, mm7
movd mm4,mm0
pshrq mm0, 32
movd mm5, mm1
pshrq mm1, 32
addq mm4, mm5
movd mm6, mm2
psrq mm2, 32
addq mm4, mm6
movd mm7, mm3
psrq mm3, 32
addq mm4, mm7
movd , mm4
pshrq mm4, 32
addq mm0, mm1
addq mm2, mm3
addq mm0, mm2
addq mm0, mm4
现在结果在mm0: