pytorch学习使用

发布于:2025-05-01 ⋅ 阅读:(14) ⋅ 点赞:(0)

1. 基础使用

1.1 基础信息

import torch

# Print the installed PyTorch version
print(torch.__version__)

# Check whether a CUDA-capable GPU can be used
print(torch.cuda.is_available())
"""
2.7.0
False
"""

1.2 创建tensor

import torch

# Create a 5x3 tensor WITHOUT initializing it: the values are whatever was already in the
# allocated memory (they may happen to be zeros, but that is not guaranteed). Use torch.zeros for zeros.
print("-------- empty --------")
print(torch.empty(5, 3))  # equivalent to torch.empty((5, 3))

# 5x3 tensor sampled uniformly from [0, 1)
print("-------- rand --------")
rand_x = torch.rand(5, 3)  # equivalent to rand_x = torch.rand((5, 3))
print(rand_x)
print(rand_x[:, 0])  # column 0 -> 1-D tensor
print(rand_x[0, :])  # row 0 -> 1-D tensor
print(rand_x[:, 0:2])  # first two columns -> 2-D tensor
print(rand_x[:, [0, 2]])  # columns 0 and 2 -> 2-D tensor
print(rand_x[::2])  # rows 0, 2, 4 -> 2-D tensor

# 2x10 tensor sampled from the standard normal distribution N(0, 1)
print("-------- normal --------")
normal_x = torch.normal(0, 1, size=(2, 10))
print(normal_x)

# 1-D tensor of 8 random integers in [0, 1000) (the upper bound is exclusive)
print("-------- randint --------")
randint_x = torch.randint(low=0, high=1000, size=(8,))
print(randint_x)

# All-zero tensor with dtype long (int64)
print("-------- zeros --------")
zero_x = torch.zeros(5, 3, dtype=torch.long)  # equivalent to zero_x = torch.zeros((5, 3), dtype=torch.long)
print(zero_x)

# All-one tensor with dtype float (float32)
print("-------- ones --------")
one_x = torch.ones(5, 3, dtype=torch.float)  # equivalent to one_x = torch.ones((5, 3), dtype=torch.float)
print(one_x)

# 5x3 float tensor with ones on the main diagonal and zeros elsewhere
print("-------- eye --------")
eye_x = torch.eye(5, 3, dtype=torch.float)
print(eye_x)
# Extract the main diagonal as a 1-D tensor (length min(5, 3) = 3)
s = eye_x.diag()
print(s)
# Embed a 1-D tensor as the diagonal of a square matrix
t = s.diag_embed()  # equivalent to t = torch.diag_embed(s)
print(t)

# 1-D tensor [5.5, 3]: one value is a float, so the whole tensor gets a floating-point dtype
print("-------- 一维tensor --------")
tensor1 = torch.tensor([5.5, 3])
print(tensor1)

# 2-D tensor containing only integers, so its dtype is an integer type (int64)
print("-------- 二维tensor --------")
tensor0 = torch.tensor([[1, 2], [3, 4], [5, 6]])
print(tensor0)

# new_ones: a new tensor of the given size filled with 1; by default it inherits the
# dtype and device of the tensor it is called on
print("-------- new_ones --------")
tensor2 = tensor1.new_ones((2, 3))
print(tensor2)

# randn_like: standard-normal samples with the same shape as tensor2; dtype can be overridden
print("-------- randn_like --------")
tensor3 = torch.randn_like(tensor2, dtype=torch.float)
print(tensor3)

# Shape of a tensor: .size() and .shape are equivalent
print("-------- tensor size --------")
print(tensor3.size())
print(tensor3.shape)

# .item() converts a one-element tensor into a Python scalar.
# torch.tensor (lowercase) is the recommended factory; the legacy torch.Tensor constructor
# produces the same float32 value here.
print("-------- tensor item --------")
tensor4 = torch.tensor([3.14])
print(tensor4.item())
"""
-------- empty --------
tensor([[0., 0., 0.],
        [0., 0., 0.],
        [0., 0., 0.],
        [0., 0., 0.],
        [0., 0., 0.]])
-------- rand --------
tensor([[0.6596, 0.3999, 0.4556],
        [0.2757, 0.6820, 0.7506],
        [0.0683, 0.1522, 0.9666],
        [0.7557, 0.1943, 0.2406],
        [0.5978, 0.7308, 0.1105]])
tensor([0.6596, 0.2757, 0.0683, 0.7557, 0.5978])
tensor([0.6596, 0.3999, 0.4556])
tensor([[0.6596, 0.3999],
        [0.2757, 0.6820],
        [0.0683, 0.1522],
        [0.7557, 0.1943],
        [0.5978, 0.7308]])
tensor([[0.6596, 0.4556],
        [0.2757, 0.7506],
        [0.0683, 0.9666],
        [0.7557, 0.2406],
        [0.5978, 0.1105]])
tensor([[0.6596, 0.3999, 0.4556],
        [0.0683, 0.1522, 0.9666],
        [0.5978, 0.7308, 0.1105]])
-------- normal --------
tensor([[ 0.3300, -0.5461,  1.3952, -1.4907, -0.4039,  0.2111,  0.4386,  0.6213,
         -0.9563, -0.4214],
        [-0.2401, -1.3838, -1.1084,  1.8060, -0.1078, -0.1417, -1.5372, -0.3526,
          0.2074, -1.0423]])
-------- randint --------
tensor([474, 834, 908, 552, 926, 543, 338, 452])
-------- zeros --------
tensor([[0, 0, 0],
        [0, 0, 0],
        [0, 0, 0],
        [0, 0, 0],
        [0, 0, 0]])
-------- ones --------
tensor([[1., 1., 1.],
        [1., 1., 1.],
        [1., 1., 1.],
        [1., 1., 1.],
        [1., 1., 1.]])
-------- eye --------
tensor([[1., 0., 0.],
        [0., 1., 0.],
        [0., 0., 1.],
        [0., 0., 0.],
        [0., 0., 0.]])
tensor([1., 1., 1.])
tensor([[1., 0., 0.],
        [0., 1., 0.],
        [0., 0., 1.]])
-------- 一维tensor --------
tensor([5.5000, 3.0000])
-------- 二维tensor --------
tensor([[1, 2],
        [3, 4],
        [5, 6]])
-------- new_ones --------
tensor([[1., 1., 1.],
        [1., 1., 1.]])
-------- randn_like --------
tensor([[ 0.4086,  0.6232, -0.6118],
        [ 0.3720,  0.0189,  1.0114]])
-------- tensor size --------
torch.Size([2, 3])
torch.Size([2, 3])
-------- tensor item --------
3.140000104904175
"""

1.3 tensor之间的运算

import torch

a = torch.tensor([[1.0, 2, 3], [4, 5, 6]])  # float32 (it contains a float literal)
b = torch.tensor([[1, 2, 3], [4, 5, 6]])  # int64; promoted to float32 when combined with a
# Addition
print("-------- tensor之间相加 --------")
c = a + b
print(c)
c = torch.add(a, b)
print(c)
c = a.add(b)
print(c)
# a.add_(b)  # in-place: modifies a; methods whose name ends in "_" modify the tensor they are called on
# print(a)

# Subtraction
print("-------- tensor之间相减 --------")
c = a - b
print(c)
c = torch.sub(a, b)
print(c)
c = a.sub(b)
print(c)
# a.sub_(b)  # in-place: modifies a
# print(a)

# Element-wise (Hadamard) product
print("-------- tensor之间相乘 --------")
c = a * b
print(c)
c = torch.mul(a, b)
print(c)
c = a.mul(b)
print(c)
# a.mul_(b)  # in-place: modifies a
# print(a)

# Division (true division, the result is floating point)
print("-------- tensor之间相除 --------")
c = a / b
print(c)
c = torch.div(a, b)
print(c)
c = a.div(b)
print(c)
# a.div_(b)  # in-place: modifies a; a must be a floating-point tensor
# print(a)

# Matrix multiplication: (2, 3) x (3, 2) -> (2, 2)
print("-------- tensor之间矩阵乘法 --------")
a = torch.tensor([[1, 1, 1], [1, 1, 1]])
b = torch.tensor([[1, 1], [1, 1], [1, 1]])
c = torch.mm(a, b)  # strictly 2-D matrices
print(c)
c = torch.matmul(a, b)  # general form (also handles batched inputs)
print(c)
c = a @ b  # operator form of matmul
print(c)
c = a.matmul(b)
print(c)

# Power
print("-------- tensor幂运算 --------")
a = torch.tensor([[1, 2, 3], [4, 5, 6]])
c = torch.pow(a, 3)
print(c)
c = a.pow(3)
print(c)
c = a**3
print(c)
# a.pow_(3)  # in-place: modifies a
# print(a)

# Square root
print("-------- tensor开方运算 --------")
a = torch.tensor([[1.0, 2, 3], [4, 5, 6]])
c = a.sqrt()
print(c)
# a.sqrt_()  # in-place: modifies a; a must be a floating-point tensor
# print(a)

# Logarithms
print("-------- tensor对数运算 --------")
a = torch.tensor([[1.0, 2, 3], [4, 5, 6]])
c = torch.log2(a)
print(c)
c = torch.log10(a)
print(c)
c = torch.log(a)  # natural logarithm (base e)
print(c)
# a.log_()  # in-place: modifies a; a must be a floating-point tensor
# print(a)
"""
-------- tensor之间相加 --------
tensor([[ 2.,  4.,  6.],
        [ 8., 10., 12.]])
tensor([[ 2.,  4.,  6.],
        [ 8., 10., 12.]])
tensor([[ 2.,  4.,  6.],
        [ 8., 10., 12.]])
-------- tensor之间相减 --------
tensor([[0., 0., 0.],
        [0., 0., 0.]])
tensor([[0., 0., 0.],
        [0., 0., 0.]])
tensor([[0., 0., 0.],
        [0., 0., 0.]])
-------- tensor之间相乘 --------
tensor([[ 1.,  4.,  9.],
        [16., 25., 36.]])
tensor([[ 1.,  4.,  9.],
        [16., 25., 36.]])
tensor([[ 1.,  4.,  9.],
        [16., 25., 36.]])
-------- tensor之间相除 --------
tensor([[1., 1., 1.],
        [1., 1., 1.]])
tensor([[1., 1., 1.],
        [1., 1., 1.]])
tensor([[1., 1., 1.],
        [1., 1., 1.]])
-------- tensor之间矩阵乘法 --------
tensor([[3, 3],
        [3, 3]])
tensor([[3, 3],
        [3, 3]])
tensor([[3, 3],
        [3, 3]])
tensor([[3, 3],
        [3, 3]])
-------- tensor幂运算 --------
tensor([[  1,   8,  27],
        [ 64, 125, 216]])
tensor([[  1,   8,  27],
        [ 64, 125, 216]])
tensor([[  1,   8,  27],
        [ 64, 125, 216]])
-------- tensor开方运算 --------
tensor([[1.0000, 1.4142, 1.7321],
        [2.0000, 2.2361, 2.4495]])
-------- tensor对数运算 --------
tensor([[0.0000, 1.0000, 1.5850],
        [2.0000, 2.3219, 2.5850]])
tensor([[0.0000, 0.3010, 0.4771],
        [0.6021, 0.6990, 0.7782]])
tensor([[0.0000, 0.6931, 1.0986],
        [1.3863, 1.6094, 1.7918]])
"""

1.4 tensor和数字之间的运算

import torch

a = torch.tensor([[1.0, 2, 3], [4, 5, 6]])
b = 2

# Arithmetic with a Python number is applied element-wise: every entry of a is combined with b
print(a + b)
print(a - b)
print(a * b)
print(a / b)
"""
tensor([[3., 4., 5.],
        [6., 7., 8.]])
tensor([[-1.,  0.,  1.],
        [ 2.,  3.,  4.]])
tensor([[ 2.,  4.,  6.],
        [ 8., 10., 12.]])
tensor([[0.5000, 1.0000, 1.5000],
        [2.0000, 2.5000, 3.0000]])
"""

1.5 tensor尺寸修改

import torch

print("-------- 二维张量 --------")
a = torch.tensor([[1.0, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12]])  # shape (3, 4)
c = a.reshape(-1, 2)  # 2 columns; -1 lets PyTorch infer the number of rows (6)
print(c)
c = a.view(-1, 3)  # 3 columns -> shape (4, 3); view requires compatible (here contiguous) memory
print(c)
c = a.reshape(-1)  # flatten row by row into a 1-D tensor
print(c)
c = a.view(-1)  # same result, as a view
print(c)
print("-------- 三维张量 --------")
a = torch.tensor([
    [
        [1.0, 2, 3],
        [4, 5, 6],
        [7, 8, 9]
    ],
    [
        [10, 11, 12],
        [13, 14, 15],
        [16, 17, 18]
    ],
    [
        [20, 21, 22],
        [23, 24, 25],
        [26, 27, 28]
    ]
])  # shape (3, 3, 3)
c = a.view(-1, 3*3)  # shape (3, 9): each 3x3 slice becomes one row
print(c)
"""
-------- 二维张量 --------
tensor([[ 1.,  2.],
        [ 3.,  4.],
        [ 5.,  6.],
        [ 7.,  8.],
        [ 9., 10.],
        [11., 12.]])
tensor([[ 1.,  2.,  3.],
        [ 4.,  5.,  6.],
        [ 7.,  8.,  9.],
        [10., 11., 12.]])
tensor([ 1.,  2.,  3.,  4.,  5.,  6.,  7.,  8.,  9., 10., 11., 12.])
tensor([ 1.,  2.,  3.,  4.,  5.,  6.,  7.,  8.,  9., 10., 11., 12.])
-------- 三维张量 --------
tensor([[ 1.,  2.,  3.,  4.,  5.,  6.,  7.,  8.,  9.],
        [10., 11., 12., 13., 14., 15., 16., 17., 18.],
        [20., 21., 22., 23., 24., 25., 26., 27., 28.]])
"""

1.6 tensor转置

import torch

print("-------- 一维张量 --------")
a = torch.tensor([1, 2, 3])  # shape: (3,)
c = a.transpose(0, 0)  # swapping a dimension with itself changes nothing: still [1, 2, 3]
print(c)
b = a.unsqueeze(0)  # shape: (1, 3); a 1-D tensor first needs a second dimension before it can be transposed
c = b.transpose(0, 1)  # shape: (3, 1)
print(c)

print("-------- 二维张量 --------")
a = torch.tensor([[1, 2], [3, 4], [5, 6]])  # shape: (3, 2)
c = a.transpose(0, 1)  # or a.t()
print(a)
print(c)  # tensor([[1, 3, 5], [2, 4, 6]]), shape: (2, 3)
c = a.t()  # t() is the shorthand for tensors with at most 2 dimensions
print(c)

print("-------- 三维张量 --------")
a = torch.arange(24).reshape(2, 3, 4)  # shape: (2, 3, 4)
c = a.transpose(0, 2)  # swap dimensions 0 and 2
print(a)
print(c)  # torch.Size([4, 3, 2])
"""
-------- 一维张量 --------
tensor([1, 2, 3])
tensor([[1],
        [2],
        [3]])
-------- 二维张量 --------
tensor([[1, 2],
        [3, 4],
        [5, 6]])
tensor([[1, 3, 5],
        [2, 4, 6]])
tensor([[1, 3, 5],
        [2, 4, 6]])
-------- 三维张量 --------
tensor([[[ 0,  1,  2,  3],
         [ 4,  5,  6,  7],
         [ 8,  9, 10, 11]],

        [[12, 13, 14, 15],
         [16, 17, 18, 19],
         [20, 21, 22, 23]]])
tensor([[[ 0, 12],
         [ 4, 16],
         [ 8, 20]],

        [[ 1, 13],
         [ 5, 17],
         [ 9, 21]],

        [[ 2, 14],
         [ 6, 18],
         [10, 22]],

        [[ 3, 15],
         [ 7, 19],
         [11, 23]]])
"""

1.7 tensor拼接(torch.stack:沿新维度堆叠;沿已有维度拼接用 torch.cat)

import torch

