题目链接:https://www.luogu.com.cn/problem/U104688。
前言
这个题并没有什么很难的东西,我就主要讲讲怎么降低常数。
正文
首先众所周知双曲三角函数是可以化成几个基本初等函数的运算的,具体来说:
$$ \sinh(x)=\frac{e^x-e^{-x}}{2} $$
$$ \cosh(x)=\frac{e^x+e^{-x}}{2} $$
$$ \operatorname{sech}(x)=\frac{2}{e^x+e^{-x}} $$
接下来就很容易了,直接按照这个算即可。
常数优化
但是你会发现,有的人写的多项式跑得就是比你快,而且快好几倍!
现在我就介绍一些比较常用的优化常数的方法:
优化开关
O2 比什么都管用,先打开再说。
取模优化
众所周知取模运算慢得出奇,如果能优化取模肯定是会快的。
加法取模
加法取模可以用在结果小于 $2\times\text{MOD}$ 的情况下,具体来说基本就是两个数相加减的时候。
inline void upd(int &x) {
x += x >> 31 & MOD;
}
上面的代码等价于下面的代码:
inline void upd(int &x) {
if (x < 0) x += MOD;
}
但是第一份代码运用了位运算,速度十分可观。
它的原理是对于一个「32 位有符号整数」,负数右移 $31$ 位会变成 $-1$,二进制位下就是全 $1$;而非负数右移 $31$ 位会变成 $0$。
使用的时候就是两个数相加之后减去 $\text{MOD}$,再将结果 upd
一下。
乘法取模
乘法取模要复杂一些,一般不常用。
有兴趣可以去 Min_25 的博客了解一下:地址。
预处理原根
这是个大优化,有的时候能让你的常数减小到原来的 $\dfrac 25$!
一般写 NTT 的时候每次要根据长度重新处理蝴蝶变换的数组,做 NTT 的过程中还要现场算原根的各次幂。这部分要做大量的乘法和取模运算,如果能预处理出来,只做一次,常数就能有极大优化!
另外有的人预处理的时候数组大小是 $O(n\log n)$ 的,其实有一维并不需要,因为长度总是 $2$ 的整数次幂,只要按照最大的长度预处理即可。
其它优化
有时候你需要将数组一段清空或者移到另一个数组中,可以使用 cstring
库里的 memset
和 memcpy
完成。但我感觉优化效果不大,所以就没用。
最后
还有一些从过程上进行的比较复杂的优化,我也不会所以就不讲了,有兴趣可以去论文哥的博客了解一下:地址。
贴一下此题代码,仅供参考:
#include <cstdio>
#include <algorithm>
using std::reverse;
#define MOD 998244353
#define N 262210
typedef long long i64;
typedef unsigned long long u64;
inline void upd(int&x) {
x+=x>>31&MOD;
}
inline int pow(int a,int b) {
int ans(1);
while(b) {
ans=b&1?(i64)ans*a%MOD:ans;
a=(i64)a*a%MOD; b>>=1;
} return ans;
}
int inv[N];
inline void pre(int n) {
inv[1]=1;
for(int i=2;i<=n;++i)
inv[i]=(i64)(MOD-MOD/i)*inv[MOD%i]%MOD;
}
int lmt(1),r[N],w[N],qaq;
inline int getLen(int n) {
return 1<<(32-__builtin_clz(n));
}
inline void init(int n) {
int l(0);
while(lmt<=n) lmt<<=1,++l;
for(int i=1;i<lmt;++i)
r[i]=(r[i>>1]>>1)|((i&1)<<(l-1));
int wn(pow(3,(MOD-1)>>l));
w[lmt>>1]=1;
for(int i=(lmt>>1)+1;i<lmt;++i)
w[i]=(i64)w[i-1]*wn%MOD;
for(int i=(lmt>>1)-1;i;--i)
w[i]=w[i<<1];
lmt=l;
}
inline void DFT(int*a,int l) {
static u64 tmp[N];
int u(lmt-__builtin_ctz(l)),t;
for(int i=0;i<l;++i)
tmp[r[i]>>u]=a[i];
for(int i=1;i<l;i<<=1)
for(int j=0,step=i<<1;j<l;j+=step)
for(int k=0;k<i;++k) {
t=tmp[i+j+k]*w[i+k]%MOD;
tmp[i+j+k]=tmp[j+k]+MOD-t;
tmp[j+k]+=t;
}
for(int i=0;i<l;++i)
a[i]=tmp[i]%MOD;
}
inline void IDFT(int*a,int l) {
reverse(a+1,a+l); DFT(a,l);
int bk(MOD-(MOD-1)/l);
for(int i=0;i<l;++i)
a[i]=(i64)a[i]*bk%MOD;
}
void getInv(int*a,int*b,int deg) {
if(deg==1) b[0]=pow(a[0],MOD-2);
else {
static int tmp[N];
getInv(a,b,(deg+1)>>1);
int l(getLen(deg<<1));
for(int i=0;i<l;++i)
tmp[i]=i<deg?a[i]:0;
DFT(tmp,l); DFT(b,l);
for(int i=0;i<l;++i) {
qaq=b[i];
b[i]=2ll-(i64)qaq*tmp[i]%MOD;
upd(b[i]); b[i]=(i64)b[i]*qaq%MOD;
} IDFT(b,l);
for(int i=deg;i<l;++i)
b[i]=0;
}
}
inline void getDer(int*a,int*b,int deg) {
for(int i=0;i+1<deg;++i)
b[i]=(i64)a[i+1]*(i+1)%MOD;
b[deg-1]=0;
}
inline void getInt(int*a,int*b,int deg) {
for(int i=1;i<deg;++i)
b[i]=(i64)a[i-1]*inv[i]%MOD;
b[0]=0;
}
inline void getLn(int*a,int*b,int deg) {
static int tmp[N];
getInv(a,tmp,deg);
getDer(a,b,deg);
int l(getLen(deg<<1));
DFT(tmp,l); DFT(b,l);
for(int i=0;i<l;++i)
tmp[i]=(i64)tmp[i]*b[i]%MOD;
IDFT(tmp,l);
getInt(tmp,b,deg);
for(int i=0;i<l;++i)
tmp[i]=0;
for(int i=deg;i<l;++i)
b[i]=0;
}
void getExp(int*a,int*b,int deg) {
if(deg==1) b[0]=1;
else {
static int tmp[N];
getExp(a,b,(deg+1)>>1);
getLn(b,tmp,deg);
int l(getLen(deg<<1));
for(int i=0;i<l;++i) {
if(i<deg) {
tmp[i]=a[i]-tmp[i];
upd(tmp[i]);
} else tmp[i]=0;
} ++tmp[0];
DFT(tmp,l); DFT(b,l);
for(int i=0;i<l;++i)
b[i]=(i64)b[i]*tmp[i]%MOD;
IDFT(b,l);
for(int i=deg;i<l;++i)
b[i]=tmp[i]=0;
}
}
int n,type,f[N];
int xp[N],ixp[N],sm[N],dc[N];
int Sinh[N],Cosh[N],Sech[N];
int main() {
scanf("%d%d",&n,&type);
pre(n); init(n<<1);
for(int i=0;i<n;++i)
scanf("%d",f+i);
getExp(f,xp,n);
getInv(xp,ixp,n);
if(type&1) {
for(int i=0;i<n;++i) {
dc[i]=xp[i]-ixp[i]; upd(dc[i]);
Sinh[i]=(i64)dc[i]*inv[2]%MOD;
}
for(int i=0;i<n;++i)
printf("%d ",Sinh[i]);
putchar('\n');
}
if(type&2) {
for(int i=0;i<n;++i) {
sm[i]=xp[i]+ixp[i]-MOD; upd(sm[i]);
Cosh[i]=(i64)sm[i]*inv[2]%MOD;
}
for(int i=0;i<n;++i)
printf("%d ",Cosh[i]);
putchar('\n');
}
if(type&4) {
if(type&2) {
getInv(Cosh,Sech,n);
for(int i=0;i<n;++i)
printf("%d ",Sech[i]);
} else {
for(int i=0;i<n;++i) {
sm[i]=xp[i]+ixp[i]-MOD;
upd(sm[i]);
}
getInv(sm,Sech,n);
for(int i=0;i<n;++i) {
Sech[i]=(Sech[i]<<1)-MOD;
upd(Sech[i]);
}
for(int i=0;i<n;++i)
printf("%d ",Sech[i]);
}
}
return 0;
}