Why is FP16 training speed so slow on Tesla T4 in Gluon? #13709
Comments
@PistonY could you try the MXNet profiler (/~https://github.com/apache/incubator-mxnet/blob/master/docs/faq/perf.md#profiler) to see which operation is costly with fp16? https://mxnet.incubator.apache.org/tutorials/python/profiler.html
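For reference, a minimal sketch of wrapping a few training iterations with the profiler, following the tutorial linked above; the output file name and configuration flags here are illustrative choices, not taken from this thread:

```python
import mxnet as mx
from mxnet import profiler

# Record all activity and keep an aggregate per-operator summary.
profiler.set_config(profile_all=True, aggregate_stats=True,
                    filename='fp16_profile.json')  # hypothetical output file
profiler.set_state('run')

# ... run a handful of fp16 training iterations here ...

mx.nd.waitall()               # flush all pending asynchronous work first
profiler.set_state('stop')
print(profiler.dumps())       # per-operator time summary
```

The dumped summary makes it easy to see whether one operator dominates the fp16 run.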
@eric-haibin-lin Hi~ I tested it with the MXNet profiler; here are my script and results. It looks fine.

```python
st_t = time()
with autograd.record():
    output = train_net(trans.astype(dtype, copy=False))
    loss = Loss(output, labels.astype(dtype, copy=False))
loss.backward()
trainer.step(batch_size)
end_t = time()
print(end_t - st_t)
```

With fp16:

```
float16
Start training with mixup.
[15:23:12] src/operator/nn/./cudnn/./cudnn_algoreg-inl.h:97: Running performance tests to find the best convolution algorithm, this can take a while... (setting env variable MXNET_CUDNN_AUTOTUNE_DEFAULT to 0 to disable)
2.0626039505004883
0.26385951042175293
0.2520616054534912
0.2604227066040039
0.25570082664489746
0.26578330993652344
0.25952720642089844
0.2606792449951172
0.2637202739715576
0.3433563709259033
0.2613410949707031
```

With fp32:

```
float32
Start training with mixup.
[15:36:23] src/operator/nn/./cudnn/./cudnn_algoreg-inl.h:97: Running performance tests to find the best convolution algorithm, this can take a while... (setting env variable MXNET_CUDNN_AUTOTUNE_DEFAULT to 0 to disable)
0.8311986923217773
0.20481181144714355
0.278350830078125
0.18038034439086914
0.21913409233093262
0.2587764263153076
0.17470550537109375
0.21522021293640137
0.2749063968658447
0.2962362766265869
0.2280411720275879
0.37300872802734375
0.18066024780273438
0.28769636154174805
0.2858397960662842
0.28676462173461914
0.24347591400146484
0.23549628257751465
0.29531288146972656
```
And if you need to, you can just run the script.
@eric-haibin-lin Hello?
Sorry, I've been busy with a submission deadline. Did you test with fixed input to check that it isn't bottlenecked by data loading?
OK, I'll test it later.
I tried to use fixed input. FP32 works fine, but FP16 runs out of memory.

```python
from mxnet import nd, autograd
from mxnet import gluon
from mxnet.gluon import loss as gloss
from gluoncv.model_zoo import *
import mxnet as mx
import time

ctx = mx.gpu(0)
data = nd.random.normal(shape=(64, 3, 224, 224), ctx=ctx)
label = nd.random.randint(low=0, high=1, shape=(64, 1), ctx=ctx)
net = resnet101_v2()
net.hybridize()
net.initialize(ctx=ctx)
net(data)
test_num = 500
dtype = 'float16'  # float32 or float16
if dtype != 'float32':
    net.cast(dtype)
Loss = gloss.SoftmaxCrossEntropyLoss()
trainer = gluon.Trainer(net.collect_params(),
                        'nag', {'learning_rate': 0.1, 'momentum': 0.9,
                                'multi_precision': True  # when fp16 is enabled
                                })
sta = time.time()
for _ in range(test_num):
    with autograd.record():
        output = net(data.astype(dtype, copy=False))
        loss = Loss(output, label.astype(dtype, copy=False))
    loss.backward()
    trainer.step(128)
end = time.time()
print(end - sta)
```

The MXNet version is 1.5.0 (--pre).
And I tried to run only the forward pass:

```python
sta = time.time()
for _ in range(test_num):
    with autograd.record():
        output = net(data.astype(dtype, copy=False))
        # loss = Loss(output, label.astype(dtype, copy=False))
    # loss.backward()
    # trainer.step(128)
end = time.time()
```

FP32 costs 7.83 s.
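One thing worth noting (not raised in the original thread): MXNet's execution engine is asynchronous, so reading the clock right after the loop can largely measure how long it takes to enqueue operations rather than to finish them. A minimal sketch of a more reliable timing, reusing net, data, dtype and test_num from the script above:

```python
import time
import mxnet as mx
from mxnet import autograd

sta = time.time()
for _ in range(test_num):
    with autograd.record():
        output = net(data.astype(dtype, copy=False))
mx.nd.waitall()   # block until all queued GPU work has actually finished
end = time.time()
print(end - sta)
```

Without the waitall() barrier, fp16 and fp32 timings are hard to compare meaningfully.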
Were you using self-attention blocks with the batch_dot operator? There was an improvement for fp16 in #13716.
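For context, a minimal sketch of the kind of fp16 batch_dot call that #13716 speeds up; the shapes and the attention-style usage here are illustrative assumptions, not code from this issue:

```python
import mxnet as mx

ctx = mx.gpu(0)
# (batch, seq_len, head_dim) query/key tensors in float16
q = mx.nd.random.normal(shape=(32, 128, 64), ctx=ctx).astype('float16')
k = mx.nd.random.normal(shape=(32, 128, 64), ctx=ctx).astype('float16')

# attention scores of shape (batch, seq_len, seq_len)
scores = mx.nd.batch_dot(q, k, transpose_b=True)
mx.nd.waitall()
print(scores.shape)
```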
Thanks, it worked.
Hi, I tried to train with FP16 on a Tesla T4, but its speed is slower than a GTX 1070 with FP32.
Could you please give me some suggestions to solve this?
The T4 is on mxnet-cu100mkl and the GTX 1070 is on mxnet-cu90mkl.
Here are my script and logs:
code: https://gist.github.com/PistonY/8dfcefdc46b747afd4d18b37f9a18665
logs:
T4 log:
GTX 1070 log: