1、数据结构
(1)线段树单点更新
#include <cstdio>
#include <cstdlib>
#include <algorithm>
#include <cstring>
#include <cmath>

#define MAXN (1<<19)

using namespace std;

// Max segment tree laid out as an implicit binary tree: node k has children
// 2k and 2k+1, and segTree[k] holds the maximum of the range node k covers.
int segTree[MAXN];

// Point assignment.
//   i        - node to start from; it covers [lft, rht]
//   index    - position whose value is replaced
//   deta     - the new value
// Descends to the leaf for `index`, stores the value, then recomputes the
// maximum of every node on the way back up to (and including) `i`.
void update(int i, int lft, int rht, int index, int deta){
    const int top = i;
    int node = i;
    int lo = lft, hi = rht;

    // Walk down to the leaf that represents `index`.
    while(lo != hi){
        const int half = (lo + hi) >> 1;
        if(index <= half){
            node = node << 1;
            hi = half;
        }
        else{
            node = node << 1 | 1;
            lo = half + 1;
        }
    }
    segTree[node] = deta;

    // Climb back towards the starting node, refreshing each ancestor.
    while(node != top){
        node >>= 1;
        segTree[node] = max(segTree[node << 1], segTree[node << 1 | 1]);
    }
}

// Range maximum over [qlft, qrht] inside node i, which covers [lft, rht].
// The running answer starts at 0, so 0 is the floor of every result.
int query(int i, int lft, int rht, int qlft, int qrht){
    // Node entirely inside the query: its stored maximum is the answer.
    if(lft >= qlft && qrht >= rht)
        return segTree[i];

    const int half = (lft + rht) >> 1;
    int best = 0;
    if(qlft <= half){
        const int fromLeft = query(i << 1, lft, half, qlft, qrht);
        if(fromLeft > best) best = fromLeft;
    }
    if(qrht > half){
        const int fromRight = query(i << 1 | 1, half + 1, rht, qlft, qrht);
        if(fromRight > best) best = fromRight;
    }
    return best;
}

// Input, repeated until EOF: n m, then n initial values, then m commands
// "<cmd> a b". A command starting with 'Q' prints the maximum of [a, b];
// any other command assigns value b to position a.
int main(){
    char cmd[5];
    int n, m;
    while(scanf("%d%d", &n, &m) != EOF){
        memset(segTree, 0, sizeof(segTree));
        int a, b;
        for(int pos = 1; pos <= n; pos++){
            scanf("%d", &a);
            update(1, 1, n, pos, a);
        }
        while(m--){
            scanf("%s %d %d", cmd, &a, &b);
            if(cmd[0] != 'Q'){
                update(1, 1, n, a, b);
                continue;
            }
            printf("%d\n", query(1, 1, n, a, b));
        }
    }
    return 0;
}
(2)线段树区间更新
#include <cstdio>

typedef long long ll;

const int N = 1<<18;

// sum[k]: total of the range covered by node k (children are 2k and 2k+1).
// add[k]: increment already included in sum[k] but not yet applied to the
//         children of k.
ll sum[N], add[N];

// Pushes the pending increment of node i (covering [lft, rht]) down to its
// two children and clears it on i.
void pushDown(const int& i, const int& lft, const int& rht)
{
    if(add[i]){
        const int mid = (lft+rht)>>1;
        add[i<<1] += add[i];
        add[i<<1|1] += add[i];
        sum[i<<1] += add[i]*(mid-lft+1);
        sum[i<<1|1] += add[i]*(rht-mid);
        add[i] = 0;
    }
}

// Adds addval to every position of [qlft, qrht] inside node i, which covers
// [lft, rht].
// addval is a long long: main passes long long deltas, and the previous
// `const int&` parameter silently truncated them and multiplied in int.
void update(int i, int lft, int rht, const int& qlft, const int& qrht, const ll& addval)
{
    // No overlap with this node.
    if(qlft > rht || qrht < lft) return ;

    if(qlft <= lft && qrht >= rht){
        // Fully covered: update the total and remember the increment lazily.
        sum[i] += addval*(rht-lft+1);
        add[i] += addval;
    }
    else{
        pushDown(i, lft, rht);
        const int mid = (lft + rht) >> 1;
        update(i<<1, lft, mid, qlft, qrht, addval);
        update(i<<1|1, mid+1, rht, qlft, qrht, addval);
        sum[i] = sum[i<<1] + sum[i<<1|1];
    }
}

// Returns the sum of [qlft, qrht] restricted to node i, which covers
// [lft, rht]; 0 when the two ranges do not overlap.
ll query(int i, int lft, int rht, const int& qlft, const int& qrht)
{
    if(qlft > rht || qrht < lft) return 0;
    if(qlft <= lft && qrht >= rht) return sum[i];

    pushDown(i, lft, rht);
    const int mid = (lft + rht) >> 1;
    return query(i<<1, lft, mid, qlft, qrht) + query(i<<1|1, mid+1, rht, qlft, qrht);
}

// Input: n q, then n initial values, then q commands. A command starting
// with 'Q' is followed by "l r" and prints the sum of [l, r]; any other
// command is followed by "l r delta" and adds delta to every position of
// [l, r].
int main()
{
    int n, q, lft, rht;
    ll delta;
    char cmd[5];

    scanf("%d%d", &n, &q);
    // The tree is filled with n single-position range updates.
    for(int i = 1; i <= n; i++)
    {
        scanf("%lld", &delta);
        update(1, 1, n, i, i, delta);
    }
    while(q--)
    {
        // Width-limited read: cmd holds at most 4 characters plus the NUL.
        scanf("%4s", cmd);
        if(cmd[0] == 'Q'){
            scanf("%d%d", &lft, &rht);
            printf("%lld\n", query(1, 1, n, lft, rht));
        }else{
            scanf("%d%d%lld", &lft, &rht, &delta);
            update(1, 1, n, lft, rht, delta);
        }
    }
    return 0;
}
(3)树状数组
#include <cstdio>
#include <cstring>

// Parenthesized so the macro expands safely inside any expression.
#define MAXN (50000 + 10)

using namespace std;

// Binary indexed (Fenwick) tree over positions 1..n.
int Tree[MAXN];

// Lowest set bit of index.
int lowbit(int index){
    return index & (-index);
}

// Prefix sum of positions 1..index (0 when index <= 0).
int sum(int index){
    int res = 0;
    while(index > 0){
        res += Tree[index];
        index -= lowbit(index);
    }
    return res;
}

// Adds value to position index in a tree of size n.
void update(int index, int value, int n){
    while(index <= n){
        Tree[index] += value;
        index += lowbit(index);
    }
}

// Input: T test cases. Each case gives n, then n initial values, then
// commands until "End":
//   Query a b - print the sum of positions a..b
//   Add a b   - add b to position a
//   Sub a b   - subtract b from position a
int main()
{
    char cmd[10];
    int a, b;
    int T, n, v;

    scanf("%d", &T);
    for(int nCase = 1; nCase <= T; nCase++){
        memset(Tree, 0, sizeof(Tree));
        printf("Case %d:\n", nCase);
        scanf("%d", &n);
        for(int i = 1; i <= n; i++){
            scanf("%d", &v);
            update(i, v, n);
        }
        do{
            // Width-limited read into cmd[10]; stop if the input ends
            // before "End" instead of looping forever on a stale command.
            if(scanf("%9s", cmd) != 1) break;

            if(!strcmp(cmd, "Query")){
                scanf("%d%d", &a, &b);
                printf("%d\n", sum(b) - sum(a-1));
            }
            else if(!strcmp(cmd, "Add")){
                scanf("%d%d", &a, &b);
                update(a, b, n);
            }
            else if(!strcmp(cmd, "Sub")){
                scanf("%d%d", &a, &b);
                update(a, -b, n);
            }
        }while(strcmp(cmd, "End"));
    }
    return 0;
}
(4)归并树
#include <cstdio>
#include <algorithm>

// Positions are 1-based (1..n), so the arrays need one slot more than the
// largest n. The previous bound of exactly 100000 made index 100000
// out of range.
#define MAXN (100000 + 1)
#define DEEP (20)

using namespace std;

// sorted[d][lft..rht] holds, for every segment [lft, rht] at depth d of the
// merge sort recursion, the elements of a[lft..rht] in ascending order.
int sorted[DEEP][MAXN], a[MAXN];

// Builds the merge tree for a[lft..rht] at the given depth by merging the
// two sorted halves stored one level deeper.
void build(int deep, int lft, int rht){
    if(lft == rht){
        sorted[deep][lft] = a[lft];
        return ;
    }
    const int mid = (lft + rht) >> 1;
    build(deep+1, lft, mid);
    build(deep+1, mid+1, rht);

    int p = lft, q = mid+1, k = lft;
    while(p <= mid && q <= rht){
        if(sorted[deep+1][p] <= sorted[deep+1][q])
            sorted[deep][k++] = sorted[deep+1][p++];
        else
            sorted[deep][k++] = sorted[deep+1][q++];
    }
    while(p <= mid) sorted[deep][k++] = sorted[deep+1][p++];
    while(q <= rht) sorted[deep][k++] = sorted[deep+1][q++];
}

// Counts the elements of a[qlft..qrht] that are strictly smaller than key,
// looking only at the part of the range inside segment [lft, rht].
int query(int deep, int lft, int rht, int qlft, int qrht, int key){
    if(qrht < lft || qlft > rht) return 0;
    if(qlft <= lft && rht <= qrht)
        return lower_bound(&sorted[deep][lft], &sorted[deep][rht]+1, key) - &sorted[deep][lft];

    const int mid = (lft + rht) >> 1;
    return query(deep+1, lft, mid, qlft, qrht, key)
         + query(deep+1, mid+1, rht, qlft, qrht, key);
}

// Returns the element of a[qlft..qrht] that has exactly k smaller-or-equal
// predecessors in sorted order, i.e. the (k+1)-th smallest (k is 0-based).
// Binary search over the globally sorted values: keep the largest candidate
// with at most k range elements strictly below it.
int solve(int n, int qlft, int qrht, int k){
    int low = 1, high = n+1;
    while(low+1 < high){
        const int mid = (low + high) >> 1;
        const int cnt = query(0, 1, n, qlft, qrht, sorted[0][mid]);
        if(cnt <= k) low = mid;
        else high = mid;
    }
    return sorted[0][low];
}

