hdu 4878 ZCC loves words(AC自动机+dp+矩阵快速幂+中国剩余定理)
题意:给出若干个模式串,总长度不超过40,对于某一个字符串,它有一个价值,对于这个价值的计算方法是这样的,设初始价值为V=1,假如这个串能匹配第k个模式串,则V=V*prime[k]*(i+len[k]),其中prime[k]表示第k个素数,i表示匹配的结束位置,len[k]表示第k个模式串的长度(注意,一个字符串可以多次匹配同意个模式串)。问字符集为‘A‘-‘Z‘的字符,组成的所有的长为L的字符串,能得到的总价值和是多少?
解法:跟以前做过的很多AC自动机的题有点类似,很容易想到一个node*L的dp,dp[i][v]表示长为i,匹配到AC自动机的V节点能得到的价值和(详见代码中的DEBUG函数)。但是L太大,没法搞。节点总数只有40,那么就可以用矩阵来加速dp了,但是很可惜,建立矩阵的时候,发现建的矩阵居然是跟i有关,这样是不能直接用矩阵快速幂做的。但是,题目给出的提示是,mod可以拆成三个较小的质数。那么我们可以分别用三个较小的质数作为mod进行运算,因为第i个矩阵,它是跟第i+mod个矩阵一样的,所以我们可以把L个矩阵分成L/mod段,每一段的矩阵乘起来都是一样的,设其为A(可以暴力乘起来,因为mod很小),那么我们要的所有的L个矩阵的乘起来得到的矩阵,就是A^(L/mod),再乘上剩下来多余的L%mod个了,这样就可以计算出在每个较小的模系下的答案。最后用中国剩余定理计算总的答案。
代码:
#include<stdio.h> #include<string.h> #include<algorithm> #include<vector> #include<queue> #define ll __int64 using namespace std ; const int N = 44 ; const int mod = 5047621 ; int pri[12345] , p_num , vis[12345] ; void get_prime () { p_num = 0 ; for ( int i = 2 ; i < 12345 ; i ++ ) { if ( !vis[i] ) pri[++p_num] = i ; for ( int j = 1 ; j <= p_num ; j ++ ) { if ( i * pri[j] >= 12345 ) break ; vis[i*pri[j]] = 1 ; if ( i % pri[j] == 0 ) break ; } } } struct Point { int len , p ; Point () {} Point ( int a , int b ):len(a),p(b) {} } ; struct RECT { int elem[N][N] ; void print ( int n ) { for ( int i = 0 ; i < n ; i ++ , puts ( "" ) ) for ( int j = 0 ; j < n ; j ++ ) printf ( "%d " , elem[i][j] ) ; } } p[222] , E ; struct AC_auto { int dp[111][44] ; int c[N][26] , fail[N] , tot ; vector<Point> vec[N] ; queue<int> Q ; void init () { tot = 0 ; new_node () ; } int new_node () { vec[tot].clear () ; fail[tot] = 0 ; memset ( c[tot] , 0 , sizeof ( c[tot] ) ) ; return tot ++ ; } void insert ( char *s , int i ) { int now = 0 , len = strlen ( s ) ; for ( ; *s ; s ++ ) { int k = *s - 'A' ; if ( !c[now][k] ) c[now][k] = new_node () ; now = c[now][k] ; } vec[now].push_back ( Point ( len , pri[i] ) ) ; } void get_fail () { int u = 0 , v ; for ( int i = 0 ; i < 26 ; i ++ ) { if ( c[u][i] ) Q.push ( c[u][i] ) ; } while ( !Q.empty () ) { u = Q.front () ; Q.pop () ; for ( int i = 0 ; i < 26 ; i ++ ) { if ( c[u][i] ) { v = c[u][i] ; fail[v] = c[fail[u]][i] ; Q.push ( v ) ; } else c[u][i] = c[fail[u]][i] ; } } } void BUILD_RECT ( int l , int mod ) { memset ( p[l].elem , 0 , sizeof ( p[l].elem ) ) ; for ( int i = 0 ; i < tot ; i ++ ) { for ( int j = 0 ; j < 26 ; j ++ ) { int u = c[i][j] ; int v = u , ret = 1 ; while ( v ) { for ( int k = 0 ; k < vec[v].size () ; k ++ ) { Point u = vec[v][k] ; ret *= (l+u.len)*u.p ; ret %= mod ; } v = fail[v] ; } p[l].elem[u][i] += ret ; if ( p[l].elem[u][i] >= mod ) p[l].elem[u][i] -= mod ; } } } void RECT_MUIL ( RECT x , RECT y , RECT &z , int mod ) { memset ( z.elem , 0 , sizeof ( z.elem ) ) ; for ( int i = 0 ; i < tot ; i ++ ) { for ( int j = 0 ; j < tot ; j ++ ) for ( int k = 0 ; k < tot ; k ++ ) { z.elem[i][j] += x.elem[i][k] * y.elem[k][j] % mod ; if ( z.elem[i][j] >= mod ) z.elem[i][j] -= mod ; } } } void GAO ( RECT& ret , ll n , int mod ) { // printf ( "n = %I64d\n" , n ) ; RECT f = ret ; ret = E ; while ( n ) { if ( n & 1 ) RECT_MUIL ( ret , f , ret , mod ) ; RECT_MUIL ( f , f , f , mod ) ; n >>= 1 ; } } int SOLVE ( int mod , ll l ) { RECT ans = E , temp = E ; // printf ( "mod = %d\n" , mod ) ; for ( int i = mod ; i >= 1 ; i -- ) { BUILD_RECT ( i , mod ) ; RECT_MUIL ( temp , p[i] , temp , mod ) ; // if (i == 1) ans.print ( tot ) ; } // puts( "fuck ") ; GAO ( temp , l/mod , mod ) ; // ans.print ( tot ) ; for ( int i = l % mod ; i >= 1 ; i -- ) { BUILD_RECT ( i , mod ) ; RECT_MUIL ( ans , p[i] , ans , mod ) ; } RECT_MUIL ( ans , temp , ans , mod ) ; // ans.print ( tot ) ; int ret = 0 ; for ( int i = 0 ; i < tot ; i ++ ) { ret += ans.elem[i][0] ; if ( ret >= mod ) ret -= mod ; } return ret ; } void DEBUG ( ll l ) { memset ( dp , 0 , sizeof ( dp ) ) ; dp[0][0] = 1 ; for ( int i = 0 ; i < l ; i ++ ) { for ( int j = 0 ; j < tot ; j ++ ) { for ( int k = 0 ; k < 26 ; k ++ ) { int u = c[j][k] ; int v = u ; int ret = 1 ; while ( v ) { for ( int g = 0 ; g < vec[v].size () ; g ++ ) { Point f = vec[v][g] ; ret *= (i+1+f.len) * f.p ; } v = fail[v] ; } dp[i+1][u] += dp[i][j] * ret % mod ; if ( dp[i+1][u] >= mod ) dp[i+1][u] -= mod ; } } } int ans = 0 ; for ( int i = 0 ; i < tot ; i ++ ) { ans += dp[l][i] ; if ( ans >= mod ) ans -= mod ; } puts ( "fuck" ) ; printf ( "%d\n" , ans ) ; } } ac ; void extend_gcd ( ll a , ll b , int &x , int &y ) { if ( !b ) x = 1 , y = 0 ; else extend_gcd ( b , a % b , y , x ) , y -= x * ( a / b ) ; } char s[1111] ; int main () { for ( int i = 0 ; i < N ; i ++ ) for ( int j = 0 ; j < N ; j ++ ) E.elem[i][j] = i == j ; get_prime () ; int n ; ll l ; int ca = 0 ; while ( scanf ( "%d%I64d" , &n , &l ) != EOF ) { ac.init () ; for ( int i = 1 ; i <= n ; i ++ ) { scanf ( "%s" , s ) ; ac.insert ( s , i ) ; } ac.get_fail () ; // ac.DEBUG ( l ) ; int m1 , mm1 , m2 , mm2 , m3 , mm3 , fuck ;//mm为m的乘法逆元 m1 = 173 * 179 , m2 = 163 * 179 , m3 = 163 * 173 ; extend_gcd ( m1 , 163 , mm1 , fuck ) ; extend_gcd ( m2 , 173 , mm2 , fuck ) ; extend_gcd ( m3 , 179 , mm3 , fuck ) ; int a1 = ac.SOLVE ( 163 , l ) ; // printf( "a1 = %d\n" , a1 ) ; int a2 = ac.SOLVE ( 173 , l ) ; // printf ( "a2 = %d\n" , a2 ) ; int a3 = ac.SOLVE ( 179 , l ) ; // printf ( "a3 = %d\n" , a3 ) ; int ans = ( a1 * m1 * mm1 + a2 * m2 * mm2 + a3 * m3 * mm3 ) % 5047621 ; printf ( "Case #%d: %d\n" , ++ ca , ans ) ; } return 0 ; } /* 2 3 AB BB 2 2 A B */
hdu 4878 ZCC loves words(AC自动机+dp+矩阵快速幂+中国剩余定理)
时间: 2024-12-27 16:52:19