CCPC 2021 网络预选赛重赛 hdu 7131-Nun Heh Heh Aaaaaaaaaaa (计数dp-公共子序列模型)

http://acm.hdu.edu.cn/showproblem.php?pid=7131

题意

给定一个序列s,求其[前缀是𝚗𝚞𝚗𝚑𝚎𝚑𝚑𝚎𝚑,后缀是>=1个a]的子序列个数

题解

先求出子序列为nunhehheh的个数,定义dp(i,j)为s的前i个字符中和nunhehheh匹配到第j个个数.然后预处理出i后面有多少个a,记为a[i],对于每个dp(i,9)乘$2^{a[i]}$再相加即可得到所有方案数

代码

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
#include<iostream>
#include<cstring>
#include<algorithm>
#include<cmath>
#include<vector>
#include<queue>
#include<stack>
#include<map>
#include<unordered_map>
#include<set>
#pragma GCC optimize(2)
#pragma GCC optimize("inline")
#pragma GCC optimize("-fgcse")
#pragma GCC target("avx","sse2")
#pragma GCC optimize("-fgcse-lm")
#pragma GCC optimize("-fipa-sra")
#pragma GCC optimize("-ftree-pre")
#pragma GCC optimize("-ftree-vrp")
#pragma GCC optimize("-fpeephole2")
#pragma GCC optimize("-ffast-math")
#pragma GCC optimize("-fsched-spec")
#pragma GCC optimize("unroll-loops")
using namespace std;
#define ll long long
#define PII pair<int,int>
#define PLL pair<ll,ll>
#define PIII pair<int,PII>
#define PLLL pair<ll,PLL>
#define fi first
#define se second
#define pb push_back
#define debug(a) cout << #a << " " << a << '\n';
const int N = 1e5 + 5;
const int M = 1e5 + 5;
const ll INF = 0x3f3f3f3f3f3f3f3f;
const ll mod = 998244353;

inline ll read();

int n, m, t;

ll dp[N][15];
ll a[N];
ll poww[N];
void solve() {
char s[N];
string p = "@nunhehheh";
cin >> (s + 1);
ll len =strlen(s+1);
for (int i = 0; i <= len+1; i++) {//初始化
for (int j = 0; i <= 10; i++)dp[i][j] = 0;
a[i] = 0;
}
for (int i = len; i >= 0; i--) {
dp[i][0] = 1;//与s中第i个字符一个都不匹配的数量是1
if (s[i] == 'a')a[i] = (a[i + 1] + 1) % mod;
else {
a[i] = a[i + 1];//预处理
}
}
ll ans = 0;
for (int i = 1; i <= len; i++) {
for (int j = 1; j <= 9; j++) {
if (s[i] == p[j])dp[i][j] = (dp[i - 1][j - 1] + dp[i - 1][j]) % mod;
else {
dp[i][j] = dp[i - 1][j] % mod;//算公共序列个数
}
}
}
for (int i = 0; i <= len; i++) {
if (s[i] == 'h') {
ans += (dp[i][8] * (poww[a[i]] - 1)) % mod;//注意这里是dp[i][8].如用dp[i][9]算答案会重复
}
}
cout << ans % mod << '\n';
}

int main() {
ios::sync_with_stdio(false);
cin >> t;
poww[0]=1;
for(int i=1;i<=1e5;i++){
poww[i] =(poww[i-1]*2)%mod;
}
while (t--) {
solve();
}

return 0;
}


inline ll read() {
char ch = getchar();
ll p = 1, data = 0;
while (ch < '0' || ch > '9') {
if (ch == '-')p = -1;
ch = getchar();
}
while (ch >= '0' && ch <= '9') {
data = data * 10 + (ch ^ 48);
ch = getchar();
}
return p * data;
}