// Input: n m, then n values, then m queries "l r k"; prints the k-th
// smallest value (k is 1-based) of positions l..r for each query.
int main(){
    int n, m;
    scanf("%d%d", &n, &m);
    for(int i = 1; i <= n; i++){
        scanf("%d", &a[i]);
    }
    build(0, 1, n);
    while(m--){
        int qlft, qrht, k;
        scanf("%d%d%d", &qlft, &qrht, &k);
        printf("%d\n", solve(n, qlft, qrht, k-1));
    }
    return 0;
}
(5)划分树
#include <cstdio>
#include <algorithm>

using namespace std;

const int MAXN = 100000 + 1;
const int DEEP = 18;

// One level of the partition tree.
//   num[i] - value sitting at position i on this level
//   cnt[i] - within the segment containing i, how many of the positions
//            from the segment start up to i were sent to the left child
struct PartitionTree{
    int num[MAXN];
    int cnt[MAXN];
};

PartitionTree tree[DEEP];
int sorted[MAXN];

// Splits segment [lft, rht] of level `deep` into its two children on level
// deep+1, keeping the original relative order inside each child.
void build(int deep, int lft, int rht){
    if(lft == rht)
        return;

    const int mid = (lft + rht) >> 1;
    const int pivot = sorted[mid];

    // Number of elements equal to the pivot that must go left so that the
    // left child receives exactly mid-lft+1 elements.
    int quota = mid - lft + 1;
    for(int i = lft; i <= rht; ++i){
        if(tree[deep].num[i] < pivot)
            --quota;
    }

    int leftTail = lft - 1;   // last filled slot of the left child
    int rightTail = mid;      // last filled slot of the right child
    int wentLeft = 0;
    for(int i = lft; i <= rht; ++i){
        const int value = tree[deep].num[i];
        bool toLeft = false;
        if(value < pivot){
            toLeft = true;
        }
        else if(value == pivot && quota != 0){
            --quota;
            toLeft = true;
        }

        if(toLeft){
            ++wentLeft;
            tree[deep+1].num[++leftTail] = value;
        }
        else{
            tree[deep+1].num[++rightTail] = value;
        }
        tree[deep].cnt[i] = wentLeft;
    }

    build(deep+1, lft, mid);
    build(deep+1, mid+1, rht);
}

// Returns the k-th smallest value (k is 1-based) among positions
// qlft..qrht of segment [lft, rht] on level `deep`.
// Each round moves one level down, into whichever child holds the answer,
// and translates the query range into that child's coordinates.
int query(int deep, int lft, int rht, int qlft, int qrht, int k){
    while(lft != rht){
        const int mid = (lft + rht) >> 1;

        // before: elements left of the query range that went to the left child
        // inside: elements of the query range that went to the left child
        int before = 0;
        int inside = tree[deep].cnt[qrht];
        if(lft != qlft){
            before = tree[deep].cnt[qlft-1];
            inside -= before;
        }

        if(inside >= k){
            // The answer is among the elements sent left.
            qlft = lft + before;
            qrht = qlft + inside - 1;
            rht = mid;
        }
        else{
            // The answer was sent right; skip the `inside` smaller ones.
            const int skipped = qlft - lft - before;
            const int span = qrht - qlft - inside;
            qlft = mid + 1 + skipped;
            qrht = qlft + span;
            k -= inside;
            lft = mid + 1;
        }
        ++deep;
    }
    return tree[deep].num[lft];
}

// Input: n m, then n values, then m queries "l r k"; prints the k-th
// smallest value of positions l..r for each query.
int main(){
    int n, m, qlft, qrht, k;
    scanf("%d%d", &n, &m);
    for(int pos = 1; pos <= n; pos++){
        scanf("%d", &sorted[pos]);
        tree[0].num[pos] = sorted[pos];
    }
    sort(sorted+1, sorted+1+n);
    build(0, 1, n);
    while(m--){
        scanf("%d%d%d", &qlft, &qrht, &k);
        printf("%d\n", query(0, 1, n, qlft, qrht, k));
    }
    return 0;
}
2、图论
(1)并查集
// Capacity of the disjoint-set arrays. The original snippet left the bound
// unspecified (`node[i]`); size this to the limit of the problem at hand.
const int MAX_NODES = 100000 + 10;

int node[MAX_NODES];  // node[x]: parent of x; a root is its own parent
int rank[MAX_NODES];  // rank[x]: height estimate of the tree rooted at x

// Initializes n singleton sets {0}, {1}, ..., {n-1}.
void Init(int n){
    for(int i = 0; i < n; i++){
        node[i] = i;
        rank[i] = 0;
    }
}

// Returns the root (representative) of the set containing x.
// Every node visited on the way is re-attached directly to the root.
int find(int x){
    if(x == node[x]) return x;
    return node[x] = find(node[x]);
}

// Merges the sets containing x and y.
void Unite(int x, int y){
    x = find(x);
    y = find(y);
    if(x == y) return ;

    // Hang the shallower tree under the deeper one; on a tie the resulting
    // tree grows by one level.
    if(rank[x] < rank[y]){
        node[x] = y;
    }else{
        node[y] = x;
        if(rank[x] == rank[y]) rank[x]++;
    }
}

// Returns true when x and y belong to the same set.
bool same(int x, int y){
    return find(x) == find(y);
}
(2)最小生成树Kruskal
#include <cstdio>
#include <cstdlib>

