Adv. PyTorch: Freezing Layers
If you’re planning to fine-tune a trained model on a different dataset, chances are you’re going to freeze some of the early layers and only update the later layers. I won’t go into the details of why you may want to freeze some layers and which ones should be frozen, but I’ll show you how to do it in PyTorch. Let’s get started!
We first need a pre-trained model to start with. The models subpackage in the torchvision package provides definitions for many of the popular model architectures for image classification. You can construct these models by simply calling their constructor, which initializes the model with random weights. To use the pre-trained models from the PyTorch Model Zoo, you can instead call the constructor with the pretrained=True argument. Let’s load the pretrained VGG16 model:
import torch
import torch.nn as nn
import torchvision.models as models
vgg16 = models.vgg16(pretrained=True)
This will start downloading the pretrained model into your computer’s PyTorch cache folder, which is usually the .cache/torch/checkpoints folder under your home directory.
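A quick side note: on newer versions of torchvision (roughly 0.13 and later), the pretrained argument is deprecated in favor of a weights argument. If the call above gives you a deprecation warning, the equivalent call looks like this:
# Equivalent call on newer torchvision releases, where pretrained=True is deprecated
vgg16 = models.vgg16(weights=models.VGG16_Weights.IMAGENET1K_V1)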
There are multiple ways you can look into the model to see its modules and layers. One way is the .modules() member function, which returns an iterator over all the member objects of the model. The .modules() function recursively goes through all the modules and submodules of the model:
print(list(vgg16.modules()))
[VGG(
(features): Sequential(
(0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(1): ReLU(inplace=True)
(2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(3): ReLU(inplace=True)
(4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
(5): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(6): ReLU(inplace=True)
(7): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(8): ReLU(inplace=True)
(9): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
(10): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(11): ReLU(inplace=True)
(12): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(13): ReLU(inplace=True)
(14): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(15): ReLU(inplace=True)
(16): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
(17): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(18): ReLU(inplace=True)
(19): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(20): ReLU(inplace=True)
(21): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(22): ReLU(inplace=True)
(23): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
(24): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(25): ReLU(inplace=True)
(26): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(27): ReLU(inplace=True)
(28): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(29): ReLU(inplace=True)
(30): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
)
(avgpool): AdaptiveAvgPool2d(output_size=(7, 7))
(classifier): Sequential(
(0): Linear(in_features=25088, out_features=4096, bias=True)
(1): ReLU(inplace=True)
(2): Dropout(p=0.5, inplace=False)
(3): Linear(in_features=4096, out_features=4096, bias=True)
(4): ReLU(inplace=True)
(5): Dropout(p=0.5, inplace=False)
(6): Linear(in_features=4096, out_features=1000, bias=True)
)
), Sequential(
(0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(1): ReLU(inplace=True)
(2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(3): ReLU(inplace=True)
(4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
(5): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(6): ReLU(inplace=True)
(7): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(8): ReLU(inplace=True)
(9): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
(10): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(11): ReLU(inplace=True)
(12): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(13): ReLU(inplace=True)
(14): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(15): ReLU(inplace=True)
(16): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
(17): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(18): ReLU(inplace=True)
(19): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(20): ReLU(inplace=True)
(21): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(22): ReLU(inplace=True)
(23): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
(24): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(25): ReLU(inplace=True)
(26): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(27): ReLU(inplace=True)
(28): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(29): ReLU(inplace=True)
(30): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
), Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)), ReLU(inplace=True), Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)), ReLU(inplace=True), MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False), Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)), ReLU(inplace=True), Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)), ReLU(inplace=True), MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False), Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)), ReLU(inplace=True), Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)), ReLU(inplace=True), Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)), ReLU(inplace=True), MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False), Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)), ReLU(inplace=True), Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)), ReLU(inplace=True), Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)), ReLU(inplace=True), MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False), Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)), ReLU(inplace=True), Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)), ReLU(inplace=True), Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)), ReLU(inplace=True), MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False), AdaptiveAvgPool2d(output_size=(7, 7)), Sequential(
(0): Linear(in_features=25088, out_features=4096, bias=True)
(1): ReLU(inplace=True)
(2): Dropout(p=0.5, inplace=False)
(3): Linear(in_features=4096, out_features=4096, bias=True)
(4): ReLU(inplace=True)
(5): Dropout(p=0.5, inplace=False)
(6): Linear(in_features=4096, out_features=1000, bias=True)
), Linear(in_features=25088, out_features=4096, bias=True), ReLU(inplace=True), Dropout(p=0.5, inplace=False), Linear(in_features=4096, out_features=4096, bias=True), ReLU(inplace=True), Dropout(p=0.5, inplace=False), Linear(in_features=4096, out_features=1000, bias=True)]
That’s a lot of information spewed out onto the screen! Let’s use the .named_modules() function instead, which yields (name, module) tuples, and only print the names:
for (name, module) in vgg16.named_modules():
    print(name)
features
features.0
features.1
features.2
features.3
features.4
features.5
features.6
features.7
features.8
features.9
features.10
features.11
features.12
features.13
features.14
features.15
features.16
features.17
features.18
features.19
features.20
features.21
features.22
features.23
features.24
features.25
features.26
features.27
features.28
features.29
features.30
avgpool
classifier
classifier.0
classifier.1
classifier.2
classifier.3
classifier.4
classifier.5
classifier.6
That’s much better! We can see that the top-level modules are features, avgpool and classifier. We can also see that the features and classifier modules consist of 31 and 7 layers respectively. These layers are not named and only have numbers associated with them. If you want to see an even more concise representation of the network, you can use the .named_children() function, which does not recurse into the top-level modules:
for (name, module) in vgg16.named_children():
    print(name)
features
avgpool
classifier
Now let’s see what layers are under the features module. Here we use the .children() function to get the layers under the features module, since these layers are not named:
for (name, module) in vgg16.named_children():
    if name == 'features':
        for layer in module.children():
            print(layer)
Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
ReLU(inplace=True)
Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
ReLU(inplace=True)
MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
ReLU(inplace=True)
Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
ReLU(inplace=True)
MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
ReLU(inplace=True)
Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
ReLU(inplace=True)
Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
ReLU(inplace=True)
MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
ReLU(inplace=True)
Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
ReLU(inplace=True)
Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
ReLU(inplace=True)
MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
ReLU(inplace=True)
Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
ReLU(inplace=True)
Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
ReLU(inplace=True)
MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
We can even go deeper and look at the parameters in each layer. Let’s get the parameters of the first layer under the features module:
for (name, module) in vgg16.named_children():
    if name == 'features':
        for layer in module.children():
            for param in layer.parameters():
                print(param)
            break
Parameter containing:
tensor([[[[-5.5373e-01, 1.4270e-01, 5.2896e-01],
[-5.8312e-01, 3.5655e-01, 7.6566e-01],
[-6.9022e-01, -4.8019e-02, 4.8409e-01]],
[[ 1.7548e-01, 9.8630e-03, -8.1413e-02],
[ 4.4089e-02, -7.0323e-02, -2.6035e-01],
[ 1.3239e-01, -1.7279e-01, -1.3226e-01]],
[[ 3.1303e-01, -1.6591e-01, -4.2752e-01],
[ 4.7519e-01, -8.2677e-02, -4.8700e-01],
[ 6.3203e-01, 1.9308e-02, -2.7753e-01]]],
[[[ 2.3254e-01, 1.2666e-01, 1.8605e-01],
[-4.2805e-01, -2.4349e-01, 2.4628e-01],
[-2.5066e-01, 1.4177e-01, -5.4864e-03]],
[[-1.4076e-01, -2.1903e-01, 1.5041e-01],
[-8.4127e-01, -3.5176e-01, 5.6398e-01],
[-2.4194e-01, 5.1928e-01, 5.3915e-01]],
[[-3.1432e-01, -3.7048e-01, -1.3094e-01],
[-4.7144e-01, -1.5503e-01, 3.4589e-01],
[ 5.4384e-02, 5.8683e-01, 4.9580e-01]]],
[[[ 1.7715e-01, 5.2149e-01, 9.8740e-03],
[-2.7185e-01, -7.1709e-01, 3.1292e-01],
[-7.5753e-02, -2.2079e-01, 3.3455e-01]],
[[ 3.0924e-01, 6.7071e-01, 2.0546e-02],
[-4.6607e-01, -1.0697e+00, 3.3501e-01],
[-8.0284e-02, -3.0522e-01, 5.4460e-01]],
[[ 3.1572e-01, 4.2335e-01, -3.4976e-01],
[ 8.6354e-02, -4.6457e-01, 1.1803e-02],
[ 1.0483e-01, -1.4584e-01, -1.5765e-02]]],
...,
[[[ 7.7599e-02, 1.2692e-01, 3.2305e-02],
[ 2.2131e-01, 2.4681e-01, -4.6637e-02],
[ 4.6407e-02, 2.8246e-02, 1.7528e-02]],
[[-1.8327e-01, -6.7425e-02, -7.2120e-03],
[-4.8855e-02, 7.0427e-03, -1.2883e-01],
[-6.4601e-02, -6.4566e-02, 4.4235e-02]],
[[-2.2547e-01, -1.1931e-01, -2.3425e-02],
[-9.9171e-02, -1.5143e-02, 9.5385e-04],
[-2.6137e-02, 1.3567e-03, 1.4282e-01]]],
[[[ 1.6520e-02, -3.2225e-02, -3.8450e-03],
[-6.8206e-02, -1.9445e-01, -1.4166e-01],
[-6.9528e-02, -1.8340e-01, -1.7422e-01]],
[[ 4.2781e-02, -6.7529e-02, -7.0309e-03],
[ 1.1765e-02, -1.4958e-01, -1.2361e-01],
[ 1.0205e-02, -1.0393e-01, -1.1742e-01]],
[[ 1.2661e-01, 8.5046e-02, 1.3066e-01],
[ 1.7585e-01, 1.1288e-01, 1.1937e-01],
[ 1.4656e-01, 9.8892e-02, 1.0348e-01]]],
[[[ 3.2176e-02, -1.0766e-01, -2.6388e-01],
[ 2.7957e-01, -3.7416e-02, -2.5471e-01],
[ 3.4872e-01, 3.0041e-02, -5.5898e-02]],
[[ 2.5063e-01, 1.5543e-01, -1.7432e-01],
[ 3.9255e-01, 3.2306e-02, -3.5191e-01],
[ 1.9299e-01, -1.9898e-01, -2.9713e-01]],
[[ 4.6032e-01, 4.3399e-01, 2.8352e-01],
[ 1.6341e-01, -5.8165e-02, -1.9196e-01],
[-1.9521e-01, -4.5630e-01, -4.2732e-01]]]], requires_grad=True)
Parameter containing:
tensor([ 0.4034, 0.3778, 0.4644, -0.3228, 0.3940, -0.3953, 0.3951, -0.5496,
0.2693, -0.7602, -0.3508, 0.2334, -1.3239, -0.1694, 0.3938, -0.1026,
0.0460, -0.6995, 0.1549, 0.5628, 0.3011, 0.3425, 0.1073, 0.4651,
0.1295, 0.0788, -0.0492, -0.5638, 0.1465, -0.3890, -0.0715, 0.0649,
0.2768, 0.3279, 0.5682, -1.2640, -0.8368, -0.9485, 0.1358, 0.2727,
0.1841, -0.5325, 0.3507, -0.0827, -1.0248, -0.6912, -0.7711, 0.2612,
0.4033, -0.4802, -0.3066, 0.5807, -1.3325, 0.4844, -0.8160, 0.2386,
0.2300, 0.4979, 0.5553, 0.5230, -0.2182, 0.0117, -0.5516, 0.2108],
requires_grad=True)
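Besides .parameters(), there is also the .named_parameters() function, which yields a fully qualified name such as features.0.weight along with each parameter tensor. If you just want a quick overview instead of the raw values, something like this works:
# Print the fully qualified name and shape of every parameter in the model
for name, param in vgg16.named_parameters():
    print(name, tuple(param.shape))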
Now that we have access to all the modules, layers and their parameters, we can easily freeze them by setting the parameters’ requires_grad flag to False. This prevents the gradients for these parameters from being calculated in the backward step, which in turn prevents the optimizer from updating them. Now let’s freeze all the parameters in the features module:
layer_counter = 0
for (name, module) in vgg16.named_children():
    if name == 'features':
        for layer in module.children():
            for param in layer.parameters():
                param.requires_grad = False
            print('Layer "{}" in module "{}" was frozen!'.format(layer_counter, name))
            layer_counter += 1
Layer "0" in module "features" was frozen!
Layer "1" in module "features" was frozen!
Layer "2" in module "features" was frozen!
Layer "3" in module "features" was frozen!
Layer "4" in module "features" was frozen!
Layer "5" in module "features" was frozen!
Layer "6" in module "features" was frozen!
Layer "7" in module "features" was frozen!
Layer "8" in module "features" was frozen!
Layer "9" in module "features" was frozen!
Layer "10" in module "features" was frozen!
Layer "11" in module "features" was frozen!
Layer "12" in module "features" was frozen!
Layer "13" in module "features" was frozen!
Layer "14" in module "features" was frozen!
Layer "15" in module "features" was frozen!
Layer "16" in module "features" was frozen!
Layer "17" in module "features" was frozen!
Layer "18" in module "features" was frozen!
Layer "19" in module "features" was frozen!
Layer "20" in module "features" was frozen!
Layer "21" in module "features" was frozen!
Layer "22" in module "features" was frozen!
Layer "23" in module "features" was frozen!
Layer "24" in module "features" was frozen!
Layer "25" in module "features" was frozen!
Layer "26" in module "features" was frozen!
Layer "27" in module "features" was frozen!
Layer "28" in module "features" was frozen!
Layer "29" in module "features" was frozen!
Layer "30" in module "features" was frozen!
Now that some of the parameters are frozen, the optimizer needs to be set up so that it only receives the parameters with requires_grad=True. We can do this by filtering the model’s parameters with a lambda function when constructing the optimizer:
optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, vgg16.parameters()), lr=0.001)
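If you want to double-check what the optimizer will actually be updating, a quick sanity check is to count the trainable versus frozen parameters (the exact numbers depend on the model, so treat them as illustrative):
# Count trainable vs. frozen parameters as a sanity check
trainable = sum(p.numel() for p in vgg16.parameters() if p.requires_grad)
frozen = sum(p.numel() for p in vgg16.parameters() if not p.requires_grad)
print('Trainable parameters: {:,}'.format(trainable))
print('Frozen parameters: {:,}'.format(frozen))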
You can now start training your partially frozen model!
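For reference, a minimal fine-tuning loop with the partially frozen model could look like the sketch below. The train_loader, loss function and number of epochs are assumptions here rather than part of the recipe above, and in a real fine-tuning setup you would typically also replace the final classifier layer to match the number of classes in your dataset:
# Minimal training-loop sketch; train_loader is an assumed DataLoader
# yielding (images, labels) batches, and 5 epochs is just illustrative.
criterion = nn.CrossEntropyLoss()

vgg16.train()
for epoch in range(5):
    for images, labels in train_loader:
        optimizer.zero_grad()
        outputs = vgg16(images)
        loss = criterion(outputs, labels)
        loss.backward()   # gradients are only computed for the unfrozen parameters
        optimizer.step()  # only the unfrozen (classifier) parameters are updated
    print('Finished epoch {}, last batch loss: {:.4f}'.format(epoch, loss.item()))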