Codeforces 842D
题意:给你$n(n \leq 3 \times 10^5)$个数,有$m(m \leq 3 \times 10^5)$次询问,每次询问给一个数$x$,要你把整个序列所有数都与$x$异或,然后取$mex$值
考虑Trie树维护序列中每个数的二进制形式值,从左到右由根到叶子插入(比如插入$(100)_2$先插入1边再插入0边)。考虑设一个最大深度$MNL$,所有数的二进制位数都要为$MNL$,不足在左边补0。$MNL$的大小必须大于所有数二进制形式长度。
之后我们就得到了一棵维护二进制数的Trie。先不管异或,我们来谈谈$mex$的求法。
这是一个贪心过程,因为首位越小数字越小,所以在Trie树中找最小的不存在的数即可以从根开始往下走,如果能走0边,就走0边。不能走的情况是,0边这个方向的子树大小是满的,不会有空,所以子树下的都在集合中,不是$mex$,那么就往1边走。如果向下走出现了没有节点可走,那么下面就直接全部选0边(这里的0边都是不存在的)往下走即可。
由于xor满足右结合,$a$异或$b$异或$c= a$异或$(b$异或$c)$。那么我们每次询问只需要把前程的所有$x$异或起来得到$nx$就行了。
字典树怎么异或?很麻烦,时间也不保证。我们尝试不修改字典树来进行查询$mex$
对于$nx$,如果要求得原序列以后的$mex$,从根向下遍历,类似不异或的情况。但是选边尽量要选和$nx$二进制下的边相同的。因为这样异或以后就是0。然后每次询问就可以了,类似不异或的情况
Trie维护二进制很常用!
#include<cstdio>
#include<cstring>
#include<algorithm>
#include<vector>
#include<set>
#define ms(i, j) memset(i, j, sizeof i)
#define LL long long
#define db double
using namespace std;
const int MAXN = 3 * 100000 + 5, MAX_NODE_LOG = 20;
int n, m, sz;
struct data {
int siz, dep, p, ch[2];
}node[MAX_NODE_LOG * MAXN];
void ins(int u, int x, int ws) {
if (ws == 0) node[u].siz = 1;
if (ws == 0) return ;
int a = x >> (ws - 1) & 1;
if (node[u].ch[a] == 0) node[u].ch[a] = ++sz, node[node[u].ch[a]].p = u, node[node[u].ch[a]].dep = node[u].dep + 1;
ins(node[u].ch[a], x, ws - 1);
node[u].siz = node[node[u].ch[0]].siz + node[node[u].ch[1]].siz + 1;
}
int flag, ans[MAX_NODE_LOG + 5], yyq;
void query(int u, int x) {
if (flag) return ;
if (yyq == 0) return ;
int a = x >> (yyq - 1) & 1;
if (node[node[u].ch[a]].siz != (1 << (MAX_NODE_LOG - node[node[u].ch[a]].dep + 1)) - 1) {
ans[yyq--] = a;
if (node[u].ch[a] != 0) query(node[u].ch[a], x); else {
while (yyq) ans[yyq] = x >> (yyq - 1) & 1, yyq--;
flag = true; return ;
}
} else {
ans[yyq--] = !a;
if (node[u].ch[!a] != 0) query(node[u].ch[!a], x); else {
while (yyq) ans[yyq] = x >> (yyq - 1) & 1, yyq--;
flag = true; return ;
}
}
}
void clean() {
sz = 1;
for (int i = 0; i < MAX_NODE_LOG * MAXN; i++) node[i].siz = node[i].dep = node[i].p = node[i].ch[0] = node[i].ch[1] = 0;
}
void solve() {
clean();
int nx = 0;
for (int x, i = 1; i <= n; i++) scanf("%d", &x), ins(1, x, MAX_NODE_LOG);
for (int x, i = 1; i <= m; i++) {
scanf("%d", &x), nx ^= x;
flag = false, yyq = MAX_NODE_LOG, ms(ans, 0), query(1, nx);
int zz = MAX_NODE_LOG, fans = 0;
while (zz > 0) {
if (ans[zz] != (nx >> (zz - 1) & 1)) fans += 1 << (zz - 1);
zz--;
}
printf("%d\n", fans);
}
}
int main() {
scanf("%d%d", &n, &m), solve();
return 0;
}