[dp专题] AC自动机与状态压缩dp的结合

最近做到好几道关于AC自动机与状态压缩dp的结合的题,这里总结一下。

题目一般会给出m个字符串,m不超过10,然后求长度为len并且包含特定给出的字符串集合的字符串个数。

以HDU 4758为例:

把题意抽象为:给出两个字符串,且只包含两种字符 ‘R‘、‘D‘,现在求满足下列条件的字符串个数:字符串长度为(m+n),其中包含n个‘D‘,m个‘R‘。

如果不用AC自动机来做,这道题还真没法做了,因为不管怎样都找不到正确的dp状态转移方程。

而如果引入AC自动机,把在AC自动机上的结点当做dp的一个维度的状态,那么问题就可解了。

dp[c][zt][i][j]:c表示当前状态的字符串对应于AC自动机上的结点,zt表示给定字符串取舍情况的压缩状态,i表示‘D‘的个数,j表示‘R‘的个数。

那么dp[c][zt][i][j]表示当前状态字符串的个数。

AC自动机的作用就是增加一个状态维度,使dp过程有足够的信息来转移状态。

#include<cstdio>
#include<cstring>
#include<queue>
using namespace std;
const int mod = 1000000007;
int ch[202][2],End[202],cur,fail[202],last[202];
void get_fail() {
    int now,tmpFail,Next;
    queue<int> q;
    for(int j=0;j<2;j++) {
        if(ch[0][j]) {
            q.push(ch[0][j]);
            fail[ch[0][j]] = 0;
            last[ch[0][j]] = 0;
        }
    }
    while(!q.empty()) {
        now = q.front();q.pop();
        for(int j=0;j<2;j++) {
            if(!ch[now][j]) continue;
            Next = ch[now][j];
            q.push(Next);
            tmpFail = fail[now];
            while(tmpFail&&!ch[tmpFail][j]) tmpFail = fail[tmpFail];
            fail[Next] = ch[tmpFail][j];
            last[Next] = End[fail[Next]] ? fail[Next]:last[fail[Next]];
        }
    }
}
int dp[202][4][102][102];//dp[c][zt][i][j]
int main() {
    int T,m,n;
    char str0[3][104];
    scanf("%d",&T);
    while(T--) {
        cur=1;
        scanf("%d%d",&m,&n);
        n++;m++;
        memset(End,0,sizeof(End));
        memset(ch,0,sizeof(ch));
        memset(last,0,sizeof(last));
        for(int i=1;i<=2;i++) {
            scanf("%s",str0[i]);
            int len = strlen(str0[i]);
            int now = 0;
            for(int j=0;j<len;j++) {
                if(str0[i][j]==‘R‘) str0[i][j]=1;
                else str0[i][j]=0;
                if(ch[now][str0[i][j]]==0) ch[now][str0[i][j]] = cur++;
                now = ch[now][str0[i][j]];
            }
            End[now] = i;
        }
        get_fail();

        memset(dp,0,sizeof(dp));
        dp[0][0][0][0]=1;
        for(int i=0;i<n;i++) //要特别注意这里内外循环顺序,必须把i、j循环放在外面
        for(int j=0;j<m;j++) {
            for(int c=0;c<cur;c++) {
                for(int zt=0;zt<=3;zt++){
                    if(dp[c][zt][i][j])
                    for(int k=0;k<2;k++) {
                        if(k==0&&i==n-1) continue;
                        else if(k==1&&j==m-1) continue;
                        int now=c;
                        while(now&&!ch[now][k]) now = fail[now];
                        now = ch[now][k];

                        int t=0;
                        if(End[now])
                            t = t|(1<<(End[now]-1));
                        int tmp = now;
                        while(last[tmp]) {
                            t = t|(1<<(End[last[tmp]]-1));
                            tmp = last[tmp];
                        }
                        if(k==0) {
                            dp[now][zt|t][i+1][j] += dp[c][zt][i][j];
                            if(dp[now][zt|t][i+1][j]>=mod) dp[now][zt|t][i+1][j]-=mod;
                        }
                        else if(k==1) {
                            dp[now][zt|t][i][j+1] += dp[c][zt][i][j];
                            if(dp[now][zt|t][i][j+1]>=mod) dp[now][zt|t][i][j+1]-=mod;
                        }
                    }
                }
            }
        }
        long long ans=0;
        for(int i=0;i<cur;i++) {
            ans+=dp[i][3][n-1][m-1];
            if(ans>=mod) ans-=mod;
        }
        printf("%I64d\n",ans);
    }
}

注意循环的内外顺序,一般情况下,字符串长度的循环都是放在外层,也就是说,一定要先计算出长度为i的所有字符串状态,才能计算长度为i+1的所有字符串状态。

类似的 HDU 2825 :

