组合数

定义:从 n 个不同元素中取出 m(m ≤ n)个元素的所有组合的个数,叫做 n 个不同元素中取出 m(m≤n)个元素的组合数。用符号 C(n,m) 表示

组合数公式:C(n,m) = n! / (m! * (n-m)!)
性质:C(n,m) = C(n,m-n)
递推公式:C(n,m) = C(n-1,m-1) + C(n-1,m)

排列数

定义:从 n 个不同的元素中任取 m(m ≤ n)个元素的所有排列的个数,叫做从 n 个不同的元素中取出 m(m≤n)个元素的排列数。用符号 A(n,m) 表示
排列与元素的顺序有关,组合与顺序无关

排列数公式:A(n,m) = n * (n-1) * ... * (n-m+1) = n! / (n-m)!

简化取模

计算过程中可能出现多次取模的情况,我们可以将计算操作用函数代替

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
ll quick_pow(ll a, ll b) {
ll ans = 1, base = a;
while (b) {
if (b & 1) ans = ans * base % mod;
base = base * base % mod;
b >>= 1;
}
return ans % mod;
}

inline ll mul(ll a, ll b) {
a %= mod, b %= mod;
return (a * b) % mod;
}

inline ll pls(ll a, ll b) {
a %= mod, b %= mod;
return (a + b) % mod;
}

inline ll mine(ll a, ll b) {
a %= mod, b %= mod;
return (a + mod - b) % mod;
}

inline ll dive(ll a, ll b) {
return mul(a, quick_pow(b, mod - 2)); //费马小定理求逆元
}

模板

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
#include <bits/stdc++.h>
// // #pragma GCC optimize(2)
// // #pragma G++ optimize(2)
// #define online_judge
#define endl "\n"
#define fi first
#define se second
#define pb push_back
#define all(x) x.begin(), x.end()
#define rep(i, x, y) for (auto i = (x); i != (y + 1); ++i)
#define dep(i, x, y) for (auto i = (x); i != (y - 1); --i)
#define debug(x) cout << "debug: " << x << endl;
using namespace std;
typedef long long ll;
typedef pair<int, int> pii;
const ll maxn = 200000, mod = 1e9 + 7;

ll inv[maxn + 5], fac[maxn + 5], invfac[maxn + 5]; //处理num的逆元,num的阶乘与阶乘逆元

void init(ll n) { //预处理阶乘与逆元
inv[1] = 1;
for (int i = 2; i <= n; i++)
inv[i] = inv[mod % i] * (mod - mod / i) % mod; //线性求逆元
fac[0] = 1, invfac[0] = 1;
for (int i = 1; i <= n; i++) {
fac[i] = fac[i - 1] * i % mod;
invfac[i] = invfac[i - 1] * inv[i] % mod;
}
}

ll C(ll n, ll m) { //组合数
if (n < m)
return 0;
return fac[n] * invfac[n - m] % mod * invfac[m] % mod;
}

ll lucas(ll n, ll m) { //n或m大于mod且mod为素数的时候
if (!m) return 1;
return C(n % mod, m % mod) * lucas(n / mod, m / mod) % mod;
}

ll A(ll n, ll m) {
return fac[n] * invfac[n - m] % mod;
}

void solve() {
init(maxn);
ll n, m;
cin >> n >> m;
if (n >= mod)
cout << lucas(n, m) << endl;
else
cout << C(n, m) << endl;
}

signed main() {
ios::sync_with_stdio(false), cin.tie(0);
#ifndef online_judge
freopen("IO\\in.txt", "r", stdin);
freopen("IO\\out.txt", "w", stdout);
clock_t start, end;
start = clock();
#endif
solve();
#ifndef online_judge
end = clock();
cout << endl
<< "Runtime: " << (double)(end - start) / CLOCKS_PER_SEC << "s\n";
#endif
return 0;
}