a = torch.tensor([[1.0, 2, 3], [4, 5, 6]])  # float32, shape (2, 3)
b = torch.tensor([[7, 8, 9], [10, 11, 12]])  # int64, shape (2, 3); the stacked result is float (see output)
# torch.stack joins tensors along a NEW dimension (torch.cat joins along an existing one)
c = torch.stack([a, b], dim=0)  # shape (2, 2, 3): a and b become the two outer slices
print(c)
c = torch.stack([a, b], dim=1)  # shape (2, 2, 3): row i of a paired with row i of b
print(c)
c = torch.stack([a, b], dim=2)  # shape (2, 3, 2): element-wise pairs (a[i][j], b[i][j])
print(c)
"""
tensor([[[ 1.,  2.,  3.],
         [ 4.,  5.,  6.]],

        [[ 7.,  8.,  9.],
         [10., 11., 12.]]])
tensor([[[ 1.,  2.,  3.],
         [ 7.,  8.,  9.]],

        [[ 4.,  5.,  6.],
         [10., 11., 12.]]])
tensor([[[ 1.,  7.],
         [ 2.,  8.],
         [ 3.,  9.]],

        [[ 4., 10.],
         [ 5., 11.],
         [ 6., 12.]]])
"""

2. 搭建常见模型

2.1 DNN

2.1.1 代码

# 导入必要的库
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

# MNIST手写数字集DNN
"""
MNIST数据集:
wget https://storage.googleapis.com/cvdf-datasets/mnist/train-images-idx3-ubyte.gz
wget https://storage.googleapis.com/cvdf-datasets/mnist/train-labels-idx1-ubyte.gz
wget https://storage.googleapis.com/cvdf-datasets/mnist/t10k-images-idx3-ubyte.gz
wget https://storage.googleapis.com/cvdf-datasets/mnist/t10k-labels-idx1-ubyte.gz
"""

# Fix the random seed for reproducibility
torch.manual_seed(42)

# Choose the compute device (GPU when available)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# -------------------- 1. Data loading and preprocessing --------------------
# Preprocessing pipeline; 0.1307 / 0.3081 are the published MNIST mean / std statistics.
transform = transforms.Compose([
    transforms.ToTensor(),  # PIL image with pixels in [0, 255] -> float tensor in [0, 1]
    # Standardize: (x - 0.1307) / 0.3081 gives roughly zero mean and unit variance.
    # Note this maps [0, 1] to about [-0.42, 2.82], NOT to [-1, 1].
    transforms.Normalize((0.1307,), (0.3081,)),
])

# Download (if needed) and load the training and test sets
train_dataset = datasets.MNIST(
    root='./data',  # where the dataset is stored
    train=True,  # training split
    download=True,  # download automatically if missing
    transform=transform  # apply the preprocessing above
)

test_dataset = datasets.MNIST(
    root='./data',
    train=False,  # test split
    download=True,
    transform=transform
)

# Data loaders: serve the data in mini-batches
train_loader = DataLoader(
    train_dataset,
    batch_size=64,  # 64 samples per batch
    shuffle=True  # reshuffle the training data every epoch
)

test_loader = DataLoader(
    test_dataset,
    batch_size=1000,  # larger batches are fine for evaluation
    shuffle=False  # order does not matter for evaluation
)


# -------------------- 2. Fully connected network (DNN / MLP) --------------------
class DNN(nn.Module):
    """Fully connected MNIST classifier: 784 -> 128 -> 64 -> 10 with ReLU and Dropout(0.5).

    Input: 28x28 images, e.g. shape (N, 1, 28, 28); they are flattened to (N, 784).
    Output: (N, 10) log-probabilities, because the last layer is LogSoftmax.
    The matching loss is therefore nn.NLLLoss. nn.CrossEntropyLoss expects raw logits; applied
    to these outputs it re-normalizes them, which gives the same value but is redundant.
    """

    def __init__(self):
        super(DNN, self).__init__()
        # Keep this creation order: with a fixed seed it determines the initial weights.
        self.fc1 = nn.Linear(28 * 28, 128)  # input layer: 784 pixels -> 128 units
        self.fc2 = nn.Linear(128, 64)  # hidden layer: 128 -> 64
        self.fc3 = nn.Linear(64, 10)  # output layer: 64 -> 10 classes
        self.dropout = nn.Dropout(0.5)  # zero 50% of activations during training to curb overfitting
        self.relu = nn.ReLU()
        # Attribute name kept for state/compatibility; this is a LOG-softmax over the class dim
        self.softmax = nn.LogSoftmax(dim=1)

    def forward(self, x):
        x = x.view(-1, 28 * 28)  # flatten each image into a 784-dim vector
        x = self.relu(self.fc1(x))  # layer 1 + activation
        x = self.dropout(x)
        x = self.relu(self.fc2(x))  # layer 2 + activation
        x = self.dropout(x)
        x = self.fc3(x)  # raw class scores (logits)
        return self.softmax(x)  # log-probabilities, not logits


# Instantiate the model and move it to the compute device
model = DNN().to(device)

# Print the network structure.
# print(model) lists layers in the order they were defined in __init__, which does not
# necessarily reflect the order forward() runs them; the helper below traces a real forward pass.
from net_structure import *
# Dummy batch: 64 single-channel (grayscale) 28x28 images. It is created on the model's device;
# a CPU tensor would make the traced forward pass fail whenever the model lives on CUDA.
print_model_leaf_structure(model, torch.randn(64, 1, 28, 28, device=device))

# -------------------- 3. Loss function and optimizer --------------------
# DNN.forward already ends with LogSoftmax, so the matching loss is NLLLoss
# (LogSoftmax + NLLLoss is the cross-entropy loss; CrossEntropyLoss expects raw logits).
criterion = nn.NLLLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)  # Adam: adaptive per-parameter learning rates


# -------------------- 4. Training loop --------------------
def train(epochs):
    """Train the module-level `model` for `epochs` passes over `train_loader`.

    Uses the module-level `criterion`, `optimizer` and `device`. Every 100 batches it prints the
    mean loss of those 100 batches; after each epoch it prints the mean loss over the whole epoch.
    """
    model.train()  # training mode (Dropout active)
    log_every = 100
    for epoch_no in range(1, epochs + 1):
        epoch_loss_sum = 0.0
        window_loss_sum = 0.0
        for step, (inputs, labels) in enumerate(train_loader, start=1):
            # Move the batch to the model's device
            inputs = inputs.to(device)
            labels = labels.to(device)

            # Forward pass
            predictions = model(inputs)
            batch_loss = criterion(predictions, labels)

            # Backward pass and parameter update
            optimizer.zero_grad()
            batch_loss.backward()
            optimizer.step()

            # Book-keeping
            loss_value = batch_loss.item()
            window_loss_sum += loss_value
            epoch_loss_sum += loss_value
            if step % log_every == 0:
                print(f'Epoch {epoch_no}, Batch {step}, Loss: {window_loss_sum / log_every:.3f}')
                window_loss_sum = 0
        print(f'Epoch {epoch_no}/{epochs} - Loss: {epoch_loss_sum / len(train_loader):.4f}')


# Run training: 5 full passes over the training set
num_epochs = 5
train(epochs=num_epochs)

# -------------------- 5. Save the trained model --------------------
# Persist only the learned parameters (state_dict) -- the recommended way to save a model
trained_state = model.state_dict()
torch.save(trained_state, 'mnist_dnn.pth')


# -------------------- 6. Model evaluation --------------------
def evaluate(new_model):
    """Print (and return) the accuracy of `new_model` on the module-level `test_loader`.

    The model is switched to eval mode (Dropout disabled) and inference runs under
    torch.no_grad(). Batches are moved to the module-level `device`.

    Returns:
        Accuracy in percent, or None when the test loader yields no samples.
    """
    new_model.eval()  # evaluation mode
    correct = 0
    total = 0
    with torch.no_grad():  # no autograd graph needed for inference, saves memory
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            outputs = new_model(data)
            predicted = outputs.argmax(dim=1)  # index of the highest score = predicted class
            total += target.size(0)
            correct += (predicted == target).sum().item()
    if total == 0:
        # Avoid ZeroDivisionError on an empty test set
        print('Test Accuracy: n/a (empty test set)')
        return None
    accuracy = 100 * correct / total
    print(f'Test Accuracy: {accuracy:.2f}%')
    return accuracy

new_model = DNN().to(device)
# Load the saved parameters into a fresh instance (demonstrates the reload path).
# map_location: a checkpoint written from a GPU run can still be loaded on a CPU-only machine.
# weights_only=True: only tensors and plain containers are unpickled -- all a state_dict needs --
# so loading the file cannot execute arbitrary pickled code.
new_model.load_state_dict(torch.load('mnist_dnn.pth', map_location=device, weights_only=True))
# Run the evaluation
evaluate(new_model)

2.1.2 结果

"""
【    Linear     】Input shape:     torch.Size([64, 784])      → Output shape:     torch.Size([64, 128])      | Params count: 100480
【     ReLU      】Input shape:     torch.Size([64, 128])      → Output shape:     torch.Size([64, 128])      | Params count: 0
【    Dropout    】Input shape:     torch.Size([64, 128])      → Output shape:     torch.Size([64, 128])      | Params count: 0
【    Linear     】Input shape:     torch.Size([64, 128])      → Output shape:      torch.Size([64, 64])      | Params count: 8256
【     ReLU      】Input shape:      torch.Size([64, 64])      → Output shape:      torch.Size([64, 64])      | Params count: 0
【    Dropout    】Input shape:      torch.Size([64, 64])      → Output shape:      torch.Size([64, 64])      | Params count: 0
【    Linear     】Input shape:      torch.Size([64, 64])      → Output shape:      torch.Size([64, 10])      | Params count: 650
【  LogSoftmax   】Input shape:      torch.Size([64, 10])      → Output shape:      torch.Size([64, 10])      | Params count: 0
***Total Parameters***: 109386 = [100480 + 0 + 0 + 8256 + 0 + 0 + 650 + 0]

Epoch 1, Batch 100, Loss: 1.360
Epoch 1, Batch 200, Loss: 0.689
Epoch 1, Batch 300, Loss: 0.536
Epoch 1, Batch 400, Loss: 0.495
Epoch 1, Batch 500, Loss: 0.457
Epoch 1, Batch 600, Loss: 0.434
Epoch 1, Batch 700, Loss: 0.405
Epoch 1, Batch 800, Loss: 0.397
Epoch 1, Batch 900, Loss: 0.378
Epoch 1/5 - Loss: 0.5635
Epoch 2, Batch 100, Loss: 0.352
Epoch 2, Batch 200, Loss: 0.345
Epoch 2, Batch 300, Loss: 0.354
Epoch 2, Batch 400, Loss: 0.340
Epoch 2, Batch 500, Loss: 0.309
Epoch 2, Batch 600, Loss: 0.297
Epoch 2, Batch 700, Loss: 0.325
Epoch 2, Batch 800, Loss: 0.318
Epoch 2, Batch 900, Loss: 0.307
Epoch 2/5 - Loss: 0.3257
Epoch 3, Batch 100, Loss: 0.285
Epoch 3, Batch 200, Loss: 0.290
Epoch 3, Batch 300, Loss: 0.282
Epoch 3, Batch 400, Loss: 0.289
Epoch 3, Batch 500, Loss: 0.280
Epoch 3, Batch 600, Loss: 0.271
Epoch 3, Batch 700, Loss: 0.273
Epoch 3, Batch 800, Loss: 0.272
Epoch 3, Batch 900, Loss: 0.267
Epoch 3/5 - Loss: 0.2788
Epoch 4, Batch 100, Loss: 0.257
Epoch 4, Batch 200, Loss: 0.236
Epoch 4, Batch 300, Loss: 0.269
Epoch 4, Batch 400, Loss: 0.269
Epoch 4, Batch 500, Loss: 0.264
Epoch 4, Batch 600, Loss: 0.272
Epoch 4, Batch 700, Loss: 0.255
Epoch 4, Batch 800, Loss: 0.251
Epoch 4, Batch 900, Loss: 0.254
Epoch 4/5 - Loss: 0.2578
Epoch 5, Batch 100, Loss: 0.247
Epoch 5, Batch 200, Loss: 0.219
Epoch 5, Batch 300, Loss: 0.236
Epoch 5, Batch 400, Loss: 0.226
Epoch 5, Batch 500, Loss: 0.236
Epoch 5, Batch 600, Loss: 0.250
Epoch 5, Batch 700, Loss: 0.240
Epoch 5, Batch 800, Loss: 0.240
Epoch 5, Batch 900, Loss: 0.235
Epoch 5/5 - Loss: 0.2361
Test Accuracy: 96.37%
"""

2.1.3 输出结果分析

网络结构

在这里插入图片描述

  • 线性全连接层的参数量为:$(C_{in}+1) \times C_{out}$,其中:

    • $C_{in}$:输入维度;

    • $C_{out}$:输出维度;

    • 其中 +1 对应每个输出神经元的偏置项(bias)。

在这里插入图片描述

  • 可以看出参数都在Linear线性层(全连接)

    • 第一个线性层参数量:$(784+1) \times 128 = 100480$

    • 第二个线性层参数量:$(128+1) \times 64 = 8256$

    • 第三个线性层参数量:$(64+1) \times 10 = 650$

Epoch、Batch、Batch_size
  • Epoch:所有训练数据训练一次称为一次Epoch;

  • Batch:所有训练数据可能被分为多组进行训练,每组数据称为一个Batch;

  • Batch_size:一个Batch中的样本数量称为Batch_size;例如上述网络结构中的64就是Batch_size。

  • 举个例子:训练数据一共6400条,每次训练输入64条数据,那么一个Epoch会有 $\frac{6400}{64} = 100$ 个Batch,每个Batch中有64个数据。

2.2 CNN

2.2.1 代码

# 导入必要的库
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt

# MNIST手写数字集CNN
"""
MNIST数据集:
wget https://storage.googleapis.com/cvdf-datasets/mnist/train-images-idx3-ubyte.gz
wget https://storage.googleapis.com/cvdf-datasets/mnist/train-labels-idx1-ubyte.gz
wget https://storage.googleapis.com/cvdf-datasets/mnist/t10k-images-idx3-ubyte.gz
wget https://storage.googleapis.com/cvdf-datasets/mnist/t10k-labels-idx1-ubyte.gz
"""

# Fix the random seed for reproducibility
torch.manual_seed(42)

# Choose the compute device (GPU when available)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# -------------------- 1. Data loading and preprocessing --------------------
# Preprocessing pipeline; 0.1307 / 0.3081 are the published MNIST mean / std statistics.
transform = transforms.Compose([
    transforms.ToTensor(),  # PIL image with pixels in [0, 255] -> float tensor in [0, 1]
    # Standardize: (x - 0.1307) / 0.3081 gives roughly zero mean and unit variance.
    # Note this maps [0, 1] to about [-0.42, 2.82], NOT to [-1, 1].
    transforms.Normalize((0.1307,), (0.3081,)),
])

# Download (if needed) and load the training and test sets
train_dataset = datasets.MNIST(
    root='./data',  # where the dataset is stored
    train=True,  # training split
    download=True,  # download automatically if missing
    transform=transform  # apply the preprocessing above
)

test_dataset = datasets.MNIST(
    root='./data',
    train=False,  # test split
    download=True,
    transform=transform
)

# Data loaders: serve the data in mini-batches
train_loader = DataLoader(
    train_dataset,
    batch_size=64,  # 64 samples per batch
    shuffle=True  # reshuffle the training data every epoch
)

test_loader = DataLoader(
    test_dataset,
    batch_size=1000,  # larger batches are fine for evaluation
    shuffle=False  # order does not matter for evaluation
)

