Skip to content

Commit

Permalink
Update classifier_classify_new.py
Browse files Browse the repository at this point in the history
  • Loading branch information
solderzzc authored Mar 23, 2019
1 parent c49c1d5 commit 481fa59
Showing 1 changed file with 11 additions and 2 deletions.
13 changes: 11 additions & 2 deletions src/embedding/classifier_classify_new.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,9 @@
#import judgeutil

BASEDIR = os.getenv('RUNTIME_BASEDIR',os.path.abspath(os.path.dirname(__file__)))


HAS_OPENCL = os.getenv('HAS_OPENCL','true')
sys.path.append(BASEDIR)
import judgeutil

Expand Down Expand Up @@ -363,7 +366,10 @@ def train_svm_with_embedding(args_list):
judge_paths = paths
judge_labels = labels
judge_nrof_images = len(judge_paths)
judge_emb_array = np.zeros((judge_nrof_images, 512))
if HAS_OPENCL == 'true:
judge_emb_array = np.zeros((judge_nrof_images, 512))
else:
judge_emb_array = np.zeros((judge_nrof_images, 128))
for j in range(judge_nrof_images):
judge_embedding = None
image_path = judge_paths[j]
Expand All @@ -390,7 +396,10 @@ def train_svm_with_embedding(args_list):
nrof_images = len(paths)
nrof_batches_per_epoch = int(math.ceil(1.0 * nrof_images / args.batch_size))

emb_array = np.zeros((nrof_images, 512))
if HAS_OPENCL == 'true':
emb_array = np.zeros((nrof_images, 512))
else:
emb_array = np.zeros((nrof_images, 128))
for i in range(nrof_batches_per_epoch):
start_index = i*args.batch_size
end_index = min((i+1)*args.batch_size, nrof_images)
Expand Down

0 comments on commit 481fa59

Please sign in to comment.