
cuda8.0 can't train YOLOv3 Loss : nan #366

Closed
ShoufaChen opened this issue Oct 7, 2018 · 20 comments

Comments

@ShoufaChen

When I use CUDA 8.0 and run the YOLOv3 script on 2 GPUs, changing only the batch size to 32, I get nan loss:

INFO:root:[Epoch 0][Batch 99], LR: 5.99E-05, Speed: 31.597 samples/sec, ObjLoss=nan, BoxCenterLoss=nan, BoxScaleLoss=nan, ClassLoss=nan
INFO:root:[Epoch 0][Batch 199], LR: 1.20E-04, Speed: 32.253 samples/sec, ObjLoss=nan, BoxCenterLoss=nan, BoxScaleLoss=nan, ClassLoss=nan
INFO:root:[Epoch 0][Batch 299], LR: 1.81E-04, Speed: 31.947 samples/sec, ObjLoss=nan, BoxCenterLoss=nan, BoxScaleLoss=nan, ClassLoss=nan

When I comment out the net.hybridize() calls in train() and validate() as mentioned here, I can run it with a proper loss, but at the cost of training speed.
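
For reference, this is roughly the workaround I mean (a minimal sketch, not the exact code in train_yolo3.py):

from gluoncv import model_zoo

# Build the network as usual, but skip hybridization so the forward/backward
# pass runs in imperative mode (slower, but the loss stays finite for me).
net = model_zoo.get_model('yolo3_darknet53_voc', pretrained_base=True)
# net.hybridize()  # disabled: the hybridized run produces nan on my CUDA 8.0 setup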

Besides, if I use batch-size=4, the loss does not become nan even with net.hybridize(), so I guess it is not the smaller batch size that causes the nan.

CUDA 9.0 with batch-size=32 also works fine.

@zhreshold
Member

I don't think this is related to CUDA or hybridize. If you are getting random nan values, especially in the first iterations, it is probably related to the warm-up setting. Warm-up is a must-have for YOLOv3 models; you can increase this number /~https://github.com/dmlc/gluon-cv/blob/master/scripts/detection/yolo/train_yolo3.py#L156 to make training more stable.

In a new PR I have made it a command-line argument, so it will be more convenient to tweak.
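
For intuition, the warm-up just ramps the learning rate linearly from roughly zero to the base LR over the first iterations; a minimal sketch of the idea (illustrative only, not the exact code in train_yolo3.py, and batches_per_epoch is a made-up value):

base_lr = 0.001
warmup_epochs = 2            # increase this if you see nan in the first epochs
batches_per_epoch = 1000     # hypothetical; depends on dataset size and batch size
warmup_iters = warmup_epochs * batches_per_epoch

def warmup_lr(iteration):
    # Linearly ramp the LR from ~0 to base_lr over warmup_iters, then hold it.
    return base_lr * min(1.0, (iteration + 1) / warmup_iters)

# Inside the training loop, before each update, something like:
# trainer.set_learning_rate(warmup_lr(global_step))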

@zhreshold
Member

Let me know if the warm-up setting turns out to be useless on CUDA 8.

@ShoufaChen
Author

Thank you very much for your reply. I'll try tweaking the warm-up argument on CUDA 8 later, since the GPU is busy right now.

I am also quite curious why exactly the same code and settings run properly on CUDA 9.0.

@ShoufaChen
Author

ShoufaChen commented Oct 9, 2018

I just ran the same code and settings on another Ubuntu 16.04 machine, and it works properly on CUDA 8.0 without changing the warm-up argument 😕 😕
It seems the nan error occurs randomly.

Another small flaw: the linked yolo3_voc train script (the 416 one) is for the COCO dataset rather than Pascal VOC. 😃

@zhreshold
Member

Script fixed, thanks for spotting.

Every training run is randomized, so you will see some random behavior. I suggest increasing the warm-up epochs to reduce the chance of nan; otherwise it's not predictable.
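
For example, once you are on a version with the command-line flag, something along these lines (the GPU ids and batch size here are just placeholders):

python train_yolo3.py --gpus 0,1 --batch-size 32 --warmup-epochs 4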

@ShoufaChen
Author

OK, thank you very much.

@nicklhy

nicklhy commented Oct 9, 2018

@zhreshold, I still get nan loss from the very first log message after changing warmup_epochs from 2 to larger values like 5 or 10.

➜  mx-yolov3 git:(master) ✗ python3 train_yolo3.py --data-root /mnt/workspace/shared_datasets/COCO --dataset coco --gpus 0,1,2,3 --num-workers 10 --syncbn
loading annotations into memory...
Done (t=16.28s)
creating index...
index created!
loading annotations into memory...
Done (t=0.47s)
creating index...
index created!
INFO:root:Namespace(batch_size=64, data_root='/mnt/workspace/shared_datasets/COCO', data_shape=416, dataset='coco', epochs=200, gpus='0,1,2,3', log_interval=100, lr=0.001, lr_decay=0.1, lr_decay_epoch='160,180', momentum=0.9, network='darknet53', num_samples=117266, num_workers=10, resume='', save_interval=10, save_prefix='yolo3_darknet53_coco', seed=233, start_epoch=0, syncbn=True, val_interval=1, wd=0.0005)
INFO:root:Start training from [Epoch 0]
[14:10:42] src/operator/nn/./cudnn/./cudnn_algoreg-inl.h:109: Running performance tests to find the best convolution algorithm, this can take a while... (setting env variable MXNET_CUDNN_AUTOTUNE_DEFAULT to 0 to disable)
INFO:root:[Epoch 0][Batch 99], LR: 2.70E-06, Speed: 84.074 samples/sec, ObjLoss=nan, BoxCenterLoss=nan, BoxScaleLoss=nan, ClassLoss=nan
INFO:root:[Epoch 0][Batch 199], LR: 5.43E-06, Speed: 58.834 samples/sec, ObjLoss=nan, BoxCenterLoss=nan, BoxScaleLoss=nan, ClassLoss=nan
INFO:root:[Epoch 0][Batch 299], LR: 8.16E-06, Speed: 87.941 samples/sec, ObjLoss=nan, BoxCenterLoss=nan, BoxScaleLoss=nan, ClassLoss=nan

(I added the --data-root argument to specify the dataset's root directory.)

@zhreshold
Member

@nicklhy Was that on CUDA 9 or later?

@nicklhy

nicklhy commented Oct 10, 2018

@zhreshold CUDA 8.0, cuDNN 7.0.5, Titan Xp.

I am wondering whether there is a specific requirement on the cuDNN version?

@ShoufaChen
Author

ShoufaChen commented Oct 10, 2018

@zhreshold @nicklhy
I tried CUDA 8.0 on a Titan Xp; both with cuDNN 7.1.3 and without cuDNN, the nan problem appears.

@zhreshold
Member

Due to the mixed environments and versions, I am not able to locate the problem. We also have several fixes to the YOLO network and training script that were merged to master recently.

Could you try the latest master and report whether any of your combinations still gets nan, even with --warmup-epochs 10 or so?

I'd appreciate it very much.

@ShoufaChen @nicklhy

@nicklhy

nicklhy commented Oct 11, 2018

@zhreshold, I just tried the newest GluonCV with mxnet_cu80-1.3.0.post0. The nan loss still occurs with --warmup-epochs 10. The training script is invoked as below:

python3 train_yolo3_new.py --data-root /mnt/workspace/shared_datasets/VOC --dataset voc --gpus 0,1,2,3 --num-workers 10 --syncbn --batch-size 32 --warmup-epochs 10

BTW, GPU memory usage seems to be much higher than with the old version. I cannot use the default batch size (64) with 4 Titan Xp GPUs now.

@ShoufaChen
Author

I am sorry, but I removed my mxnet-cu80 environment because there is little space left on my machine.

@kuonangzhe
Contributor

I tested YOLOv3 on my own dataset and also hit the nan problem. The default initial lr is 0.001; when I set it to half of that, 0.0005, training becomes normal with no nan. The learning rate might be the cause.
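
For example, since the script already exposes --lr (see the Namespace output above), I just pass the halved value on the command line (the GPU ids and batch size here are placeholders for my usual run):

python train_yolo3.py --gpus 0,1,2,3 --batch-size 32 --lr 0.0005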

@weiaicunzai

Same problem here: running the tutorial demo that trains YOLOv3 on the Pascal VOC dataset also produces nan loss, on a P40 with CUDA 8.0.

@weiaicunzai

Never mind, I changed the warm-up argument and now everything works fine.

@wshuail

wshuail commented Nov 16, 2018

Hey guys, make sure your GPU driver is compatible with your CUDA version.
This happened to me before I updated the driver to the latest version.
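
A quick way to check is nvidia-smi, which prints the installed driver version (and, on recent drivers, the highest CUDA version that driver supports):

nvidia-smi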

@ymm4739

ymm4739 commented Dec 26, 2018

I trained on my own dataset and also hit the nan problem, with an lr of 0.001 or larger on CUDA 9.0. When I changed the lr to 0.0005, it worked. Maybe the lr was too large?

@BackT0TheFuture

Driver Version: 410.48
CUDA 10
CUDNN 7.4.2.24
MXNET mxnet-cu100mkl

@zhreshold Same problem here.

@zhreshold
Member

Just an update: the root cause has been found and the fix has been merged to master: apache/mxnet#14209

If you use a pip package built from master (a nightly build), you hopefully won't hit the same problem any more.
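
Something along these lines should work (the exact package name depends on your CUDA version, and the availability of nightly wheels may change over time):

pip install --pre --upgrade mxnet-cu100mkl
pip install --upgrade git+/~https://github.com/dmlc/gluon-cv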