// Parenthesized so the macro expands safely inside any expression.
#define MAXN (10000 + 10)

using namespace std;

// Disjoint-set forest: par[x] is the parent of x, Rank[x] a height estimate.
int par[MAXN], Rank[MAXN];

// One undirected edge between vertices a and b with weight price.
typedef struct{
    int a, b, price;
}Node;

Node a[MAXN];

// qsort comparator: ascending by price. Uses comparisons rather than a
// subtraction so that large weights cannot overflow.
int cmp(const void* lhs, const void* rhs){
    const int x = static_cast<const Node*>(lhs)->price;
    const int y = static_cast<const Node*>(rhs)->price;
    return (x > y) - (x < y);
}

// Resets the first n vertices to singleton sets.
void Init(int n){
    for(int i = 0; i < n; i++){
        Rank[i] = 0;
        par[i] = i;
    }
}

// Returns the root of x's set, then re-attaches every node on the walked
// path directly to that root.
int find(int x){
    int root = x;
    while(root != par[root])
        root = par[root];
    while(x != root){
        const int next = par[x];
        par[x] = root;
        x = next;
    }
    return root;
}

// Merges the sets containing x and y.
void unite(int x, int y){
    x = find(x);
    y = find(y);
    // Already merged: nothing to do (and the rank must not be bumped).
    if(x == y) return ;

    if(Rank[x] < Rank[y]){
        par[x] = y;
    }
    else{
        par[y] = x;
        if(Rank[x] == Rank[y]) Rank[x]++;
    }
}

// n is the number of edges stored in a[0..n-1], m the number of vertices.
// Returns the total weight of a minimum spanning tree, or -1 when the graph
// is not connected.
int Kruskal(int n, int m){
    int nEdge = 0, res = 0;

    // Process the edges from lightest to heaviest.
    qsort(a, n, sizeof(a[0]), cmp);
    for(int i = 0; i < n && nEdge != m - 1; i++){
        // Take the edge only when its endpoints are in different trees.
        if(find(a[i].a) != find(a[i].b)){
            unite(a[i].a, a[i].b);
            res += a[i].price;
            nEdge++;
        }
    }
    // Fewer than m-1 accepted edges: no spanning tree exists.
    if(nEdge < m-1) res = -1;
    return res;
}

// Input, repeated until the edge count is 0: n m, then n lines "u v price"
// with 1-based vertex numbers. Prints the spanning tree weight, or "?" when
// the graph is disconnected.
int main(){
    int n, m, ans;
    // Also stop when the input ends or is malformed, instead of re-running
    // on stale values.
    while(scanf("%d%d", &n, &m) == 2 && n){
        Init(m);
        for(int i = 0; i < n; i++){
            scanf("%d%d%d", &a[i].a, &a[i].b, &a[i].price);
            // Shift vertex numbers to 0..m-1 (a convention, not a necessity).
            a[i].a--;
            a[i].b--;
        }
        ans = Kruskal(n, m);
        if(ans == -1) printf("?\n");
        else printf("%d\n", ans);
    }
    return 0;
}
(3)最小生成树prim
【未优化版】
#include <cstdio>
#include <vector>

#define INF 0xfffffff
// Parenthesized so the macro expands safely inside any expression.
#define MAXN (100 + 10)

using namespace std;

// Adjacency entry: an edge to vertex v with the given weight.
struct Vex{
    int v, weight;
    Vex(int tv, int tw):v(tv), weight(tw){}
};

vector<Vex> graph[MAXN];
bool inTree[MAXN];   // inTree[v]: v already belongs to the growing tree
int mindist[MAXN];   // mindist[v]: lightest known edge from the tree to v

// Clears the graph and the per-vertex state for vertices 1..n.
void Init(int n){
    for(int i = 1; i <= n; i++){
        mindist[i] = INF;
        inTree[i] = false;
        graph[i].clear();
    }
}

// Prim's algorithm without a heap, starting from vertex s on a graph with
// vertices 1..n. Returns the total weight of the minimum spanning tree, or
// -1 when some vertex cannot be reached from s.
int Prim(int s, int n){
    int ret = 0;

    // Put s into the tree and seed the distances of its neighbours. The
    // minimum is kept so that a heavier parallel edge cannot overwrite a
    // lighter one.
    inTree[s] = true;
    for(unsigned int i = 0; i < graph[s].size(); i++){
        const int to = graph[s][i].v;
        if(graph[s][i].weight < mindist[to])
            mindist[to] = graph[s][i].weight;
    }

    // n-1 vertices are still outside the tree, so n-1 rounds are needed.
    for(int NodeCount = 1; NodeCount <= n-1; NodeCount++){
        int tempMin = INF;
        int addNode = -1;

        // Pick the outside vertex closest to the tree.
        for(int i = 1; i <= n; i++){
            if(!inTree[i] && mindist[i] < tempMin){
                tempMin = mindist[i];
                addNode = i;
            }
        }
        // Nothing reachable is left: the graph is disconnected.
        if(addNode == -1) return -1;

        inTree[addNode] = true;
        ret += tempMin;

        // Relax the distances of the vertices still outside the tree.
        for(unsigned int i = 0; i < graph[addNode].size(); i++){
            const int tempVex = graph[addNode][i].v;
            if(!inTree[tempVex] && graph[addNode][i].weight < mindist[tempVex]){
                mindist[tempVex] = graph[addNode][i].weight;
            }
        }
    }
    return ret;
}

