Recommendations for training Detr on custom dataset? #9
Hello, fine-tuning should work in theory, but at the moment it's not tested/supported. If you want to give it a go anyway, you just need to adapt the number of classes and the dataset loading to your data. Best of luck!
Hi, when fine-tuning from the model zoo on my own dataset, how should I modify the number of classes? I have 50 labels, while the checkpointed model has 91. Thanks!
If you just want to replace the classification head, you need to erase it before loading the state dict. One approach would be:
Best of luck. |
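A minimal sketch of that approach (assuming DETR's parameter naming, where the classification head lives under `class_embed` — verify the key names against your checkpoint):

```python
def strip_class_head(state_dict):
    """Remove the 91-class COCO classification head from a checkpoint's
    state dict, so it can be loaded into a model with a different head
    via load_state_dict(..., strict=False)."""
    for key in ("class_embed.weight", "class_embed.bias"):
        state_dict.pop(key, None)  # tolerate checkpoints without the key
    return state_dict

# Usage sketch (the URL is DETR's public ResNet-50 checkpoint):
# ckpt = torch.hub.load_state_dict_from_url(
#     "https://dl.fbaipublicfiles.com/detr/detr-r50-e632da11.pth",
#     map_location="cpu", check_hash=True)
# model.load_state_dict(strip_class_head(ckpt["model"]), strict=False)
```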
It would be easier (or at least more standard practice) to first load the pre-trained model, and then replace the classification head. |
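For the load-then-replace route, a hedged sketch — the hidden dimension 256 and the attribute name `class_embed` follow DETR's ResNet-50 model, and the tiny stand-in module here is only for illustration:

```python
import torch.nn as nn

class TinyDetr(nn.Module):
    """Stand-in for the real model; in practice you would instead do
    model = torch.hub.load('facebookresearch/detr', 'detr_resnet50',
                           pretrained=True)."""
    def __init__(self):
        super().__init__()
        self.class_embed = nn.Linear(256, 92)  # 91 COCO classes + "no object"

model = TinyDetr()
num_classes = 50
# Swap in a freshly initialized head sized for the new label set
# (+1 for DETR's "no object" class):
model.class_embed = nn.Linear(256, num_classes + 1)
```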
Related question: how should we reduce the query number for a smaller class count (continuing from the approach above)? I'm looking at model.query_embed with shape (100, 256) and assume that is the right place to adjust, but I'm unclear. If we set model.query_embed.num_embeddings = my_new_query_count, is that enough?
Also, wouldn't we want to re-init the weights in class_embed to normal or uniform after wiping the checkpoint weights, to kick off the new training?
If you're fine-tuning, I don't recommend changing the number of queries on the fly; it is extremely unlikely to work out of the box. In this case you're probably better off retraining from scratch (you can change the number of queries with the num_queries argument). As for the initialization of class_embed, the default initialization of a freshly created layer should be fine. Best of luck
Hi @alcinos - excellent, thanks tremendously for the advice here, esp on a Sat night. |
My dataset has images of various sizes. |
I can't answer definitively, but if you look at the code in datasets/coco.py, you can see how they handled image resizing for COCO training. Basically they do random rescaling per the scales list, with the largest dimension capped at 1333.
The colab example used a max size of 800, with half-precision weights. Thus if your images are all larger than 1333 in one dimension, they'll all be resized below that, with padding, anyway. Hopefully others can add more info here, but hope this provides some starter info for you.
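For reference, the resize logic can be sketched like this — a paraphrase of the aspect-ratio-preserving resize in DETR's transforms, not the exact repo code:

```python
def target_size(h, w, short_side, max_size=1333):
    """Scale so the shorter side equals short_side, but cap the longer
    side at max_size (mirrors DETR's RandomResize behaviour)."""
    scale = short_side / min(h, w)
    if max(h, w) * scale > max_size:
        scale = max_size / max(h, w)
    return round(h * scale), round(w * scale)

# During training, DETR picks short_side at random from the scales list
# [480, 512, 544, 576, 608, 640, 672, 704, 736, 768, 800].
```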
As was noted by @lessw2020, the images will be randomly resized in an appropriate range by our data-augmentation. The images will then be padded, so having different sizes is not an issue.
I'm not sure about the specifics of your dataset, but in general I'd say all the advice provided in this thread applies to the case where there is only one object class. |
@alcinos, @lessw2020 It seems that these resizes are for data augmentation when training. |
@raviv no, the architecture doesn't expect a maximum size, but note that the Transformer Encoder is quadratic wrt the number of pixels in the feature map, so if your image is very large (like larger than 2000 pixels), you might face memory issues. |
@mlk1337 thanks for sharing the results! I think you are at a good starting point. I would say that from the logs you might want to adjust some of the training hyper-parameters. I don't know if using apex with fp16 affects something or not, as I haven't tried, but maybe @szagoruyko can comment on this? @raviv your training logs are very weird; it seems that the model stopped working at some point early in training. Are you using gradient clipping (it's on by default)? |
@fmassa I'm running with the default args. |
@mlk1337 with such a small dataset, I'd recommend trying to fine-tune the class head, while starting from a pre-trained encoder/decoder. You'll have to keep the 100 queries if you do that, but unless you're after very marginal speed improvement it shouldn't hurt. |
Hey, I wanted to fine-tune DETR myself on custom datasets, but I am new to all this; I have been using torchvision models all the time to fine-tune on my dataset. I would be glad if someone shared a demo code for fine-tuning. @alcinos |
@raviv - happy to share my training results but can you post your plot code for the graphs and I'll use that? Right now I just have text output as the detr plot_utils wasn't working (wasn't sure if I should debug that or just move it to tensorboard, looking at that now). |
@tanulsingh I wrote a quick gist on how you can modify DETR to fine-tune on your own COCO-formatted dataset: Link. Hope this helps. |
Line 20 in 5617b89, changed to … |
Hi @ayberksener - here's a link with some initial code, and if you are using mobilenet there's a full implementation in a link further down in that thread: |
Hi @lessw2020, hmmm, I am just using the DETR folders... I can't see a train.py, just detr.py and main.py |
Hi guys, has anyone managed to get reproducible results yet? I have tried everything I can think of, and to this day I cannot get reproducible results. Each time I run the model, the results for the first epoch are identical, but thereafter the numbers begin to diverge. I have also gone through each and every class and function in every .py file and added the seed function, and the results are still not reproducible! |
Hi @lessw2020, I hope you're well. I have had some beautiful results with DETR for object detection. I am now wondering how easy it would be to adapt my model so that it can segment my detected object, using panoptic segmentation. If you've any experience with this, I would be grateful for any info. 😀 thanks 😊 |
Hi guys, I am training from scratch on the Pascal VOC dataset (train: trainval07+12, ~16.5k images; val: test2007, ~5k images). I adjusted num_classes = 21 and num_queries = 70. You can see my notebook [here]. I trained over 100 epochs but mAP was still too low. |
You need to show us your training and validation curves. See the log.txt file. |
Hi, thanks for the resolution. 1- What should be the preferable size of a custom dataset to get good results from fine-tuning? 2- On which cloud platform (AWS, Google Cloud, etc.) should I set up a pre-trained DETR for quick inference? Can you hint at some preferable hardware resources (GPU, CPU, memory) required for fine-tuning and quick inference? I've been researching this lately but couldn't get a reasonable answer specifically for DETR. Will highly appreciate your valuable response. Thanks! |
Hi @linsecDev, 2 - Re: inference - quick is relative, as on CPU you might see 4 seconds and one could consider that quick, vs on a V100 (p3) GPU you would be under 1 sec. As noted, things will vary based on your specific use case, but hope the above adds some insight! |
Thanks a lot for the detailed response!! It's quite helpful. |
Hi, if I want to train on a custom dataset and use a different num_queries, the suggestions above say to load the backbone only and retrain the encoder/decoder and classification head. How do I load the pretrained backbone only, e.g., from a ResNet-50 trained on ImageNet? Thanks. |
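One hedged answer: DETR builds its backbone from torchvision, which already loads ImageNet weights when constructed with pretrained=True, so training from scratch with the default args gives you an ImageNet-pretrained backbone for free. Alternatively, you can filter a full DETR checkpoint down to its backbone keys — the `backbone.` prefix follows DETR's naming, so verify it against your checkpoint:

```python
def backbone_only(state_dict):
    """Keep only the backbone weights from a full DETR checkpoint,
    so everything else (encoder/decoder, heads) trains from scratch."""
    return {k: v for k, v in state_dict.items()
            if k.startswith("backbone.")}

# model.load_state_dict(backbone_only(ckpt["model"]), strict=False)
```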
hii, |
Very impressed with the all-new innovative architecture in DETR!
Can you clarify recommendations for training on a custom dataset?
Should we build a model similar to demo and train, or better to use and fine tune a full coco pretrained model and adjust the linear layer to desired class count?
Thanks in advance for any input.