Fine-tuning SSD Lite in torchvision

I want to fine-tune an object detector in PyTorch. For that, I was using this tutorial:

https://pytorch.org/tutorials/intermediate/torchvision_tutorial.html

However, the Faster R-CNN model is not suitable for my use case, so I fine-tuned SSDLite instead. I wrote this code to set a new classification head:

from functools import partial

import torch.nn as nn
import torchvision
from torchvision.models.detection import _utils as det_utils
from torchvision.models.detection.ssdlite import SSDLiteClassificationHead

model = torchvision.models.detection.ssdlite320_mobilenet_v3_large(pretrained=True)

in_channels = det_utils.retrieve_out_channels(model.backbone, (320, 320))
num_anchors = model.anchor_generator.num_anchors_per_location()
norm_layer = partial(nn.BatchNorm2d, eps=0.001, momentum=0.03)
num_classes = 2
model.head.classification_head = SSDLiteClassificationHead(in_channels, num_anchors, num_classes, norm_layer)

Since my model is not performing well, I want to ask the community whether the above code is correct.

Thanks in advance.

Kipton answered 12/2, 2022 at 17:30 Comment(1)
I'm wondering if the poor performance stems from not also resetting the regression head. You can re-initialize both the classification and the regression head in one go by doing model.head = SSDLiteHead(in_channels, num_anchors, num_classes, norm_layer) (see github.com/pytorch/vision/blob/…). – Expository
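
A minimal sketch of that suggestion, assuming the same setup as in the question (SSDLiteHead is torchvision's combined head, so this swaps the classification and regression sub-heads at once):

from functools import partial

import torch.nn as nn
import torchvision
from torchvision.models.detection import _utils as det_utils
from torchvision.models.detection.ssdlite import SSDLiteHead

model = torchvision.models.detection.ssdlite320_mobilenet_v3_large(pretrained=True)

in_channels = det_utils.retrieve_out_channels(model.backbone, (320, 320))
num_anchors = model.anchor_generator.num_anchors_per_location()
norm_layer = partial(nn.BatchNorm2d, eps=0.001, momentum=0.03)
num_classes = 2

# Replaces both the classification and the regression head in one go.
model.head = SSDLiteHead(in_channels, num_anchors, num_classes, norm_layer)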

If your goal is to create a model with a custom num_classes, then you could just:

  1. Create the model with the custom number of classes and no pretrained weights.
  2. Explicitly load the default pretrained checkpoint.
  3. Compare shapes and discard the pretrained weights whose shapes do not match the new model.
  4. Load the adjusted pretrained weights into the model non-strictly, then retrain.

For example:

import torch
import torchvision

num_classes = 2

# Step 1: create the model with the custom number of classes and no pretrained weights.
model = torchvision.models.detection.ssdlite320_mobilenet_v3_large(pretrained=False, num_classes=num_classes)

# Step 2: load the model's state_dict and the default pretrained checkpoint's state_dict.
# On Windows, you can find the downloaded checkpoint under C:\Users\user\.cache\torch\hub\checkpoints
mstate_dict = model.state_dict()
cstate_dict = torch.load(default_pretrained_model_path)

# Step 3: drop the pretrained weights whose shapes no longer match (i.e. the classification head).
for k in mstate_dict.keys():
    if mstate_dict[k].shape != cstate_dict[k].shape:
        print('key {} will be removed, original shape: {}, training shape: {}'.format(k, cstate_dict[k].shape, mstate_dict[k].shape))
        cstate_dict.pop(k)

# Step 4: load the adjusted pretrained weights non-strictly, then retrain.
model.load_state_dict(cstate_dict, strict=False)

Hope it helps, cheers~

Goldenseal answered 25/5, 2022 at 9:17 Comment(0)

So, it's my first time doing this kind of thing, but I got good results with this:

model = torchvision.models.detection.ssdlite320_mobilenet_v3_large(num_classes=num_classes, weights_backbone='DEFAULT', trainable_backbone_layers=0)

So I just use an existing backbone, train the rest from scratch, and don't train the backbone. Compared to the idea in the question and the approach in the other answer, it took at least 10 times less training to reach a similar point (probably because neither of those approaches froze the backbone at the start). And with the frozen backbone you can increase the batch size, which speeds up training even more. Once the model stopped improving I unfroze the backbone and trained it some more.
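
A rough sketch of that last step, assuming the model above; the optimizer choice and hyperparameters are illustrative, not from the answer. Once the metrics plateau, make the backbone parameters trainable again and rebuild the optimizer so it sees them:

import torch
import torchvision

num_classes = 2  # assumed, as in the question

model = torchvision.models.detection.ssdlite320_mobilenet_v3_large(
    num_classes=num_classes, weights_backbone='DEFAULT', trainable_backbone_layers=0)

# ... train with the frozen backbone until the model stops improving ...

# Unfreeze the backbone.
for param in model.backbone.parameters():
    param.requires_grad = True

# Rebuild the optimizer so it also updates the now-trainable backbone parameters
# (illustrative hyperparameters).
optimizer = torch.optim.SGD(
    [p for p in model.parameters() if p.requires_grad], lr=0.0005, momentum=0.9)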

Until answered 24/12, 2023 at 13:22 Comment(0)
