This repository has been archived by the owner on Mar 12, 2024. It is now read-only.

Recommendations for training Detr on custom dataset? #9

Open
lessw2020 opened this issue May 28, 2020 · 205 comments
Labels
question Further information is requested

Comments

@lessw2020
Contributor

Very impressed with the innovative new architecture in DETR!
Can you clarify your recommendations for training on a custom dataset?
Should we build a model similar to the demo and train it, or is it better to fine-tune a full COCO-pretrained model and adjust the linear layer to the desired class count?
Thanks in advance for any input.

@Zumbalamambo

+1

@PancakeAwesome

agree

@alcinos
Contributor

alcinos commented May 28, 2020

Hello,
Thanks for your interest in DETR.
It depends on the size of your dataset. If you have enough data (say at least 10K images), training from scratch should work just fine. You'll need to prepare the data in the COCO format and then follow the instructions from the README. Note that if your dataset has a substantially different average number of objects per image than COCO, you might need to adjust the number of object queries (--num_queries). It should be strictly higher than the maximum number of objects you may have to detect, and it's good to have some slack (for COCO we use 100; the max number of objects in a COCO image is ~70).
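
For concreteness, a from-scratch run on a COCO-formatted custom dataset would look roughly like the README command, with the query count adapted to your data (the path and value below are placeholders):

python main.py --coco_path /path/to/custom_coco --num_queries 40 --output_dir output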

Fine-tuning should work in theory, but at the moment it's not tested/supported. If you want to give it a go anyway, you just need to --resume from one of the checkpoints we provide. Feel free to report back any results you obtain :)
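
(If I read main.py correctly, --resume accepts either a local path or a URL, so pointing it at e.g. the R50 checkpoint referenced later in this thread, https://dl.fbaipublicfiles.com/detr/detr-r50-e632da11.pth, should work directly.)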

Best of luck

@alcinos alcinos added the question Further information is requested label May 28, 2020
@raviv

raviv commented May 30, 2020

Hi,

When fine-tuning from model zoo, using my own dataset, how should I modify the number of classes?
Loading the model fails (as expected) on:

```
RuntimeError: Error(s) in loading state_dict for DETR:
	size mismatch for class_embed.weight: copying a param with shape torch.Size([92, 256]) from checkpoint, the shape in current model is torch.Size([51, 256]).
	size mismatch for class_embed.bias: copying a param with shape torch.Size([92]) from checkpoint, the shape in current model is torch.Size([51]).
```

This is because I have 50 labels while the checkpointed model has 91 classes.

Thanks!

@alcinos
Contributor

alcinos commented May 30, 2020

If you just want to replace the classification head, you need to erase it before loading the state dict. One approach would be:

```python
import torch

# Build a DETR-R50 with a fresh (randomly initialized) 50-class head.
model = torch.hub.load('facebookresearch/detr', 'detr_resnet50',
                       pretrained=False, num_classes=50)

# Fetch the COCO-pretrained weights.
checkpoint = torch.hub.load_state_dict_from_url(
    url='https://dl.fbaipublicfiles.com/detr/detr-r50-e632da11.pth',
    map_location='cpu',
    check_hash=True)

# Drop the classification head weights, whose shapes no longer match.
del checkpoint["model"]["class_embed.weight"]
del checkpoint["model"]["class_embed.bias"]

# strict=False lets the deleted keys keep their fresh initialization.
model.load_state_dict(checkpoint["model"], strict=False)
```

Best of luck.

@cbasavaraj

cbasavaraj commented May 30, 2020

It would be easier (or at least more standard practice) to first load the pre-trained model, and then replace the classification head.
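
A minimal sketch of that alternative (untested; note from the shape mismatch above that DETR's head has num_classes + 1 outputs, the extra one being the no-object class):

```python
import torch
from torch import nn

# Load the full COCO-pretrained model first...
model = torch.hub.load('facebookresearch/detr', 'detr_resnet50', pretrained=True)

# ...then swap in a fresh head sized for the new label set.
num_classes = 50  # illustrative value from the discussion above
hidden_dim = model.class_embed.in_features  # 256 for the R50 model
model.class_embed = nn.Linear(hidden_dim, num_classes + 1)
```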

@lessw2020
Contributor Author

lessw2020 commented May 30, 2020

Related question: how should we reduce the query count for smaller datasets (continuing from the approach above)?
For example, I only have 5 classes to detect and each image will contain exactly 5 objects, so I was planning to run with queries = 12 instead of the default 100 (or should it be 5, if we know that's the max our images will ever have...).

I'm looking at model.query_embed, which is (100, 256), and assume that is the right place to adjust, but I'm unclear. Is setting model.query_embed.num_embeddings = my_new_query_count enough?
(Update: I'm working on this, and the DETR model also stores a self.num_queries, but this is only referenced later for segmentation.
To be correct, both model.num_queries and model.query_embed.num_embeddings would need to be adjusted together...)
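
For what it's worth, num_embeddings is just a stored attribute, so a consistent change would presumably have to replace the embedding module itself, along the lines of this untested sketch:

```python
from torch import nn

new_num_queries = 12  # from the example above
hidden_dim = model.query_embed.embedding_dim  # 256 for the R50 model
model.num_queries = new_num_queries
model.query_embed = nn.Embedding(new_num_queries, hidden_dim)
```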

@lessw2020
Contributor Author

Also wouldn't we want to re-init the weights in class_embed to normal or uniform after wiping the checkpoint weights to kick off the new training?

@alcinos
Contributor

alcinos commented May 31, 2020

If you're fine-tuning, I don't recommend changing the number of queries on the fly, it is extremely unlikely to work out of the box. In this case you're probably better off retraining from scratch (you can change the --num_queries arg from our training script).

As for the initialization of class_embed, the solution I posted above makes sure it is initialized as it should be.

Best of luck

@lessw2020
Contributor Author

Hi @alcinos - excellent, thanks tremendously for the advice here, esp on a Sat night.
I will try both fine-tuning for now (with the smaller dataset, and I won't touch num_queries) and training from scratch once we have a larger dataset, and will update here to share results.
Thanks again!

@raviv

raviv commented May 31, 2020

My dataset has images of various sizes.
Do I need to resize them to a specific size?

@lessw2020
Contributor Author

> My dataset has images of various sizes.
> Do I need to resize them to a specific size?

I can't answer definitively, but if you look at the code in datasets/coco.py, you can see how they handled image resizing for COCO training. Basically they do random rescaling per the scales list, with the largest dimension capped at 1333:

```python
scales = [480, 512, 544, 576, 608, 640, 672, 704, 736, 768, 800]

if image_set == 'train':
    return T.Compose([
        T.RandomHorizontalFlip(),
        T.RandomSelect(
            T.RandomResize(scales, max_size=1333),
            T.Compose([
                T.RandomResize([400, 500, 600]),
                T.RandomSizeCrop(384, 600),
                T.RandomResize(scales, max_size=1333),
            ])
        ),
        normalize,
    ])
```

The colab example used a max size of 800, with half precision weights.

Thus if your images are all larger than 1333 in one dimension, they'll all be resized below that, with padding applied when batching anyway.

Hopefully others can add more info here but hope this provides some starter info for you.

@alcinos
Contributor

alcinos commented May 31, 2020

> My dataset has images of various sizes.
> Do I need to resize them to a specific size?

As was noted by @lessw2020, the images will be randomly resized in an appropriate range by our data-augmentation. The images will then be padded, so having different sizes is not an issue.
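
For reference, the padding alcinos mentions happens when batching: if I read util/misc.py correctly, variably-sized images are collated into a padded tensor plus a boolean mask, roughly like this (the image names are hypothetical):

```python
from util.misc import nested_tensor_from_tensor_list

# Two 3xHxW tensors of different sizes are padded to a common shape:
# samples.tensors is the padded batch, samples.mask marks the padded area.
samples = nested_tensor_from_tensor_list([img_small, img_large])
```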

> Thanks for the wonderful work.
> What is your recommendation for using DETR on single-object-class detection (e.g., scene text detection) datasets?

I'm not sure about the specifics of your dataset, but in general I'd say all the advice provided in this thread applies to the case where there is only one object class.

@raviv

raviv commented Jun 1, 2020

@alcinos, @lessw2020 It seems that these resizes are data augmentation for training.
As I'm using my own dataloader and augmentations, my question is: does the architecture (or implementation) expect images to have some maximum size?
Thanks.

@fmassa
Contributor

fmassa commented Jun 1, 2020

@raviv no, the architecture doesn't expect a maximum size, but note that the Transformer Encoder is quadratic wrt the number of pixels in the feature map, so if your image is very large (like larger than 2000 pixels), you might face memory issues.
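
To put rough numbers on that (a back-of-the-envelope estimate, assuming the standard stride-32 ResNet feature map): a 1333×800 input yields roughly 42×25 ≈ 1050 feature-map pixels, so self-attention covers ~1050² ≈ 1.1M pairs per head per encoder layer, while a 2000×2000 input yields ~63×63 ≈ 4000 pixels and ~16M pairs; memory grows with the fourth power of the image side.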

@raviv

raviv commented Jun 5, 2020

This is how my losses look so far.
Would love to get others' input on their attempts to train DETR on custom datasets.

[training loss curves screenshot]

@m-klasen

m-klasen commented Jun 5, 2020

Hi, I'm currently working with my custom dataset. It's relatively small, with ~2k training and 400 validation images (32 video-sequence clips) and only 4 classes, with a maximum of 6 instances per image.
For my first training attempt I set num_queries=20 and discarded all transformer weights etc.
I trained 400 epochs with apex fp16 at lr 1e-4, with an lr_drop to 1e-5 at epoch 200.

[training loss and mAP curves screenshot]

Evaluation at epoch 400 gives me an mAP of 0.45, which I can benchmark against a colleague's known-good Mask R-CNN, which achieves 0.63 mAP.
My question now is: what are the primary reasons for the weaker performance?

  1. More training? Better LR adjustments, with a decay for example (hard to do on a first attempt when you are going in blind)?
  2. Reducing num_queries further?
  3. Class/bg loss coefficient adjustment?
  4. ...?

@fmassa
Contributor

fmassa commented Jun 5, 2020

@mlk1337 thanks for sharing the results!

I think you are at a good starting point. I would say that from the logs you might want to change the eos_coef a bit and try different values. I think the number of num_queries is ok, but the eos_coef probably needs to be adapted.
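(For reference: eos_coef is the relative classification weight applied to the no-object class; it's exposed as --eos_coef in main.py and, if I remember the argparse defaults right, is set to 0.1.)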

I don't know if using apex with fp16 affects something or not as I haven't tried, but maybe @szagoruyko can comment on this?

@raviv your training logs are very weird; it seems that the model stopped working at some point early in training. Are you using gradient clipping (it's on by default)?

@raviv

raviv commented Jun 5, 2020

@fmassa I'm running with the default args.
To keep things simple, I'm using 1 class and disabled all augmentations.
The behavior was similar when training multiple classes and with aug enabled.
To speed things up, I'm using a subset of my dataset with 8K train and 2K test images.

@alcinos
Contributor

alcinos commented Jun 5, 2020

@mlk1337 with such a small dataset, I'd recommend trying to fine-tune the class head, while starting from a pre-trained encoder/decoder. You'll have to keep the 100 queries if you do that, but unless you're after very marginal speed improvement it shouldn't hurt.
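
One way to read that in code (an untested sketch; it assumes the class_embed attribute seen earlier in this thread and freezes everything else before building the optimizer):

```python
import torch

# Keep the pretrained backbone/encoder/decoder frozen; train only the class head.
for name, param in model.named_parameters():
    param.requires_grad = name.startswith("class_embed")

optimizer = torch.optim.AdamW(
    (p for p in model.parameters() if p.requires_grad), lr=1e-4)
```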

@tanulsingh

tanulsingh commented Jun 5, 2020

Hey, I want to fine-tune DETR on custom datasets myself, but I am new to all this; I have been using torchvision models to fine-tune on my datasets. I would be glad if someone could share demo code for fine-tuning. @alcinos

@lessw2020
Contributor Author

@raviv - happy to share my training results, but can you post your plotting code for the graphs so I can use it? Right now I just have text output, as the DETR plot_utils wasn't working (wasn't sure if I should debug that or just move it to TensorBoard; looking at that now).
@mlk1337 - same question: can you share your plotting code for the logs?

@m-klasen

m-klasen commented Jun 5, 2020

@tanulsingh I wrote a quick gist on how you can modify DETR to fine-tune on your own COCO-formatted dataset: Link. Hope this helps.

@m-klasen

m-klasen commented Jun 5, 2020

@lessw2020

```python
# original line in plot_utils
coco_eval = pd.DataFrame(pd.np.stack(df.test_coco_eval.dropna().values)[:, 1]).ewm(com=ewm_col).mean()
```

changed to

```python
coco_eval = pd.DataFrame(pd.np.stack(df.test_coco_eval_bbox.dropna().values)[:, 1]).ewm(com=ewm_col).mean()
```

worked for me (for bounding boxes). (pd.np was pandas' alias for numpy, so np.stack works equally well in newer pandas versions.)

@lessw2020
Contributor Author

> @lessw2020 I want to use DETR with a different backbone; can you provide any starter code for this? Thank you.

Hi @ayberksener - here's a link with some initial code, and if you are using MobileNet there's a full implementation in a link further down in that thread:
#154

@Dicko87

Dicko87 commented Dec 14, 2020

Hi @lessw2020, hmm, I am just using the DETR folders... I can't see a train.py, just detr.py and main.py.
Sorry, the results I am on about are the mAP values; they're different every time I run the model.

@Dicko87

Dicko87 commented Jan 5, 2021

Hi guys, has anyone yet managed to get reproducible results? I have tried everything I can think of and to this day cannot get reproducible results. Each time I run the model, the results for the first epoch are identical, but thereafter the numbers begin to diverge. I have also gone through each and every class and function in every .py file and added the seed function, and the results are still not reproducible!

@Dicko87

Dicko87 commented Jan 6, 2021

Hi folks, I am struggling to get repeatable results for a DETR model I am running with PyTorch.
I have defined the following function:

```python
import os
import random

import numpy as np
import torch

def seed_torch(seed=42):
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)  # if you are using multi-GPU
    torch.backends.cudnn.benchmark = True  # note: full determinism usually requires this to be False
    torch.backends.cudnn.deterministic = True
```

Please note that I have also tried this function with cudnn.benchmark set to False.
This is the folder structure for DETR and the .py files contained within it:

[screenshot of the DETR folder structure]

I have tried using the seed function in each file separately to see if any one of them would make my results reproducible, and this did not work.
I then tried using the seed function in all files in each folder separately, and this did not work.
I tried using the seed function in the first three folders; the first epoch was repeatable, but after the first epoch the numbers were different.
The numbers I am talking about look like this:

[screenshot of per-epoch loss values]

But I am comparing losses for now.
In the end I copied the seed function into every .py file and then called it in every class and function in those files, and the results are still not reproducible.
Does anybody have any idea how to resolve this?
I am running my code in the Anaconda terminal.

@arunraja-hub

> @tanulsingh I wrote a quick gist on how you can modify DETR to fine-tune on your own COCO-formatted dataset: Link. Hope this helps.

Hi @m-klasen, the link is no longer available. Could you please make it available again? It would be very useful. Thanks!

@Dicko87

Dicko87 commented Aug 4, 2021

Hi @lessw2020, I hope you're well. I have had some beautiful results with DETR for object detection. I am now wondering how easy it would be to adapt my model so that it can segment the detected objects using panoptic segmentation. If you have any experience with this, I would be grateful for any info. 😀 Thanks 😊

@Dicko87

Dicko87 commented Aug 26, 2021

Hi guys, I am having a problem. I trained the panoptic DETR model and tried to follow their notebook to view the results. When I tried to plot the predicted masks, I received this error:

```
KeyError: 'pred_masks'
```

[traceback screenshot]

I do not think that the saved weights included the mask outputs, because when I print the model output keys I get:

[screenshot of model output keys]

The command used to train the model was:
python main.py --masks --dataset_file coco_panoptic --coco_path /home/detr/datasets/mycoco --coco_panoptic_path /home/detr/datasets/coco_panoptic --epochs 25 --num_classes 2 --lr=1e-5 --lr_drop 15 --batch_size=4 --num_workers=4 --output_dir="outputs" --frozen_weights /home/detr/outputs/checkpoint0100.pth

Any help would be appreciated, thank you.

@ghost

ghost commented Sep 6, 2021

Hi guys, I am training from scratch on the Pascal VOC dataset (train: trainval07+12, ~16.5k images; val: test2007, ~5k images). I adjusted num_classes = 21 and num_queries = 70. You can see my notebook [here]. I trained for over 100 epochs but the mAP is still very low.

[training curve screenshots]

File: log.txt
I want to ask: does DETR really take this long, so I just need more training time, or is there a problem with my notebook and code?
Thank you in advance.

@Dicko87

Dicko87 commented Sep 6, 2021

You need to show us your training and validation curves. See the log.txt file.

@cc-guowenchang

Hi, did you solve it?
I'm in a similar situation.

@linsecDev

linsecDev commented Jul 4, 2022

> It depends on the size of your dataset. If you have enough data (say at least 10K), training from scratch should work just fine. [...] Fine-tuning should work in theory, but at the moment it's not tested/supported. [...] (quoting @alcinos's first reply above)

Hi, thanks for the answer.
I have two follow-up questions; I would highly appreciate your response.

1. What is the preferred size of a custom dataset to get good results from fine-tuning?

2. On which cloud platform (AWS, Google Cloud, etc.) should I set up a pre-trained DETR for quick inference? Can you hint at some preferable hardware resources (GPU, CPU, memory) required for fine-tuning and quick inference? I've been researching this lately but couldn't find a reasonable answer specifically for DETR.

Will highly appreciate your valuable response. Thanks!

@lessw2020
Contributor Author

Hi @linsecDev,
I can offer some insight from my experience, but ultimately things will vary based on your specific dataset etc.

1 - I found that anywhere from 1-2K images is a good range for fine-tuning. Again, how many classes you have etc. will impact this, but ~500 images per class is another way to put it.
Also note - I did a lot of work using synthetic data for fine-tuning (from the Unity Perception package), to the point where I could get pretty strong detection with zero real images, going straight from synthetic training to inference. I found about 8K synthetic images did a good job, or 2K per class.

2 - Re: inference - "quick" is relative: on CPU you might see 4 seconds and one could consider that quick, vs. on a V100 (p3) GPU you would be under 1 sec.
There's a cost tradeoff, and the other thing to realize is that a lot of your "total" inference time will be:
a) resizing a large image to one acceptable for DETR (i.e. 4086 -> 1024 for inference), assuming CPU;
b) if your server is not started, it could take 15 seconds to spin up;
c) potentially uploading the image could also take 10+ seconds depending on connection type.
In these cases the non-inference aspects of the process can easily be far more time-costly than the ~1 second to run DETR on the resized image with a V100.
Thus, make sure you consider the entire end-to-end process and account for those before looking solely at inference time.
For fine-tuning DETR, I generally found a p3 with a single V100 (p3.2xlarge) was sufficient, assuming you are using R50 as the backbone. That runs around $3/hour atm on AWS.
I haven't used other cloud providers so can't comment, but you could look at V100 equivalents on other providers.

As noted, things will vary based on your specific use case, but hope the above adds some insight!
Less

@linsecDev

> Hi @linsecDev, I can offer some insight from my experience [...] (quoting @lessw2020's full reply above)

Thanks a lot for the detailed response!! It's quite helpful.

@wjtan99

wjtan99 commented Aug 23, 2022

Hi, if I want to train on a custom dataset with a different num_queries, the suggestions above say to load only the backbone and retrain the encoder/decoder and classification head. How do I load just the pretrained backbone, e.g., a ResNet-50 trained on ImageNet? Thanks.

@shruti22kumari

> It depends on the size of your dataset. If you have enough data (say at least 10K), training from scratch should work just fine. [...] (quoting @alcinos's first reply above)

Hi,
I need to test a pothole dataset on a pretrained model. How do I do that?
