发布于 ,更新于 ,文章内容可能已经过时

同余方程组

同余方程组形式如下:

中国剩余定理

原理

本算法用于所有 两两互质的情况。

实际上类似构造,先把 累积起来

先考虑如何构造出一个同余方程的特解,设 𝑀𝑖mod𝑛𝑖 意义下的逆元,想到 𝑀𝑖𝑀1𝑖 =1(mod𝑛𝑖),于是有 𝑥 =𝑎𝑖𝑀𝑖𝑀1𝑖 满足方程 𝑥 𝑎𝑖(mod𝑛𝑖)
这实际上就是上面定义 𝑀𝑖 为质数的原因:需要保证有逆元存在。

至于解的合并,把 𝑥 相加即可。即 𝑘𝑖=1𝑎𝑖𝑀𝑖𝑀1𝑖。通解即为 𝑛𝑀 +𝑘𝑖=1𝑎𝑖𝑀𝑖𝑀1𝑖,𝑛 。为什么能这样合并?因为对于 𝑖,𝑗,𝑖 𝑗 来说,必然存在 𝑛𝑗 𝑀𝑖,因为 𝑛𝑖 两两互质,𝑀𝑖 的因子包含 𝑛𝑗,因此不会对已经推出来的解造成任何影响。

通解方面:容易想到,加减 𝑛𝑀 也不会对解的可行性造成任何影响。

实现

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
llint ex_gcd(llint u, llint v, llint& x, llint& y) {
if (!v) {
x = 1, y = 0;
return u;
}

llint g = ex_gcd(v, u % v, x, y);
llint temp = x;
x = y;
y = temp - u / v * y;
return g;
}

llint crt() {
llint tmp = 1, ans = 0;
for (int i = 1; i <= n; ++ i) tmp *= b[i];
for (int i = 1; i <= n; ++ i) {
llint m = tmp / b[i], x, y;
ex_gcd(m, a[i], x, y);
ans = (ans + n[i] * m * x % tmp) % tmp;
}
return (ans % tmp + tmp) % tmp;
}
复制

扩展中国剩余定理/同余方程合并

原理

𝑛𝑖 不互质的时候,上面做法就不再适用(不能保证逆元存在),此时我们考虑逐一合并同余方程组。
方法是使用 ex_gcd

假设方程组只有两个方程

{𝑥𝑎1(mod𝑛1)𝑥𝑎2(mod𝑛2)

可以写作

{𝑥=𝑎1+𝑘1𝑛1𝑥=𝑎2+𝑘2𝑛2𝑎1 +𝑘1𝑛1 =𝑎2 +𝑘2𝑛2
移项,得 𝑘1𝑛1 𝑘2𝑛2 =𝑎2 𝑎1 ,接下来按照扩展欧几里得的思路推导。
𝑔 =gcd(𝑛1,𝑛2),若使用扩展欧几里得算法,可以求出 𝑛1𝑘1 +𝑛2(𝑘2) =gcd(𝑛1,𝑛2) 的一组特解 𝑘1,𝑘2,即 𝑛1𝑘1+𝑛2(𝑘2)=𝑔(𝑘1𝑎2𝑎1𝑔)𝑛1+(𝑘2𝑎2𝑎1𝑔)𝑛2=𝑎2𝑎1 抽出刚才表示 𝑥 的方程组的一条,得到 𝑥0 =𝑎1 +𝑘1𝑛1𝑘1 的通解为 𝑘1 +𝑛2𝑔𝑝,𝑝 ,则 𝑥=𝑎1+(𝑘1+𝑛2𝑔𝑝)𝑛1=𝑛1𝑘1+𝑎1+𝑛1𝑛2𝑔𝑝=𝑛1𝑘1+𝑎1+lcm(𝑛1,𝑛2)𝑝 显然我们可以把这个等式写成同余式的形式(没必要)

𝑥0𝑥(modlcm(𝑛1,𝑛2))𝑥𝑛1𝑘1+𝑎1(modlcm(𝑛1,𝑛2)) lcm(𝑛1,𝑛2),𝑛1𝑘1 +𝑎1 都是已知的,因此方程可解,合并完成。

实现

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
#define N 100005

i128 ex_gcd(i128 a, i128 b, i128& x, i128& y) {
if (b) {
i128 tmp;
i128 g = ex_gcd(b, a % b, x, y);
tmp = x;
x = y;
y = tmp - a / b * y;
return g;
} else {
x = 1, y = 0;
return a;
}
}

i64 n;
i128 a[N], m[N]; // a === x (mod m)

i128 ex_crt() {
i128 k1, k2, a1, m1, a2, m2, c, gc, g;
a1 = a[1], m1 = m[1];
for (i64 i = 2; i <= n; ++i) {
a2 = a[i], m2 = m[i], c = ((a2 - a1) % m2 + m2) % m2;
g = ex_gcd(m1, m2, k1, k2);
gc = m2 / g;
k1 = k1 % gc * ((c / g) % gc) % gc;
a1 = k1 * m1 + a1;
m1 *= gc;
a1 = (a1 % m1 + m1) % m1;
}
return a1;
}

int main() {
read(n);
for (i64 i = 1; i <= n; ++i) {
read(m[i], a[i]);
}

write(ex_crt(), '\n');
return 0;
}
复制