When I use the Keras neural network library, I often use the built-in model.summary() function to display information about a network such as the number of weights and biases.
When I use the PyTorch neural network library, I rarely display network information. I'm not really sure why. I suspect it's because when I use PyTorch I almost always design and implement my networks from scratch, so I already know everything about them.
I ran across some blog posts that describe how to display model information for a PyTorch neural network, but some of the posts were contradictory. So I decided to do a quick experiment. Briefly, 1.) you can get simple information just by issuing a print(network_name) statement, and 2.) you can get more detailed information by installing the torchinfo package and then calling its summary() function. Some of the blog-world confusion on this topic is related to the fact that the torchinfo package is a successor to the older torchsummary package, but torchsummary still exists.
I grabbed my existing MNIST CNN example to use as the basis for my PyTorch network information mini-exploration. I installed the torchinfo package by issuing the command "pip install torchinfo" in a shell.
At the top of the MNIST CNN program I added the statement:
from torchinfo import summary  # for network info
# import torchinfo as TI  # alternative syntax
Then in the program I displayed network information in two ways:
. . .
# 2. create network
print("Creating CNN network with 2 conv and 3 linear ")
net = Net().to(device)

# 2b. show network info
print("Simple print(net) info: ")
print(net)

# 2c. more detailed
# requires separate package: pip install torchinfo
print("Detailed info from torchinfo package: ")
summary(net, (20, 1, 28, 28))  # bs, chnl, rows, cols
# TI.summary(net, (20, 1, 28, 28))  # alternative syntax
. . .
The detailed information display looked like:
====================================================
Layer (type:depth-idx)    Output Shape       Param #
====================================================
Net                       [20, 10]           --
├─Conv2d: 1-1             [20, 32, 24, 24]   832
├─MaxPool2d: 1-2          [20, 32, 12, 12]   --
├─Dropout: 1-3            [20, 32, 12, 12]   --
├─Conv2d: 1-4             [20, 64, 8, 8]     51,264
├─MaxPool2d: 1-5          [20, 64, 4, 4]     --
├─Linear: 1-6             [20, 512]          524,800
├─Dropout: 1-7            [20, 512]          --
├─Linear: 1-8             [20, 256]          131,328
├─Linear: 1-9             [20, 10]           2,570
====================================================
Total params: 710,794
Trainable params: 710,794
Non-trainable params: 0
Total mult-adds (M): 88.38
====================================================
Input size (MB): 0.06
Forward/backward pass size (MB): 3.73
Params size (MB): 2.84
Estimated Total Size (MB): 6.63
====================================================
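For reference, here is a minimal Net class definition that is consistent with the summary output above. I reconstructed the kernel sizes and layer widths from the parameter counts and output shapes, so treat this as a sketch rather than the exact program code:

```python
import torch as T

class Net(T.nn.Module):
  # 2 convolutional and 3 linear layers, matching the torchinfo table
  def __init__(self):
    super(Net, self).__init__()
    self.conv1 = T.nn.Conv2d(1, 32, 5)       # 832 params
    self.conv2 = T.nn.Conv2d(32, 64, 5)      # 51,264 params
    self.fc1 = T.nn.Linear(64 * 4 * 4, 512)  # 524,800 params
    self.fc2 = T.nn.Linear(512, 256)         # 131,328 params
    self.fc3 = T.nn.Linear(256, 10)          # 2,570 params
    self.pool1 = T.nn.MaxPool2d(2, stride=2)
    self.pool2 = T.nn.MaxPool2d(2, stride=2)
    self.drop1 = T.nn.Dropout(0.25)
    self.drop2 = T.nn.Dropout(0.50)

  def forward(self, x):
    z = T.relu(self.conv1(x))   # [bs, 32, 24, 24]
    z = self.pool1(z)           # [bs, 32, 12, 12]
    z = self.drop1(z)
    z = T.relu(self.conv2(z))   # [bs, 64, 8, 8]
    z = self.pool2(z)           # [bs, 64, 4, 4]
    z = z.reshape(-1, 1024)     # flatten to [bs, 1024]
    z = T.relu(self.fc1(z))
    z = self.drop2(z)
    z = T.relu(self.fc2(z))
    z = self.fc3(z)             # raw logits, [bs, 10]
    return z
```

The per-layer counts sum to 832 + 51,264 + 524,800 + 131,328 + 2,570 = 710,794, matching the "Total params" line in the output.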
I think that printing detailed network/model information is probably most useful in scenarios where you load a model that was created by someone else, such as a ResNet model.
Left: Model summary = The design is direct and to the point. Center: Model summary = The designer is getting the hang of it. Right: Model summary = It appears that the designer may have an inflated opinion of the fashion value of plastic.