#include<iostream>
#include<algorithm>
#include<cstring>
#include<cstdio>
#include<queue>
using namespace std;
const int mod=20090717;
int ch[11*11][26],End[11*11],cur,fail[11*11],last[11*11];
char str0[12][12];
void get_fail() {
    int now,tmpFail,Next;
    queue<int> q;
    for(int j=0;j<26;j++) {
        if(ch[0][j]) {
            q.push(ch[0][j]);
            fail[ch[0][j]] = 0;
            last[ch[0][j]] = 0;
        }
    }
    while(!q.empty()) {
        now = q.front();q.pop();
        for(int j=0;j<26;j++) {
            if(!ch[now][j]) continue;
            Next = ch[now][j];
            q.push(Next);
            tmpFail = fail[now];
            while(tmpFail&&!ch[tmpFail][j]) tmpFail = fail[tmpFail];
            fail[Next] = ch[tmpFail][j];
            last[Next] = End[fail[Next]] ? fail[Next]:last[fail[Next]];
        }
    }
}
int dp[27][11*11][1055];
int main()
{
    int sum[1055];
    for(int I=0;I<(1<<10);I++) {
            sum[I]=0;
            int tmp=I;
            while(tmp) {
                if(tmp&1) sum[I]++;
                tmp>>=1;
            }
    }
    int n,m,k;
    while(scanf("%d%d%d",&n,&m,&k)!=EOF&&(n||m||k))
    {
        cur=1;
        int len[13];
        memset(End,0,sizeof(End));
        memset(ch,0,sizeof(ch));
        memset(last,0,sizeof(last));
        for(int i=1;i<=m;i++) {
            scanf("%s",str0[i]);
            len[i] = strlen(str0[i]);
            int now = 0;
            for(int j=0;j<len[i];j++) {
                str0[i][j]-=‘a‘;
                if(ch[now][str0[i][j]]==0) ch[now][str0[i][j]] = cur++;
                now = ch[now][str0[i][j]];
                str0[i][j]+=‘a‘;
            }
            End[now] = i;

        }
        get_fail();
        memset(dp,0,sizeof(dp));
        dp[0][0][0]=1;
        int pre=0,zt=0;
        int ans=0;
        for(int i=0;i<n;i++) {
            for(int j=0;j<cur;j++) {
                for(int zt=0;zt<(1<<m);zt++) {
                    if(dp[i][j][zt]) {
                    for(int c=0;c<26;c++) {
                        int now = j;
                        while(now&&!ch[now][c]) now = fail[now];
                        now = ch[now][c];
                        int t=0;
                        if(End[now])
                            t = t|(1<<(End[now]-1));
                        int tmp = now;
                        while(last[tmp]) {
                            t = t|(1<<(End[last[tmp]]-1));
                            tmp = last[tmp];
                        }
                        dp[i+1][now][zt|t] += dp[i][j][zt];
                        if(dp[i+1][now][zt|t]>=mod) dp[i+1][now][zt|t]-=mod;
                    }
                    }
                }
            }
        }
        for(int I=0;I<(1<<m);I++) {
            if(sum[I]>=k) {
                for(int j=0;j<cur;j++){
                    ans+=dp[n][j][I];
                    if(ans>=mod) ans-=mod;

                }
            }
        }
        printf("%d\n",ans);
    }
}

 

HDU 4057:

#include<cstdio>
#include<cstring>
#include<queue>
using namespace std;
int ch[11*102][4],End[11*102],cur,fail[11*102],last[11*102];
int w[11];
char str[102],str0[11][102];
void get_fail()
{
    int now,tmpFail,Next;
    queue<int> q;
    //用bfs生成fail
    //初始化队列
    for(int j=0; j<4; j++)
    {
        if(ch[0][j])
        {
            q.push(ch[0][j]);
            fail[ch[0][j]] = 0;
            last[ch[0][j]] = 0;
        }
    }
    while(!q.empty())
    {
        //从队列中拿出now
        //此时now中的fail、last已经算好了
        //下面计算的是ch[now][j]中的fail、last。
        now = q.front();
        q.pop();
        for(int j=0; j<4; j++)
        {
            if(!ch[now][j]) continue;
            Next = ch[now][j];
            q.push(Next);
            tmpFail = fail[now];
            while(tmpFail&&!ch[tmpFail][j]) tmpFail = fail[tmpFail];
            fail[Next] = ch[tmpFail][j];
            last[Next] = End[fail[Next]] ? fail[Next]:last[fail[Next]];
        }
    }
}
int dp[1029][11*102][2];
bool vis[1029][11*102][2];
int n,l,now,ans;
queue<int> quezt;
queue<int> quenow;
queue<int> quelen;
void bfs (int zt,int now0,int len)
{
    //printf("%d %d %d %d\n",zt,now0,len,dp[zt][now0][len%2]);
    //printf("%d\n",quezt.size());
    if(len==l) ans=max(ans,dp[zt][now0][l%2]);
    if(len==l+1) return;
    for(int i=0; i<4; i++)
    {
        int now=now0,temp=0;
        while(now&&!ch[now][i]) now = fail[now];
        now = ch[now][i];
        int newzt = zt;
        if(End[now])
        {
            if(((1<<(End[now]-1))|newzt)!=newzt) temp+=w[End[now]];
            newzt = (1<<(End[now]-1))|newzt;
        }
        int tmp = now;
        while(last[tmp])
        {
            if(End[last[tmp]])
            {
                if(((1<<(End[last[tmp]]-1))|newzt)!=newzt) temp+=w[End[last[tmp]]];
                newzt = (1<<(End[last[tmp]]-1))|newzt;
            }
            tmp = last[tmp];
        }
        if(newzt!=zt) {
            //printf("%d\n",temp);
            if(!vis[newzt][now][(len+1)%2]) dp[newzt][now][(len+1)%2]=dp[zt][now0][len%2]+temp;
            else dp[newzt][now][(len+1)%2]=max(dp[zt][now0][len%2]+temp,dp[newzt][now][(len+1)%2]);
        }
        else{
            if(!vis[zt][now][(len+1)%2]) dp[zt][now][(len+1)%2]=dp[zt][now0][len%2];
            else dp[zt][now][(len+1)%2]=max(dp[zt][now0][len%2],dp[zt][now][(len+1)%2]);
        }
        //dfs(newzt,now,len+1);
        if(!vis[newzt][now][(len+1)%2]) {
            quezt.push(newzt);
            quenow.push(now);
            quelen.push(len+1);
            vis[newzt][now][(len+1)%2]=true;
        }
    }
    //if(len==l) ans=max(ans,dp[zt][now0][l%2]);
}
int main()
{
    while(scanf("%d%d",&n,&l)!=EOF)
    {
        memset(dp,-1,sizeof(dp));
        memset(ch,0,sizeof(ch));
        memset(End,0,sizeof(End));
        memset(last,0,sizeof(last));
        cur = 1;
        int len;
        for(int i=1; i<=n; i++)
        {
            scanf("%s%d",str0[i],&w[i]);
            //puts(str0[i]);
            len = strlen(str0[i]);
            now = 0;
            for(int j=0; j<len; j++)
            {
                if(str0[i][j]==‘A‘) str0[i][j]=0;
                if(str0[i][j]==‘T‘) str0[i][j]=1;
                if(str0[i][j]==‘G‘) str0[i][j]=2;
                if(str0[i][j]==‘C‘) str0[i][j]=3;
                if(ch[now][str0[i][j]]==0) ch[now][str0[i][j]] = cur++;
                now = ch[now][str0[i][j]];
                if(str0[i][j]==0) str0[i][j]=‘A‘;
                if(str0[i][j]==1) str0[i][j]=‘T‘;
                if(str0[i][j]==2) str0[i][j]=‘G‘;
                if(str0[i][j]==3) str0[i][j]=‘C‘;
            }
            End[now] = i;
        }
        //printf("%d\n",cur);
        get_fail();
        //printf("%d\n",cur);
        dp[0][0][0]=0;
        quezt.push(0);
        quenow.push(0);
        quelen.push(0);
        memset(vis,false,sizeof(vis));
        vis[0][0][0]=true;
        ans=-1;
        int pre=0;
        while(!quezt.empty()) {
            //if(quelen.front()!=pre) {
            //    for(int i=0;i<1029;i++)
            //    for(int j=0;j<11*102;j++) dp[i][j][pre%2]=0;
            //    pre=quelen.front();
            //}
            bfs(quezt.front(),quenow.front(),quelen.front());
            vis[quezt.front()][quenow.front()][quelen.front()%2]=false;
            quezt.pop();quenow.pop();quelen.pop();
        }
        if(ans==-1) puts("No Rabbit after 2012!");
        else printf("%d\n",ans);
    }
}
时间: 2024-08-07 04:31:47

[dp专题] AC自动机与状态压缩dp的结合的相关文章

HDU 3247 Resource Archiver (AC自动机 + BFS + 状态压缩DP)

题目链接:Resource Archiver 解析:n个正常的串,m个病毒串,问包含所有正常串(可重叠)且不包含任何病毒串的字符串的最小长度为多少. AC自动机 + bfs + 状态压缩DP 用最短路预处理出状态的转移.可以优化很多 AC代码: #include <cstdio> #include <iostream> #include <cstring> #include <algorithm> #include <queue> using n

【HDU2825】Wireless Password【AC自动机,状态压缩DP】