# -------------------- 2. Convolutional neural network --------------------
class CNN(nn.Module):
    """Small MNIST classifier: two conv -> ReLU -> max-pool stages followed by one linear layer.

    Input: (N, 1, 28, 28) grayscale images. Output: (N, 10) raw class scores (logits).
    """

    def __init__(self):
        super().__init__()
        # Layers are created in the original order so seeded initialization is unchanged.
        self.conv1 = nn.Conv2d(1, 10, kernel_size=5)  # 1 -> 10 channels, 5x5 kernel
        self.conv2 = nn.Conv2d(10, 20, kernel_size=5)  # 10 -> 20 channels, 5x5 kernel
        self.relu = nn.ReLU()
        self.pool = nn.MaxPool2d(2)  # 2x2 window; stride defaults to the window size
        self.fc = nn.Linear(320, 10)  # 320 = 20 channels * 4 * 4 positions -> 10 digit classes

    def forward(self, x):
        # (N, 1, 28, 28) -conv-> (N, 10, 24, 24) -pool-> (N, 10, 12, 12)
        stage1 = self.pool(self.relu(self.conv1(x)))
        # (N, 10, 12, 12) -conv-> (N, 20, 8, 8) -pool-> (N, 20, 4, 4)
        stage2 = self.pool(self.relu(self.conv2(stage1)))
        flat = stage2.view(-1, 20 * 4 * 4)  # (N, 320)
        return self.fc(flat)


# Instantiate the model and move it to the compute device
model = CNN().to(device)

# Print the network structure.
# print(model) lists layers in the order they were defined in __init__, which does not
# necessarily reflect the order forward() runs them; the helper below traces a real forward pass.
from net_structure import *
# Dummy batch: 64 single-channel (grayscale) 28x28 images. It is created on the model's device;
# a CPU tensor would make the traced forward pass fail whenever the model lives on CUDA.
print_model_leaf_structure(model, torch.randn(64, 1, 28, 28, device=device))

# -------------------- 3. Loss function and optimizer --------------------
criterion = nn.CrossEntropyLoss()  # CNN.forward returns raw logits, which is what CrossEntropyLoss expects
optimizer = optim.SGD(
    model.parameters(),
    lr=0.01,  # learning rate
    momentum=0.5  # momentum to speed up convergence
)


# -------------------- 4. Training loop --------------------
def train(epochs):
    """Train the module-level `model` for `epochs` passes over `train_loader`.

    Uses the module-level `criterion`, `optimizer` and `device`. Every 100 batches it prints the
    mean loss of those 100 batches; after each epoch it prints the mean loss over the whole epoch.
    """
    model.train()  # training mode
    log_every = 100
    for epoch_no in range(1, epochs + 1):
        epoch_loss_sum = 0.0
        window_loss_sum = 0.0
        for step, (inputs, labels) in enumerate(train_loader, start=1):
            # Move the batch to the model's device
            inputs = inputs.to(device)
            labels = labels.to(device)

            # Forward pass
            predictions = model(inputs)
            batch_loss = criterion(predictions, labels)

            # Backward pass and parameter update
            optimizer.zero_grad()
            batch_loss.backward()
            optimizer.step()

            # Book-keeping
            loss_value = batch_loss.item()
            window_loss_sum += loss_value
            epoch_loss_sum += loss_value
            if step % log_every == 0:
                print(f'Epoch {epoch_no}, Batch {step}, Loss: {window_loss_sum / log_every:.3f}')
                window_loss_sum = 0
        print(f'Epoch {epoch_no}/{epochs} - Loss: {epoch_loss_sum / len(train_loader):.4f}')

# Run training: 5 full passes over the training set
num_epochs = 5
train(epochs=num_epochs)

# -------------------- 5. Save the trained model --------------------
# Persist only the learned parameters (state_dict)
trained_state = model.state_dict()
torch.save(trained_state, 'mnist_cnn.pth')


# -------------------- 6. Model evaluation --------------------
def evaluate(new_model):
    """Print (and return) the accuracy of `new_model` on the module-level `test_loader`.

    The model is switched to eval mode and inference runs under torch.no_grad().
    Batches are moved to the module-level `device`.

    Returns:
        Accuracy in percent, or None when the test loader yields no samples.
    """
    new_model.eval()  # evaluation mode
    correct = 0
    total = 0
    with torch.no_grad():  # no autograd graph needed for inference, saves memory
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            outputs = new_model(data)
            predicted = outputs.argmax(dim=1)  # index of the highest score = predicted class
            total += target.size(0)
            correct += (predicted == target).sum().item()
    if total == 0:
        # Avoid ZeroDivisionError on an empty test set
        print('Test Accuracy: n/a (empty test set)')
        return None
    accuracy = 100 * correct / total
    print(f'Test Accuracy: {accuracy:.2f}%')
    return accuracy

new_model = CNN().to(device)
# Load the saved parameters into a fresh instance (demonstrates the reload path).
# map_location: a checkpoint written from a GPU run can still be loaded on a CPU-only machine.
# weights_only=True: only tensors and plain containers are unpickled -- all a state_dict needs --
# so loading the file cannot execute arbitrary pickled code.
new_model.load_state_dict(torch.load('mnist_cnn.pth', map_location=device, weights_only=True))
# Run the evaluation
evaluate(new_model)

# -------------------- 7. Visualize predictions (optional) --------------------
# Take one batch from the test set
dataiter = iter(test_loader)
images, labels = next(dataiter)
images, labels = images.to(device), labels.to(device)

# Predict in eval mode without building an autograd graph
new_model.eval()
with torch.no_grad():
    outputs = new_model(images)
preds = outputs.argmax(dim=1)  # predicted class per image

# Show (up to) the first 16 images with their predicted and true labels
num_shown = min(16, images.size(0))  # the 4x4 grid holds at most 16; a smaller batch shows fewer
fig = plt.figure(figsize=(12, 6))
for idx in range(num_shown):
    ax = fig.add_subplot(4, 4, idx + 1)
    img = images[idx].cpu().numpy().squeeze()  # (1, 28, 28) -> (28, 28) for imshow
    ax.imshow(img, cmap='gray_r')
    ax.set_title(f'Pred: {preds[idx]} | True: {labels[idx]}')
    ax.axis('off')
plt.tight_layout()
plt.show()

2.2.2 结果

"""
【    Conv2d     】Input shape:  torch.Size([64, 1, 28, 28])   → Output shape:  torch.Size([64, 10, 24, 24])  | Params count: 260
【     ReLU      】Input shape:  torch.Size([64, 10, 24, 24])  → Output shape:  torch.Size([64, 10, 24, 24])  | Params count: 0
【   MaxPool2d   】Input shape:  torch.Size([64, 10, 24, 24])  → Output shape:  torch.Size([64, 10, 12, 12])  | Params count: 0
【    Conv2d     】Input shape:  torch.Size([64, 10, 12, 12])  → Output shape:   torch.Size([64, 20, 8, 8])   | Params count: 5020
【     ReLU      】Input shape:   torch.Size([64, 20, 8, 8])   → Output shape:   torch.Size([64, 20, 8, 8])   | Params count: 0
【   MaxPool2d   】Input shape:   torch.Size([64, 20, 8, 8])   → Output shape:   torch.Size([64, 20, 4, 4])   | Params count: 0
【    Linear     】Input shape:     torch.Size([64, 320])      → Output shape:      torch.Size([64, 10])      | Params count: 3210
***Total Parameters***: 8490 = [260 + 0 + 0 + 5020 + 0 + 0 + 3210]

Epoch 1, Batch 100, Loss: 1.293
Epoch 1, Batch 200, Loss: 0.383
Epoch 1, Batch 300, Loss: 0.275
Epoch 1, Batch 400, Loss: 0.223
Epoch 1, Batch 500, Loss: 0.182
Epoch 1, Batch 600, Loss: 0.161
Epoch 1, Batch 700, Loss: 0.151
Epoch 1, Batch 800, Loss: 0.142
Epoch 1, Batch 900, Loss: 0.131
Epoch 1/5 - Loss: 0.3183
Epoch 2, Batch 100, Loss: 0.114
Epoch 2, Batch 200, Loss: 0.107
Epoch 2, Batch 300, Loss: 0.114
Epoch 2, Batch 400, Loss: 0.098
Epoch 2, Batch 500, Loss: 0.100
Epoch 2, Batch 600, Loss: 0.095
Epoch 2, Batch 700, Loss: 0.090
Epoch 2, Batch 800, Loss: 0.092
Epoch 2, Batch 900, Loss: 0.086
Epoch 2/5 - Loss: 0.0998
Epoch 3, Batch 100, Loss: 0.090
Epoch 3, Batch 200, Loss: 0.072
Epoch 3, Batch 300, Loss: 0.072
Epoch 3, Batch 400, Loss: 0.079
Epoch 3, Batch 500, Loss: 0.078
Epoch 3, Batch 600, Loss: 0.072
Epoch 3, Batch 700, Loss: 0.069
Epoch 3, Batch 800, Loss: 0.083
Epoch 3, Batch 900, Loss: 0.068
Epoch 3/5 - Loss: 0.0749
Epoch 4, Batch 100, Loss: 0.063
Epoch 4, Batch 200, Loss: 0.066
Epoch 4, Batch 300, Loss: 0.063
Epoch 4, Batch 400, Loss: 0.070
Epoch 4, Batch 500, Loss: 0.061
Epoch 4, Batch 600, Loss: 0.065
Epoch 4, Batch 700, Loss: 0.058
Epoch 4, Batch 800, Loss: 0.058
Epoch 4, Batch 900, Loss: 0.055
Epoch 4/5 - Loss: 0.0625
Epoch 5, Batch 100, Loss: 0.052
Epoch 5, Batch 200, Loss: 0.057
Epoch 5, Batch 300, Loss: 0.063
Epoch 5, Batch 400, Loss: 0.052
Epoch 5, Batch 500, Loss: 0.052
Epoch 5, Batch 600, Loss: 0.066
Epoch 5, Batch 700, Loss: 0.053
Epoch 5, Batch 800, Loss: 0.051
Epoch 5, Batch 900, Loss: 0.054
Epoch 5/5 - Loss: 0.0553
Test Accuracy: 98.30%
"""

在这里插入图片描述

2.2.3 输出结果分析

网络结构

在这里插入图片描述

  • 卷积层参数量:$(K_h \times K_w \times C_{in} + 1) \times C_{out}$,其中

    • $K_h, K_w$:卷积核的高和宽;

    • $C_{in}$:输入通道数;

    • $C_{out}$:输出通道数;

    • 其中 +1 对应每个输出通道的偏置项(bias)。

在这里插入图片描述

  • 参数量都在卷积层和线性层:

    • 第一个卷积层参数量:$(5 \times 5 \times 1 + 1) \times 10 = 260$

    • 第二个卷积层参数量:$(5 \times 5 \times 10 + 1) \times 20 = 5020$

    • 全连接(线性)层参数量:$(320 + 1) \times 10 = 3210$

3. 绘制forward定义的模型结构

第2节中使用了如下函数(保存为 net_structure.py,通过 from net_structure import * 导入)

3.1 打印函数定义

from functools import partial
import traceback


# register_out_hooks: attach a forward hook to every *direct* child of `model` that prints the
# child's name with its input and output shapes. Only the outermost layer of the structure is shown.
def register_out_hooks(model):
    """Hook each direct child of `model`; return the hook handles so the caller can remove them."""

    def shape_text(value):
        # Positional inputs arrive as a tuple; outputs may be a tuple too -- report the first element
        first = value[0] if isinstance(value, tuple) else value
        return str(first.shape)

    def report(module, input, output):
        kind = str(module).partition('(')[0]  # repr starts with "<ClassName>(" -> keep the class name
        print(f"【{kind}】: Input shape: {shape_text(input)} → Output shape: {shape_text(output)}")

    return [child.register_forward_hook(report) for child in model.children()]


# register_leaf_hooks: attach a forward hook to every leaf module (one without children) that prints
# the module's class name, input/output shapes and parameter count. Only modules registered as
# sub-modules (e.g. assigned in __init__) can be hooked; layers created inside forward() are invisible.
def register_leaf_hooks(model):
    """Register shape/parameter printing hooks on every leaf module of `model`.

    Returns:
        (hooks, total_params_list): the hook handles (call ``.remove()`` on each when done) and a
        list that receives one entry per hook call. A parameter is counted only the first time a
        module owning it runs; later calls of reused/shared modules contribute 0. As a result,
        ``sum(total_params_list)`` is the number of distinct parameters actually used. The printed
        per-layer "Params count" still shows each module's full parameter count.
    """
    total_params_list = []  # one entry per hook invocation
    seen_params = set()  # ids of parameters already included in total_params_list

    def hook_fn(module, input, output):
        params = list(module.parameters())
        params_count = sum(p.numel() for p in params)
        # Only add parameters not counted yet, so a module run twice (or tied weights)
        # does not inflate the total.
        new_params = [p for p in params if id(p) not in seen_params]
        seen_params.update(id(p) for p in new_params)
        total_params_list.append(sum(p.numel() for p in new_params))
        input_shape = str(input[0].shape if isinstance(input, tuple) else input.shape)
        output_shape = str(output[0].shape if isinstance(output, tuple) else output.shape)
        print(
            f"【{module.__class__.__name__:^15}】Input shape: {input_shape:^30} → Output shape: {output_shape:^30} | Params count: {params_count}")

    hooks = []
    # named_modules() only yields registered sub-modules; layers built on the fly in forward() are missed
    for _, module in model.named_modules():
        if len(list(module.children())) == 0:  # leaf module: no sub-modules
            hooks.append(module.register_forward_hook(hook_fn))
    return hooks, total_params_list


def register_tree_hooks(model):
    """Hook every module of `model` (root included) to build a tree-shaped structure report.

    Each record holds the module's class name, input/output shapes and its *own* parameter count
    (``recurse=False``), prefixed according to its depth in the module tree.

    Returns:
        (hooks, total_params_list, print_cached_results):
        hooks -- handles to ``.remove()`` once the forward pass is done;
        total_params_list -- one entry per hook call; a parameter is counted only the first time
            its module runs, so the sum is the number of distinct parameters used;
        print_cached_results -- prints the collected records parent-before-children and
            clears the cache.
    """
    output_cache = []  # one record per hook call, in the order the hooks fire
    total_params_list = []
    seen_params = set()  # ids of parameters already included in total_params_list

    def tree_hook_fn(module, input, output, depth):
        # Parameters owned directly by this module (children report their own)
        own_params = list(module.parameters(recurse=False))
        params_count = sum(p.numel() for p in own_params)
        # Only add parameters not counted yet, so reused/shared modules do not inflate the total
        new_params = [p for p in own_params if id(p) not in seen_params]
        seen_params.update(id(p) for p in new_params)
        total_params_list.append(sum(p.numel() for p in new_params))
        input_shape = str(input[0].shape if isinstance(input, tuple) else input.shape)
        output_shape = str(output[0].shape if isinstance(output, tuple) else output.shape)
        # Tree prefix: "│   " for each level below depth 1, then "├── "; the root (depth 0) has none
        prefix = "│   " * (depth - 1) + "├── " if depth > 0 else ""
        output_cache.append({
            "depth": depth,
            "info": f"{prefix}【{module.__class__.__name__:^15}】Input shape: {input_shape:^30} → Output shape: {output_shape:^30} | Params count: {params_count}",
        })

    hooks = []

    def recursive_hook_register(module, depth=0):
        # Hook this module, then recurse into its children one level deeper
        hook = module.register_forward_hook(partial(tree_hook_fn, depth=depth))
        hooks.append(hook)
        for child in module.children():
            recursive_hook_register(child, depth + 1)

    def print_cached_results():
        # A forward hook fires when the module's forward returns, so children are recorded before
        # their parent. For each depth i from the deepest level up to 1, every record shallower than
        # i is moved in front of the run of records (depth >= i) directly preceding it, which puts
        # parents before their children. Example using depths only:
        #     [2, 2, 1, 2, 2, 1, 1, 0]  ->  [0, 1, 2, 2, 1, 2, 2, 1]
        max_depth = max((item["depth"] for item in output_cache), default=0)  # default: nothing recorded
        for i in range(max_depth, 0, -1):
            k = 0
            while k < len(output_cache):
                j = k  # start of the current run of records with depth >= i
                while k < len(output_cache) and output_cache[k]["depth"] >= i:
                    k += 1
                if k < len(output_cache):
                    # Rotate the shallower record at k to position j
                    val = output_cache[k]
                    for t in range(k, j, -1):
                        output_cache[t] = output_cache[t - 1]
                    output_cache[j] = val
                k += 1
        for val in output_cache:
            print(val["info"])
        output_cache.clear()

    recursive_hook_register(model)
    return hooks, total_params_list, print_cached_results