// Input, repeated until n is 0: n, then n*(n-1)/2 lines "u v weight".
// Prints the weight of the minimum spanning tree.
int main(){
    int n;
    int v1, v2, weight;
    // Also stop when the input ends or is malformed.
    while(scanf("%d", &n) == 1 && n){
        Init(n);
        for(int i = 0; i < n*(n-1)/2; i++){
            scanf("%d%d%d", &v1, &v2, &weight);
            graph[v1].push_back(Vex(v2, weight));
            graph[v2].push_back(Vex(v1, weight));
        }
        printf("%d\n", Prim(1, n));
    }
    return 0;
}
【堆优化版】
#include <cstdio>
#include <vector>
#include <queue>

#define INF 0xfffffff
#define MAXN (100 + 10)

using namespace std;

// Heap entry / adjacency entry: vertex v reachable with the given weight.
// The comparison is reversed so that priority_queue yields the smallest
// weight first.
struct Vex{
    int v, weight;
    Vex(int tv = 0, int tw = 0):v(tv), weight(tw){}
    bool operator < (const Vex& t) const{
        return t.weight < this->weight;
    }
};

vector<Vex> graph[MAXN];
bool inTree[MAXN];   // inTree[v]: v already belongs to the growing tree
int mindist[MAXN];   // mindist[v]: lightest known edge from the tree to v

// Clears the graph and the per-vertex state for vertices 1..n.
void Init(int n){
    for(int v = 1; v <= n; v++){
        graph[v].clear();
        inTree[v] = false;
        mindist[v] = INF;
    }
}

// Heap-based Prim starting from s. Returns the summed weight of the edges
// that connected every vertex popped from the heap.
int Prim(int s, int n){
    priority_queue<Vex> heap;
    int total = 0;

    // s joins the tree first; each neighbour whose distance improves is
    // queued.
    inTree[s] = true;
    const vector<Vex>& fromStart = graph[s];
    for(vector<Vex>::size_type e = 0; e < fromStart.size(); ++e){
        const int to = fromStart[e].v;
        if(mindist[to] <= fromStart[e].weight)
            continue;
        mindist[to] = fromStart[e].weight;
        heap.push(Vex(to, mindist[to]));
    }

    while(!heap.empty()){
        // Closest candidate; outdated entries for vertices that are already
        // in the tree are simply dropped.
        const Vex nearest = heap.top();
        heap.pop();
        const int addNode = nearest.v;
        if(!inTree[addNode]){
            inTree[addNode] = true;
            total += mindist[addNode];

            // Relax the neighbours that are still outside the tree.
            const vector<Vex>& adj = graph[addNode];
            for(vector<Vex>::size_type e = 0; e < adj.size(); ++e){
                const int to = adj[e].v;
                const int w = adj[e].weight;
                if(inTree[to] || mindist[to] <= w)
                    continue;
                mindist[to] = w;
                heap.push(Vex(to, mindist[to]));
            }
        }
    }
    return total;
}

// Input, repeated until n is 0: n, then n*(n-1)/2 lines "u v weight".
// Prints the value returned by Prim(1, n).
int main(){
    int n;
    int v1, v2, weight;
    for(;;){
        scanf("%d", &n);
        if(!n) break;

        Init(n);
        for(int e = 0; e < n*(n-1)/2; e++){
            scanf("%d%d%d", &v1, &v2, &weight);
            graph[v1].push_back(Vex(v2, weight));
            graph[v2].push_back(Vex(v1, weight));
        }
        printf("%d\n", Prim(1, n));
    }
    return 0;
}
(4)最短路bellman-ford
【未优化版】
#include <cstdio>
#include <cstring>

#define INF 0xfffffff
#define MAXN (100 + 10)

using namespace std;

// Directed edge from -> to.
struct edge{
    int from, to;
    edge(int f = 0, int t = 0) : from(f), to(t){}
};

edge es[MAXN*MAXN];       // edge list
int cost[MAXN];           // cost[v]: amount added on entering vertex v
bool graph[MAXN][MAXN];   // direct edges; after Floyd(): reachability
int d[MAXN];              // d[v]: largest positive total found on reaching v

// Transitive closure of `graph` over vertices 1..n: afterwards graph[i][j]
// is true when j can be reached from i.
// The intermediate vertex k must be the OUTERMOST loop; with k nested
// inside i, paths whose intermediate vertices are visited out of index
// order (e.g. 1->3->2->4) were missed.
void Floyd(int n){
    for(int k = 1; k <= n; k++){
        for(int i = 1; i <= n; i++){
            if(!graph[i][k]) continue;
            for(int j = 1; j <= n; j++){
                if(graph[k][j]) graph[i][j] = true;
            }
        }
    }
}

