-
Notifications
You must be signed in to change notification settings - Fork 32
/
Copy pathLinearBase.py
97 lines (81 loc) · 2.68 KB
/
LinearBase.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
# 线性基第k小异或 — linear (XOR) basis over GF(2): k-th smallest subset XOR
from collections import defaultdict
from typing import List
class LinearBase:
    """Linear (XOR) basis of non-negative integers over GF(2).

    Vectors are inserted with :meth:`add`. The internal rows are kept in
    reduced row echelon form: each row is keyed by its highest set bit (its
    pivot), and that pivot bit is cleared in every other row. Call
    :meth:`build` after the last :meth:`add` to refresh :attr:`bases`, which
    :meth:`kthXor`, :meth:`maxXor` and ``len()`` read. ``x in basis`` reads
    the rows directly and does not need a build.
    """

    __slots__ = ("bases", "_rows", "_bit")

    @classmethod
    def fromlist(cls, nums: List[int]) -> "LinearBase":
        """Insert every value of ``nums`` into a new basis and build it.

        ``nums`` may be any iterable of non-negative ints. It is materialized
        once because it is traversed twice (for the maximum and for insertion).
        """
        nums = list(nums)
        res = cls(bit=max(nums, default=0).bit_length())
        for x in nums:
            res.add(x)
        res.build()
        return res

    def __init__(self, bit: int = 62) -> None:
        """Create an empty basis.

        Args:
            bit: bit width of the largest expected value. It is stored and
                propagated by :meth:`copy`; insertion does not depend on it,
                so wider values are still handled correctly.
        """
        self.bases: List[int] = []  # basis vectors in ascending pivot order; set by build()
        self._rows = defaultdict(int)  # Gaussian-elimination rows: pivot bit -> row
        self._bit = bit  # bit width of the largest number

    def add(self, x: int) -> bool:
        """Insert vector ``x``.

        Returns:
            True if ``x`` was independent of the current rows (the rank grew),
            False if it was already representable.
        """
        x = self.normalize(x)
        if x == 0:
            return False
        pivot = x.bit_length() - 1
        rows = self._rows
        # Keep the reduced form: clear the new pivot bit from every row with a
        # higher pivot (rows with a lower pivot cannot contain this bit).
        # Scanning the existing rows, rather than range(self._bit), also
        # covers pivots at or above self._bit.
        for j, row in list(rows.items()):
            if j > pivot and (row >> pivot) & 1:
                rows[j] = row ^ x
        rows[pivot] = x
        return True

    def build(self) -> None:
        """Refresh :attr:`bases` from the non-zero rows, by ascending pivot."""
        self.bases = [row for _, row in sorted(self._rows.items()) if row > 0]

    def kthXor(self, k: int) -> int:
        """Return the k-th smallest XOR over subsets (empty subset included).

        Each distinct value is counted once, so ``1 <= k <= 2 ** len(self.bases)``.
        Reads :attr:`bases`, so :meth:`build` must have run after the last
        :meth:`add`.
        """
        assert 1 <= k <= 2 ** len(self.bases)
        k -= 1
        res = 0
        # With a reduced basis in ascending pivot order, bit i of (k - 1)
        # selects bases[i], and the resulting XORs are ordered like k.
        for i in range(k.bit_length()):
            if (k >> i) & 1:
                res ^= self.bases[i]
        return res

    def maxXor(self) -> int:
        """Return the largest subset XOR (XOR of every basis vector)."""
        return self.kthXor(2 ** len(self.bases))

    def copy(self) -> "LinearBase":
        """Return an independent copy (rows and bases are not shared)."""
        res = type(self)(self._bit)
        res.bases = self.bases.copy()
        res._rows = self._rows.copy()
        return res

    def normalize(self, x: int) -> int:
        """Reduce ``x`` by the current rows and return the remainder.

        The remainder has no bit set at any pivot position. It is 0 exactly
        when ``x`` lies in the span of the rows.
        """
        rows = self._rows
        for i in range(x.bit_length() - 1, -1, -1):
            if (x >> i) & 1:
                # .get avoids inserting placeholder zero rows into the defaultdict.
                x ^= rows.get(i, 0)
        return x

    def __len__(self) -> int:
        """Rank of the basis as of the last :meth:`build`."""
        return len(self.bases)

    def __contains__(self, x: int) -> bool:
        """Whether ``x`` can be expressed as an XOR of inserted vectors."""
        return self.normalize(x) == 0
if __name__ == "__main__":
    nums = [1, 2, 3, 4, 5, 6, 7, 8, 9, 999]
    lb = LinearBase.fromlist(nums)
    print(lb.bases, len(lb))
    print(lb.maxXor())
    print(lb.kthXor(2))
    print(lb.kthXor(17))

    # Brute force: the XOR of every subset of the basis vectors.
    spanned = set()
    for mask in range(1 << len(lb)):
        acc = 0
        for j, base in enumerate(lb.bases):
            if (mask >> j) & 1:
                acc ^= base
        spanned.add(acc)
    expected = sorted(spanned)

    # test __contains__: exactly the spanned values up to the maximum are members
    ok = [i for i in range(lb.maxXor() + 1) if i in lb]
    assert expected == ok

    # test kthXor: the k-th value (1-indexed) matches the brute-force ordering
    assert [lb.kthXor(k) for k in range(1, len(expected) + 1)] == expected