This repository has been archived by the owner on Mar 6, 2023. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 27
/
Copy pathmali_imagenet_bench.py
110 lines (92 loc) · 3.82 KB
/
mali_imagenet_bench.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
"""
Benchmark inference speed on ImageNet
Example (run on Firefly RK3399):
python mali_imagenet_bench.py --target-host 'llvm -target=aarch64-linux-gnu' --host 192.168.0.100 --port 9090 --model mobilenet
"""
import time
import argparse
import numpy as np
import tvm
import nnvm.compiler
import nnvm.testing
from tvm.contrib import util, rpc
from tvm.contrib import graph_runtime as runtime
def run_case(model, dtype):
    """Benchmark one model/dtype combination and print the mean latency.

    Builds the requested nnvm ImageNet workload, compiles it for the Mali
    GPU target, runs it either on a remote RPC device (when ``args.host``
    is set) or on the local OpenCL device, and prints one tab-separated
    result line with the mean per-run cost in seconds.

    Relies on module-level globals set in the ``__main__`` block:
    ``args`` (parsed CLI options), ``image_shape`` and ``data_shape``.

    Parameters
    ----------
    model : str
        One of 'vgg16', 'resnet18', 'mobilenet'.
    dtype : str
        'float32' or 'float16'.

    Raises
    ------
    ValueError
        If ``model`` is not one of the supported names.
    """
    # load model
    if model == 'vgg16':
        net, params = nnvm.testing.vgg.get_workload(num_layers=16,
            batch_size=1, image_shape=image_shape, dtype=dtype)
    elif model == 'resnet18':
        net, params = nnvm.testing.resnet.get_workload(num_layers=18,
            batch_size=1, image_shape=image_shape, dtype=dtype)
    elif model == 'mobilenet':
        net, params = nnvm.testing.mobilenet.get_workload(
            batch_size=1, image_shape=image_shape, dtype=dtype)
    else:
        raise ValueError('no benchmark prepared for {}.'.format(model))

    # compile
    # NOTE(review): float16 builds at opt_level=1 while float32 uses 2 —
    # presumably higher opt levels break the fp16 path; confirm before changing.
    opt_level = 2 if dtype == 'float32' else 1
    with nnvm.compiler.build_config(opt_level=opt_level):
        graph, lib, params = nnvm.compiler.build(
            net, tvm.target.mali(), shape={"data": data_shape}, params=params,
            dtype=dtype, target_host=args.target_host)

    # export the compiled library, then either upload it to the remote
    # device over RPC or run on the local OpenCL device
    tmp = util.tempdir()
    lib_fname = tmp.relpath('net.tar')
    lib.export_library(lib_fname)

    if args.host is not None:
        remote = rpc.connect(args.host, args.port)
        remote.upload(lib_fname)
        ctx = remote.cl(0)
        rlib = remote.load_module('net.tar')
        # parameters must live on the remote device's memory
        rparams = {k: tvm.nd.array(v, ctx) for k, v in params.items()}
    else:
        ctx = tvm.cl(0)
        rlib = lib
        rparams = params

    # create graph runtime and feed one random input image
    module = runtime.create(graph, rlib, ctx)
    module.set_input('data', tvm.nd.array(
        np.random.uniform(size=data_shape).astype(dtype)))
    module.set_input(**rparams)

    # the num of runs for warm up and test; mobilenet is fast, so it needs
    # more runs for a stable measurement
    num_warmup = 10
    num_test = 60
    if model == 'mobilenet':
        num_warmup *= 5
        num_test *= 5

    # perform some warm up runs, then measure
    warm_up_timer = module.module.time_evaluator("run", ctx, num_warmup)
    warm_up_timer()
    ftimer = module.module.time_evaluator("run", ctx, num_test)
    prof_res = ftimer()
    print("backend: TVM-mali\tmodel: %s\tdtype: %s\tcost:%.4f" % (model, dtype, prof_res.mean))
if __name__ == '__main__':
    # command-line interface
    parser = argparse.ArgumentParser()
    parser.add_argument('--model', type=str, required=True, choices=['vgg16', 'resnet18', 'mobilenet', 'all'],
                        help="The model type.")
    parser.add_argument('--dtype', type=str, default='float32', choices=['float16', 'float32'])
    parser.add_argument('--host', type=str, help="The host address of your arm device.", default=None)
    parser.add_argument('--port', type=int, help="The port number of your arm device", default=None)
    parser.add_argument('--target-host', type=str, help="The compilation target of host device.", default=None)
    args = parser.parse_args()

    # benchmark configuration, shared with run_case via module globals
    batch_size = 1
    num_classes = 1000
    image_shape = (3, 224, 224)
    data_shape = (batch_size,) + image_shape
    out_shape = (batch_size, num_classes)

    if args.model != 'all':
        # single model/dtype combination
        run_case(args.model, args.dtype)
    else:
        # sweep every model/dtype combination, pausing between runs
        for net_name in ['vgg16', 'resnet18', 'mobilenet']:
            for precision in ['float32', 'float16']:
                run_case(net_name, precision)
                time.sleep(10)