// Longest-path Bellman-Ford from s over V vertices and the E edges in es.
// A step along an edge is only allowed while the running total stays
// strictly positive. Returns true when vertex V can be reached with a
// positive total, either directly or through a gain cycle from which V is
// reachable.
bool bellman_ford(int s, int V, int E){
    for(int i = 0; i <= V; i++) d[i] = -INF;
    d[s] = 100;

    // Relax every edge V-1 times.
    for(int k = 0; k < V-1; k++){
        for(int i = 0; i < E; i++){
            const edge e = es[i];
            const int candidate = d[e.from] + cost[e.to];
            if(d[e.to] < candidate && candidate > 0){
                d[e.to] = candidate;
            }
        }
    }

    // An edge that can still be relaxed lies on, or behind, a cycle that
    // keeps increasing the total. It only helps when V is reachable from it.
    for(int i = 0; i < E; i++){
        const edge e = es[i];
        const int candidate = d[e.from] + cost[e.to];
        if(d[e.to] < candidate && graph[e.to][V] && candidate > 0)
            return true;
    }
    return d[V] > 0;
}

// Input, repeated until n is -1: n, then for each vertex 1..n its cost, its
// out-degree m and m target vertices. Prints "winnable" when vertex n can be
// reached from vertex 1 under the rules above, otherwise "hopeless".
int main(){
    int n, m, cnt, vex;
    // Also stop when the input ends or is malformed.
    while(scanf("%d", &n) == 1 && n != -1){
        memset(graph, false, sizeof(graph));
        cnt = 0;
        for(int i = 1; i <= n; i++){
            scanf("%d%d", &cost[i], &m);
            for(int j = 0; j < m; j++){
                scanf("%d", &vex);
                es[cnt++] = edge(i, vex);
                graph[i][vex] = true;
            }
        }
        Floyd(n);
        if(!graph[1][n] || !bellman_ford(1, n, cnt)){
            printf("hopeless\n");
        }
        else{
            printf("winnable\n");
        }
    }
    return 0;
}
【SPFA】
#include <cstdio>
#include <cstring>
#include <queue>

#define MAXN (100 + 10)

using namespace std;

// d[v]:    largest positive total found so far on reaching v from the source
// cost[v]: amount added on entering vertex v
// cnt[v]:  how many times v has been taken out of the queue
int d[MAXN], cost[MAXN], cnt[MAXN];

// reach[i][j]: j is reachable from i (after Floyd())
// graph[i][j]: there is a direct edge i -> j
bool reach[MAXN][MAXN], graph[MAXN][MAXN];

// Resets all per-test-case state.
void Init(){
    memset(d, 0, sizeof(d));
    memset(cnt, 0, sizeof(cnt));
    memset(graph, false, sizeof(graph));
    memset(reach, false, sizeof(reach));
}

// Transitive closure of `reach` over vertices 1..n.
// The intermediate vertex k must be the OUTERMOST loop; with k nested
// inside i, paths whose intermediate vertices are visited out of index
// order (e.g. 1->3->2->4) were missed.
void Floyd(int n){
    for(int k = 1; k <= n; k++){
        for(int i = 1; i <= n; i++){
            if(!reach[i][k]) continue;
            for(int j = 1; j <= n; j++){
                if(reach[k][j]) reach[i][j] = true;
            }
        }
    }
}

// Queue-based longest-path relaxation from s over vertices 1..n. A step is
// only allowed while the running total stays strictly positive.
// Returns true when vertex n can be reached with a positive total.
bool SPFA(int s, int n){
    queue<int> Q;
    d[s] = 100;
    Q.push(s);
    while(!Q.empty()){
        const int now = Q.front();
        Q.pop();
        cnt[now]++;

        // A vertex dequeued n times is treated as sitting on a cycle that
        // keeps increasing the total (the usual "negative cycle" test,
        // mirrored for maximisation).
        if(cnt[now] >= n){
            // The cycle can feed the goal: success.
            if(reach[now][n]) return true;
            // Otherwise nothing reachable from `now` can reach the goal
            // either, so stop expanding it, but keep processing the rest of
            // the queue: another path may still succeed. Returning false
            // here rejected such inputs.
            continue;
        }

        // Try every outgoing edge.
        for(int next = 1; next <= n; next++){
            const int candidate = d[now] + cost[next];
            if(graph[now][next] && candidate > d[next] && candidate > 0){
                Q.push(next);
                d[next] = candidate;
            }
        }
    }
    return d[n] > 0;
}

// Input, repeated until n is -1: n, then for each vertex 1..n its cost, its
// out-degree m and m target vertices. Prints "winnable" when vertex n can be
// reached from vertex 1 under the rules above, otherwise "hopeless".
int main(){
    int n, m, vex;
    // Also stop when the input ends or is malformed.
    while(scanf("%d", &n) == 1 && n != -1){
        Init();
        for(int i = 1; i <= n; i++){
            scanf("%d%d", &cost[i], &m);
            for(int j = 0; j < m; j++){
                scanf("%d", &vex);
                reach[i][vex] = true;
                graph[i][vex] = true;
            }
        }
        Floyd(n);
        if(!reach[1][n] || !SPFA(1, n)){
            printf("hopeless\n");
        }
        else{
            printf("winnable\n");
        }
    }
    return 0;
}
(5)最短路之Dijkstra
【未优化版】
#include <cstdio>
#include <vector>
#include <algorithm>

// Parenthesized so the macro expands safely inside any expression.
#define MAXN (200 + 10)
#define INF 0xffffff

using namespace std;

// Adjacency entry: an edge to vertex v with the given weight.
struct Vex{
    int v, weight;
    Vex(int tv, int tw):v(tv), weight(tw){}
};

