Printing PyTorch Model Summary Information

When I use the Keras neural network library, I often use the built-in model.summary() function to display information about a network such as the number of weights and biases.

When I use the PyTorch neural network library, I rarely display network information. I’m not really sure why. I suspect it’s because when I use PyTorch, I almost always design and implement my networks from scratch, so I already know everything about them.

I ran across some blog posts that describe how to display model information for a PyTorch neural network, but some of the posts were contradictory. So I decided to do a quick experiment. Briefly, 1.) you can get simple information just by issuing a print(network_name) statement, and 2.) you can get more detailed information by installing the torchinfo package and then calling its summary() function. Some of the blog-world confusion on this topic stems from the fact that the torchinfo package is the successor to the older torchsummary package, but torchsummary still exists.
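A plain print(network_name) shows layer types and sizes but no parameter counts. If you want just the totals without installing anything extra, you can sum numel() over model.parameters(). The sketch below is my own illustration, not part of the original program: count_params() is a hypothetical helper name and the tiny nn.Sequential network is made up for the example.

```python
from __future__ import annotations

import torch.nn as nn

def count_params(model: nn.Module) -> tuple[int, int]:
    # Hypothetical helper: (total, trainable) parameter counts, which are
    # the same two totals that torchinfo prints at the bottom of its summary.
    total = sum(p.numel() for p in model.parameters())
    trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
    return total, trainable

# A small made-up network, just for illustration.
model = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 3))
print(model)                # simple view: layer types and sizes, no counts
print(count_params(model))  # (67, 67) = (4*8 + 8) + (8*3 + 3)

# Freezing the first layer changes the trainable count but not the total.
for p in model[0].parameters():
    p.requires_grad = False
print(count_params(model))  # (67, 27)
```

This only gives totals; for per-layer output shapes you still need a forward pass, which is what torchinfo does for you.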

I grabbed my existing MNIST CNN example to use as the basis for my PyTorch network information mini-exploration. I installed the torchinfo package by issuing the command “pip install torchinfo” in a shell.

At the top of the MNIST CNN program I added the statement:

from torchinfo import summary  # for network info
# import torchinfo as TI  # alternative syntax

Then in the program I displayed network information in two ways:

. . . 
# 2. create network
print("Creating CNN network with 2 conv and 3 linear ")
net = Net().to(device)

# 2b. show network info
print("Simple print(net) info: ")
print(net)

# 2c. more detailed
# requires separate package: pip install torchinfo
print("Detailed info from torchinfo package: ")
summary(net, (20, 1, 28, 28))  # bs, chnl, rows, cols
# TI.summary(net, (20, 1, 28, 28))  # alternative syntax
. . .
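The fragment above relies on a Net class that isn't shown. Here is a hypothetical, self-contained reconstruction so the two display calls can be run end to end. The layer sizes (5x5 kernels, 2x2 pooling, 1024-512-256-10 linear layers) are inferred from the torchinfo output that follows; the ReLU placement and dropout rates are my guesses, and the CPU device is an assumption.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class Net(nn.Module):
    # Hypothetical reconstruction: sizes inferred from the torchinfo table,
    # activations (functional ReLU) and dropout rates are assumptions.
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 32, 5)        # [bs,1,28,28] -> [bs,32,24,24]
        self.pool1 = nn.MaxPool2d(2, stride=2)  # -> [bs,32,12,12]
        self.drop1 = nn.Dropout(0.25)
        self.conv2 = nn.Conv2d(32, 64, 5)       # -> [bs,64,8,8]
        self.pool2 = nn.MaxPool2d(2, stride=2)  # -> [bs,64,4,4]
        self.fc1 = nn.Linear(64 * 4 * 4, 512)   # 1024 flattened values in
        self.drop2 = nn.Dropout(0.50)
        self.fc2 = nn.Linear(512, 256)
        self.fc3 = nn.Linear(256, 10)           # one logit per digit class

    def forward(self, x):
        z = self.drop1(self.pool1(F.relu(self.conv1(x))))
        z = self.pool2(F.relu(self.conv2(z)))
        z = z.reshape(-1, 64 * 4 * 4)           # flatten for the linear layers
        z = self.drop2(F.relu(self.fc1(z)))
        z = F.relu(self.fc2(z))
        return self.fc3(z)                      # raw logits, no softmax

device = torch.device("cpu")
net = Net().to(device)
print(net)  # simple view: module names and constructor arguments

# Push a dummy batch of 20 MNIST-sized images through to check the output shape.
x = torch.zeros(20, 1, 28, 28, device=device)
with torch.no_grad():
    out = net(x)
print(out.shape)  # torch.Size([20, 10])

try:
    from torchinfo import summary
    summary(net, (20, 1, 28, 28))  # batch size, channels, rows, cols
except ImportError:
    print("torchinfo not installed: pip install torchinfo")
```

Because ReLU is applied with F.relu() rather than as an nn.ReLU module, it doesn't show up as a row in either display.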

The detailed information display looked like:

====================================================
Layer (type:depth-idx)   Output Shape    Param #
====================================================
Net                 [20, 10]              --
├─Conv2d: 1-1       [20, 32, 24, 24]      832
├─MaxPool2d: 1-2    [20, 32, 12, 12]      --
├─Dropout: 1-3      [20, 32, 12, 12]      --
├─Conv2d: 1-4       [20, 64, 8, 8]        51,264
├─MaxPool2d: 1-5    [20, 64, 4, 4]        --
├─Linear: 1-6       [20, 512]             524,800
├─Dropout: 1-7      [20, 512]             --
├─Linear: 1-8       [20, 256]             131,328
├─Linear: 1-9       [20, 10]              2,570
====================================================
Total params: 710,794
Trainable params: 710,794
Non-trainable params: 0
Total mult-adds (M): 88.38
====================================================
Input size (MB): 0.06
Forward/backward pass size (MB): 3.73
Params size (MB): 2.84
Estimated Total Size (MB): 6.63
====================================================
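The numbers in the table can be checked by hand. A Conv2d layer has out_ch * in_ch * k * k weights plus out_ch biases, and a Linear layer has in * out weights plus out biases. The sketch below reproduces the parameter counts and several of the summary lines, assuming 5x5 kernels, float32 (4-byte) values, and a batch size of 20. The assumptions that torchinfo uses decimal megabytes (1e6 bytes) and counts mult-adds as each layer's parameter count times the number of output positions that reuse those weights are inferences on my part; they happen to reproduce the printed figures.

```python
# Hand check of the torchinfo summary for the MNIST CNN above.
# Assumptions: 5x5 kernels, float32 values (4 bytes), batch size 20,
# decimal MB (1e6 bytes); mult-adds counted as params * output positions.

BATCH = 20
BYTES_PER_FLOAT = 4

def conv2d_params(in_ch, out_ch, k):
    # one k x k kernel per (in_ch, out_ch) pair, plus one bias per out_ch
    return out_ch * in_ch * k * k + out_ch

def linear_params(in_f, out_f):
    # weight matrix plus one bias per output node
    return in_f * out_f + out_f

# (name, parameter count, output positions per image that reuse the weights)
layers = [
    ("conv1", conv2d_params(1, 32, 5), 24 * 24),  # 28 - 5 + 1 = 24
    ("conv2", conv2d_params(32, 64, 5), 8 * 8),   # 12 - 5 + 1 = 8
    ("fc1", linear_params(64 * 4 * 4, 512), 1),   # 64 x 4 x 4 = 1024 inputs
    ("fc2", linear_params(512, 256), 1),
    ("fc3", linear_params(256, 10), 1),
]

total_params = sum(p for _, p, _ in layers)
params_mb = total_params * BYTES_PER_FLOAT / 1e6
input_mb = BATCH * 1 * 28 * 28 * BYTES_PER_FLOAT / 1e6
mult_adds = sum(p * positions * BATCH for _, p, positions in layers)

for name, p, _ in layers:
    print(f"{name:6s} {p:>9,}")
print(f"Total params: {total_params:,}")
print(f"Params size (MB): {params_mb:.2f}")
print(f"Input size (MB): {input_mb:.2f}")
print(f"Total mult-adds (M): {mult_adds / 1e6:.2f}")
```

This prints 832, 51,264, 524,800, 131,328, and 2,570 for the five layers, a total of 710,794, and 2.84, 0.06, and 88.38 for the last three lines, matching the torchinfo output. The MaxPool2d and Dropout layers contribute no parameters, which is why they show "--".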

I think that printing detailed network/model information is probably most useful in scenarios where you load a model that was created by someone else, such as a ResNet model.



Left: Model summary = The design is direct and to the point. Center: Model summary = The designer is getting the hang of it. Right: Model summary = It appears that the designer may have an inflated opinion of the fashion value of plastic.


This entry was posted in PyTorch.