This is a car image classification task on a custom Stanford Cars dataset. We fine-tune the pre-trained ResNeXt-101-32x8d model provided by torchvision. The highest testing accuracy reaches 91.64%.
- Define CarsData class for custom dataset
- Create model and set hyperparameters
- Load data and pre-process
- Train model
- Save model
- Plot training loss curve and accuracy curve
- Inference
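The plotting step in the list above can be sketched with matplotlib, assuming the per-epoch loss and accuracy values were collected in Python lists during training. The helper name plot_curves and the output file names are hypothetical, not taken from the project code.

```python
import matplotlib

matplotlib.use("Agg")  # non-interactive backend: figures are written to files, no display needed
import matplotlib.pyplot as plt


def plot_curves(losses, accuracies, out_prefix="training"):
    """Save one PNG per curve and return their paths (hypothetical helper).

    losses and accuracies hold one value per epoch, as recorded during training.
    """
    paths = []
    for values, ylabel, suffix in (
        (losses, "loss", "loss"),
        (accuracies, "accuracy (%)", "accuracy"),
    ):
        fig, ax = plt.subplots()
        ax.plot(range(1, len(values) + 1), values, marker="o")
        ax.set_xlabel("epoch")
        ax.set_ylabel(ylabel)
        path = f"{out_prefix}_{suffix}.png"
        fig.savefig(path)
        plt.close(fig)  # release the figure so repeated calls do not accumulate memory
        paths.append(path)
    return paths
```

Calling plot_curves(epoch_losses, epoch_accuracies) with the default prefix writes training_loss.png and training_accuracy.png to the working directory.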
Model | Testing Accuracy (%) |
---|---|
Wide ResNet-50-2 | 91.04 |
ResNeXt-101-32x8d | 91.64 |
Use pip to install the Python packages listed in requirements.txt:

```shell
pip install -r requirements.txt
```
Our data is a custom Stanford Cars dataset.
There are 196 car classes in the dataset. We have 11,185 images for training and 5,000 for testing. We divide the training data into 10,000 images for training and 1,185 images for validation.
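The 10,000/1,185 split above (10,000 + 1,185 = 11,185) can be reproduced with a seeded shuffle of the sample indices. This stdlib-only sketch uses a hypothetical helper, split_indices; in PyTorch the usual equivalent is torch.utils.data.random_split(dataset, [10000, 1185]), and the source does not say which approach the project uses.

```python
import random


def split_indices(num_samples, num_val, seed=0):
    """Randomly split sample indices into disjoint training and validation lists (hypothetical helper)."""
    if not 0 <= num_val <= num_samples:
        raise ValueError("num_val must be between 0 and num_samples")
    indices = list(range(num_samples))
    random.Random(seed).shuffle(indices)  # fixed seed gives the same split on every run
    return indices[num_val:], indices[:num_val]


train_idx, val_idx = split_indices(11185, 1185)
print(len(train_idx), len(val_idx))  # 10000 1185
```

The index lists can then be wrapped with torch.utils.data.Subset(dataset, train_idx) and Subset(dataset, val_idx) to obtain the two datasets.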
To load the data without modifying the code, arrange the data directory as follows:
```
data
+- training_labels.csv
+- training_data
|  +- training_data
|     +- training_image.jpg
|     +- ...
+- testing_data
   +- testing_data
      +- testing_image.jpg
      +- ...
```
Alternatively, pass the paths of the data directory and the label CSV file when constructing the CarsDataset object.
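A minimal sketch of a map-style dataset that follows this layout is shown below. It is not the project's dataset class: the name SimpleCarsDataset, the CSV column names id and label, and image files named after the id column plus a .jpg extension are all assumptions to adjust to the real files. torch.utils.data.DataLoader accepts any object implementing __len__ and __getitem__, so the class can be passed to it directly; the loader argument is where image decoding and the transforms would plug in.

```python
import csv
import os


class SimpleCarsDataset:
    """Map-style dataset over (image, class index) pairs (hypothetical sketch).

    Assumes a CSV with a header row and columns "id" and "label", and image
    files named name_format.format(id) inside image_dir.
    """

    def __init__(self, image_dir, label_csv, loader=None, name_format="{}.jpg"):
        self.image_dir = image_dir
        self.loader = loader  # e.g. opens the image file and applies transforms
        with open(label_csv, newline="") as f:
            rows = list(csv.DictReader(f))
        # Map each car class name to a stable integer index (alphabetical order)
        self.classes = sorted({row["label"] for row in rows})
        self.class_to_idx = {name: i for i, name in enumerate(self.classes)}
        self.samples = [
            (os.path.join(image_dir, name_format.format(row["id"])), self.class_to_idx[row["label"]])
            for row in rows
        ]

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, index):
        path, target = self.samples[index]
        # Without a loader, return the file path so the class works without Pillow installed
        image = self.loader(path) if self.loader is not None else path
        return image, target
```

With the layout above, a call might look like SimpleCarsDataset("data/training_data/training_data", "data/training_labels.csv", loader=lambda p: Image.open(p).convert("RGB")), using PIL's Image to decode each file as RGB.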
- Resize image
- Random crop image
- Random horizontal flip
- Color jitter
- Random rotation
- Image normalization
- Pre-trained ResNeXt-101-32x8d
- Epochs: 100
- Batch size: 32
- Optimizer: SGD (learning rate = 0.001, momentum = 0.9)
At the end of the training loop, the model parameters are saved to the models directory.
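Saving only the state_dict, rather than pickling the whole model object, is the approach PyTorch's serialization guide recommends for inference. A minimal sketch follows; the helper names and the file name resnext101_32x8d.pth are hypothetical.

```python
import os

import torch


def save_model(model, directory="models", filename="resnext101_32x8d.pth"):
    """Save only the model's parameters (state_dict) and return the file path."""
    os.makedirs(directory, exist_ok=True)  # create the models directory if needed
    path = os.path.join(directory, filename)
    torch.save(model.state_dict(), path)
    return path


def load_model(model, path, device="cpu"):
    """Load parameters saved by save_model into a model of the same architecture."""
    model.load_state_dict(torch.load(path, map_location=device))
    model.eval()  # inference mode: fixes batch-norm statistics and disables dropout
    return model
```

Because only parameters are stored, load_model needs a model built with the same architecture and the same 196-way head before the weights are loaded.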
Load the testing images and run inference; the predictions are written to an output CSV file.
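Inference can be sketched as a loop that takes the arg-max class for each test image and writes one CSV row per image. The loader yielding (images, image_ids) batches, the function name predict_to_csv and the id,label header are assumptions about the output format, not the project's confirmed layout.

```python
import csv

import torch


@torch.no_grad()  # no gradients are needed at inference time
def predict_to_csv(model, loader, class_names, out_path="predictions.csv", device="cpu"):
    """Predict a class for every test image and write "id,label" rows (hypothetical helper).

    Assumes loader yields (images, image_ids) batches and class_names[i] names class index i.
    Returns the number of rows written.
    """
    model.eval()
    rows = []
    for images, image_ids in loader:
        predicted = model(images.to(device)).argmax(dim=1).tolist()
        rows.extend(zip(image_ids, (class_names[i] for i in predicted)))
    with open(out_path, "w", newline="") as f:
        writer = csv.writer(f)
        writer.writerow(["id", "label"])  # assumed header; match the expected submission format
        writer.writerows(rows)
    return len(rows)
```

class_names must list the class names in the same order as the class indices used during training, otherwise the written labels will be shifted.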
- utkuozbulak, pytorch-custom-dataset-examples, viewed 11 Nov 2020, https://github.com/utkuozbulak/pytorch-custom-dataset-examples
- Soumith Chintala, TRAINING A CLASSIFIER, viewed 11 Nov 2020, https://pytorch.org/tutorials/beginner/blitz/cifar10_tutorial.html#sphx-glr-beginner-blitz-cifar10-tutorial-py
- deepBear, Pytorch car classifier - 90% accuracy, viewed 11 Nov 2020, https://www.kaggle.com/deepbear/pytorch-car-classifier-90-accuracy