Upstream changes to big batch search #3170

Closed · wants to merge 1 commit
51 changes: 29 additions & 22 deletions contrib/big_batch_search.py
@@ -6,6 +6,7 @@
 import time
 import pickle
 import os
+import logging
 from multiprocessing.pool import ThreadPool
 import threading
 import _thread
@@ -41,7 +42,7 @@ def __init__(
         self.use_float16 = use_float16
         keep_max = faiss.is_similarity_metric(index.metric_type)
         self.rh = faiss.ResultHeap(len(xq), k, keep_max=keep_max)
-        self.t_accu = [0] * 5
+        self.t_accu = [0] * 6
         self.t_display = self.t0 = time.time()

     def start_t_accu(self):
@@ -74,11 +75,12 @@ def report(self, l):
f"[{t:.1f} s] list {l}/{self.index.nlist} "
f"times prep q {self.t_accu[0]:.3f} prep b {self.t_accu[1]:.3f} "
f"comp {self.t_accu[2]:.3f} res {self.t_accu[3]:.3f} "
f"wait {self.t_accu[4]:.3f} "
f"wait in {self.t_accu[4]:.3f} "
f"wait out {self.t_accu[5]:.3f} "
f"eta {datetime.timedelta(seconds=t*self.index.nlist/(l+1)-t)} "
f"mem {faiss.get_mem_usage_kb()}",
end="\r" if self.verbose <= 2 else "\n",
flush=True,
end="\r" if self.verbose <= 2 else "\n",
flush=True,
)
self.t_display = time.time()

@@ -293,7 +295,7 @@ def big_batch_search(
     )
     mem_tot = mem_queries + mem_assign + mem_res
     if verbose > 0:
-        print(
+        logging.info(
             f"memory: queries {mem_queries} assign {mem_assign} "
             f"result {mem_res} total {mem_tot} = {mem_tot / (1<<30):.3f} GiB"
         )
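A side effect of switching from print to logging.info throughout this file: with Python's default root logger (level WARNING, no handler configured), these messages are silent. A caller that wants the old progress output back would configure logging first, for example:

import logging

# Enable the INFO-level progress messages emitted by big_batch_search.
# (Assumption: configuration is left to the caller; the module itself
# does not call basicConfig.)
logging.basicConfig(
    format="%(asctime)s %(levelname)s %(message)s",
    level=logging.INFO,
)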
@@ -312,8 +314,8 @@
     )

     bbs.decode_func = comp.decode_func
-    bbs.by_residual = comp.by_residual

+    bbs.by_residual = comp.by_residual
     if q_assign is None:
         bbs.coarse_quantization()
     else:
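On the branch above: when q_assign is None the searcher computes the coarse assignment itself; passing a precomputed assignment lets repeated runs over the same queries skip that step. A hypothetical usage sketch (index setup and sizes are invented for illustration):

import numpy as np
import faiss

# Invented toy setup
d, nlist = 64, 1024
index = faiss.index_factory(d, f"IVF{nlist},Flat")
xb = np.random.rand(100_000, d).astype("float32")
index.train(xb)
index.add(xb)
index.nprobe = 16

xq = np.random.rand(50_000, d).astype("float32")

# Coarse assignment: for each query, the nprobe nearest centroid ids.
# Computed once, it can be passed as q_assign to skip coarse_quantization().
_, q_assign = index.quantizer.search(xq, index.nprobe)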
@@ -327,11 +329,11 @@
     if checkpoint is not None:
         assert (start_list, end_list) == (0, index.nlist)
         if os.path.exists(checkpoint):
-            print("recovering checkpoint", checkpoint)
+            logging.info(f"recovering checkpoint: {checkpoint}")
             completed = bbs.read_checkpoint(checkpoint)
-            print(" already completed", len(completed))
+            logging.info(f" already completed: {len(completed)}")
         else:
-            print("no checkpoint: starting from scratch")
+            logging.info("no checkpoint: starting from scratch")

     if threaded == 0:
         # simple sequential version
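read_checkpoint and write_checkpoint persist which inverted lists are already merged, together with the partial results, so a crashed run can resume. This diff does not show their implementation; a minimal pickle-based sketch of the idea (pickle is already imported at the top of this file):

import os
import pickle

def write_checkpoint_sketch(fname, completed, partial_state):
    # Dump to a temp file, then rename atomically, so a crash while
    # writing cannot destroy the previous valid checkpoint.
    tmp = fname + ".tmp"
    with open(tmp, "wb") as f:
        pickle.dump({"completed": completed, "state": partial_state}, f)
    os.replace(tmp, fname)

def read_checkpoint_sketch(fname):
    with open(fname, "rb") as f:
        ckpt = pickle.load(f)
    return ckpt["completed"], ckpt["state"]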
@@ -414,29 +416,30 @@ def task_manager(*args):

         def prepare_task(task_id, output_queue, input_queue=None):
             try:
-                # print(f"Prepare start: {task_id}")
+                logging.info(f"Prepare start: {task_id}")
                 q_subset, xq_l, list_ids, xb_l = bbs.prepare_bucket(task_id)
                 output_queue.put((task_id, q_subset, xq_l, list_ids, xb_l))
-                # print(f"Prepare end: {task_id}")
+                logging.info(f"Prepare end: {task_id}")
             except:
                 traceback.print_exc()
                 _thread.interrupt_main()
                 raise

         def compute_task(task_id, output_queue, input_queue):
             try:
-                # print(f"Compute start: {task_id}")
-                t_wait = 0
+                logging.info(f"Compute start: {task_id}")
+                t_wait_out = 0
                 while True:
                     t0 = time.time()
+                    logging.info(f'Compute input: task {task_id}')
                     input_value = input_queue.get()
-                    t_wait += time.time() - t0
+                    t_wait_in = time.time() - t0
                     if input_value is None:
                         # signal for other compute tasks
                         input_queue.put(None)
                         break
                     centroid, q_subset, xq_l, list_ids, xb_l = input_value
-                    # print(f'Compute work start: task {task_id}, centroid {centroid}')
+                    logging.info(f'Compute work: task {task_id}, centroid {centroid}')
                     t0 = time.time()
                     if computation_threads > 1:
                         D, I = comp.block_search(
@@ -445,13 +448,13 @@ def compute_task(task_id, output_queue, input_queue):
                     else:
                         D, I = comp.block_search(xq_l, xb_l, list_ids, k)
                     t_compute = time.time() - t0
-                    # print(f'Compute work end: task {task_id}, centroid {centroid}')
+                    logging.info(f'Compute output: task {task_id}, centroid {centroid}')
                     t0 = time.time()
                     output_queue.put(
-                        (centroid, t_wait, t_compute, q_subset, D, list_ids, I)
+                        (centroid, t_wait_in, t_wait_out, t_compute, q_subset, D, list_ids, I)
                     )
-                    t_wait = time.time() - t0
-                    # print(f"Compute end: {task_id}")
+                    t_wait_out = time.time() - t0
+                    logging.info(f"Compute end: {task_id}")
             except:
                 traceback.print_exc()
                 _thread.interrupt_main()
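The input_queue.put(None) on receiving the sentinel is what lets a single end-of-work marker shut down a whole pool of compute threads: each worker re-posts the sentinel before exiting, so the next worker sees it too. The pattern in isolation (a self-contained toy, not faiss code):

import queue
import threading

def worker(in_q, out_q):
    while True:
        item = in_q.get()
        if item is None:
            in_q.put(None)  # re-post the sentinel for sibling workers
            break
        out_q.put(item * item)  # stand-in for the real block_search work

in_q, out_q = queue.Queue(), queue.Queue()
threads = [threading.Thread(target=worker, args=(in_q, out_q)) for _ in range(4)]
for t in threads:
    t.start()
for i in range(10):
    in_q.put(i)
in_q.put(None)  # one sentinel is enough for the whole pool
for t in threads:
    t.join()
print(sorted(out_q.get() for _ in range(10)))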
@@ -480,21 +483,25 @@ def compute_task(task_id, output_queue, input_queue):

         t_checkpoint = time.time()
         while True:
+            logging.info("Waiting for result")
             value = compute_to_main_queue.get()
             if not value:
                 break
-            centroid, t_wait, t_compute, q_subset, D, list_ids, I = value
+            centroid, t_wait_in, t_wait_out, t_compute, q_subset, D, list_ids, I = value
             # to test checkpointing
             if centroid == crash_at:
                 1 / 0
             bbs.t_accu[2] += t_compute
-            bbs.t_accu[4] += t_wait
+            bbs.t_accu[4] += t_wait_in
+            bbs.t_accu[5] += t_wait_out
+            logging.info(f"Adding to heap start: centroid {centroid}")
             bbs.add_results_to_heap(q_subset, D, list_ids, I)
+            logging.info(f"Adding to heap end: centroid {centroid}")
             completed.add(centroid)
             bbs.report(centroid)
             if checkpoint is not None:
                 if time.time() - t_checkpoint > checkpoint_freq:
-                    print("writing checkpoint")
+                    logging.info("writing checkpoint")
                     bbs.write_checkpoint(checkpoint, completed)
                     t_checkpoint = time.time()

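The crash_at hook in the loop above exists purely to exercise checkpoint recovery: dividing by zero after a chosen centroid simulates a crash mid-run. A hypothetical test harness, assuming the keyword arguments visible in this diff (threaded, checkpoint, checkpoint_freq, crash_at) and that big_batch_search returns the merged (D, I) arrays:

import numpy as np
import faiss
from faiss.contrib.big_batch_search import big_batch_search

# Invented toy setup
d, nlist, k = 32, 64, 10
index = faiss.index_factory(d, f"IVF{nlist},Flat")
xb = np.random.rand(10_000, d).astype("float32")
index.train(xb)
index.add(xb)
xq = np.random.rand(1_000, d).astype("float32")

# First run: force a crash after centroid 20, checkpointing every second.
try:
    big_batch_search(index, xq, k, threaded=4,
                     checkpoint="/tmp/bbs.ckpt", checkpoint_freq=1,
                     crash_at=20)
except ZeroDivisionError:
    pass  # the simulated crash from the `1 / 0` above

# Second run: recovers the completed lists from the checkpoint and finishes.
D, I = big_batch_search(index, xq, k, threaded=4,
                        checkpoint="/tmp/bbs.ckpt")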