模块化算术和 NTT(有限域 DFT)优化

2024-05-19

我想使用 NTT 进行快速平方(参见快速大数平方计算 https://stackoverflow.com/q/18465326/2521214),但即使对于非常大的数字,结果也很慢......超过 12000 位。

所以我的问题是:

  1. 有没有办法优化我的 NTT 转换? 我并不是想通过并行性(线程)来加速它;这只是低层。
  2. 有没有办法加快我的模运算速度?

这是我的(已经优化的)NTT 的 C++ 源代码(它是完整的并且 100% 在 C++ 中工作,不需要任何第三方库,并且也应该是线程安全的。请注意源数组被用作临时数组! ,它也不能将数组转换为其自身)。

//---------------------------------------------------------------------------
class fourier_NTT                                    // Number theoretic transform
    {

public:
    DWORD r,L,p,N;
    DWORD W,iW,rN;
    fourier_NTT(){ r=0; L=0; p=0; W=0; iW=0; rN=0; }

    // main interface
    void  NTT(DWORD *dst,DWORD *src,DWORD n=0);               // DWORD dst[n] = fast  NTT(DWORD src[n])
    void INTT(DWORD *dst,DWORD *src,DWORD n=0);               // DWORD dst[n] = fast INTT(DWORD src[n])

    // Helper functions
    bool init(DWORD n);                                       // init r,L,p,W,iW,rN
    void  NTT_fast(DWORD *dst,DWORD *src,DWORD n,DWORD w);    // DWORD dst[n] = fast  NTT(DWORD src[n])

    // Only for testing
    void  NTT_slow(DWORD *dst,DWORD *src,DWORD n,DWORD w);    // DWORD dst[n] = slow  NTT(DWORD src[n])
    void INTT_slow(DWORD *dst,DWORD *src,DWORD n,DWORD w);    // DWORD dst[n] = slow INTT(DWORD src[n])

    // DWORD arithmetics
    DWORD shl(DWORD a);
    DWORD shr(DWORD a);

    // Modular arithmetics
    DWORD mod(DWORD a);
    DWORD modadd(DWORD a,DWORD b);
    DWORD modsub(DWORD a,DWORD b);
    DWORD modmul(DWORD a,DWORD b);
    DWORD modpow(DWORD a,DWORD b);
    };

//---------------------------------------------------------------------------
void fourier_NTT:: NTT(DWORD *dst,DWORD *src,DWORD n)
    {
    if (n>0) init(n);
    NTT_fast(dst,src,N,W);
//    NTT_slow(dst,src,N,W);
    }

//---------------------------------------------------------------------------
void fourier_NTT::INTT(DWORD *dst,DWORD *src,DWORD n)
    {
    if (n>0) init(n);
    NTT_fast(dst,src,N,iW);
    for (DWORD i=0;i<N;i++) dst[i]=modmul(dst[i],rN);
       //    INTT_slow(dst,src,N,W);
    }

//---------------------------------------------------------------------------
bool fourier_NTT::init(DWORD n)
    {
    // (max(src[])^2)*n < p else NTT overflow can ocur !!!
    r=2; p=0xC0000001; if ((n<2)||(n>0x10000000)) { r=0; L=0; p=0; W=0; iW=0; rN=0; N=0; return false; } L=0x30000000/n; // 32:30 bit best for unsigned 32 bit
//    r=2; p=0x78000001; if ((n<2)||(n>0x04000000)) { r=0; L=0; p=0; W=0; iW=0; rN=0; N=0; return false; } L=0x3c000000/n; // 31:27 bit best for signed 32 bit
//    r=2; p=0x00010001; if ((n<2)||(n>0x00000020)) { r=0; L=0; p=0; W=0; iW=0; rN=0; N=0; return false; } L=0x00000020/n; // 17:16 bit best for 16 bit
//    r=2; p=0x0a000001; if ((n<2)||(n>0x01000000)) { r=0; L=0; p=0; W=0; iW=0; rN=0; N=0; return false; } L=0x01000000/n; // 28:25 bit
     N=n;                // size of vectors [DWORDs]
     W=modpow(r,    L);    // Wn for NTT
    iW=modpow(r,p-1-L);    // Wn for INTT
    rN=modpow(n,p-2  );    // scale for INTT
    return true;
    }

//---------------------------------------------------------------------------
void fourier_NTT:: NTT_fast(DWORD *dst,DWORD *src,DWORD n,DWORD w)
    {
    if (n<=1) { if (n==1) dst[0]=src[0]; return; }
    DWORD i,j,a0,a1,n2=n>>1,w2=modmul(w,w);
    // reorder even,odd
    for (i=0,j=0;i<n2;i++,j+=2) dst[i]=src[j];
    for (    j=1;i<n ;i++,j+=2) dst[i]=src[j];
    // recursion
    NTT_fast(src   ,dst   ,n2,w2);    // even
    NTT_fast(src+n2,dst+n2,n2,w2);    // odd
    // restore results
    for (w2=1,i=0,j=n2;i<n2;i++,j++,w2=modmul(w2,w))
        {
        a0=src[i];
        a1=modmul(src[j],w2);
        dst[i]=modadd(a0,a1);
        dst[j]=modsub(a0,a1);
        }
    }

//---------------------------------------------------------------------------
void fourier_NTT:: NTT_slow(DWORD *dst,DWORD *src,DWORD n,DWORD w)
    {
    DWORD i,j,wj,wi,a,n2=n>>1;
    for (wj=1,j=0;j<n;j++)
        {
        a=0;
        for (wi=1,i=0;i<n;i++)
            {
            a=modadd(a,modmul(wi,src[i]));
            wi=modmul(wi,wj);
            }
        dst[j]=a;
        wj=modmul(wj,w);
        }
    }

//---------------------------------------------------------------------------
void fourier_NTT::INTT_slow(DWORD *dst,DWORD *src,DWORD n,DWORD w)
    {
    DWORD i,j,wi=1,wj=1,a,n2=n>>1;
    for (wj=1,j=0;j<n;j++)
        {
        a=0;
        for (wi=1,i=0;i<n;i++)
            {
            a=modadd(a,modmul(wi,src[i]));
            wi=modmul(wi,wj);
            }
        dst[j]=modmul(a,rN);
        wj=modmul(wj,iW);
        }
    }

//---------------------------------------------------------------------------
DWORD fourier_NTT::shl(DWORD a) { return (a<<1)&0xFFFFFFFE; }
DWORD fourier_NTT::shr(DWORD a) { return (a>>1)&0x7FFFFFFF; }

//---------------------------------------------------------------------------
DWORD fourier_NTT::mod(DWORD a)
    {
    DWORD bb;
    for (bb=p;(DWORD(a)>DWORD(bb))&&(!DWORD(bb&0x80000000));bb=shl(bb));
    for (;;)
        {
        if (DWORD(a)>=DWORD(bb)) a-=bb;
        if (bb==p) break;
        bb =shr(bb);
        }
    return a;
    }

//---------------------------------------------------------------------------
DWORD fourier_NTT::modadd(DWORD a,DWORD b)
    {
    DWORD d,cy;
    a=mod(a);
    b=mod(b);
    d=a+b;
    cy=(shr(a)+shr(b)+shr((a&1)+(b&1)))&0x80000000;
    if (cy) d-=p;
    if (DWORD(d)>=DWORD(p)) d-=p;
    return d;
    }

//---------------------------------------------------------------------------
DWORD fourier_NTT::modsub(DWORD a,DWORD b)
    {
    DWORD d;
    a=mod(a);
    b=mod(b);
    d=a-b; if (DWORD(a)<DWORD(b)) d+=p;
    if (DWORD(d)>=DWORD(p)) d-=p;
    return d;
    }

//---------------------------------------------------------------------------
DWORD fourier_NTT::modmul(DWORD a,DWORD b)
    {    // b bez orezania !
    int i;
    DWORD d;
    a=mod(a);
    for (d=0,i=0;i<32;i++)
        {
        if (DWORD(a&1))    d=modadd(d,b);
        a=shr(a);
        b=modadd(b,b);
        }
    return d;
    }

//---------------------------------------------------------------------------
DWORD fourier_NTT::modpow(DWORD a,DWORD b)
    {    // a,b bez orezania !
    int i;
    DWORD d=1;
    for (i=0;i<32;i++)
        {
        d=modmul(d,d);
        if (DWORD(b&0x80000000)) d=modmul(d,a);
        b=shl(b);
        }
    return d;
    }
//---------------------------------------------------------------------------

我的 NTT 类的用法示例:

fourier_NTT ntt;
const DWORD n=32
DWORD x[N]={0,1,2,3,....31},y[N]={32,33,34,35,...63},z[N];

ntt.NTT(z,x,N);    // z[N]=NTT(x[N]), also init constants for N
ntt.NTT(x,y);    // x[N]=NTT(y[N]), no recompute of constants, use last N
// modular convolution y[]=z[].x[]
for (i=0;i<n;i++) y[i]=ntt.modmul(z[i],x[i]);
ntt.INTT(x,y);    // x[N]=INTT(y[N]), no recompute of constants, use last N
// x[]=convolution of original x[].y[]

优化前的一些测量(非 NTT 类):

a = 0.98765588997654321000 | 389*32 bits
looped 1x times
sqr1[ 3.177 ms ] fast sqr
sqr2[ 720.419 ms ] NTT sqr
mul1[ 5.588 ms ] simpe mul
mul2[ 3.172 ms ] karatsuba mul
mul3[ 1053.382 ms ] NTT mul

优化后的一些测量(当前代码、较低的递归参数大小/计数以及更好的模块化算术):

a = 0.98765588997654321000 | 389*32 bits
looped 1x times
sqr1[ 3.214 ms ] fast sqr
sqr2[ 208.298 ms ] NTT sqr
mul1[ 5.564 ms ] simpe mul
mul2[ 3.113 ms ] karatsuba mul
mul3[ 302.740 ms ] NTT mul

检查 NTT mul 和 NTT sqr 时间(我的优化将其速度加快了 3 倍多一点)。它只有 1 倍循环,所以不是很精确(误差 ~ 10%),但即使现在加速也很明显(通常我循环 1000 倍或更多,但我的 NTT 太慢了)。

您可以自由使用我的代码...只需将我的昵称和/或指向此页面的链接保留在某处(代码中的 rem、readme.txt、about 或其他内容)。我希望它有帮助...(我没有在任何地方看到快速 NTT 的 C++ 源代码,所以我不得不自己编写它)。对所有接受的 N 进行了统一根测试,请参阅fourier_NTT::init(DWORD n)功能。

PS:有关 NTT 的更多信息,请参阅从复数 FFT 到有限场 FFT 的转换 https://stackoverflow.com/a/18547575/2521214。该代码基于我在该链接内的帖子。

[edit1:] 代码的进一步更改

通过利用模素数始终为 0xC0000001 并消除不必要的调用,我成功地进一步优化了我的模运算。由此产生的加速效果令人惊叹(超过 40 倍),并且在大约 1500 * 32 位阈值后,NTT 乘法比 karatsuba 更快。顺便说一句,我的 NTT 的速度现在与我在 64 位双精度上优化的 DFFT 相同。

一些测量:

a = 0.98765588997654321000 | 1553*32bits
looped 10x times
mul2[ 28.585 ms ] karatsuba mul
mul3[ 26.311 ms ] NTT mul

模算术的新源代码:

//---------------------------------------------------------------------------
DWORD fourier_NTT::mod(DWORD a)
    {
    if (a>p) a-=p;
    return a;
    }

//---------------------------------------------------------------------------
DWORD fourier_NTT::modadd(DWORD a,DWORD b)
    {
    DWORD d,cy;
    if (a>p) a-=p;
    if (b>p) b-=p;
    d=a+b;
    cy=((a>>1)+(b>>1)+(((a&1)+(b&1))>>1))&0x80000000;
    if (cy ) d-=p;
    if (d>p) d-=p;
    return d;
    }

//---------------------------------------------------------------------------
DWORD fourier_NTT::modsub(DWORD a,DWORD b)
    {
    DWORD d;
    if (a>p) a-=p;
    if (b>p) b-=p;
    d=a-b;
    if (a<b) d+=p;
    if (d>p) d-=p;
    return d;
    }

//---------------------------------------------------------------------------
DWORD fourier_NTT::modmul(DWORD a,DWORD b)
    {
    DWORD _a,_b,_p;
    _a=a;
    _b=b;
    _p=p;
    asm    {
        mov    eax,_a
        mov    ebx,_b
        mul    ebx        // H(edx),L(eax) = eax * ebx
        mov    ebx,_p
        div    ebx        // eax = H(edx),L(eax) / ebx
        mov    _a,edx    // edx = H(edx),L(eax) % ebx
        }
    return _a;
    }

//---------------------------------------------------------------------------
DWORD fourier_NTT::modpow(DWORD a,DWORD b)
    {    // b bez orezania!
    int i;
    DWORD d=1;
    if (a>p) a-=p;
    for (i=0;i<32;i++)
        {
        d=modmul(d,d);
        if (DWORD(b&0x80000000)) d=modmul(d,a);
        b<<=1;
        }
    return d;
    }

//---------------------------------------------------------------------------

正如你所看到的,函数shl and shr不再使用。我认为 modpow 可以进一步优化,但它不是一个关键函数,因为它只被调用很少次。最关键的函数是 modmul,它似乎处于最佳状态。

进一步的问题:

  • 还有其他选项可以加速 NTT 吗?
  • 我对模算术的优化安全吗? (结果似乎是一样的,但我可能会错过一些东西。)

[edit2] 新的优化

a = 0.99991970486 | 2000*32 bits
looped 10x
sqr1[  13.908 ms ] fast sqr
sqr2[  13.649 ms ] NTT sqr
mul1[  19.726 ms ] simpe mul
mul2[  31.808 ms ] karatsuba mul
mul3[  19.373 ms ] NTT mul

我根据您的所有评论实现了所有可用的内容(感谢您的见解)。

加速:

  • +2.5% 通过删除不必要的安全模块 (Mandalf The Beige)
  • 使用预先计算的 W,iW 能量 +34.9%(神秘)
  • +35% 总计

实际完整源代码:

//---------------------------------------------------------------------------
//--- Number theoretic transforms: 2.03 -------------------------------------
//---------------------------------------------------------------------------
#ifndef _fourier_NTT_h
#define _fourier_NTT_h
//---------------------------------------------------------------------------
//---------------------------------------------------------------------------
class fourier_NTT        // Number theoretic transform
    {
public:
    DWORD r,L,p,N;
    DWORD W,iW,rN;        // W=(r^L) mod p, iW=inverse W, rN = inverse N
    DWORD *WW,*iWW,NN;    // Precomputed (W,iW)^(0,..,NN-1) powers

    // Internals
    fourier_NTT(){ r=0; L=0; p=0; W=0; iW=0; rN=0; WW=NULL; iWW=NULL; NN=0; }
    ~fourier_NTT(){ _free(); }
    void _free();                                            // Free precomputed W,iW powers tables
    void _alloc(DWORD n);                                    // Allocate and precompute W,iW powers tables

    // Main interface
    void  NTT(DWORD *dst,DWORD *src,DWORD n=0);                // DWORD dst[n] = fast  NTT(DWORD src[n])
    void iNTT(DWORD *dst,DWORD *src,DWORD n=0);               // DWORD dst[n] = fast INTT(DWORD src[n])

    // Helper functions
    bool init(DWORD n);                                          // init r,L,p,W,iW,rN
    void  NTT_fast(DWORD *dst,DWORD *src,DWORD n,DWORD w);    // DWORD dst[n] = fast  NTT(DWORD src[n])
    void  NTT_fast(DWORD *dst,DWORD *src,DWORD n,DWORD *w2,DWORD i2);

    // Only for testing
    void  NTT_slow(DWORD *dst,DWORD *src,DWORD n,DWORD w);    // DWORD dst[n] = slow  NTT(DWORD src[n])
    void iNTT_slow(DWORD *dst,DWORD *src,DWORD n,DWORD w);    // DWORD dst[n] = slow INTT(DWORD src[n])

    // Modular arithmetics (optimized, but it works only for p >= 0x80000000!!!)
    DWORD mod(DWORD a);
    DWORD modadd(DWORD a,DWORD b);
    DWORD modsub(DWORD a,DWORD b);
    DWORD modmul(DWORD a,DWORD b);
    DWORD modpow(DWORD a,DWORD b);
    };
//---------------------------------------------------------------------------

//---------------------------------------------------------------------------
void fourier_NTT::_free()
    {
    NN=0;
    if ( WW) delete[]  WW;  WW=NULL;
    if (iWW) delete[] iWW; iWW=NULL;
    }

//---------------------------------------------------------------------------
void fourier_NTT::_alloc(DWORD n)
    {
    if (n<=NN) return;
    DWORD *tmp,i,w;
    tmp=new DWORD[n]; if ((NN)&&( WW)) for (i=0;i<NN;i++) tmp[i]= WW[i]; if ( WW) delete[]  WW;  WW=tmp;  WW[0]=1; for (i=NN?NN:1,w= WW[i-1];i<n;i++){ w=modmul(w, W);  WW[i]=w; }
    tmp=new DWORD[n]; if ((NN)&&(iWW)) for (i=0;i<NN;i++) tmp[i]=iWW[i]; if (iWW) delete[] iWW; iWW=tmp; iWW[0]=1; for (i=NN?NN:1,w=iWW[i-1];i<n;i++){ w=modmul(w,iW); iWW[i]=w; }
    NN=n;
    }

//---------------------------------------------------------------------------
void fourier_NTT:: NTT(DWORD *dst,DWORD *src,DWORD n)
    {
    if (n>0) init(n);
    NTT_fast(dst,src,N,WW,1);
//    NTT_fast(dst,src,N,W);
//    NTT_slow(dst,src,N,W);
    }

//---------------------------------------------------------------------------
void fourier_NTT::iNTT(DWORD *dst,DWORD *src,DWORD n)
    {
    if (n>0) init(n);
    NTT_fast(dst,src,N,iWW,1);
//    NTT_fast(dst,src,N,iW);
    for (DWORD i=0;i<N;i++) dst[i]=modmul(dst[i],rN);
//    iNTT_slow(dst,src,N,W);
    }

//---------------------------------------------------------------------------
bool fourier_NTT::init(DWORD n)
    {
    // (max(src[])^2)*n < p else NTT overflow can ocur!!!
    r=2; p=0xC0000001; if ((n<2)||(n>0x10000000)) { r=0; L=0; p=0; W=0; iW=0; rN=0; N=0; return false; } L=0x30000000/n; // 32:30 bit best for unsigned 32 bit
//    r=2; p=0x78000001; if ((n<2)||(n>0x04000000)) { r=0; L=0; p=0; W=0; iW=0; rN=0; N=0; return false; } L=0x3c000000/n; // 31:27 bit best for signed 32 bit
//    r=2; p=0x00010001; if ((n<2)||(n>0x00000020)) { r=0; L=0; p=0; W=0; iW=0; rN=0; N=0; return false; } L=0x00000020/n; // 17:16 bit best for 16 bit
//    r=2; p=0x0a000001; if ((n<2)||(n>0x01000000)) { r=0; L=0; p=0; W=0; iW=0; rN=0; N=0; return false; } L=0x01000000/n; // 28:25 bit
     N=n;                // Size of vectors [DWORDs]
     W=modpow(r,    L);  // Wn for NTT
    iW=modpow(r,p-1-L);  // Wn for INTT
    rN=modpow(n,p-2  );  // Scale for INTT
    _alloc(n>>1);        // Precompute W,iW powers
    return true;
    }

//---------------------------------------------------------------------------
void fourier_NTT:: NTT_fast(DWORD *dst,DWORD *src,DWORD n,DWORD w)
    {
    if (n<=1) { if (n==1) dst[0]=src[0]; return; }
    DWORD i,j,a0,a1,n2=n>>1,w2=modmul(w,w);

    // Reorder even,odd
    for (i=0,j=0;i<n2;i++,j+=2) dst[i]=src[j];
    for (    j=1;i<n ;i++,j+=2) dst[i]=src[j];

    // Recursion
    NTT_fast(src   ,dst   ,n2,w2);    // Even
    NTT_fast(src+n2,dst+n2,n2,w2);    // Odd

    // Restore results
    for (w2=1,i=0,j=n2;i<n2;i++,j++,w2=modmul(w2,w))
        {
        a0=src[i];
        a1=modmul(src[j],w2);
        dst[i]=modadd(a0,a1);
        dst[j]=modsub(a0,a1);
        }
    }

//---------------------------------------------------------------------------
void fourier_NTT:: NTT_fast(DWORD *dst,DWORD *src,DWORD n,DWORD *w2,DWORD i2)
    {
    if (n<=1) { if (n==1) dst[0]=src[0]; return; }
    DWORD i,j,a0,a1,n2=n>>1;

    // Reorder even,odd
    for (i=0,j=0;i<n2;i++,j+=2) dst[i]=src[j];
    for (    j=1;i<n ;i++,j+=2) dst[i]=src[j];

    // Recursion
    i=i2<<1;
    NTT_fast(src   ,dst   ,n2,w2,i);    // Even
    NTT_fast(src+n2,dst+n2,n2,w2,i);    // Odd

    // Restore results
    for (i=0,j=n2;i<n2;i++,j++,w2+=i2)
        {
        a0=src[i];
        a1=modmul(src[j],*w2);
        dst[i]=modadd(a0,a1);
        dst[j]=modsub(a0,a1);
        }
    }

//---------------------------------------------------------------------------
void fourier_NTT:: NTT_slow(DWORD *dst,DWORD *src,DWORD n,DWORD w)
    {
    DWORD i,j,wj,wi,a;
    for (wj=1,j=0;j<n;j++)
        {
        a=0;
        for (wi=1,i=0;i<n;i++)
            {
            a=modadd(a,modmul(wi,src[i]));
            wi=modmul(wi,wj);
            }
        dst[j]=a;
        wj=modmul(wj,w);
        }
    }

//---------------------------------------------------------------------------
void fourier_NTT::iNTT_slow(DWORD *dst,DWORD *src,DWORD n,DWORD w)
    {
    DWORD i,j,wi=1,wj=1,a;
    for (wj=1,j=0;j<n;j++)
        {
        a=0;
        for (wi=1,i=0;i<n;i++)
            {
            a=modadd(a,modmul(wi,src[i]));
            wi=modmul(wi,wj);
            }
        dst[j]=modmul(a,rN);
        wj=modmul(wj,iW);
        }
    }

//---------------------------------------------------------------------------
DWORD fourier_NTT::mod(DWORD a)
    {
    if (a>p) a-=p;
    return a;
    }

//---------------------------------------------------------------------------
DWORD fourier_NTT::modadd(DWORD a,DWORD b)
    {
    DWORD d,cy;
    //if (a>p) a-=p;
    //if (b>p) b-=p;
    d=a+b;
    cy=((a>>1)+(b>>1)+(((a&1)+(b&1))>>1))&0x80000000;
    if (cy ) d-=p;
    if (d>p) d-=p;
    return d;
    }

//---------------------------------------------------------------------------
DWORD fourier_NTT::modsub(DWORD a,DWORD b)
    {
    DWORD d;
    //if (a>p) a-=p;
    //if (b>p) b-=p;
    d=a-b;
    if (a<b) d+=p;
    if (d>p) d-=p;
    return d;
    }

//---------------------------------------------------------------------------
DWORD fourier_NTT::modmul(DWORD a,DWORD b)
    {
    DWORD _a,_b,_p;
    _a=a;
    _b=b;
    _p=p;
    asm    {
        mov    eax,_a
        mov    ebx,_b
        mul    ebx        // H(edx),L(eax) = eax * ebx
        mov    ebx,_p
        div    ebx        // eax = H(edx),L(eax) / ebx
        mov    _a,edx    // edx = H(edx),L(eax) % ebx
        }
    return _a;
    }

//---------------------------------------------------------------------------
DWORD fourier_NTT::modpow(DWORD a,DWORD b)
    {    // b is not mod(p)!
    int i;
    DWORD d=1;
    //if (a>p) a-=p;
    for (i=0;i<32;i++)
        {
        d=modmul(d,d);
        if (DWORD(b&0x80000000)) d=modmul(d,a);
        b<<=1;
        }
    return d;
    }
//---------------------------------------------------------------------------
//---------------------------------------------------------------------------
#endif
//---------------------------------------------------------------------------
//---------------------------------------------------------------------------

仍然有可能通过分离来减少堆垃圾的使用NTT_fast到两个函数。一与WW[]另一个与iWW[]这会导致递归调用中少一个参数。但我对它的期望并不高(仅限 32 位指针),而是希望拥有一个函数,以便将来更好地管理代码。许多功能现在处于休眠状态(用于测试),例如慢速变体,mod和旧的快速功能(与w参数而不是*w2,i2).

为了避免大数据集溢出,请将输入数量限制为p/4 bits. Where p是每位数NTT元素,因此对于这个 32 位版本,使用 max(32 bit/4 -> 8 bit)输入值。

[edit3] 简单字符串bigint用于测试的乘法

//---------------------------------------------------------------------------
char* mul_NTT(const char *sx,const char *sy)
    {
    char *s;
    int i,j,k,n;
    // n = min power of 2 <= 2 max length(x,y)
    for (i=0;sx[i];i++); for (n=1;n<i;n<<=1);        i--;
    for (j=0;sx[j];j++); for (n=1;n<j;n<<=1); n<<=1; j--;
    DWORD *x,*y,*xx,*yy,a;
    x=new DWORD[n]; xx=new DWORD[n];
    y=new DWORD[n]; yy=new DWORD[n];

    // Zero padding
    for (k=0;i>=0;i--,k++) x[k]=sx[i]-'0'; for (;k<n;k++) x[k]=0;
    for (k=0;j>=0;j--,k++) y[k]=sy[j]-'0'; for (;k<n;k++) y[k]=0;

    //NTT
    fourier_NTT ntt;
    ntt.NTT(xx,x,n);
    ntt.NTT(yy,y);

    // Convolution
    for (i=0;i<n;i++) xx[i]=ntt.modmul(xx[i],yy[i]);

    //INTT
    ntt.iNTT(yy,xx);

    //suma
    a=0; s=new char[n+1]; for (i=0;i<n;i++) { a+=yy[i]; s[n-i-1]=(a%10)+'0'; a/=10; } s[n]=0;
    delete[] x; delete[] xx;
    delete[] y; delete[] yy;

    return s;
    }
//---------------------------------------------------------------------------

I use AnsiString的,所以我将其移植到char*希望我没有犯一些错误。看起来它工作正常(与AnsiString版本)。

  • sx,sy是十进制整数
  • 返回分配的字符串(char*)=sx*sy

每个 32 位数据字大约只有 4 位,因此不存在溢出风险,但当然速度较慢。在我的bignumlib 我使用二进制表示并使用8 bit每个 32 位 WORD 的块数NTT。超过这个值就会有风险,如果N很大...

玩得开心


首先,非常感谢您发布并免费使用。我真的很感激。

我能够使用一些技巧来消除一些分支,重新安排主循环,并修改程序集,并且能够获得 1.35 倍的加速。

另外,我添加了 64 位的预处理器条件,因为 Visual Studio 不允许在 64 位模式下进行内联汇编(谢谢 Microsoft;随意去搞砸吧)。

当我优化 modsub() 函数时,发生了一些奇怪的事情。我用位黑客重写了它,就像我做 modadd 一样(更快)。但由于某种原因,modsub 的按位版本速度较慢。不知道为什么。可能只是我的电脑。

//
// Mandalf The Beige
// Based on:
// Spektre
// http://stackoverflow.com/questions/18577076/modular-arithmetics-and-ntt-finite-field-dft-optimizations
//
// This code may be freely used however you choose, so long as it is accompanied by this notice.
//




#ifndef H__OPTIMIZED_NUMBER_THEORETIC_TRANSFORM__HDR
#define H__OPTIMIZED_NUMBER_THEORETIC_TRANSFORM__HDR

#include <string.h>

#ifndef uint32
#define uint32 unsigned long int
#endif

#ifndef uint64
#define uint64 unsigned long long int
#endif


class fast_ntt                                   // number theoretic transform
{
    public:
    fast_ntt()
    {
        r = 0; L = 0;
        W = 0; iW = 0; rN = 0;
    }
    // main interface
    void  NTT(uint32 *dst, uint32 *src, uint32 n = 0);             // uint32 dst[n] = fast  NTT(uint32 src[n])
    void INTT(uint32 *dst, uint32 *src, uint32 n = 0);             // uint32 dst[n] = fast INTT(uint32 src[n])
    // helper functions

    private:
    bool init(uint32 n);                                     // init r,L,p,W,iW,rN
    void NTT_calc(uint32 *dst, uint32 *src, uint32 n, uint32 w);  // uint32 dst[n] = fast  NTT(uint32 src[n])

    void  NTT_fast(uint32 *dst, uint32 *src, uint32 n, uint32 w);  // uint32 dst[n] = fast  NTT(uint32 src[n])
    void NTT_fast(uint32 *dst, const uint32 *src, uint32 n, uint32 w);
    // only for testing
    void  NTT_slow(uint32 *dst, uint32 *src, uint32 n, uint32 w);  // uint32 dst[n] = slow  NTT(uint32 src[n])
    void INTT_slow(uint32 *dst, uint32 *src, uint32 n, uint32 w);  // uint32 dst[n] = slow INTT(uint32 src[n])
    // uint32 arithmetics


    // modular arithmetics
    inline uint32 modadd(uint32 a, uint32 b);
    inline uint32 modsub(uint32 a, uint32 b);
    inline uint32 modmul(uint32 a, uint32 b);
    inline uint32 modpow(uint32 a, uint32 b);

    uint32 r, L, N;//, p;
    uint32 W, iW, rN;

    const uint32 p = 0xC0000001;
};

//---------------------------------------------------------------------------
void fast_ntt::NTT(uint32 *dst, uint32 *src, uint32 n)
{
    if (n > 0)
    {
        init(n);
    }
    NTT_fast(dst, src, N, W);
    //  NTT_slow(dst,src,N,W);
}

//---------------------------------------------------------------------------
void fast_ntt::INTT(uint32 *dst, uint32 *src, uint32 n)
{
    if (n > 0)
    {
        init(n);
    }
    NTT_fast(dst, src, N, iW);
    for (uint32 i = 0; i<N; i++)
    {
        dst[i] = modmul(dst[i], rN);
    }
    //  INTT_slow(dst,src,N,W);
}

//---------------------------------------------------------------------------
bool fast_ntt::init(uint32 n)
{
    // (max(src[])^2)*n < p else NTT overflow can ocur !!!
    r = 2;
    //p = 0xC0000001;
    if ((n < 2) || (n > 0x10000000))
    {
        r = 0; L = 0; W = 0; // p = 0;
        iW = 0; rN = 0; N = 0;
        return false;
    }
    L = 0x30000000 / n; // 32:30 bit best for unsigned 32 bit
    //  r=2; p=0x78000001; if ((n<2)||(n>0x04000000)) { r=0; L=0; p=0; W=0; iW=0; rN=0; N=0; return false; } L=0x3c000000/n; // 31:27 bit best for signed 32 bit
    //  r=2; p=0x00010001; if ((n<2)||(n>0x00000020)) { r=0; L=0; p=0; W=0; iW=0; rN=0; N=0; return false; } L=0x00000020/n; // 17:16 bit best for 16 bit
    //  r=2; p=0x0a000001; if ((n<2)||(n>0x01000000)) { r=0; L=0; p=0; W=0; iW=0; rN=0; N=0; return false; } L=0x01000000/n; // 28:25 bit
    N = n;               // size of vectors [uint32s]
    W = modpow(r, L); // Wn for NTT
    iW = modpow(r, p - 1 - L); // Wn for INTT
    rN = modpow(n, p - 2); // scale for INTT
    return true;
}

//---------------------------------------------------------------------------

void fast_ntt::NTT_fast(uint32 *dst, uint32 *src, uint32 n, uint32 w)
{
    if(n > 1)
    {
        if(dst != src)
        {
            NTT_calc(dst, src, n, w);
        }
        else
        {
            uint32* temp = new uint32[n];
            NTT_calc(temp, src, n, w);
            memcpy(dst, temp, n * sizeof(uint32));
            delete [] temp;
        }
    }
    else if(n == 1)
    {
        dst[0] = src[0];
    }
}

void fast_ntt::NTT_fast(uint32 *dst, const uint32 *src, uint32 n, uint32 w)
{
    if (n > 1)
    {
        uint32* temp = new uint32[n];
        memcpy(temp, src, n * sizeof(uint32));
        NTT_calc(dst, temp, n, w);
        delete[] temp;
    }
    else if (n == 1)
    {
        dst[0] = src[0];
    }
}



void fast_ntt::NTT_calc(uint32 *dst, uint32 *src, uint32 n, uint32 w)
{
    if(n > 1)
    {
        uint32 i, j, a0, a1,
        n2 = n >> 1,
        w2 = modmul(w, w);

        // reorder even,odd
        for (i = 0, j = 0; i < n2; i++, j += 2)
        {
            dst[i] = src[j];
        }
        for (j = 1; i < n; i++, j += 2)
        {
            dst[i] = src[j];
        }
        // recursion
        if(n2 > 1)
        {
            NTT_calc(src, dst, n2, w2);  // even
            NTT_calc(src + n2, dst + n2, n2, w2);  // odd
        }
        else if(n2 == 1)
        {
            src[0] = dst[0];
            src[1] = dst[1];
        }

        // restore results

        w2 = 1, i = 0, j = n2;
        a0 = src[i];
        a1 = src[j];
        dst[i] = modadd(a0, a1);
        dst[j] = modsub(a0, a1);
        while (++i < n2)
        {
            w2 = modmul(w2, w);
            j++;
            a0 = src[i];
            a1 = modmul(src[j], w2);
            dst[i] = modadd(a0, a1);
            dst[j] = modsub(a0, a1);
        }
    }
}

//---------------------------------------------------------------------------
void fast_ntt::NTT_slow(uint32 *dst, uint32 *src, uint32 n, uint32 w)
{
    uint32 i, j, wj, wi, a,
        n2 = n >> 1;
    for (wj = 1, j = 0; j < n; j++)
    {
        a = 0;
        for (wi = 1, i = 0; i < n; i++)
        {
            a = modadd(a, modmul(wi, src[i]));
            wi = modmul(wi, wj);
        }
        dst[j] = a;
        wj = modmul(wj, w);
    }
}

//---------------------------------------------------------------------------
void fast_ntt::INTT_slow(uint32 *dst, uint32 *src, uint32 n, uint32 w)
{
    uint32 i, j, wi = 1, wj = 1, a, n2 = n >> 1;

    for (wj = 1, j = 0; j < n; j++)
    {
        a = 0;
        for (wi = 1, i = 0; i < n; i++)
        {
            a = modadd(a, modmul(wi, src[i]));
            wi = modmul(wi, wj);
        }
        dst[j] = modmul(a, rN);
        wj = modmul(wj, iW);
    }
}    


//---------------------------------------------------------------------------
uint32 fast_ntt::modadd(uint32 a, uint32 b)
{
    uint32 d;
    d = a + b;

    if(d < a)
    {
        d -= p;
    }
    if (d >= p)
    {
        d -= p;
    }
    return d;
}

//---------------------------------------------------------------------------
uint32 fast_ntt::modsub(uint32 a, uint32 b)
{
    uint32 d;
    d = a - b;
    if (d > a)
    {
        d += p;
    }
    return d;
}

//---------------------------------------------------------------------------
uint32 fast_ntt::modmul(uint32 a, uint32 b)
{
    uint32 _a = a;
    uint32 _b = b;

    // Original
    uint32 _p = p;
    __asm
    {
        mov eax, _a;
        mul _b;
        div _p;
        mov eax, edx;
    };
}


uint32 fast_ntt::modpow(uint32 a, uint32 b)
{
    //*
    uint64 D, M, A, P;

    P = p; A = a;
    M = 0llu - (b & 1);
    D = (M & A) | ((~M) & 1);

    while ((b >>= 1) != 0)
    {
        A = modmul(A, A);
        //A = (A * A) % P;

        if ((b & 1) == 1)
        {
            //D = (D * A) % P;
            D = modmul(D, A);
        }
    }
    return (uint32)D;
}

新模乘

uint32 fast_ntt::modmul(uint32 a, uint32 b)
{
    uint32 _a = a;
    uint32 _b = b;   

    __asm
    {
    mov eax, a;
    mul b;
    mov ebx, eax;
    mov eax, 2863311530;
    mov ecx, edx;
    mul edx;
    shld edx, eax, 1;
    mov eax, 3221225473;

    mul edx;
    sub ebx, eax;
    mov eax, 3221225473;
    sbb ecx, edx;
    jc addback;

            neg ecx;
            and ecx, eax;
            sub ebx, ecx;

    sub ebx, eax;
    sbb edx, edx;
    and eax, edx;
            addback:
    add eax, ebx;          
    };  
}

[EDIT]Spektre,根据您的反馈,我将 modadd 和 modsub 更改回原来的状态。我还意识到我对递归 NTT 函数做了一些不应该做的更改。

[EDIT2]删除了不需要的 if 语句和按位函数。

[EDIT3]添加了新的 modmul 内联汇编。

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

模块化算术和 NTT(有限域 DFT)优化 的相关文章

  • MVC 重定向到没有控制器的视图

    希望应该是一个简单的 我创建了一个通用错误视图 当整个站点的操作方法内发生异常时 我想显示该视图 我创建了一个部分页面 所有导航都位于其中 因此我不需要在此视图上使用控制器 那么如何从控制器内的操作方法重定向到它 像这样的东西 HttpPo
  • SetWindowsHookEx 函数返回 NULL

    我正在研究 DLL 注入 但收到错误如下 挂接进程失败 87 参数不正确 目标进程和dll都是64位的 注入代码为 BOOL HookInjection TCHAR target TCHAR dll name https msdn micr
  • WebClient读取错误页面的内容

    我有一个加载页面内容的应用程序 我使用 WebClient 类 即使服务器返回 404 500 等错误 我也需要检索内容 我需要这样的东西 WebClient wc new WebClient string pageContent try
  • Android NDK C++“wstring”支持

    我有用 C 编写的源代码 lib 现在我想在 Android NDK 项目 NDK 6 中编译并使用相同的源代码 lib 我能够编译大多数 C 文件 除了基于 std wstring 的功能 在 Application mk 中 当我指定时
  • 如何使用汇编获取BIOS时间?

    我正在从头开始实现一个小型操作系统 用于教育目的 现在 我想使用汇编来获取 BIOS 时间 我对此进行了很多搜索 但找不到任何代码示例来执行此操作 如果有人可以提供任何参考或代码示例或与此相关的任何内容 我将非常感激 See 时钟中断 1a
  • rand() 播种与 time() 问题

    我很难弄清楚如何使用 rand 并使用 Xcode 用 time 为其播种 我想生成 0 到 1 之间的随机十进制数 该代码为我提供了元素 1 和 2 看似随机的数字 但元素 0 始终在 0 077 左右 有什么想法吗 我的代码是 incl
  • 如何在 Windows 窗体中运行屏幕保护程序作为其背景?

    如何在 Windows 窗体中运行屏幕保护程序作为其背景 用户还可以在屏幕保护程序运行时与表单控件进行交互 为什么这个 我们有一个案例 需要在用户时运行 Windows Bubbles 屏幕保护程序 可以继续与表单控件交互吗 您可以使用以下
  • F10键没被抓住

    I have a Windows Form and there overriden ProcessCmdKey However this works with all of the F Keys except for F10 I am tr
  • 在通过网络发送之前压缩位图

    我正在尝试通过网络发送位图屏幕截图 因此我需要在发送之前对其进行压缩 有一个库或方法可以做到这一点吗 当您将图像保存到流时 您have选择一种格式 几乎所有位图格式 bmp gif jpg png 都使用一种或多种压缩形式 因此 只需选择适
  • 使用反射获取基类的受保护属性值

    I would like to know if it is possible to access the value of the ConfigurationId property which is located in the base
  • 特定设备的不同字体大小

    我目前正在开发通用应用程序 我需要分别处理移动设备和桌面的文本框字体大小 我找到了一些方法 但都不能解决问题 使用 VisualStateManager 和 StateTrigger 为例
  • 使用 openssl 检查服务器安全协议

    我有一个框架应用程序 它根据使用方式连接到不同的服务器 对于 https 连接 使用 openssl 我的问题是 我需要知道我连接的服务器是否使用 SSL 还是 TLS 以便我可以创建正确的 SSL 上下文 目前 如果我使用错误的上下文尝试
  • 为什么重载方法在 ref 仅符合 CLS 方面有所不同

    公共语言规范对方法重载非常严格 仅允许根据其参数的数量和类型来重载方法 如果是泛型方法 则根据其泛型参数的数量进行重载 根据 csc 为什么此代码符合 CLS 无 CS3006 警告 using System assembly CLSCom
  • 为什么WCF中不允许方法重载?

    假设这是一个ServiceContract ServiceContract public interface MyService OperationContract int Sum int x int y OperationContract
  • 快速查询最新记录的方法?

    我有一张这样的表 USER PLAN START DATE END DATE 1 A 20110101 NULL 1 B 20100101 20101231 2 A 20100101 20100505 在某种程度上 如果END DATE i
  • 在 .NET 中记录 StackOverflowException

    最近 我的 NET 应用程序 asp net 网站 中出现了堆栈溢出异常 我之所以知道该异常是因为它出现在我的 EventLog 中 我知道 StackOverflow 异常无法被捕获或处理 但是有没有办法在它杀死您的应用程序之前记录它 我
  • 展开路径中具有环境变量的文件名

    最好的扩张方式是什么 MyPath filename txt to home user filename txt or MyPath filename txt to c Documents and settings user filenam
  • 你能解释一下这个C++删除问题吗?

    我有以下代码 std string F WideString ws GetMyWideString std string ret StringUtils ConvertWideStringToUTF8 ws ret return ret W
  • 为什么C语言中可以使用多个分号?

    在 C 中我可以执行以下操作 int main printf HELLO WORLD 它有效 这是为什么 我个人的想法 分号是一个 NO OPERATION 来自维基百科 指示符 拥有一大串分号与拥有一个分号并告诉 C 语句已结束具有相同的
  • 如何将 CSV 文件读入 .NET 数据表

    如何将 CSV 文件加载到System Data DataTable 根据CSV文件创建数据表 常规 ADO net 功能是否允许这样做 我一直在使用OleDb提供者 但是 如果您正在读取具有数值的行 但希望将它们视为文本 则会出现问题 但

随机推荐