思路
看到多个子串并且不能包含的情况,想到了AC自动机
但是题目多了一个不能大于给出的n的限制条件,联想数位dp的过程,设f[i][j][0/1]表示在第i位,AC自动机的第j个节点,数位有/无限制的方案数
dp方程就是对应的转移到子节点即可,不向有标记的节点转移
注意如果跳fail能够跳到限制节点,就也不能转移,因为fail树上的父节点是子节点的子串,如果父节点是单词节点,子节点一定包含单词
另外题目中的数不能出现前导零,所以从根节点向子节点转移时不能转移到根的0号子节点
代码
#include <cstdio>
#include <algorithm>
#include <cstring>
#include <queue>
#define int long long
const int MOD = 1000000007;
using namespace std;
int trie[1501][10],Nodecnt,fail[1501],mark[1501],root,f[1201][1501][2],lens,lent;
char s[1501],t[1501];
void insert(char *s,int len){
int o=root;
for(int i=1;i<=len;i++){
if(!trie[o][s[i]-'0'])
trie[o][s[i]-'0']=++Nodecnt;
o=trie[o][s[i]-'0'];
}
mark[o]++;
}
void get_fail(void){
queue<int> q;
for(int i=0;i<10;i++){
if(trie[root][i]){
fail[trie[root][i]]=root;
q.push(trie[root][i]);
}
}
while(!q.empty()){
int x=q.front();
q.pop();
for(int i=0;i<10;i++){
if(trie[x][i]){
fail[trie[x][i]]=trie[fail[x]][i];
// mark[trie[x][i]]|=mark[trie[fail[x]][i]];
q.push(trie[x][i]);
}
else{
trie[x][i]=trie[fail[x]][i];
}
}
}
}
void getban(void){
for(int i=0;i<=Nodecnt;i++){
int p=i;
for(;p;p=fail[p])
if(mark[p]){
mark[i]=true;
break;
}
}
}
int dp(void){
for(int i=0;i<lens;i++){
for(int j=0;j<=Nodecnt;j++){
if(f[i][j][0]){
for(int k=0;k<s[i+1]-'0';k++)
if(!mark[trie[j][k]])
f[i+1][trie[j][k]][1]=(f[i+1][trie[j][k]][1]+f[i][j][0])%MOD;
if(!mark[trie[j][s[i+1]-'0']])
f[i+1][trie[j][s[i+1]-'0']][0]=(f[i+1][trie[j][s[i+1]-'0']][0]+f[i][j][0])%MOD;
}
if(f[i][j][1]){
for(int k=0;k<10;k++)
if(!mark[trie[j][k]])
f[i+1][trie[j][k]][1]=(f[i+1][trie[j][k]][1]+f[i][j][1])%MOD;
}
if(!j){
if(!i){
for(int k=1;k<s[i+1]-'0';k++)
if(!mark[trie[j][k]])
f[i+1][trie[j][k]][1]=(1+f[i+1][trie[j][k]][1])%MOD;
if(!mark[trie[j][s[i+1]-'0']])
f[i+1][trie[j][s[i+1]-'0']][0]=(1+f[i+1][trie[j][s[i+1]-'0']][0])%MOD;
}
else{
for(int k=1;k<10;k++)
if(!mark[trie[j][k]])
f[i+1][trie[j][k]][1]=(1+f[i+1][trie[j][k]][1])%MOD;
}
}
}
}
int ans=0;
for(int i=0;i<=Nodecnt;i++)
ans=(f[lens][i][0]+f[lens][i][1]+ans)%MOD;
return ans;
}
signed main(){
scanf("%s",s+1);
lens=strlen(s+1);
int n;
scanf("%lld",&n);
for(int i=1;i<=n;i++){
scanf("%s",t+1);
lent=strlen(t+1);
insert(t,lent);
}
get_fail();
getban();
// for(int i=1;i<=Nodecnt;i++)
// if(mark[i])
// printf("%d!\n",i);
printf("%lld\n",dp());
return 0;
}
原文地址:https://www.cnblogs.com/dreagonm/p/10459643.html
时间: 2024-10-05 03:09:15