题意 题目给出m(m<=10)个单词,每个单词的长度不超过10且仅由小写字母组成,给出一个正整数n(n<=25)和正整数k,问有多少方法可以组成长度为n的文本且最少包含k个给出的单词. 分析 和上一个AC自动机很相似,上一篇博客是不包含任何一个单词长度为n的方案数,这个题是包含至少k个单词的方案数,而且n,m,k都非常的小. 按照前面的经验很容易想到,我们还是得先建一个AC自动机,然后把它的单词结点标记出来.与前面不同的是我们在状态转移的时候需要考虑到当前走过的结点已经包含多少单词了.所以我们

HDU 3247 Resource Archiver AC自动机 + bfs + 状态压缩dp

题意:给定你n个文本串 ,m个模式串,怎么构造最短的新的文本串使得这个新的文本串包含n个文本串的所有信息且文本串的长度最短且不包含模式串. 解题思路:这里看题解撸的,首先我们算出两两文本串的距离(end数组标记文本和模式串的值不同,利用这个进行bfs算出两两之间的最短距离,注意到这里模式串的end是不能走到的.这里也不需要松弛操作),然后因为n只有10这么大,所以我们可以状态压缩  ,dp[i][j] 表示 压缩后状态为 i(二进制压缩,每i位表示第i个是否在)且 以j结尾的文本串的最小花费.这

dp乱写2:状态压缩dp(状压dp)炮兵阵地

https://www.luogu.org/problem/show?pid=2704 题意: 炮兵在地图上的摆放位子只能在平地('P') 炮兵可以攻击上下左右各两格的格子: 而高原('H')上炮兵能够攻击到但是不能摆放 求最多能摆放的炮兵的数量 就是这个意思. 难度提高,弱省省选 一开始是想写dfs(迷之八皇后)的, 但是看到数据量100就想dp了: 因为题目n的范围给的很少n<=10,想到状压 非常明显是一个状态压缩的dp(状压dp) 其实可以当做状压的入门题目来做. 由于本行的状态是由前若

HDU 4511 (AC自动机+状态压缩DP)

题目链接:  http://acm.hdu.edu.cn/showproblem.php?pid=4511 题目大意:从1走到N,中间可以选择性经过某些点,比如1->N,或1->2->N,但是某些段路径(注意不是某些条)是被禁止的.问从1->N的最短距离. 解题思路: AC自动机部分: 如果只是禁掉某些边,最短路算法加提前标记被禁的边即可. 但是本题是禁掉指定的路段,所以得边走边禁,需要一个在线算法. 所以使用AC自动机来压缩路段,如禁掉的路段是1->2->3,那么in

POJ 3691 (AC自动机+状态压缩DP)

题目链接:  http://poj.org/problem?id=3691 题目大意:给定N的致病DNA片段以及一个最终DNA片段.问最终DNA片段最少修改多少个字符,使得不包含任一致病DNA. 解题思路: 首先说一下AC自动机在本题中的作用. ①字典树部分:负责判断当前0~i个字符组成的串是否包含致病DNA,这部分靠字典树上的cnt标记完成. ②匹配部分:主要依赖于匹配和失配转移关系的计算,这部分非常重要,用来构建不同字符间状态压缩的转移关系(代替反人类的位运算). 这也是必须使用AC自动机而

HDU 2825 Wireless Password (AC自动机 + 状态压缩DP)

题目链接:Wireless Password 解析:给 m 个单词构成的集合,统计所有长度为 n 的串中,包含至少 k 个单词的方案数. AC自动机 + 状态压缩DP. DP[i][j][k]:长度为i的字符串匹配到状态j且包含k个magic word的可能字符串个数. AC代码: #include <algorithm> #include <iostream> #include <cstdio> #include <queue> #include <

hdu 4057 AC自动机+状态压缩dp

http://acm.hdu.edu.cn/showproblem.php?pid=4057 Problem Description Dr. X is a biologist, who likes rabbits very much and can do everything for them. 2012 is coming, and Dr. X wants to take some rabbits to Noah's Ark, or there are no rabbits any more.

HDU 3341 Lost&#39;s revenge AC自动机+ 状态压缩DP

题意:这个题目和HDU2457有点类似,都是AC自动机上的状态dp,题意就是给你只含有'A','T','C','G',四个字符的子串和文本串,问你文本串如何排列才可以使得文本串中包含有更多的模式串 解题思路:我们知道了 有 num[0] 个 'A', num[1] 个 ‘T’, num[2] 个 ‘C’,num[3] 个‘G’, 我们的可以知道暴力的思路就是把所有的文本串都枚举出来然后一一匹配.我们膜拜了一下春哥以后,就可以有以下思路:  把一个串的信息压缩一下,把具有同样个数字符的串看成是同一