YZOJ P3752 序列求差问题
时间限制:2000MS 内存限制:131072KB
出题人:Night
难度:\(6.0\)
-
题目描述
有一个序列 \(x_1,x_2,\cdots,x_n\) 。
求有多少个从 \(1,2,\cdots,n\) 中取三个元素的排列 \((a,b,c)\) 满足 \(x_a=x_b-x_c\) 。
由于是排列,所以 \((a,b,c)\) 与 \((c,b,a)\) 视为两组解。
-
输入格式
第一行一个整数 \(n\) 表示序列长度。
第二行为 \(n\) 个整数表示序列里的 \(n\) 个数。
-
输出格式
一行一个正整数,表示答案。
-
样例输入
1 2 |
10 1 6 2 9 5 9 2 5 0 5 |
-
样例输出
1 |
26 |
-
数据规模与约定
对于 \(20\%\) 的数据,\(1 \leq n \leq 500\);
对于 \(45\%\) 的数据,\(1 \leq n \leq 5000\);
对于 \(100\%\) 的数据,\(1 \leq n \leq 1000000\),\(0 \leq \left|x_i\right| \leq 100000\) 。
首先这个东西和 \(x_a+x_b=x_c\) 是等价的。
对于 \(n \leq 500\) 的数据,显然可以 \(O(n^3)\) 枚举 \((a,b,c)\) 暴力判断。
对于 \(n \leq 5000\) 的数据,可以记桶 \(cnt_i\) 表示 \(i=x_j\) 的不同 \(j\) 的个数,只要枚举 \((a,b)\) ,答案加上 \(cnt_{x_a+x_b}\) 即可,\(O(n^2)\)。
然后对于 \(100\%\) 的数据,出题人就发现 \(x_a+x_b=x_c\) 这个东西很像多项式乘法,因为可以把 \(cnt_{x_a}\) 和 \(cnt_{x_b}\) 贡献到 \(cnt_{x_c}\) 上。
所以就把 \(cnt\) 作为一个多项式,与它自己相乘,得到的就是答案了。
多项式乘法用 FFT 优化至 \(O(nlogn)\) 。
细节:
1,因为不能取两个相同的,所以 \(ans_{x_i+x_i}\) 要减一。
2,还有要特判一下 \(0\) 的情况,所以答案要减去 \(2 \times m \times (n-1)\) (其中 \(m\) 为 \(x\) 中 \(0\) 的个数)。
3,有负数,所以要整体偏移,\(x_a+diff+x_b+diff=x_c+2 \times diff\) ,注意多项式乘法的答案意义也有所变化。
4,因为答案可能超出 \(int\) ,所以不能使用 FWT/NTT 求解,而且
double 会被卡精度,所以要换成
long double 继续,正确的做法是考虑到答案肯定不超过 \(A^3_{1000000}=999997000002000000\),可以使用 \(998244353\) 和 \(1004535809\) 两个模数分别 FNT 求一遍,然后再 CRT(中国剩余定理) 合并计算结果。
答案为 \(\sum ans_{x_i+2 \times diff}-2 \times m \times (n-1)\) 。
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 |
#include <cstdio> #include <cstdlib> #include <cstring> #include <climits> #include <cmath> #define DIFF 100001 #define _max(_a_,_b_) ((_a_)>(_b_)?(_a_):(_b_)) inline int getnum() { register char c=0; register bool neg=false; while(!(c>='0' && c<='9')) c=getchar(),neg|=(c=='-'); register int a=0; while(c>='0' && c<='9') a=a*10+c-'0',c=getchar(); return (neg?-1:1)*a; } struct _cm { // a+bi long double a,b; _cm(long double na=0,long double nb=0){a=na,b=nb;} _cm operator + (const _cm&o)const { return (_cm){a+o.a,b+o.b}; } _cm operator - (const _cm&o)const { return (_cm){a-o.a,b-o.b}; } _cm operator * (const _cm&o)const { return (_cm){a*o.a-b*o.b,a*o.b+b*o.a}; } }c[805050]; int flip[805050]; inline void FFT(register int len,_cm c[],bool inverse=false) { _cm t; for(register int i=0;i<len;i++) if(i<flip[i]) t=c[flip[i]],c[flip[i]]=c[i],c[i]=t; for(register int step=2;step<=len;step<<=1) { _cm wn=(_cm){std::cos((long double)2*M_PI/step*(inverse ? -1 : 1)), \ std::sin((long double)2*M_PI/step*(inverse ? -1 : 1))}; for(register int k=0;k<len;k+=step) { _cm w=(_cm){1}; for(register int i=k;i < k+(step>>1);i++) { register int j=i+(step>>1); t=c[j]*w; c[j]=c[i]-t; c[i]=c[i]+t; w=w*wn; } } } if(inverse) for(register int i=0;i<len;i++) c[i].a/=len; } int a[1050505]; long long cnt[805050]; int main() { register int N=getnum(),mxa=0,cnt0=0; for(register int i=1;i<=N;i++) { a[i]=getnum()+DIFF; mxa=_max(a[i],mxa),cnt0+=(a[i]==DIFF); cnt[a[i]]++; } mxa<<=1; register int len=1,dig=0; while(len<mxa) len<<=1,dig++; for(register int i=0;i<len;i++) { c[i]=(_cm){(long double)((i<<1)<=mxa ? cnt[i] : 0)}; flip[i]=(flip[i>>1]>>1)|((i&1)<<(dig-1)); } //printf("--------------- A\n"); //for(register int i=0;i<len;i++) // printf("%d: %.2Lf+%.2Lfi\n",i,c[i].a,c[i].b); FFT(len,c); //printf("--------------- B\n"); //for(register int i=0;i<len;i++) // printf("%d: %.2Lf+%.2Lfi\n",i,c[i].a,c[i].b); for(register int i=0;i<len;i++) c[i]=c[i]*c[i]; FFT(len,c,true); //printf("--------------- C\n"); //for(register int i=0;i<len;i++) // printf("%d: %.2Lf+%.2Lfi\n",i,c[i].a,c[i].b); for(register int i=1;i<=mxa;i++) cnt[i]=(long long)(c[i].a+0.49999); long long ans=-(long long)cnt0*(N-1)<<1; for(register int i=1;i<=N;i++) cnt[a[i]<<1]--; for(register int i=1;i<=N;i++) ans+=cnt[a[i]+DIFF]; printf("%lld\n",ans); return 0; } |