// Code By CloudySky #include<bits/stdc++.h> // #define int long long namespace IO // namespace IO usingnamespace IO; constint Maxn = 5e5 + 10; usingnamespace std;
char s[Maxn]; int sa[Maxn], rk[Maxn], h[Maxn], trk[Maxn << 1];
boolcmp(int x, int y, int w){return trk[x] == trk[y] && trk[x + w] == trk[y + w];}
voidSA(int n){ staticint cnt[Maxn], id[Maxn]; int m = 300, p = 0; for (int i = 1; i <= n; ++i) cnt[rk[i] = s[i]]++; for (int i = 1; i <= m; ++i) cnt[i] += cnt[i - 1]; for (int i = n; i; --i) sa[cnt[rk[i]]--] = i; for (int w = 1; p != n; w <<= 1, m = p) { p = 0; for (int i = n; i > n - w; --i) id[++p] = i; for (int i = 1; i <= n; ++i) if (sa[i] > w) id[++p] = sa[i] - w; memset(cnt, 0, sizeof(cnt)); for (int i = 1; i <= n; ++i) cnt[rk[id[i]]]++; for (int i = 1; i <= m; ++i) cnt[i] += cnt[i - 1]; for (int i = n; i; --i) sa[cnt[rk[id[i]]]--] = id[i]; memcpy(trk, rk, sizeof(rk)); p = 0; for (int i = 1; i <= n; ++i) rk[sa[i]] = cmp(sa[i], sa[i - 1], w) ? p : ++p; if (p == n) for (int i = 1; i <= n; ++i) sa[rk[i]] = i; } for (int i = 1, k = 0; i <= n; ++i) { if (k) --k; while (s[i + k] == s[sa[rk[i] - 1] + k]) ++k; h[rk[i]] = k; } }
pair <int, int> kth (int n, int k) { for (int i = 1, x; i <= n; ++i) { x = n - sa[i] + 1 - h[i]; if (x >= k) return {i, k}; else k -= x; } return {0, -1}; }
namespace ST{ int st[Maxn][33], lg[Maxn];
voidinit(int n){ for (int i = 1; i <= n; ++i) st[i][0] = h[i]; for (int k = 1; k <= log2(n); ++k) { for (int i = 1; i + (1 << k) - 1 <= n; ++i) st[i][k] = min(st[i][k - 1], st[i + (1 << (k - 1))][k - 1]); } }
intquery(int l, int r, int n){ int k = log2(r - l); returnmin(st[l + 1][k], st[r - (1 << k) + 1][k]); } }
signedmain(){ scanf("%s", s + 1); int n = strlen(s + 1), t = read(), k = read(), p, w; SA(n); if (t == 0){ auto ans = kth(n, k); if (ans.second == -1) returnprint(-1), 0; p = ans.first, w = ans.second; for (int i = sa[p]; i <= sa[p] + h[p] + w - 1; ++i) putchar(s[i]); return0; } else { if (k > 1ll * n * (n + 1) / 2) returnprint(-1), 0; int l = 0, r = k; ST::init(n); while (l < r) { int mid = (l + r) >> 1, tot = 0, p, w; auto tmp = kth(n, mid); p = tmp.first, w = tmp.second; for (int i = 1; i < p; ++i) tot += n - sa[i] + 1; tot += min(w, n - sa[p] + 1); for (int i = p + 1; i <= n; ++i) tot += min(w, ST::query(p, i, n)); if (tot >= k) r = mid; else l = mid + 1; } auto ans = kth(n, l); p = ans.first, w = ans.second; for (int i = sa[p]; i <= sa[p] + w - 1; ++i) putchar(s[i]); } return0; }
题目思路 - SAM
对于 SAM 来讲,只要思路对了两种情况差不多。
构建出 SAM,对于每个点 x ,统计出以它结尾的子串个数,记为 siz[x] ,利用 siz 统计出经过它的所有子串个数,记为 sum[x] 。
至于如何去求 siz 和 min ,大概有两种方法,一种是常规的倒着连 fail 然后跑 dfs 或 bfs ,而另一种则是对 len 从小到大进行排序,倒着扫更新(当然不能直接对 len 排序,而是要建立一个映射关系。) 这样就能保证更新 fail[x] 时已经更新过 x 了。对于条件 1 在求 sum 之前将每个点的 siz 赋成 1 即可。
递归进行输出,每到一个点 a∼z 枚举转移边,比较 k 和 sum[sam[p][i]] 的大小关系。
voidinsert(char c){ int p = lst, np = ++tot, ch = c - 'a'; lst = np, len[np] = len[p] + 1, siz[np] = 1; for (; p != -1 && !sam[p][ch]; p = fail[p]) sam[p][ch] = np; if (p == -1) return fail[np] = 0, void(); int q = sam[p][ch]; if (len[q] == len[p] + 1) return fail[np] = q, void(); int nq = ++tot; fail[nq] = fail[q], fail[np] = fail[q] = nq; memcpy(sam[nq], sam[q], sizeof(sam[nq])); len[nq] = len[p] + 1, siz[nq] = 0; for (; p != -1 && sam[p][ch] == q; p = fail[p]) sam[p][ch] = nq; }
voidprint_kth(int p, int k){ if (k <= siz[p]) return; k -= siz[p]; for (int i = 0; i < 26; ++i) { if (!sam[p][i]) continue; if (sum[sam[p][i]] >= k) returnputchar(i + 'a'), print_kth(sam[p][i], k); else k -= sum[sam[p][i]]; } }
signedmain(){ staticchar s[Maxn]; staticint cnt[Maxn], id[Maxn]; scanf("%s", s); SAM(); int n = strlen(s), t = read(), k = read(); for (int i = 0; i < n; ++i) insert(s[i]); // 基数排序建立 len 从小到大映射关系 for (int i = 0; i <= tot; ++i) ++cnt[len[i]]; for (int i = 0; i <= tot; ++i) cnt[i] += cnt[i - 1]; for (int i = 0; i <= tot; ++i) id[--cnt[len[i]]] = i; // end // 求 siz 和 sum for (int i = tot; i >= 0; --i) siz[fail[id[i]]] += siz[id[i]]; for (int i = 0; i <= tot; ++i) sum[i] = (t == 0 ? (siz[i] = 1) : siz[i]); // 注意这里 siz[0] = sum[0] = 0; for (int i = tot; i >= 0; --i) for (int j = 0; j < 26; ++j) if (sam[id[i]][j]) sum[id[i]] += sum[sam[id[i]][j]]; // end if (sum[0] < k) returnprint(-1), 0; print_kth(0, k); return0; }