AI Compiler Study Notes, Part 11 -- Visualizing Network Structures with the torchsummary Library

Published: 2024-11-03

Recommendation: install the torch-summary library rather than torchsummary. The former inherits the latter's functionality and also fixes many of its bugs.

The torchsummary library is commonly used for visualizing deep learning network structures (see its installation page).
The torch-summary library is an enhanced version of torchsummary; see its documentation and installation page.

Common issues with the torchsummary library

Problem 1: when using torchsummary to inspect a network structure, it raises: AttributeError: 'list' object has no attribute 'size'

Solution: install torch-summary

pip uninstall torchsummary        # uninstall the original torchsummary library
pip install torch-summary==1.4.4  # install the upgraded torch-summary package
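
The error above usually comes from the original torchsummary calling .size() directly on a layer's output, which fails when forward returns a list or tuple. Below is a minimal sketch of such a model (MultiOutputNet is made up for illustration); with torch-summary installed as above, the same summary call is expected to work, since torch-summary traverses nested outputs:

import torch.nn as nn
from torchsummary import summary  # torch-summary installs under the same module name

class MultiOutputNet(nn.Module):
    """Toy model whose forward returns a list of tensors."""
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(10, 20)
        self.fc2 = nn.Linear(10, 5)

    def forward(self, x):
        # A list return value is what triggers the AttributeError
        # in the original torchsummary hooks.
        return [self.fc1(x), self.fc2(x)]

summary(MultiOutputNet(), (10,))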

Scenario 1: how to use pytorch-summary

# Basic usage
from torchsummary import summary
summary(model, input_size=(channels, H, W))

# Multiple inputs, printing the feature-map size of each layer
from torchsummary import summary
summary(model, first_input, second_input)
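
# Note: with the original pytorch-summary library, multiple inputs can also be
# passed as a list of input shapes (a sketch, assuming two image-like inputs):
summary(model, [(1, 16, 16), (1, 28, 28)])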

# Printing different content (custom columns)
import torch
import torch.nn as nn
from torchsummary import summary

class LSTMNet(nn.Module):
    """ Batch-first LSTM model. """
    def __init__(self, vocab_size=20, embed_dim=300, hidden_dim=512, num_layers=2):
        super().__init__()
        self.hidden_dim = hidden_dim
        self.embedding = nn.Embedding(vocab_size, embed_dim)
        self.encoder = nn.LSTM(embed_dim, hidden_dim, num_layers=num_layers, batch_first=True)
        self.decoder = nn.Linear(hidden_dim, vocab_size)

    def forward(self, x):
        embed = self.embedding(x)
        out, hidden = self.encoder(embed)
        out = self.decoder(out)
        out = out.view(-1, out.size(2))
        return out, hidden

summary(
    LSTMNet(),
    (100,),
    dtypes=[torch.long],
    branching=False,
    verbose=2,
    col_width=16,
    col_names=["kernel_size", "output_size", "num_params", "mult_adds"],)

Scenario 2: pass input data directly; the input shape is inferred automatically from the data

Installation: pip install torchinfo (torch-summary was later renamed to torchinfo). Multi-input example:

import torch
import torch.nn as nn
import torch.nn.functional as F
from torchinfo import summary  # with torch-summary installed instead, use: from torchsummary import summary

class MultipleInputNetDifferentDtypes(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1a = nn.Linear(300, 50)
        self.fc1b = nn.Linear(50, 10)

        self.fc2a = nn.Linear(300, 50)
        self.fc2b = nn.Linear(50, 10)

    def forward(self, x1, x2):
        x1 = F.relu(self.fc1a(x1))
        x1 = self.fc1b(x1)
        x2 = x2.type(torch.float)
        x2 = F.relu(self.fc2a(x2))
        x2 = self.fc2b(x2)
        x = torch.cat((x1, x2), 0)
        return F.log_softmax(x, dim=1)


model = MultipleInputNetDifferentDtypes()

input_data = torch.randn(1, 300)
other_input_data = torch.randn(1, 300).long()
summary(model, input_data=[input_data, other_input_data]) # infer input shapes from the actual input data

summary(model, [(1, 300), (1, 300)], dtypes=[torch.float, torch.long]) # specify input shapes and dtypes explicitly
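
torchinfo's summary also accepts display options such as depth and col_names; reusing model and the inputs from above, a sketch (option names follow the torchinfo API):

from torchinfo import summary

# Limit the displayed module nesting depth and pick which columns to show.
summary(
    model,
    input_data=[input_data, other_input_data],
    depth=2,
    col_names=["input_size", "output_size", "num_params"],
)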