# print_model_out_structure: print input/output shapes of the model's direct children (outermost layer)
def print_model_out_structure(model, inputs):
    """Trace one forward pass and print the shapes seen by each direct child of `model`.

    Args:
        model: the nn.Module to inspect.
        inputs: one input, or a list/tuple of positional inputs (unpacked into the call).

    The pass runs under torch.no_grad() with every module temporarily in eval mode. No autograd
    graph is built, and stateful layers (e.g. BatchNorm running statistics) are not updated by the
    dummy input. Each module's train/eval flag is restored afterwards. A model whose forward()
    branches on self.training is traced along its eval-mode path. Errors are printed with a
    traceback instead of being raised (best-effort debugging helper).
    """
    import torch  # local import: this helper file does not import torch at the top

    hooks = []
    saved_modes = []
    try:
        saved_modes = [(module, module.training) for module in model.modules()]
        hooks = register_out_hooks(model)
        model.eval()
        with torch.no_grad():
            if isinstance(inputs, (list, tuple)):
                model(*inputs)
            else:
                model(inputs)
        print()
    except Exception as e:
        print(e)
        traceback.print_exc()
    finally:
        for hook in hooks:
            hook.remove()
        # Restore each module's own flag (a model may mix train/eval sub-modules)
        for module, mode in saved_modes:
            module.training = mode


# print_model_leaf_structure: print the innermost (leaf) layers of `model` in execution order
def print_model_leaf_structure(model, inputs):
    """Trace one forward pass, print every leaf layer's shapes, then the total parameter count.

    Args:
        model: the nn.Module to inspect.
        inputs: one input, or a list/tuple of positional inputs (unpacked into the call).
            It must live on the same device as the model.

    The pass runs under torch.no_grad() with every module temporarily in eval mode. No autograd
    graph is built, and stateful layers (e.g. BatchNorm running statistics) are not updated by the
    dummy input. Each module's train/eval flag is restored afterwards. A model whose forward()
    branches on self.training is traced along its eval-mode path. Errors are printed with a
    traceback instead of being raised (best-effort debugging helper).
    """
    import torch  # local import: this helper file does not import torch at the top

    hooks = []
    saved_modes = []
    try:
        saved_modes = [(module, module.training) for module in model.modules()]
        hooks, total_params_list = register_leaf_hooks(model)
        model.eval()
        with torch.no_grad():
            if isinstance(inputs, (list, tuple)):
                model(*inputs)
            else:
                model(inputs)
        print(f"***Total Parameters***: {sum(total_params_list)} = [" + " + ".join(
            [str(e) for e in total_params_list]) + "]\n")
    except Exception as e:
        print(e)
        traceback.print_exc()
    finally:
        for hook in hooks:
            hook.remove()
        # Restore each module's own flag (a model may mix train/eval sub-modules)
        for module, mode in saved_modes:
            module.training = mode


# print_model_tree_structure: print the model as a tree of nested modules
def print_model_tree_structure(model, inputs):
    """Run one forward pass through ``model``, print its module tree, then the parameter total.

    ``register_tree_hooks`` (defined elsewhere in this module) returns:
    - the hook handles,
    - a list of per-module parameter counts,
    - a function that prints the cached per-module lines in tree order.

    The hooks are always removed afterwards, even if the forward pass raises.

    Args:
        model: the ``nn.Module`` to inspect.
        inputs: a single input, or a list/tuple of positional inputs that is
            unpacked as ``model(*inputs)``.
    """
    import torch  # only needed for torch.no_grad below

    hooks = []
    try:
        hooks, total_params_list, print_fn = register_tree_hooks(model)
        # Only shapes/parameter counts are inspected, so skip building the autograd graph.
        with torch.no_grad():
            if isinstance(inputs, (list, tuple)):
                model(*inputs)
            else:
                model(inputs)
        # Output is cached during the forward pass and printed here in tree order.
        print_fn()
        print(f"***Total Parameters***: {sum(total_params_list)} = [" + " + ".join(
            map(str, total_params_list)) + "]\n")
    except Exception as e:
        # Best-effort diagnostic helper: report the failure instead of raising.
        print(e)
        traceback.print_exc()
    finally:
        for hook in hooks:
            hook.remove()

3.2 打印函数的使用

注意点:

  1. 需要传入实际的输入张量(模型有多个输入时用 list/tuple 打包),其第一维一般是 Batch_size,后面是单个数据的维度;输入张量须与模型位于同一设备(CPU/GPU)上;
  2. 网络定义时要在__init__函数中提前定义好各层,之后直接在forward中使用__init__定义好的层,这样输出网络结构时才能够捕获到。相当于要在__init__中提前注册好各层的定义。由于钩子只注册在 nn.Module 子模块上,forward 中直接调用的 torch.nn.functional 函数或 view、cat 等张量操作不会出现在输出结果中。
  • 例如对于上述DNN和CNN:
# Choose the compute device (prefer the GPU when available)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Instantiate the model and move it to the compute device
model = DNN().to(device)
# model = CNN().to(device)

# Print the network structure.
# print(model) only shows the layers in the order they are defined in __init__,
# not the order in which forward actually executes them.
from net_structure import *
# 64 images, 1 channel each (grayscale), 28x28 pixels. The input is created on the
# same device as the model; a CPU input with a CUDA model fails in the forward pass.
print_model_leaf_structure(model, torch.randn(64, 1, 28, 28, device=device))

3.3 使用demo

3.3.1 demo 1:简单嵌套结构

  • 代码:
import torch
import torch.nn as nn

class MyModel(nn.Module):
    """Small CNN demo: a conv block, a pooling block, then a linear classifier.

    Sized for 3-channel 28x28 inputs:
    - the 3x3 conv yields 26x26 maps,
    - the 2x2 max-pool halves them to 13x13,
    - ``fc`` is built for that 16*13*13 flattened width.
    """

    def __init__(self):
        super().__init__()
        # Conv (3 -> 16 channels, 3x3 kernel) followed by ReLU
        self.block1 = nn.Sequential(nn.Conv2d(3, 16, 3), nn.ReLU())
        # 2x2 max-pool followed by ReLU
        self.block2 = nn.Sequential(nn.MaxPool2d(2), nn.ReLU())
        flat_features = 16 * 13 * 13  # channels * height * width after block2
        self.fc = nn.Linear(flat_features, 10)

    def forward(self, x):
        features = self.block2(self.block1(x))
        # Flatten every dimension except the batch dimension
        flat = features.view(features.size(0), -1)
        return self.fc(flat)

from net_structure import print_model_tree_structure  # the snippet uses this helper, so import it explicitly

# Instantiate the model
model = MyModel()
print_model_tree_structure(model, torch.randn(1, 3, 28, 28))
  • 输出结果:
【 MyModel  】Input shape:   torch.Size([1, 3, 28, 28])   → Output shape:      torch.Size([1, 10])       | Params count: 0
├── 【Sequential】Input shape:   torch.Size([1, 3, 28, 28])   → Output shape:  torch.Size([1, 16, 26, 26])   | Params count: 0
│   ├── 【  Conv2d  】Input shape:   torch.Size([1, 3, 28, 28])   → Output shape:  torch.Size([1, 16, 26, 26])   | Params count: 448
│   ├── 【   ReLU   】Input shape:  torch.Size([1, 16, 26, 26])   → Output shape:  torch.Size([1, 16, 26, 26])   | Params count: 0
├── 【Sequential】Input shape:  torch.Size([1, 16, 26, 26])   → Output shape:  torch.Size([1, 16, 13, 13])   | Params count: 0
│   ├── 【MaxPool2d 】Input shape:  torch.Size([1, 16, 26, 26])   → Output shape:  torch.Size([1, 16, 13, 13])   | Params count: 0
│   ├── 【   ReLU   】Input shape:  torch.Size([1, 16, 13, 13])   → Output shape:  torch.Size([1, 16, 13, 13])   | Params count: 0
├── 【  Linear  】Input shape:     torch.Size([1, 2704])      → Output shape:      torch.Size([1, 10])       | Params count: 27050
***Total Parameters***: 27498 = [448 + 0 + 0 + 0 + 0 + 0 + 27050 + 0]

3.3.2 demo 2:复杂嵌套结构UNet

  • 代码:
import math
import torch
import torch.nn as nn


class SinusoidalPositionEmbeddings(nn.Module):
    """Map integer timesteps to a learned projection of fixed sinusoidal embeddings.

    A frozen table with ``total_time_steps`` rows (sin half followed by cos half)
    is indexed by timestep. The looked-up rows pass through
    Linear -> SiLU -> Linear, producing ``time_emb_dims_exp`` features.

    Args:
        total_time_steps: number of rows in the table. Timesteps passed to
            ``forward`` must lie in ``[0, total_time_steps)``.
        time_emb_dims: size of the sinusoidal embedding; must be >= 2. Sin and
            cos each use ``time_emb_dims // 2`` frequencies, so an odd value
            yields a table with ``time_emb_dims - 1`` columns.
        time_emb_dims_exp: output feature size.

    Raises:
        ValueError: if ``time_emb_dims < 2`` (the table would have no columns).
    """

    def __init__(self, total_time_steps=1000, time_emb_dims=128, time_emb_dims_exp=512):
        super().__init__()
        if time_emb_dims < 2:
            raise ValueError(f"time_emb_dims must be >= 2, got {time_emb_dims}")

        half_dim = time_emb_dims // 2

        # Frequencies decay geometrically from 1 down to 1/10000. With a single
        # frequency (half_dim == 1) the divisor would be 0, so fall back to 1.
        emb = math.log(10000) / max(half_dim - 1, 1)
        emb = torch.exp(torch.arange(half_dim, dtype=torch.float32) * -emb)

        ts = torch.arange(total_time_steps, dtype=torch.float32)

        # Outer product: (total_time_steps, 1) * (1, half_dim) -> (total_time_steps, half_dim)
        emb = torch.unsqueeze(ts, dim=-1) * torch.unsqueeze(emb, dim=0)
        emb = torch.cat((emb.sin(), emb.cos()), dim=-1)

        self.time_blocks = nn.Sequential(
            # from_pretrained freezes the table by default (freeze=True)
            nn.Embedding.from_pretrained(emb),
            # Size the first projection from the table itself so odd time_emb_dims works
            nn.Linear(in_features=emb.shape[-1], out_features=time_emb_dims_exp),
            nn.SiLU(),
            nn.Linear(in_features=time_emb_dims_exp, out_features=time_emb_dims_exp),
        )

    def forward(self, time):
        """Return the projected embedding for each integer timestep in ``time`` (one extra trailing dim)."""
        return self.time_blocks(time)


class AttentionBlock(nn.Module):
    """Residual multi-head self-attention over the spatial positions of a feature map.

    The input is group-normalised (8 groups). Each spatial position then becomes
    one token of ``channels`` features, and the attention output is added back to
    the un-normalised input. ``channels`` must be divisible by 8 (GroupNorm) and
    by 4 (attention heads).
    """

    def __init__(self, channels=64):
        super().__init__()
        self.channels = channels

        self.group_norm = nn.GroupNorm(num_groups=8, num_channels=channels)
        self.mhsa = nn.MultiheadAttention(embed_dim=self.channels, num_heads=4, batch_first=True)

    def forward(self, x):
        batch, _, height, width = x.shape
        n_tokens = height * width

        # [B, C, H, W] -> [B, C, H*W] -> [B, H*W, C]: one token per pixel
        tokens = self.group_norm(x).reshape(batch, self.channels, n_tokens).transpose(1, 2)
        attended, _ = self.mhsa(tokens, tokens, tokens)
        # [B, H*W, C] -> [B, C, H*W] -> [B, C, H, W]
        attended = attended.transpose(1, 2).view(batch, self.channels, height, width)
        return x + attended


class ResnetBlock(nn.Module):
    """Pre-activation residual block conditioned on a time embedding.

    Pipeline:
    - norm -> SiLU -> conv,
    - add a per-channel projection of the time embedding,
    - norm -> SiLU -> dropout -> conv,
    - add the (channel-matched) input,
    - finally an optional AttentionBlock.

    All arguments are keyword-only. Channel counts must be divisible by 8
    because both GroupNorms use 8 groups.
    """

    def __init__(self, *, in_channels, out_channels, dropout_rate=0.1, time_emb_dims=512, apply_attention=False):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels

        self.act_fn = nn.SiLU()
        # Group 1: normalise the input, then a same-padding 3x3 conv to out_channels
        self.normlize_1 = nn.GroupNorm(num_groups=8, num_channels=in_channels)
        self.conv_1 = nn.Conv2d(
            in_channels=in_channels, out_channels=out_channels, kernel_size=3, stride=1, padding="same"
        )

        # Group 2: project the time embedding to one value per output channel
        self.dense_1 = nn.Linear(in_features=time_emb_dims, out_features=out_channels)

        # Group 3: normalise, channel dropout, then a second same-padding 3x3 conv
        self.normlize_2 = nn.GroupNorm(num_groups=8, num_channels=out_channels)
        self.dropout = nn.Dropout2d(p=dropout_rate)
        self.conv_2 = nn.Conv2d(
            in_channels=out_channels, out_channels=out_channels, kernel_size=3, stride=1, padding="same"
        )

        # The skip path needs a 1x1 conv only when the channel count changes
        self.match_input = (
            nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=1, stride=1)
            if in_channels != out_channels
            else nn.Identity()
        )

        self.attention = AttentionBlock(channels=out_channels) if apply_attention else nn.Identity()

    def forward(self, x, t):
        # Group 1
        h = self.conv_1(self.act_fn(self.normlize_1(x)))

        # Group 2: broadcast the time projection over the two spatial dims
        # (assumes t is [B, time_emb_dims] -- TODO confirm against callers)
        time_bias = self.dense_1(self.act_fn(t))
        h += time_bias[:, :, None, None]

        # Group 3
        h = self.conv_2(self.dropout(self.act_fn(self.normlize_2(h))))

        # Residual connection, then (optional) attention
        return self.attention(h + self.match_input(x))


class DownSample(nn.Module):
    """Halve the spatial size with a stride-2 3x3 conv; the channel count is unchanged."""

    def __init__(self, channels):
        super().__init__()
        self.downsample = nn.Conv2d(channels, channels, kernel_size=3, stride=2, padding=1)

    def forward(self, x, *args):
        # Extra positional args (e.g. the time embedding UNet passes to every layer) are ignored.
        return self.downsample(x)


class UpSample(nn.Module):
    """Double the spatial size (nearest-neighbour), then refine with a 3x3 conv; channels unchanged."""

    def __init__(self, in_channels):
        super().__init__()
        enlarge = nn.Upsample(scale_factor=2, mode="nearest")
        refine = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1)
        self.upsample = nn.Sequential(enlarge, refine)

    def forward(self, x, *args):
        # Extra positional args (e.g. the time embedding UNet passes to every layer) are ignored.
        return self.upsample(x)