// Adjacency lists of the graph.
vector<Vex> graph[MAXN];
// inTree[v]: the shortest distance to v is final.
bool inTree[MAXN];
// mindist[v]: shortest known distance from the source to v.
int mindist[MAXN];

// Clears the graph and the per-vertex state for vertices 0..n-1.
void Init(int n){
    for(int i = 0; i < n; i++){
        inTree[i] = false;
        graph[i].clear();
        mindist[i] = INF;
    }
}

// Dijkstra without a heap. s is the source, t the target and n the number
// of vertices (0..n-1). Returns the shortest distance from s to t, or INF
// when t cannot be reached.
int Dijkstra(int s, int t, int n){
    // The source is settled at distance 0.
    mindist[s] = 0;
    inTree[s] = true;

    // Seed the neighbours of s; min() keeps the lightest of parallel edges.
    for(unsigned int i = 0; i < graph[s].size(); i++)
        mindist[graph[s][i].v] = min(mindist[graph[s][i].v], graph[s][i].weight);

    // Settle the remaining n-1 vertices one at a time.
    for(int nNode = 1; nNode <= n-1; nNode++){
        int tempMin = INF;
        int addNode = -1;

        // Closest vertex that is not settled yet.
        for(int i = 0; i < n; i++){
            if(!inTree[i] && mindist[i] < tempMin){
                tempMin = mindist[i];
                addNode = i;
            }
        }
        // Every remaining vertex is unreachable. Without this check the
        // loop went on with an uninitialized or stale addNode.
        if(addNode == -1) break;

        inTree[addNode] = true;

        // Relax the edges leaving the newly settled vertex.
        for(unsigned int i = 0; i < graph[addNode].size(); i++){
            const int tempVex = graph[addNode][i].v;
            const int through = tempMin + graph[addNode][i].weight;
            if(!inTree[tempVex] && through < mindist[tempVex]){
                mindist[tempVex] = through;
            }
        }
    }
    return mindist[t];
}

// Input, repeated until EOF: n m, then m undirected edges "u v w", then the
// query "s t". Prints the shortest distance from s to t, or -1 when t is
// unreachable.
int main(){
    int n, m;
    int v1, v2, x, s, t;
    while(scanf("%d%d", &n, &m) != EOF){
        Init(n);
        for(int i = 0; i < m; i++){
            scanf("%d%d%d", &v1, &v2, &x);
            graph[v1].push_back(Vex(v2, x));
            graph[v2].push_back(Vex(v1, x));
        }
        scanf("%d%d", &s, &t);
        int ans = Dijkstra(s, t, n);
        if(ans == INF) printf("-1\n");
        else printf("%d\n", ans);
    }
    return 0;
}
【堆优化版】
#include <cstdio>
#include <vector>
#include <queue>
#include <algorithm>

#define MAXN (200 + 10)
#define INF 0xffffff

using namespace std;

// Heap entry / adjacency entry: vertex v together with a weight. The
// comparison is reversed so that priority_queue yields the smallest weight
// first.
struct Vex{
    int v, weight;
    bool operator < (const Vex & t) const{
        return t.weight < this->weight;
    }
    Vex(int tv = 0, int tw = 0):v(tv), weight(tw){}
};

vector<Vex> graph[MAXN];
bool inTree[MAXN];   // inTree[v]: the distance to v is final
int mindist[MAXN];   // mindist[v]: shortest known distance from the source

// Clears the graph and the per-vertex state for vertices 0..n-1.
void Init(int n){
    for(int v = 0; v < n; v++){
        mindist[v] = INF;
        graph[v].clear();
        inTree[v] = false;
    }
}

// Heap-based Dijkstra from s. Returns mindist[t], which is still INF when t
// was never reached.
int Dijkstra_heap(int s, int t, int n){
    priority_queue<Vex> heap;

    // The source starts at distance 0.
    mindist[s] = 0;
    heap.push(Vex(s, 0));

    while(!heap.empty()){
        // Smallest tentative distance; entries for vertices that are already
        // final are outdated and dropped.
        const Vex nearest = heap.top();
        heap.pop();
        const int from = nearest.v;
        if(!inTree[from]){
            inTree[from] = true;

            // Relax every edge leaving the newly finalized vertex.
            const vector<Vex>& adj = graph[from];
            for(vector<Vex>::size_type e = 0; e < adj.size(); ++e){
                const int to = adj[e].v;
                const int through = mindist[from] + adj[e].weight;
                if(inTree[to] || through >= mindist[to])
                    continue;
                mindist[to] = through;
                // Queue the vertex again with its improved distance.
                heap.push(Vex(to, mindist[to]));
            }
        }
    }
    return mindist[t];
}

// Input, repeated until EOF: n m, then m undirected edges "u v w", then the
// query "s t". Prints the shortest distance from s to t, or -1 when the
// result equals INF.
int main()
{
    int n, m;
    int v1, v2, x, s, t;
    while(scanf("%d%d", &n, &m) != EOF){
        Init(n);
        for(int e = 0; e < m; e++){
            scanf("%d%d%d", &v1, &v2, &x);
            graph[v1].push_back(Vex(v2, x));
            graph[v2].push_back(Vex(v1, x));
        }
        scanf("%d%d", &s, &t);
        const int ans = Dijkstra_heap(s, t, n);
        if(ans != INF) printf("%d\n", ans);
        else printf("-1\n");
    }
    return 0;
}
3、数学
(1)快速幂
typedef long long ll; //注意这里不一定都是long long 有时 int 也行 ll mod_pow(ll x, ll n, ll mod){ ll res = 1; while( n > 0 ){ if( n & 1 ) res = res * x % mod; //n&1其实在这里和 n%2表达的是一个意思 x = x * x % mod; n >>= 1; //n >>= 1这个和 n/=2表达的是一个意思 } return res; }
(2)矩阵快速幂
#include <limits.h>

