-
Notifications
You must be signed in to change notification settings - Fork 6
/
Copy pathmain.py
70 lines (45 loc) · 1.78 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
import os
import os.path as osp
import argparse
import numpy as np
from operators import cslbp, csldp, csldmp, cslmp, csltp, xcslbp, xcslmp, xcsltp
import load_grayscale
def get_texture_operator(operator):
    """Return the texture-operator module registered under ``operator``.

    Only operators present (uncommented) in the dictionary below can be used.

    Args:
        operator: Operator name, e.g. ``"cslbp"``.

    Returns:
        The operator module. ``compute_features`` calls its
        ``get_features(images, img_height, img_width)`` function.

    Raises:
        ValueError: If ``operator`` is not enabled in the dictionary below.
    """
    texture_operator_dict = {
        "cslbp": cslbp,
        # Disabled operators (imported at the top of the file but not wired up).
        # NOTE(review): confirm each module provides get_features() before
        # re-enabling it here.
        # "csldp": csldp,
        # "csldmp": csldmp,
        # "cslmp": cslmp,
        # "csltp": csltp,
        # "xcslbp": xcslbp,
        # "xcslmp": xcslmp,
        # "xcsltp": xcsltp
    }
    try:
        return texture_operator_dict[operator]
    except KeyError:
        # The command line accepts more operator names than are enabled here,
        # so explain the failure instead of surfacing a bare KeyError.
        raise ValueError(
            "Texture operator {!r} is not enabled; available: {}".format(
                operator, ", ".join(sorted(texture_operator_dict))
            )
        ) from None
def compute_features(args, data_dir="grayscale-images", out_dir="features"):
    """Compute texture features for the grayscale train/test images and save them.

    Args:
        args: Parsed command-line namespace; ``operator``, ``img_height`` and
            ``img_width`` are read from it.
        data_dir: Directory passed to ``load_grayscale.data`` to load the
            ``(x_train, x_test)`` images.
        out_dir: Directory the feature files are written to; created if it
            does not exist.

    Writes ``<out_dir>/<operator>_train_features.npy`` and
    ``<out_dir>/<operator>_test_features.npy`` with ``np.save``.
    """
    # Load grayscale images
    x_train, x_test = load_grayscale.data(data_dir)
    print("Loaded grayscale images.\n")

    # Get corresponding texture operator
    texture_operator = get_texture_operator(args.operator)
    print("Using {} texture operator".format(args.operator))
    print("Computing features.\n")

    # This takes about 30-35 minutes for CIFAR-10, hence saving to disk.
    x_train = texture_operator.get_features(x_train, args.img_height, args.img_width)
    x_test = texture_operator.get_features(x_test, args.img_height, args.img_width)

    # exist_ok avoids the check-then-create race of osp.exists() + os.makedirs().
    os.makedirs(out_dir, exist_ok=True)
    for split, features in (("train", x_train), ("test", x_test)):
        path = osp.join(out_dir, "{}_{}_features.npy".format(args.operator, split))
        with open(path, "wb") as handle:
            np.save(handle, features)
    print("Computed features and saved to disk in '{}' directory.".format(out_dir))
def main(args):
    """Run the feature-extraction pipeline for the parsed command-line ``args``."""
    compute_features(args=args)
if __name__ == '__main__':
    # Every operator name accepted on the command line.
    operator_names = (
        "cslbp", "csldp", "csldmp", "cslmp",
        "csltp", "xcslbp", "xcslmp", "xcsltp",
    )
    cli = argparse.ArgumentParser()
    cli.add_argument("--operator", choices=operator_names, default="cslbp")
    # Both image dimensions share the same integer type and default.
    for size_flag in ("--img_height", "--img_width"):
        cli.add_argument(size_flag, default=32, type=int)
    main(cli.parse_args())