class UNet(nn.Module):
    """Time-conditioned UNet: encoder, bottleneck and decoder built from ResnetBlocks.

    Args:
        input_channels: channels of the input image ``x``.
        output_channels: channels of the returned tensor.
        num_res_blocks: ResnetBlocks per resolution level in the encoder.
            The decoder uses ``num_res_blocks + 1`` per level.
        base_channels: channel width of the first level. Also the size of the
            sinusoidal time embedding.
        base_channels_multiples: per-level channel multipliers. Its length sets
            the number of resolution levels; every level except the last is
            followed by a DownSample.
        apply_attention: per-level flags enabling self-attention in that level's
            ResnetBlocks. Needs at least as many entries as
            ``base_channels_multiples``.
        dropout_rate: dropout probability passed to every ResnetBlock.
        time_multiple: width of the projected time embedding, as a multiple of
            ``base_channels``.
        total_time_steps: rows in the sinusoidal timestep table. Timesteps
            passed to ``forward`` must lie in ``[0, total_time_steps)``.

    All channel widths must be divisible by 8, since the GroupNorms use 8 groups.
    """

    def __init__(
            self,
            input_channels=3,
            output_channels=3,
            num_res_blocks=2,
            base_channels=128,
            base_channels_multiples=(1, 2, 4, 8),
            apply_attention=(False, False, True, False),
            dropout_rate=0.1,
            time_multiple=4,
            total_time_steps=1000,
    ):
        super().__init__()

        time_emb_dims_exp = base_channels * time_multiple
        # Sinusoidal table of width base_channels, projected to time_emb_dims_exp features
        self.time_embeddings = SinusoidalPositionEmbeddings(total_time_steps=total_time_steps,
                                                            time_emb_dims=base_channels,
                                                            time_emb_dims_exp=time_emb_dims_exp)

        self.first = nn.Conv2d(in_channels=input_channels, out_channels=base_channels, kernel_size=3, stride=1,
                               padding="same")

        num_resolutions = len(base_channels_multiples)

        # Encoder part of the UNet. Dimension reduction.
        self.encoder_blocks = nn.ModuleList()
        # Channel count of every tensor forward() pushes onto `outs` (first conv output plus
        # every encoder layer output), so each decoder block knows its skip-connection width.
        curr_channels = [base_channels]
        in_channels = base_channels

        for level in range(num_resolutions):
            out_channels = base_channels * base_channels_multiples[level]

            for _ in range(num_res_blocks):
                block = ResnetBlock(
                    in_channels=in_channels,
                    out_channels=out_channels,
                    dropout_rate=dropout_rate,
                    time_emb_dims=time_emb_dims_exp,
                    apply_attention=apply_attention[level],
                )
                self.encoder_blocks.append(block)

                in_channels = out_channels
                curr_channels.append(in_channels)

            if level != (num_resolutions - 1):
                self.encoder_blocks.append(DownSample(channels=in_channels))
                curr_channels.append(in_channels)

        # Bottleneck in between
        self.bottleneck_blocks = nn.ModuleList(
            (
                ResnetBlock(
                    in_channels=in_channels,
                    out_channels=in_channels,
                    dropout_rate=dropout_rate,
                    time_emb_dims=time_emb_dims_exp,
                    apply_attention=True,
                ),
                ResnetBlock(
                    in_channels=in_channels,
                    out_channels=in_channels,
                    dropout_rate=dropout_rate,
                    time_emb_dims=time_emb_dims_exp,
                    apply_attention=False,
                ),
            )
        )

        # Decoder part of the UNet. Dimension restoration with skip-connections.
        self.decoder_blocks = nn.ModuleList()

        for level in reversed(range(num_resolutions)):
            out_channels = base_channels * base_channels_multiples[level]

            for _ in range(num_res_blocks + 1):
                encoder_in_channels = curr_channels.pop()
                # Input is the current feature map concatenated with the matching skip tensor
                block = ResnetBlock(
                    in_channels=encoder_in_channels + in_channels,
                    out_channels=out_channels,
                    dropout_rate=dropout_rate,
                    time_emb_dims=time_emb_dims_exp,
                    apply_attention=apply_attention[level],
                )

                in_channels = out_channels
                self.decoder_blocks.append(block)

            if level != 0:
                self.decoder_blocks.append(UpSample(in_channels))

        self.final = nn.Sequential(
            nn.GroupNorm(num_groups=8, num_channels=in_channels),
            nn.SiLU(),
            nn.Conv2d(in_channels=in_channels, out_channels=output_channels, kernel_size=3, stride=1, padding="same"),
        )

    def forward(self, x, t):
        """Run the UNet on ``x`` conditioned on integer timesteps ``t``.

        Args:
            x: image batch, presumably [B, input_channels, H, W]. H and W
                presumably need to be divisible by
                2 ** (len(base_channels_multiples) - 1) so upsampled maps match
                the skip tensors in ``torch.cat``.
            t: integer timesteps indexing the sinusoidal table, presumably one per sample.

        Returns:
            Tensor with ``output_channels`` channels (the final conv uses "same" padding).
        """
        time_emb = self.time_embeddings(t)

        h = self.first(x)
        outs = [h]

        for layer in self.encoder_blocks:
            h = layer(h, time_emb)
            outs.append(h)

        for layer in self.bottleneck_blocks:
            h = layer(h, time_emb)

        for layer in self.decoder_blocks:
            # Only ResnetBlocks consume a skip connection; UpSample just resizes.
            if isinstance(layer, ResnetBlock):
                out = outs.pop()
                h = torch.cat([h, out], dim=1)
            h = layer(h, time_emb)

        h = self.final(h)

        return h


model = UNet(
    input_channels=3,
    output_channels=3,
    base_channels=64,
    base_channels_multiples=(1, 2, 4, 4),
    apply_attention=(False, True, True, False),
    dropout_rate=0.1,
    time_multiple=4,
)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

from net_structure import *

# 32 RGB 32x32 images plus one integer timestep per image. Timesteps index the
# sinusoidal table, so they must lie in [0, total_time_steps) (1000 by default).
# Both inputs are created on the model's device; CPU inputs with a CUDA model
# would fail in the forward pass.
print_model_tree_structure(
    model,
    [
        torch.randn(32, 3, 32, 32, device=device),
        torch.randint(low=1, high=1000, size=(32,), device=device),
    ],
)
  • 输出结果:
【     UNet      】Input shape:  torch.Size([32, 3, 32, 32])   → Output shape:  torch.Size([32, 3, 32, 32])   | Params count: 0
├── 【SinusoidalPositionEmbeddings】Input shape:        torch.Size([32])        → Output shape:     torch.Size([32, 256])      | Params count: 0
│   ├── 【  Sequential   】Input shape:        torch.Size([32])        → Output shape:     torch.Size([32, 256])      | Params count: 0
│   │   ├── 【   Embedding   】Input shape:        torch.Size([32])        → Output shape:      torch.Size([32, 64])      | Params count: 64000
│   │   ├── 【    Linear     】Input shape:      torch.Size([32, 64])      → Output shape:     torch.Size([32, 256])      | Params count: 16640
│   │   ├── 【     SiLU      】Input shape:     torch.Size([32, 256])      → Output shape:     torch.Size([32, 256])      | Params count: 0
│   │   ├── 【    Linear     】Input shape:     torch.Size([32, 256])      → Output shape:     torch.Size([32, 256])      | Params count: 65792
├── 【    Conv2d     】Input shape:  torch.Size([32, 3, 32, 32])   → Output shape:  torch.Size([32, 64, 32, 32])  | Params count: 1792
├── 【  Sequential   】Input shape:  torch.Size([32, 64, 32, 32])  → Output shape:  torch.Size([32, 3, 32, 32])   | Params count: 0
│   ├── 【  ResnetBlock  】Input shape:  torch.Size([32, 64, 32, 32])  → Output shape:  torch.Size([32, 64, 32, 32])  | Params count: 0
│   │   ├── 【   GroupNorm   】Input shape:  torch.Size([32, 64, 32, 32])  → Output shape:  torch.Size([32, 64, 32, 32])  | Params count: 128
│   │   ├── 【     SiLU      】Input shape:  torch.Size([32, 64, 32, 32])  → Output shape:  torch.Size([32, 64, 32, 32])  | Params count: 0
│   │   ├── 【    Conv2d     】Input shape:  torch.Size([32, 64, 32, 32])  → Output shape:  torch.Size([32, 64, 32, 32])  | Params count: 36928
│   │   ├── 【     SiLU      】Input shape:     torch.Size([32, 256])      → Output shape:     torch.Size([32, 256])      | Params count: 0
│   │   ├── 【    Linear     】Input shape:     torch.Size([32, 256])      → Output shape:      torch.Size([32, 64])      | Params count: 16448
│   │   ├── 【   GroupNorm   】Input shape:  torch.Size([32, 64, 32, 32])  → Output shape:  torch.Size([32, 64, 32, 32])  | Params count: 128
│   │   ├── 【     SiLU      】Input shape:  torch.Size([32, 64, 32, 32])  → Output shape:  torch.Size([32, 64, 32, 32])  | Params count: 0
│   │   ├── 【   Dropout2d   】Input shape:  torch.Size([32, 64, 32, 32])  → Output shape:  torch.Size([32, 64, 32, 32])  | Params count: 0
│   │   ├── 【    Conv2d     】Input shape:  torch.Size([32, 64, 32, 32])  → Output shape:  torch.Size([32, 64, 32, 32])  | Params count: 36928
│   │   ├── 【   Identity    】Input shape:  torch.Size([32, 64, 32, 32])  → Output shape:  torch.Size([32, 64, 32, 32])  | Params count: 0
│   │   ├── 【   Identity    】Input shape:  torch.Size([32, 64, 32, 32])  → Output shape:  torch.Size([32, 64, 32, 32])  | Params count: 0
│   ├── 【  ResnetBlock  】Input shape:  torch.Size([32, 64, 32, 32])  → Output shape:  torch.Size([32, 64, 32, 32])  | Params count: 0
│   │   ├── 【   GroupNorm   】Input shape:  torch.Size([32, 64, 32, 32])  → Output shape:  torch.Size([32, 64, 32, 32])  | Params count: 128
│   │   ├── 【     SiLU      】Input shape:  torch.Size([32, 64, 32, 32])  → Output shape:  torch.Size([32, 64, 32, 32])  | Params count: 0
│   │   ├── 【    Conv2d     】Input shape:  torch.Size([32, 64, 32, 32])  → Output shape:  torch.Size([32, 64, 32, 32])  | Params count: 36928
│   │   ├── 【     SiLU      】Input shape:     torch.Size([32, 256])      → Output shape:     torch.Size([32, 256])      | Params count: 0
│   │   ├── 【    Linear     】Input shape:     torch.Size([32, 256])      → Output shape:      torch.Size([32, 64])      | Params count: 16448
│   │   ├── 【   GroupNorm   】Input shape:  torch.Size([32, 64, 32, 32])  → Output shape:  torch.Size([32, 64, 32, 32])  | Params count: 128
│   │   ├── 【     SiLU      】Input shape:  torch.Size([32, 64, 32, 32])  → Output shape:  torch.Size([32, 64, 32, 32])  | Params count: 0
│   │   ├── 【   Dropout2d   】Input shape:  torch.Size([32, 64, 32, 32])  → Output shape:  torch.Size([32, 64, 32, 32])  | Params count: 0
│   │   ├── 【    Conv2d     】Input shape:  torch.Size([32, 64, 32, 32])  → Output shape:  torch.Size([32, 64, 32, 32])  | Params count: 36928
│   │   ├── 【   Identity    】Input shape:  torch.Size([32, 64, 32, 32])  → Output shape:  torch.Size([32, 64, 32, 32])  | Params count: 0
│   │   ├── 【   Identity    】Input shape:  torch.Size([32, 64, 32, 32])  → Output shape:  torch.Size([32, 64, 32, 32])  | Params count: 0
│   ├── 【  DownSample   】Input shape:  torch.Size([32, 64, 32, 32])  → Output shape:  torch.Size([32, 64, 16, 16])  | Params count: 0
│   │   ├── 【    Conv2d     】Input shape:  torch.Size([32, 64, 32, 32])  → Output shape:  torch.Size([32, 64, 16, 16])  | Params count: 36928
│   ├── 【  ResnetBlock  】Input shape:  torch.Size([32, 64, 16, 16])  → Output shape: torch.Size([32, 128, 16, 16])  | Params count: 0
│   │   ├── 【   GroupNorm   】Input shape:  torch.Size([32, 64, 16, 16])  → Output shape:  torch.Size([32, 64, 16, 16])  | Params count: 128
│   │   ├── 【     SiLU      】Input shape:  torch.Size([32, 64, 16, 16])  → Output shape:  torch.Size([32, 64, 16, 16])  | Params count: 0
│   │   ├── 【    Conv2d     】Input shape:  torch.Size([32, 64, 16, 16])  → Output shape: torch.Size([32, 128, 16, 16])  | Params count: 73856
│   │   ├── 【     SiLU      】Input shape:     torch.Size([32, 256])      → Output shape:     torch.Size([32, 256])      | Params count: 0
│   │   ├── 【    Linear     】Input shape:     torch.Size([32, 256])      → Output shape:     torch.Size([32, 128])      | Params count: 32896
│   │   ├── 【   GroupNorm   】Input shape: torch.Size([32, 128, 16, 16])  → Output shape: torch.Size([32, 128, 16, 16])  | Params count: 256
│   │   ├── 【     SiLU      】Input shape: torch.Size([32, 128, 16, 16])  → Output shape: torch.Size([32, 128, 16, 16])  | Params count: 0
│   │   ├── 【   Dropout2d   】Input shape: torch.Size([32, 128, 16, 16])  → Output shape: torch.Size([32, 128, 16, 16])  | Params count: 0
│   │   ├── 【    Conv2d     】Input shape: torch.Size([32, 128, 16, 16])  → Output shape: torch.Size([32, 128, 16, 16])  | Params count: 147584
│   │   ├── 【    Conv2d     】Input shape:  torch.Size([32, 64, 16, 16])  → Output shape: torch.Size([32, 128, 16, 16])  | Params count: 8320
│   │   ├── 【AttentionBlock 】Input shape: torch.Size([32, 128, 16, 16])  → Output shape: torch.Size([32, 128, 16, 16])  | Params count: 0
│   │   │   ├── 【   GroupNorm   】Input shape: torch.Size([32, 128, 16, 16])  → Output shape: torch.Size([32, 128, 16, 16])  | Params count: 256
│   │   │   ├── 【MultiheadAttention】Input shape:   torch.Size([32, 256, 128])   → Output shape:   torch.Size([32, 256, 128])   | Params count: 49536
│   ├── 【  ResnetBlock  】Input shape: torch.Size([32, 128, 16, 16])  → Output shape: torch.Size([32, 128, 16, 16])  | Params count: 0
│   │   ├── 【   GroupNorm   】Input shape: torch.Size([32, 128, 16, 16])  → Output shape: torch.Size([32, 128, 16, 16])  | Params count: 256
│   │   ├── 【     SiLU      】Input shape: torch.Size([32, 128, 16, 16])  → Output shape: torch.Size([32, 128, 16, 16])  | Params count: 0
│   │   ├── 【    Conv2d     】Input shape: torch.Size([32, 128, 16, 16])  → Output shape: torch.Size([32, 128, 16, 16])  | Params count: 147584
│   │   ├── 【     SiLU      】Input shape:     torch.Size([32, 256])      → Output shape:     torch.Size([32, 256])      | Params count: 0
│   │   ├── 【    Linear     】Input shape:     torch.Size([32, 256])      → Output shape:     torch.Size([32, 128])      | Params count: 32896
│   │   ├── 【   GroupNorm   】Input shape: torch.Size([32, 128, 16, 16])  → Output shape: torch.Size([32, 128, 16, 16])  | Params count: 256
│   │   ├── 【     SiLU      】Input shape: torch.Size([32, 128, 16, 16])  → Output shape: torch.Size([32, 128, 16, 16])  | Params count: 0
│   │   ├── 【   Dropout2d   】Input shape: torch.Size([32, 128, 16, 16])  → Output shape: torch.Size([32, 128, 16, 16])  | Params count: 0
│   │   ├── 【    Conv2d     】Input shape: torch.Size([32, 128, 16, 16])  → Output shape: torch.Size([32, 128, 16, 16])  | Params count: 147584
│   │   ├── 【   Identity    】Input shape: torch.Size([32, 128, 16, 16])  → Output shape: torch.Size([32, 128, 16, 16])  | Params count: 0
│   │   ├── 【AttentionBlock 】Input shape: torch.Size([32, 128, 16, 16])  → Output shape: torch.Size([32, 128, 16, 16])  | Params count: 0
│   │   │   ├── 【   GroupNorm   】Input shape: torch.Size([32, 128, 16, 16])  → Output shape: torch.Size([32, 128, 16, 16])  | Params count: 256
│   │   │   ├── 【MultiheadAttention】Input shape:   torch.Size([32, 256, 128])   → Output shape:   torch.Size([32, 256, 128])   | Params count: 49536
│   ├── 【  DownSample   】Input shape: torch.Size([32, 128, 16, 16])  → Output shape:  torch.Size([32, 128, 8, 8])   | Params count: 0
│   │   ├── 【    Conv2d     】Input shape: torch.Size([32, 128, 16, 16])  → Output shape:  torch.Size([32, 128, 8, 8])   | Params count: 147584
│   ├── 【  ResnetBlock  】Input shape:  torch.Size([32, 128, 8, 8])   → Output shape:  torch.Size([32, 256, 8, 8])   | Params count: 0
│   │   ├── 【   GroupNorm   】Input shape:  torch.Size([32, 128, 8, 8])   → Output shape:  torch.Size([32, 128, 8, 8])   | Params count: 256
│   │   ├── 【     SiLU      】Input shape:  torch.Size([32, 128, 8, 8])   → Output shape:  torch.Size([32, 128, 8, 8])   | Params count: 0
│   │   ├── 【    Conv2d     】Input shape:  torch.Size([32, 128, 8, 8])   → Output shape:  torch.Size([32, 256, 8, 8])   | Params count: 295168
│   │   ├── 【     SiLU      】Input shape:     torch.Size([32, 256])      → Output shape:     torch.Size([32, 256])      | Params count: 0
│   │   ├── 【    Linear     】Input shape:     torch.Size([32, 256])      → Output shape:     torch.Size([32, 256])      | Params count: 65792
│   │   ├── 【   GroupNorm   】Input shape:  torch.Size([32, 256, 8, 8])   → Output shape:  torch.Size([32, 256, 8, 8])   | Params count: 512
│   │   ├── 【     SiLU      】Input shape:  torch.Size([32, 256, 8, 8])   → Output shape:  torch.Size([32, 256, 8, 8])   | Params count: 0
│   │   ├── 【   Dropout2d   】Input shape:  torch.Size([32, 256, 8, 8])   → Output shape:  torch.Size([32, 256, 8, 8])   | Params count: 0
│   │   ├── 【    Conv2d     】Input shape:  torch.Size([32, 256, 8, 8])   → Output shape:  torch.Size([32, 256, 8, 8])   | Params count: 590080
│   │   ├── 【    Conv2d     】Input shape:  torch.Size([32, 128, 8, 8])   → Output shape:  torch.Size([32, 256, 8, 8])   | Params count: 33024
│   │   ├── 【AttentionBlock 】Input shape:  torch.Size([32, 256, 8, 8])   → Output shape:  torch.Size([32, 256, 8, 8])   | Params count: 0
│   │   │   ├── 【   GroupNorm   】Input shape:  torch.Size([32, 256, 8, 8])   → Output shape:  torch.Size([32, 256, 8, 8])   | Params count: 512
│   │   │   ├── 【MultiheadAttention】Input shape:   torch.Size([32, 64, 256])    → Output shape:   torch.Size([32, 64, 256])    | Params count: 197376
│   ├── 【  ResnetBlock  】Input shape:  torch.Size([32, 256, 8, 8])   → Output shape:  torch.Size([32, 256, 8, 8])   | Params count: 0
│   │   ├── 【   GroupNorm   】Input shape:  torch.Size([32, 256, 8, 8])   → Output shape:  torch.Size([32, 256, 8, 8])   | Params count: 512
│   │   ├── 【     SiLU      】Input shape:  torch.Size([32, 256, 8, 8])   → Output shape:  torch.Size([32, 256, 8, 8])   | Params count: 0
│   │   ├── 【    Conv2d     】Input shape:  torch.Size([32, 256, 8, 8])   → Output shape:  torch.Size([32, 256, 8, 8])   | Params count: 590080
│   │   ├── 【     SiLU      】Input shape:     torch.Size([32, 256])      → Output shape:     torch.Size([32, 256])      | Params count: 0
│   │   ├── 【    Linear     】Input shape:     torch.Size([32, 256])      → Output shape:     torch.Size([32, 256])      | Params count: 65792
│   │   ├── 【   GroupNorm   】Input shape:  torch.Size([32, 256, 8, 8])   → Output shape:  torch.Size([32, 256, 8, 8])   | Params count: 512
│   │   ├── 【     SiLU      】Input shape:  torch.Size([32, 256, 8, 8])   → Output shape:  torch.Size([32, 256, 8, 8])   | Params count: 0
│   │   ├── 【   Dropout2d   】Input shape:  torch.Size([32, 256, 8, 8])   → Output shape:  torch.Size([32, 256, 8, 8])   | Params count: 0
│   │   ├── 【    Conv2d     】Input shape:  torch.Size([32, 256, 8, 8])   → Output shape:  torch.Size([32, 256, 8, 8])   | Params count: 590080
│   │   ├── 【   Identity    】Input shape:  torch.Size([32, 256, 8, 8])   → Output shape:  torch.Size([32, 256, 8, 8])   | Params count: 0
│   │   ├── 【AttentionBlock 】Input shape:  torch.Size([32, 256, 8, 8])   → Output shape:  torch.Size([32, 256, 8, 8])   | Params count: 0
│   │   │   ├── 【   GroupNorm   】Input shape:  torch.Size([32, 256, 8, 8])   → Output shape:  torch.Size([32, 256, 8, 8])   | Params count: 512
│   │   │   ├── 【MultiheadAttention】Input shape:   torch.Size([32, 64, 256])    → Output shape:   torch.Size([32, 64, 256])    | Params count: 197376
│   ├── 【  DownSample   】Input shape:  torch.Size([32, 256, 8, 8])   → Output shape:  torch.Size([32, 256, 4, 4])   | Params count: 0
│   │   ├── 【    Conv2d     】Input shape:  torch.Size([32, 256, 8, 8])   → Output shape:  torch.Size([32, 256, 4, 4])   | Params count: 590080
│   ├── 【  ResnetBlock  】Input shape:  torch.Size([32, 256, 4, 4])   → Output shape:  torch.Size([32, 256, 4, 4])   | Params count: 0
│   │   ├── 【   GroupNorm   】Input shape:  torch.Size([32, 256, 4, 4])   → Output shape:  torch.Size([32, 256, 4, 4])   | Params count: 512
│   │   ├── 【     SiLU      】Input shape:  torch.Size([32, 256, 4, 4])   → Output shape:  torch.Size([32, 256, 4, 4])   | Params count: 0
│   │   ├── 【    Conv2d     】Input shape:  torch.Size([32, 256, 4, 4])   → Output shape:  torch.Size([32, 256, 4, 4])   | Params count: 590080
│   │   ├── 【     SiLU      】Input shape:     torch.Size([32, 256])      → Output shape:     torch.Size([32, 256])      | Params count: 0
│   │   ├── 【    Linear     】Input shape:     torch.Size([32, 256])      → Output shape:     torch.Size([32, 256])      | Params count: 65792
│   │   ├── 【   GroupNorm   】Input shape:  torch.Size([32, 256, 4, 4])   → Output shape:  torch.Size([32, 256, 4, 4])   | Params count: 512
│   │   ├── 【     SiLU      】Input shape:  torch.Size([32, 256, 4, 4])   → Output shape:  torch.Size([32, 256, 4, 4])   | Params count: 0
│   │   ├── 【   Dropout2d   】Input shape:  torch.Size([32, 256, 4, 4])   → Output shape:  torch.Size([32, 256, 4, 4])   | Params count: 0
│   │   ├── 【    Conv2d     】Input shape:  torch.Size([32, 256, 4, 4])   → Output shape:  torch.Size([32, 256, 4, 4])   | Params count: 590080
│   │   ├── 【   Identity    】Input shape:  torch.Size([32, 256, 4, 4])   → Output shape:  torch.Size([32, 256, 4, 4])   | Params count: 0
│   │   ├── 【   Identity    】Input shape:  torch.Size([32, 256, 4, 4])   → Output shape:  torch.Size([32, 256, 4, 4])   | Params count: 0
│   ├── 【  ResnetBlock  】Input shape:  torch.Size([32, 256, 4, 4])   → Output shape:  torch.Size([32, 256, 4, 4])   | Params count: 0
│   │   ├── 【   GroupNorm   】Input shape:  torch.Size([32, 256, 4, 4])   → Output shape:  torch.Size([32, 256, 4, 4])   | Params count: 512
│   │   ├── 【     SiLU      】Input shape:  torch.Size([32, 256, 4, 4])   → Output shape:  torch.Size([32, 256, 4, 4])   | Params count: 0
│   │   ├── 【    Conv2d     】Input shape:  torch.Size([32, 256, 4, 4])   → Output shape:  torch.Size([32, 256, 4, 4])   | Params count: 590080
│   │   ├── 【     SiLU      】Input shape:     torch.Size([32, 256])      → Output shape:     torch.Size([32, 256])      | Params count: 0
│   │   ├── 【    Linear     】Input shape:     torch.Size([32, 256])      → Output shape:     torch.Size([32, 256])      | Params count: 65792
│   │   ├── 【   GroupNorm   】Input shape:  torch.Size([32, 256, 4, 4])   → Output shape:  torch.Size([32, 256, 4, 4])   | Params count: 512
│   │   ├── 【     SiLU      】Input shape:  torch.Size([32, 256, 4, 4])   → Output shape:  torch.Size([32, 256, 4, 4])   | Params count: 0
│   │   ├── 【   Dropout2d   】Input shape:  torch.Size([32, 256, 4, 4])   → Output shape:  torch.Size([32, 256, 4, 4])   | Params count: 0
│   │   ├── 【    Conv2d     】Input shape:  torch.Size([32, 256, 4, 4])   → Output shape:  torch.Size([32, 256, 4, 4])   | Params count: 590080
│   │   ├── 【   Identity    】Input shape:  torch.Size([32, 256, 4, 4])   → Output shape:  torch.Size([32, 256, 4, 4])   | Params count: 0
│   │   ├── 【   Identity    】Input shape:  torch.Size([32, 256, 4, 4])   → Output shape:  torch.Size([32, 256, 4, 4])   | Params count: 0
│   ├── 【  ResnetBlock  】Input shape:  torch.Size([32, 256, 4, 4])   → Output shape:  torch.Size([32, 256, 4, 4])   | Params count: 0
│   │   ├── 【   GroupNorm   】Input shape:  torch.Size([32, 256, 4, 4])   → Output shape:  torch.Size([32, 256, 4, 4])   | Params count: 512
│   │   ├── 【     SiLU      】Input shape:  torch.Size([32, 256, 4, 4])   → Output shape:  torch.Size([32, 256, 4, 4])   | Params count: 0
│   │   ├── 【    Conv2d     】Input shape:  torch.Size([32, 256, 4, 4])   → Output shape:  torch.Size([32, 256, 4, 4])   | Params count: 590080
│   │   ├── 【     SiLU      】Input shape:     torch.Size([32, 256])      → Output shape:     torch.Size([32, 256])      | Params count: 0
│   │   ├── 【    Linear     】Input shape:     torch.Size([32, 256])      → Output shape:     torch.Size([32, 256])      | Params count: 65792
│   │   ├── 【   GroupNorm   】Input shape:  torch.Size([32, 256, 4, 4])   → Output shape:  torch.Size([32, 256, 4, 4])   | Params count: 512
│   │   ├── 【     SiLU      】Input shape:  torch.Size([32, 256, 4, 4])   → Output shape:  torch.Size([32, 256, 4, 4])   | Params count: 0
│   │   ├── 【   Dropout2d   】Input shape:  torch.Size([32, 256, 4, 4])   → Output shape:  torch.Size([32, 256, 4, 4])   | Params count: 0
│   │   ├── 【    Conv2d     】Input shape:  torch.Size([32, 256, 4, 4])   → Output shape:  torch.Size([32, 256, 4, 4])   | Params count: 590080
│   │   ├── 【   Identity    】Input shape:  torch.Size([32, 256, 4, 4])   → Output shape:  torch.Size([32, 256, 4, 4])   | Params count: 0
│   │   ├── 【AttentionBlock 】Input shape:  torch.Size([32, 256, 4, 4])   → Output shape:  torch.Size([32, 256, 4, 4])   | Params count: 0
│   │   │   ├── 【   GroupNorm   】Input shape:  torch.Size([32, 256, 4, 4])   → Output shape:  torch.Size([32, 256, 4, 4])   | Params count: 512
│   │   │   ├── 【MultiheadAttention】Input shape:   torch.Size([32, 16, 256])    → Output shape:   torch.Size([32, 16, 256])    | Params count: 197376
│   ├── 【  ResnetBlock  】Input shape:  torch.Size([32, 256, 4, 4])   → Output shape:  torch.Size([32, 256, 4, 4])   | Params count: 0
│   │   ├── 【   GroupNorm   】Input shape:  torch.Size([32, 256, 4, 4])   → Output shape:  torch.Size([32, 256, 4, 4])   | Params count: 512
│   │   ├── 【     SiLU      】Input shape:  torch.Size([32, 256, 4, 4])   → Output shape:  torch.Size([32, 256, 4, 4])   | Params count: 0
│   │   ├── 【    Conv2d     】Input shape:  torch.Size([32, 256, 4, 4])   → Output shape:  torch.Size([32, 256, 4, 4])   | Params count: 590080
│   │   ├── 【     SiLU      】Input shape:     torch.Size([32, 256])      → Output shape:     torch.Size([32, 256])      | Params count: 0
│   │   ├── 【    Linear     】Input shape:     torch.Size([32, 256])      → Output shape:     torch.Size([32, 256])      | Params count: 65792
│   │   ├── 【   GroupNorm   】Input shape:  torch.Size([32, 256, 4, 4])   → Output shape:  torch.Size([32, 256, 4, 4])   | Params count: 512
│   │   ├── 【     SiLU      】Input shape:  torch.Size([32, 256, 4, 4])   → Output shape:  torch.Size([32, 256, 4, 4])   | Params count: 0
│   │   ├── 【   Dropout2d   】Input shape:  torch.Size([32, 256, 4, 4])   → Output shape:  torch.Size([32, 256, 4, 4])   | Params count: 0
│   │   ├── 【    Conv2d     】Input shape:  torch.Size([32, 256, 4, 4])   → Output shape:  torch.Size([32, 256, 4, 4])   | Params count: 590080
│   │   ├── 【   Identity    】Input shape:  torch.Size([32, 256, 4, 4])   → Output shape:  torch.Size([32, 256, 4, 4])   | Params count: 0
│   │   ├── 【   Identity    】Input shape:  torch.Size([32, 256, 4, 4])   → Output shape:  torch.Size([32, 256, 4, 4])   | Params count: 0
│   ├── 【  ResnetBlock  】Input shape:  torch.Size([32, 512, 4, 4])   → Output shape:  torch.Size([32, 256, 4, 4])   | Params count: 0
│   │   ├── 【   GroupNorm   】Input shape:  torch.Size([32, 512, 4, 4])   → Output shape:  torch.Size([32, 512, 4, 4])   | Params count: 1024
│   │   ├── 【     SiLU      】Input shape:  torch.Size([32, 512, 4, 4])   → Output shape:  torch.Size([32, 512, 4, 4])   | Params count: 0
│   │   ├── 【    Conv2d     】Input shape:  torch.Size([32, 512, 4, 4])   → Output shape:  torch.Size([32, 256, 4, 4])   | Params count: 1179904
│   │   ├── 【     SiLU      】Input shape:     torch.Size([32, 256])      → Output shape:     torch.Size([32, 256])      | Params count: 0
│   │   ├── 【    Linear     】Input shape:     torch.Size([32, 256])      → Output shape:     torch.Size([32, 256])      | Params count: 65792
│   │   ├── 【   GroupNorm   】Input shape:  torch.Size([32, 256, 4, 4])   → Output shape:  torch.Size([32, 256, 4, 4])   | Params count: 512
│   │   ├── 【     SiLU      】Input shape:  torch.Size([32, 256, 4, 4])   → Output shape:  torch.Size([32, 256, 4, 4])   | Params count: 0
│   │   ├── 【   Dropout2d   】Input shape:  torch.Size([32, 256, 4, 4])   → Output shape:  torch.Size([32, 256, 4, 4])   | Params count: 0
│   │   ├── 【    Conv2d     】Input shape:  torch.Size([32, 256, 4, 4])   → Output shape:  torch.Size([32, 256, 4, 4])   | Params count: 590080
│   │   ├── 【    Conv2d     】Input shape:  torch.Size([32, 512, 4, 4])   → Output shape:  torch.Size([32, 256, 4, 4])   | Params count: 131328
│   │   ├── 【   Identity    】Input shape:  torch.Size([32, 256, 4, 4])   → Output shape:  torch.Size([32, 256, 4, 4])   | Params count: 0
│   ├── 【  ResnetBlock  】Input shape:  torch.Size([32, 512, 4, 4])   → Output shape:  torch.Size([32, 256, 4, 4])   | Params count: 0
│   │   ├── 【   GroupNorm   】Input shape:  torch.Size([32, 512, 4, 4])   → Output shape:  torch.Size([32, 512, 4, 4])   | Params count: 1024
│   │   ├── 【     SiLU      】Input shape:  torch.Size([32, 512, 4, 4])   → Output shape:  torch.Size([32, 512, 4, 4])   | Params count: 0
│   │   ├── 【    Conv2d     】Input shape:  torch.Size([32, 512, 4, 4])   → Output shape:  torch.Size([32, 256, 4, 4])   | Params count: 1179904
│   │   ├── 【     SiLU      】Input shape:     torch.Size([32, 256])      → Output shape:     torch.Size([32, 256])      | Params count: 0
│   │   ├── 【    Linear     】Input shape:     torch.Size([32, 256])      → Output shape:     torch.Size([32, 256])      | Params count: 65792
│   │   ├── 【   GroupNorm   】Input shape:  torch.Size([32, 256, 4, 4])   → Output shape:  torch.Size([32, 256, 4, 4])   | Params count: 512
│   │   ├── 【     SiLU      】Input shape:  torch.Size([32, 256, 4, 4])   → Output shape:  torch.Size([32, 256, 4, 4])   | Params count: 0
│   │   ├── 【   Dropout2d   】Input shape:  torch.Size([32, 256, 4, 4])   → Output shape:  torch.Size([32, 256, 4, 4])   | Params count: 0
│   │   ├── 【    Conv2d     】Input shape:  torch.Size([32, 256, 4, 4])   → Output shape:  torch.Size([32, 256, 4, 4])   | Params count: 590080
│   │   ├── 【    Conv2d     】Input shape:  torch.Size([32, 512, 4, 4])   → Output shape:  torch.Size([32, 256, 4, 4])   | Params count: 131328
│   │   ├── 【   Identity    】Input shape:  torch.Size([32, 256, 4, 4])   → Output shape:  torch.Size([32, 256, 4, 4])   | Params count: 0
│   ├── 【  ResnetBlock  】Input shape:  torch.Size([32, 512, 4, 4])   → Output shape:  torch.Size([32, 256, 4, 4])   | Params count: 0
│   │   ├── 【   GroupNorm   】Input shape:  torch.Size([32, 512, 4, 4])   → Output shape:  torch.Size([32, 512, 4, 4])   | Params count: 1024
│   │   ├── 【     SiLU      】Input shape:  torch.Size([32, 512, 4, 4])   → Output shape:  torch.Size([32, 512, 4, 4])   | Params count: 0
│   │   ├── 【    Conv2d     】Input shape:  torch.Size([32, 512, 4, 4])   → Output shape:  torch.Size([32, 256, 4, 4])   | Params count: 1179904
│   │   ├── 【     SiLU      】Input shape:     torch.Size([32, 256])      → Output shape:     torch.Size([32, 256])      | Params count: 0
│   │   ├── 【    Linear     】Input shape:     torch.Size([32, 256])      → Output shape:     torch.Size([32, 256])      | Params count: 65792
│   │   ├── 【   GroupNorm   】Input shape:  torch.Size([32, 256, 4, 4])   → Output shape:  torch.Size([32, 256, 4, 4])   | Params count: 512
│   │   ├── 【     SiLU      】Input shape:  torch.Size([32, 256, 4, 4])   → Output shape:  torch.Size([32, 256, 4, 4])   | Params count: 0
│   │   ├── 【   Dropout2d   】Input shape:  torch.Size([32, 256, 4, 4])   → Output shape:  torch.Size([32, 256, 4, 4])   | Params count: 0
│   │   ├── 【    Conv2d     】Input shape:  torch.Size([32, 256, 4, 4])   → Output shape:  torch.Size([32, 256, 4, 4])   | Params count: 590080
│   │   ├── 【    Conv2d     】Input shape:  torch.Size([32, 512, 4, 4])   → Output shape:  torch.Size([32, 256, 4, 4])   | Params count: 131328
│   │   ├── 【   Identity    】Input shape:  torch.Size([32, 256, 4, 4])   → Output shape:  torch.Size([32, 256, 4, 4])   | Params count: 0
│   ├── 【   UpSample    】Input shape:  torch.Size([32, 256, 4, 4])   → Output shape:  torch.Size([32, 256, 8, 8])   | Params count: 0
│   │   ├── 【  Sequential   】Input shape:  torch.Size([32, 256, 4, 4])   → Output shape:  torch.Size([32, 256, 8, 8])   | Params count: 0
│   │   │   ├── 【   Upsample    】Input shape:  torch.Size([32, 256, 4, 4])   → Output shape:  torch.Size([32, 256, 8, 8])   | Params count: 0
│   │   │   ├── 【    Conv2d     】Input shape:  torch.Size([32, 256, 8, 8])   → Output shape:  torch.Size([32, 256, 8, 8])   | Params count: 590080
│   ├── 【  ResnetBlock  】Input shape:  torch.Size([32, 512, 8, 8])   → Output shape:  torch.Size([32, 256, 8, 8])   | Params count: 0
│   │   ├── 【   GroupNorm   】Input shape:  torch.Size([32, 512, 8, 8])   → Output shape:  torch.Size([32, 512, 8, 8])   | Params count: 1024
│   │   ├── 【     SiLU      】Input shape:  torch.Size([32, 512, 8, 8])   → Output shape:  torch.Size([32, 512, 8, 8])   | Params count: 0
│   │   ├── 【    Conv2d     】Input shape:  torch.Size([32, 512, 8, 8])   → Output shape:  torch.Size([32, 256, 8, 8])   | Params count: 1179904
│   │   ├── 【     SiLU      】Input shape:     torch.Size([32, 256])      → Output shape:     torch.Size([32, 256])      | Params count: 0
│   │   ├── 【    Linear     】Input shape:     torch.Size([32, 256])      → Output shape:     torch.Size([32, 256])      | Params count: 65792
│   │   ├── 【   GroupNorm   】Input shape:  torch.Size([32, 256, 8, 8])   → Output shape:  torch.Size([32, 256, 8, 8])   | Params count: 512
│   │   ├── 【     SiLU      】Input shape:  torch.Size([32, 256, 8, 8])   → Output shape:  torch.Size([32, 256, 8, 8])   | Params count: 0
│   │   ├── 【   Dropout2d   】Input shape:  torch.Size([32, 256, 8, 8])   → Output shape:  torch.Size([32, 256, 8, 8])   | Params count: 0
│   │   ├── 【    Conv2d     】Input shape:  torch.Size([32, 256, 8, 8])   → Output shape:  torch.Size([32, 256, 8, 8])   | Params count: 590080
│   │   ├── 【    Conv2d     】Input shape:  torch.Size([32, 512, 8, 8])   → Output shape:  torch.Size([32, 256, 8, 8])   | Params count: 131328
│   │   ├── 【AttentionBlock 】Input shape:  torch.Size([32, 256, 8, 8])   → Output shape:  torch.Size([32, 256, 8, 8])   | Params count: 0
│   │   │   ├── 【   GroupNorm   】Input shape:  torch.Size([32, 256, 8, 8])   → Output shape:  torch.Size([32, 256, 8, 8])   | Params count: 512
│   │   │   ├── 【MultiheadAttention】Input shape:   torch.Size([32, 64, 256])    → Output shape:   torch.Size([32, 64, 256])    | Params count: 197376
│   ├── 【  ResnetBlock  】Input shape:  torch.Size([32, 512, 8, 8])   → Output shape:  torch.Size([32, 256, 8, 8])   | Params count: 0
│   │   ├── 【   GroupNorm   】Input shape:  torch.Size([32, 512, 8, 8])   → Output shape:  torch.Size([32, 512, 8, 8])   | Params count: 1024
│   │   ├── 【     SiLU      】Input shape:  torch.Size([32, 512, 8, 8])   → Output shape:  torch.Size([32, 512, 8, 8])   | Params count: 0
│   │   ├── 【    Conv2d     】Input shape:  torch.Size([32, 512, 8, 8])   → Output shape:  torch.Size([32, 256, 8, 8])   | Params count: 1179904
│   │   ├── 【     SiLU      】Input shape:     torch.Size([32, 256])      → Output shape:     torch.Size([32, 256])      | Params count: 0
│   │   ├── 【    Linear     】Input shape:     torch.Size([32, 256])      → Output shape:     torch.Size([32, 256])      | Params count: 65792
│   │   ├── 【   GroupNorm   】Input shape:  torch.Size([32, 256, 8, 8])   → Output shape:  torch.Size([32, 256, 8, 8])   | Params count: 512
│   │   ├── 【     SiLU      】Input shape:  torch.Size([32, 256, 8, 8])   → Output shape:  torch.Size([32, 256, 8, 8])   | Params count: 0
│   │   ├── 【   Dropout2d   】Input shape:  torch.Size([32, 256, 8, 8])   → Output shape:  torch.Size([32, 256, 8, 8])   | Params count: 0
│   │   ├── 【    Conv2d     】Input shape:  torch.Size([32, 256, 8, 8])   → Output shape:  torch.Size([32, 256, 8, 8])   | Params count: 590080
│   │   ├── 【    Conv2d     】Input shape:  torch.Size([32, 512, 8, 8])   → Output shape:  torch.Size([32, 256, 8, 8])   | Params count: 131328
│   │   ├── 【AttentionBlock 】Input shape:  torch.Size([32, 256, 8, 8])   → Output shape:  torch.Size([32, 256, 8, 8])   | Params count: 0
│   │   │   ├── 【   GroupNorm   】Input shape:  torch.Size([32, 256, 8, 8])   → Output shape:  torch.Size([32, 256, 8, 8])   | Params count: 512
│   │   │   ├── 【MultiheadAttention】Input shape:   torch.Size([32, 64, 256])    → Output shape:   torch.Size([32, 64, 256])    | Params count: 197376
│   ├── 【  ResnetBlock  】Input shape:  torch.Size([32, 384, 8, 8])   → Output shape:  torch.Size([32, 256, 8, 8])   | Params count: 0
│   │   ├── 【   GroupNorm   】Input shape:  torch.Size([32, 384, 8, 8])   → Output shape:  torch.Size([32, 384, 8, 8])   | Params count: 768
│   │   ├── 【     SiLU      】Input shape:  torch.Size([32, 384, 8, 8])   → Output shape:  torch.Size([32, 384, 8, 8])   | Params count: 0
│   │   ├── 【    Conv2d     】Input shape:  torch.Size([32, 384, 8, 8])   → Output shape:  torch.Size([32, 256, 8, 8])   | Params count: 884992
│   │   ├── 【     SiLU      】Input shape:     torch.Size([32, 256])      → Output shape:     torch.Size([32, 256])      | Params count: 0
│   │   ├── 【    Linear     】Input shape:     torch.Size([32, 256])      → Output shape:     torch.Size([32, 256])      | Params count: 65792
│   │   ├── 【   GroupNorm   】Input shape:  torch.Size([32, 256, 8, 8])   → Output shape:  torch.Size([32, 256, 8, 8])   | Params count: 512
│   │   ├── 【     SiLU      】Input shape:  torch.Size([32, 256, 8, 8])   → Output shape:  torch.Size([32, 256, 8, 8])   | Params count: 0
│   │   ├── 【   Dropout2d   】Input shape:  torch.Size([32, 256, 8, 8])   → Output shape:  torch.Size([32, 256, 8, 8])   | Params count: 0
│   │   ├── 【    Conv2d     】Input shape:  torch.Size([32, 256, 8, 8])   → Output shape:  torch.Size([32, 256, 8, 8])   | Params count: 590080
│   │   ├── 【    Conv2d     】Input shape:  torch.Size([32, 384, 8, 8])   → Output shape:  torch.Size([32, 256, 8, 8])   | Params count: 98560
│   │   ├── 【AttentionBlock 】Input shape:  torch.Size([32, 256, 8, 8])   → Output shape:  torch.Size([32, 256, 8, 8])   | Params count: 0
│   │   │   ├── 【   GroupNorm   】Input shape:  torch.Size([32, 256, 8, 8])   → Output shape:  torch.Size([32, 256, 8, 8])   | Params count: 512
│   │   │   ├── 【MultiheadAttention】Input shape:   torch.Size([32, 64, 256])    → Output shape:   torch.Size([32, 64, 256])    | Params count: 197376
│   ├── 【   UpSample    】Input shape:  torch.Size([32, 256, 8, 8])   → Output shape: torch.Size([32, 256, 16, 16])  | Params count: 0
│   │   ├── 【  Sequential   】Input shape:  torch.Size([32, 256, 8, 8])   → Output shape: torch.Size([32, 256, 16, 16])  | Params count: 0
│   │   │   ├── 【   Upsample    】Input shape:  torch.Size([32, 256, 8, 8])   → Output shape: torch.Size([32, 256, 16, 16])  | Params count: 0
│   │   │   ├── 【    Conv2d     】Input shape: torch.Size([32, 256, 16, 16])  → Output shape: torch.Size([32, 256, 16, 16])  | Params count: 590080
│   ├── 【  ResnetBlock  】Input shape: torch.Size([32, 384, 16, 16])  → Output shape: torch.Size([32, 128, 16, 16])  | Params count: 0
│   │   ├── 【   GroupNorm   】Input shape: torch.Size([32, 384, 16, 16])  → Output shape: torch.Size([32, 384, 16, 16])  | Params count: 768
│   │   ├── 【     SiLU      】Input shape: torch.Size([32, 384, 16, 16])  → Output shape: torch.Size([32, 384, 16, 16])  | Params count: 0
│   │   ├── 【    Conv2d     】Input shape: torch.Size([32, 384, 16, 16])  → Output shape: torch.Size([32, 128, 16, 16])  | Params count: 442496
│   │   ├── 【     SiLU      】Input shape:     torch.Size([32, 256])      → Output shape:     torch.Size([32, 256])      | Params count: 0
│   │   ├── 【    Linear     】Input shape:     torch.Size([32, 256])      → Output shape:     torch.Size([32, 128])      | Params count: 32896
│   │   ├── 【   GroupNorm   】Input shape: torch.Size([32, 128, 16, 16])  → Output shape: torch.Size([32, 128, 16, 16])  | Params count: 256
│   │   ├── 【     SiLU      】Input shape: torch.Size([32, 128, 16, 16])  → Output shape: torch.Size([32, 128, 16, 16])  | Params count: 0
│   │   ├── 【   Dropout2d   】Input shape: torch.Size([32, 128, 16, 16])  → Output shape: torch.Size([32, 128, 16, 16])  | Params count: 0
│   │   ├── 【    Conv2d     】Input shape: torch.Size([32, 128, 16, 16])  → Output shape: torch.Size([32, 128, 16, 16])  | Params count: 147584
│   │   ├── 【    Conv2d     】Input shape: torch.Size([32, 384, 16, 16])  → Output shape: torch.Size([32, 128, 16, 16])  | Params count: 49280
│   │   ├── 【AttentionBlock 】Input shape: torch.Size([32, 128, 16, 16])  → Output shape: torch.Size([32, 128, 16, 16])  | Params count: 0
│   │   │   ├── 【   GroupNorm   】Input shape: torch.Size([32, 128, 16, 16])  → Output shape: torch.Size([32, 128, 16, 16])  | Params count: 256
│   │   │   ├── 【MultiheadAttention】Input shape:   torch.Size([32, 256, 128])   → Output shape:   torch.Size([32, 256, 128])   | Params count: 49536
│   ├── 【  ResnetBlock  】Input shape: torch.Size([32, 256, 16, 16])  → Output shape: torch.Size([32, 128, 16, 16])  | Params count: 0
│   │   ├── 【   GroupNorm   】Input shape: torch.Size([32, 256, 16, 16])  → Output shape: torch.Size([32, 256, 16, 16])  | Params count: 512
│   │   ├── 【     SiLU      】Input shape: torch.Size([32, 256, 16, 16])  → Output shape: torch.Size([32, 256, 16, 16])  | Params count: 0
│   │   ├── 【    Conv2d     】Input shape: torch.Size([32, 256, 16, 16])  → Output shape: torch.Size([32, 128, 16, 16])  | Params count: 295040
│   │   ├── 【     SiLU      】Input shape:     torch.Size([32, 256])      → Output shape:     torch.Size([32, 256])      | Params count: 0
│   │   ├── 【    Linear     】Input shape:     torch.Size([32, 256])      → Output shape:     torch.Size([32, 128])      | Params count: 32896
│   │   ├── 【   GroupNorm   】Input shape: torch.Size([32, 128, 16, 16])  → Output shape: torch.Size([32, 128, 16, 16])  | Params count: 256
│   │   ├── 【     SiLU      】Input shape: torch.Size([32, 128, 16, 16])  → Output shape: torch.Size([32, 128, 16, 16])  | Params count: 0
│   │   ├── 【   Dropout2d   】Input shape: torch.Size([32, 128, 16, 16])  → Output shape: torch.Size([32, 128, 16, 16])  | Params count: 0
│   │   ├── 【    Conv2d     】Input shape: torch.Size([32, 128, 16, 16])  → Output shape: torch.Size([32, 128, 16, 16])  | Params count: 147584
│   │   ├── 【    Conv2d     】Input shape: torch.Size([32, 256, 16, 16])  → Output shape: torch.Size([32, 128, 16, 16])  | Params count: 32896
│   │   ├── 【AttentionBlock 】Input shape: torch.Size([32, 128, 16, 16])  → Output shape: torch.Size([32, 128, 16, 16])  | Params count: 0
│   │   │   ├── 【   GroupNorm   】Input shape: torch.Size([32, 128, 16, 16])  → Output shape: torch.Size([32, 128, 16, 16])  | Params count: 256
│   │   │   ├── 【MultiheadAttention】Input shape:   torch.Size([32, 256, 128])   → Output shape:   torch.Size([32, 256, 128])   | Params count: 49536
│   ├── 【  ResnetBlock  】Input shape: torch.Size([32, 192, 16, 16])  → Output shape: torch.Size([32, 128, 16, 16])  | Params count: 0
│   │   ├── 【   GroupNorm   】Input shape: torch.Size([32, 192, 16, 16])  → Output shape: torch.Size([32, 192, 16, 16])  | Params count: 384
│   │   ├── 【     SiLU      】Input shape: torch.Size([32, 192, 16, 16])  → Output shape: torch.Size([32, 192, 16, 16])  | Params count: 0
│   │   ├── 【    Conv2d     】Input shape: torch.Size([32, 192, 16, 16])  → Output shape: torch.Size([32, 128, 16, 16])  | Params count: 221312
│   │   ├── 【     SiLU      】Input shape:     torch.Size([32, 256])      → Output shape:     torch.Size([32, 256])      | Params count: 0
│   │   ├── 【    Linear     】Input shape:     torch.Size([32, 256])      → Output shape:     torch.Size([32, 128])      | Params count: 32896
│   │   ├── 【   GroupNorm   】Input shape: torch.Size([32, 128, 16, 16])  → Output shape: torch.Size([32, 128, 16, 16])  | Params count: 256
│   │   ├── 【     SiLU      】Input shape: torch.Size([32, 128, 16, 16])  → Output shape: torch.Size([32, 128, 16, 16])  | Params count: 0
│   │   ├── 【   Dropout2d   】Input shape: torch.Size([32, 128, 16, 16])  → Output shape: torch.Size([32, 128, 16, 16])  | Params count: 0
│   │   ├── 【    Conv2d     】Input shape: torch.Size([32, 128, 16, 16])  → Output shape: torch.Size([32, 128, 16, 16])  | Params count: 147584
│   │   ├── 【    Conv2d     】Input shape: torch.Size([32, 192, 16, 16])  → Output shape: torch.Size([32, 128, 16, 16])  | Params count: 24704
│   │   ├── 【AttentionBlock 】Input shape: torch.Size([32, 128, 16, 16])  → Output shape: torch.Size([32, 128, 16, 16])  | Params count: 0
│   │   │   ├── 【   GroupNorm   】Input shape: torch.Size([32, 128, 16, 16])  → Output shape: torch.Size([32, 128, 16, 16])  | Params count: 256
│   │   │   ├── 【MultiheadAttention】Input shape:   torch.Size([32, 256, 128])   → Output shape:   torch.Size([32, 256, 128])   | Params count: 49536
│   ├── 【   UpSample    】Input shape: torch.Size([32, 128, 16, 16])  → Output shape: torch.Size([32, 128, 32, 32])  | Params count: 0
│   │   ├── 【  Sequential   】Input shape: torch.Size([32, 128, 16, 16])  → Output shape: torch.Size([32, 128, 32, 32])  | Params count: 0
│   │   │   ├── 【   Upsample    】Input shape: torch.Size([32, 128, 16, 16])  → Output shape: torch.Size([32, 128, 32, 32])  | Params count: 0
│   │   │   ├── 【    Conv2d     】Input shape: torch.Size([32, 128, 32, 32])  → Output shape: torch.Size([32, 128, 32, 32])  | Params count: 147584
│   ├── 【  ResnetBlock  】Input shape: torch.Size([32, 192, 32, 32])  → Output shape:  torch.Size([32, 64, 32, 32])  | Params count: 0
│   │   ├── 【   GroupNorm   】Input shape: torch.Size([32, 192, 32, 32])  → Output shape: torch.Size([32, 192, 32, 32])  | Params count: 384
│   │   ├── 【     SiLU      】Input shape: torch.Size([32, 192, 32, 32])  → Output shape: torch.Size([32, 192, 32, 32])  | Params count: 0
│   │   ├── 【    Conv2d     】Input shape: torch.Size([32, 192, 32, 32])  → Output shape:  torch.Size([32, 64, 32, 32])  | Params count: 110656
│   │   ├── 【     SiLU      】Input shape:     torch.Size([32, 256])      → Output shape:     torch.Size([32, 256])      | Params count: 0
│   │   ├── 【    Linear     】Input shape:     torch.Size([32, 256])      → Output shape:      torch.Size([32, 64])      | Params count: 16448
│   │   ├── 【   GroupNorm   】Input shape:  torch.Size([32, 64, 32, 32])  → Output shape:  torch.Size([32, 64, 32, 32])  | Params count: 128
│   │   ├── 【     SiLU      】Input shape:  torch.Size([32, 64, 32, 32])  → Output shape:  torch.Size([32, 64, 32, 32])  | Params count: 0
│   │   ├── 【   Dropout2d   】Input shape:  torch.Size([32, 64, 32, 32])  → Output shape:  torch.Size([32, 64, 32, 32])  | Params count: 0
│   │   ├── 【    Conv2d     】Input shape:  torch.Size([32, 64, 32, 32])  → Output shape:  torch.Size([32, 64, 32, 32])  | Params count: 36928
│   │   ├── 【    Conv2d     】Input shape: torch.Size([32, 192, 32, 32])  → Output shape:  torch.Size([32, 64, 32, 32])  | Params count: 12352
│   │   ├── 【   Identity    】Input shape:  torch.Size([32, 64, 32, 32])  → Output shape:  torch.Size([32, 64, 32, 32])  | Params count: 0
│   ├── 【  ResnetBlock  】Input shape: torch.Size([32, 128, 32, 32])  → Output shape:  torch.Size([32, 64, 32, 32])  | Params count: 0
│   │   ├── 【   GroupNorm   】Input shape: torch.Size([32, 128, 32, 32])  → Output shape: torch.Size([32, 128, 32, 32])  | Params count: 256
│   │   ├── 【     SiLU      】Input shape: torch.Size([32, 128, 32, 32])  → Output shape: torch.Size([32, 128, 32, 32])  | Params count: 0
│   │   ├── 【    Conv2d     】Input shape: torch.Size([32, 128, 32, 32])  → Output shape:  torch.Size([32, 64, 32, 32])  | Params count: 73792
│   │   ├── 【     SiLU      】Input shape:     torch.Size([32, 256])      → Output shape:     torch.Size([32, 256])      | Params count: 0
│   │   ├── 【    Linear     】Input shape:     torch.Size([32, 256])      → Output shape:      torch.Size([32, 64])      | Params count: 16448
│   │   ├── 【   GroupNorm   】Input shape:  torch.Size([32, 64, 32, 32])  → Output shape:  torch.Size([32, 64, 32, 32])  | Params count: 128
│   │   ├── 【     SiLU      】Input shape:  torch.Size([32, 64, 32, 32])  → Output shape:  torch.Size([32, 64, 32, 32])  | Params count: 0
│   │   ├── 【   Dropout2d   】Input shape:  torch.Size([32, 64, 32, 32])  → Output shape:  torch.Size([32, 64, 32, 32])  | Params count: 0
│   │   ├── 【    Conv2d     】Input shape:  torch.Size([32, 64, 32, 32])  → Output shape:  torch.Size([32, 64, 32, 32])  | Params count: 36928
│   │   ├── 【    Conv2d     】Input shape: torch.Size([32, 128, 32, 32])  → Output shape:  torch.Size([32, 64, 32, 32])  | Params count: 8256
│   │   ├── 【   Identity    】Input shape:  torch.Size([32, 64, 32, 32])  → Output shape:  torch.Size([32, 64, 32, 32])  | Params count: 0
│   ├── 【  ResnetBlock  】Input shape: torch.Size([32, 128, 32, 32])  → Output shape:  torch.Size([32, 64, 32, 32])  | Params count: 0
│   │   ├── 【   GroupNorm   】Input shape: torch.Size([32, 128, 32, 32])  → Output shape: torch.Size([32, 128, 32, 32])  | Params count: 256
│   │   ├── 【     SiLU      】Input shape: torch.Size([32, 128, 32, 32])  → Output shape: torch.Size([32, 128, 32, 32])  | Params count: 0
│   │   ├── 【    Conv2d     】Input shape: torch.Size([32, 128, 32, 32])  → Output shape:  torch.Size([32, 64, 32, 32])  | Params count: 73792
│   │   ├── 【     SiLU      】Input shape:     torch.Size([32, 256])      → Output shape:     torch.Size([32, 256])      | Params count: 0
│   │   ├── 【    Linear     】Input shape:     torch.Size([32, 256])      → Output shape:      torch.Size([32, 64])      | Params count: 16448
│   │   ├── 【   GroupNorm   】Input shape:  torch.Size([32, 64, 32, 32])  → Output shape:  torch.Size([32, 64, 32, 32])  | Params count: 128
│   │   ├── 【     SiLU      】Input shape:  torch.Size([32, 64, 32, 32])  → Output shape:  torch.Size([32, 64, 32, 32])  | Params count: 0
│   │   ├── 【   Dropout2d   】Input shape:  torch.Size([32, 64, 32, 32])  → Output shape:  torch.Size([32, 64, 32, 32])  | Params count: 0
│   │   ├── 【    Conv2d     】Input shape:  torch.Size([32, 64, 32, 32])  → Output shape:  torch.Size([32, 64, 32, 32])  | Params count: 36928
│   │   ├── 【    Conv2d     】Input shape: torch.Size([32, 128, 32, 32])  → Output shape:  torch.Size([32, 64, 32, 32])  | Params count: 8256
│   │   ├── 【   Identity    】Input shape:  torch.Size([32, 64, 32, 32])  → Output shape:  torch.Size([32, 64, 32, 32])  | Params count: 0
│   ├── 【   GroupNorm   】Input shape:  torch.Size([32, 64, 32, 32])  → Output shape:  torch.Size([32, 64, 32, 32])  | Params count: 128
│   ├── 【     SiLU      】Input shape:  torch.Size([32, 64, 32, 32])  → Output shape:  torch.Size([32, 64, 32, 32])  | Params count: 0
│   ├── 【    Conv2d     】Input shape:  torch.Size([32, 64, 32, 32])  → Output shape:  torch.Size([32, 3, 32, 32])   | Params count: 1731
***Total Parameters***: 25223171 = [64000 + 16640 + 0 + 65792 + 0 + 0 + 1792 + 128 + 0 + 36928 + 0 + 16448 + 128 + 0 + 0 + 36928 + 0 + 0 + 0 + 128 + 0 + 36928 + 0 + 16448 + 128 + 0 + 0 + 36928 + 0 + 0 + 0 + 36928 + 0 + 128 + 0 + 73856 + 0 + 32896 + 256 + 0 + 0 + 147584 + 8320 + 256 + 49536 + 0 + 0 + 256 + 0 + 147584 + 0 + 32896 + 256 + 0 + 0 + 147584 + 0 + 256 + 49536 + 0 + 0 + 147584 + 0 + 256 + 0 + 295168 + 0 + 65792 + 512 + 0 + 0 + 590080 + 33024 + 512 + 197376 + 0 + 0 + 512 + 0 + 590080 + 0 + 65792 + 512 + 0 + 0 + 590080 + 0 + 512 + 197376 + 0 + 0 + 590080 + 0 + 512 + 0 + 590080 + 0 + 65792 + 512 + 0 + 0 + 590080 + 0 + 0 + 0 + 512 + 0 + 590080 + 0 + 65792 + 512 + 0 + 0 + 590080 + 0 + 0 + 0 + 512 + 0 + 590080 + 0 + 65792 + 512 + 0 + 0 + 590080 + 0 + 512 + 197376 + 0 + 0 + 512 + 0 + 590080 + 0 + 65792 + 512 + 0 + 0 + 590080 + 0 + 0 + 0 + 1024 + 0 + 1179904 + 0 + 65792 + 512 + 0 + 0 + 590080 + 131328 + 0 + 0 + 1024 + 0 + 1179904 + 0 + 65792 + 512 + 0 + 0 + 590080 + 131328 + 0 + 0 + 1024 + 0 + 1179904 + 0 + 65792 + 512 + 0 + 0 + 590080 + 131328 + 0 + 0 + 0 + 590080 + 0 + 0 + 1024 + 0 + 1179904 + 0 + 65792 + 512 + 0 + 0 + 590080 + 131328 + 512 + 197376 + 0 + 0 + 1024 + 0 + 1179904 + 0 + 65792 + 512 + 0 + 0 + 590080 + 131328 + 512 + 197376 + 0 + 0 + 768 + 0 + 884992 + 0 + 65792 + 512 + 0 + 0 + 590080 + 98560 + 512 + 197376 + 0 + 0 + 0 + 590080 + 0 + 0 + 768 + 0 + 442496 + 0 + 32896 + 256 + 0 + 0 + 147584 + 49280 + 256 + 49536 + 0 + 0 + 512 + 0 + 295040 + 0 + 32896 + 256 + 0 + 0 + 147584 + 32896 + 256 + 49536 + 0 + 0 + 384 + 0 + 221312 + 0 + 32896 + 256 + 0 + 0 + 147584 + 24704 + 256 + 49536 + 0 + 0 + 0 + 147584 + 0 + 0 + 384 + 0 + 110656 + 0 + 16448 + 128 + 0 + 0 + 36928 + 12352 + 0 + 0 + 256 + 0 + 73792 + 0 + 16448 + 128 + 0 + 0 + 36928 + 8256 + 0 + 0 + 256 + 0 + 73792 + 0 + 16448 + 128 + 0 + 0 + 36928 + 8256 + 0 + 0 + 128 + 0 + 1731 + 0 + 0]


网站公告

今日签到

点亮在社区的每一天
去签到