// Dense integer matrix whose arithmetic is carried out modulo `mod`.
class Matrix{
public:
    int **m;            // row pointers; m[i][j] is the entry in row i, column j
    int row, col, mod;  // number of rows, number of columns, modulus

    Matrix(int r = 0, int c = 0, int d = INT_MAX);
    Matrix(const Matrix &value);
    ~Matrix();
    Matrix operator + (const Matrix &rht) const;
    Matrix operator * (const Matrix &rht) const;
    Matrix& operator = (const Matrix &rht);
    Matrix pow(int n) const;

private:
    // Allocates an r x c table of zeros.
    static int** allocate(int r, int c);
    // Frees a table with r rows that was obtained from allocate().
    static void release(int **table, int r);
};

int** Matrix::allocate(int r, int c){
    int **table = new int*[r];
    for(int i = 0; i < r; i++){
        table[i] = new int[c];
        for(int j = 0; j < c; j++){
            table[i][j] = 0;
        }
    }
    return table;
}

void Matrix::release(int **table, int r){
    for(int i = 0; i < r; i++){
        delete[] table[i];
    }
    delete[] table;
}

// Creates an r x c zero matrix with modulus d.
Matrix::Matrix(int r, int c, int d):m(allocate(r, c)), row(r), col(c), mod(d){
}

// Deep copy.
Matrix::Matrix(const Matrix &value):m(allocate(value.row, value.col)), row(value.row), col(value.col), mod(value.mod){
    for(int i = 0; i < row; i++){
        for(int j = 0; j < col; j++){
            m[i][j] = value.m[i][j];
        }
    }
}

Matrix::~Matrix(){
    release(m, row);
}

// Entry-wise sum modulo mod. Both operands must have the same dimensions.
Matrix Matrix::operator + (const Matrix &rht) const{
    Matrix temp(row, col, mod);
    for(int i = 0; i < row; i++){
        for(int j = 0; j < col; j++){
            // Added in long long: two values close to INT_MAX overflow int.
            const long long total = static_cast<long long>(m[i][j]) + rht.m[i][j];
            temp.m[i][j] = static_cast<int>(total % mod);
        }
    }
    return temp;
}

// Matrix product modulo mod. The column count of *this must equal rht.row.
Matrix Matrix::operator * (const Matrix &rht) const{
    Matrix temp(row, rht.col, mod);
    for(int i = 0; i < row; i++){
        for(int k = 0; k < rht.row; k++){
            for(int j = 0; j < rht.col; j++){
                // Multiplied in long long: int * int overflows long before
                // the % is applied.
                const long long term = static_cast<long long>(m[i][k]) * rht.m[k][j];
                temp.m[i][j] = static_cast<int>((temp.m[i][j] + term) % mod);
            }
        }
    }
    return temp;
}

// Deep assignment. Unlike the previous version, which copied row x col
// entries no matter what rht looked like, this also takes over rht's
// dimensions (reallocating when they differ) and its modulus.
Matrix& Matrix::operator = (const Matrix &rht){
    if(this == &rht) return *this;

    if(row != rht.row || col != rht.col){
        // Allocate first so *this stays intact if the allocation throws.
        int **fresh = allocate(rht.row, rht.col);
        release(m, row);
        m = fresh;
        row = rht.row;
        col = rht.col;
    }
    mod = rht.mod;
    for(int i = 0; i < row; i++){
        for(int j = 0; j < col; j++){
            m[i][j] = rht.m[i][j];
        }
    }
    return *this;
}

// Fast matrix exponentiation: returns (*this)^n modulo mod for a square
// matrix; n <= 0 yields the identity matrix.
Matrix Matrix::pow(int n) const{
    Matrix a(*this), res(row, col, mod);
    // Start from the identity matrix.
    for(int i = 0; i < row; i++){
        res.m[i][i] = 1;
    }
    // Plain binary exponentiation, with matrices instead of numbers.
    while(n > 0){
        if(n & 1) res = res * a;
        a = a * a;
        n >>= 1;
    }
    return res;
}
(3)倍增法
//Matrix 类,同矩阵快速幂的Matrix类 //倍增法求解a^1 + a^2 + ... + a^n Matrix slove(const Matrix &a, int n){ //递归终点 if(n == 1) return a; //temp 递归表示a^1 + a^2 + ... + a^(n/2) Matrix temp = slove(a, n/2); //sum 表示 a^1 + a^2 + ... + a^(n/2) + (a^(n/2))*(a^1 + a^2 + ... + a^(n/2)) Matrix sum = temp + temp*a.pow(n/2); //如果当n为奇数,我们会发现我们的(n/2 + n/2) == n-1 //于是我们需要补上一项: a^n if(n & 1) sum = sum + a.pow(n); return sum; }
时间: 2024-10-12 22:27:04