AIGC Notes--Stable Diffusion Source Code Analysis: UNetModel

Published: 2024-06-11

1--Preface

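Taking the open-source project released with the paper "High-Resolution Image Synthesis with Latent Diffusion Models" as the example, this note dissects the classic components of Stable Diffusion to consolidate and deepen my understanding.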

2--UNetModel

A small demo that can be stepped through in a debugger: SD_UNet

Using text-to-image generation as the example, we dissect the core modules of UNetModel.

2-1--Forward Overview

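In the provided text-to-image demo, only three arguments are actually passed: x, timesteps and context, where:

x is the randomly initialized noise tensor (shape: [B*2, 4, 64, 64]; the factor of 2 comes from Classifier-Free Diffusion Guidance, which stacks the conditional and unconditional inputs into one batch).

timesteps holds the timestep of the current denoising step for every sample (shape: [B*2]).

context is the text prompt encoded by CLIP (shape: [B*2, 77, 768]).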
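The following is a minimal sketch of how the B*2 batch is typically assembled and then split for classifier-free guidance. It assumes two things used by common samplers: the unconditional-first ordering and the guidance formula eps_uncond + s * (eps_cond - eps_uncond). cfg_predict and toy_model are hypothetical stand-ins, not functions from the repository.

```python
import torch


def cfg_predict(model, x, t, cond, uncond, guidance_scale):
    """Hypothetical helper: one classifier-free guidance prediction."""
    x_in = torch.cat([x, x])          # [B*2, 4, 64, 64]
    t_in = torch.cat([t, t])          # [B*2]
    c_in = torch.cat([uncond, cond])  # [B*2, 77, 768]; assumed order: unconditional first
    eps_uncond, eps_cond = model(x_in, t_in, c_in).chunk(2)
    # Move the prediction away from the unconditional branch, scaled by guidance_scale
    return eps_uncond + guidance_scale * (eps_cond - eps_uncond)


def toy_model(x, timesteps, context):
    """Stand-in for UNetModel.forward: returns a tensor shaped like x."""
    return x * context.mean(dim=(1, 2)).view(-1, 1, 1, 1)


B = 2
x = torch.randn(B, 4, 64, 64)
t = torch.full((B,), 999)
cond = torch.randn(B, 77, 768)    # CLIP embedding of the prompt
uncond = torch.randn(B, 77, 768)  # CLIP embedding of the empty prompt
eps = cfg_predict(toy_model, x, t, cond, uncond, guidance_scale=7.5)
print(eps.shape)  # torch.Size([2, 4, 64, 64])
```

With guidance_scale = 1.0 the formula reduces to the purely conditional prediction. Larger values push the sample harder toward the prompt.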

    def forward(self, x, timesteps=None, context=None, y=None,**kwargs):
        """
        Apply the model to an input batch.
        :param x: an [N x C x ...] Tensor of inputs.
        :param timesteps: a 1-D batch of timesteps.
        :param context: conditioning plugged in via crossattn
        :param y: an [N] Tensor of labels, if class-conditional.
        :return: an [N x C x ...] Tensor of outputs.
        """
        assert (y is not None) == (
            self.num_classes is not None
        ), "must specify y if and only if the model is class-conditional"
        hs = []
        t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False) # Create sinusoidal timestep embeddings.
        emb = self.time_embed(t_emb) # MLP

        if self.num_classes is not None:
            assert y.shape == (x.shape[0],)
            emb = emb + self.label_emb(y)

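        h = x.type(self.dtype)
        # Encoder: keep every block's output for the skip connections
        for module in self.input_blocks:
            h = module(h, emb, context)
            hs.append(h)
        h = self.middle_block(h, emb, context)
        # Decoder: concatenate with the matching encoder feature (last in, first out)
        for module in self.output_blocks:
            h = th.cat([h, hs.pop()], dim=1)
            h = module(h, emb, context)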
        h = h.type(x.dtype)
        if self.predict_codebook_ids:
            return self.id_predictor(h)
        else:
            return self.out(h)
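The interplay between hs.append() and hs.pop() is the U-Net skip connection. Features saved on the way down are concatenated channel-wise on the way up, in last-in, first-out order. The toy module below is hypothetical: it uses plain convolutions instead of ResBlock/SpatialTransformer and has no emb or context. It reproduces only that control flow, so the shapes can be checked.

```python
import torch
from torch import nn


class ToySkipUNet(nn.Module):
    """Hypothetical two-level toy: only the hs push/pop pattern mirrors UNetModel.forward."""

    def __init__(self):
        super().__init__()
        self.input_blocks = nn.ModuleList([
            nn.Conv2d(4, 8, kernel_size=3, padding=1),            # 64x64, 4 -> 8 channels
            nn.Conv2d(8, 8, kernel_size=3, stride=2, padding=1),  # 64x64 -> 32x32
        ])
        self.middle_block = nn.Conv2d(8, 8, kernel_size=3, padding=1)
        self.output_blocks = nn.ModuleList([
            # 8 (h) + 8 (skip) channels in, then upsample back to 64x64
            nn.Sequential(nn.Conv2d(16, 8, kernel_size=3, padding=1),
                          nn.Upsample(scale_factor=2, mode="nearest")),
            nn.Conv2d(16, 4, kernel_size=3, padding=1),           # back to 4 latent channels
        ])

    def forward(self, x):
        hs = []
        h = x
        for module in self.input_blocks:
            h = module(h)
            hs.append(h)                         # save for the skip connection
        h = self.middle_block(h)
        for module in self.output_blocks:
            h = torch.cat([h, hs.pop()], dim=1)  # last saved feature is used first
            h = module(h)
        return h


net = ToySkipUNet()
x = torch.randn(2, 4, 64, 64)
out = net(x)
print(out.shape)  # torch.Size([2, 4, 64, 64])
```

In this toy there are as many output blocks as input blocks, so each pop consumes exactly one saved feature map. Every output block must also accept the doubled channel count produced by the concatenation.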

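2-2--Timestep Embedding Generation

timestep_embedding() converts each incoming timestep into a sinusoidal timestep embedding of size model_channels (320). self.time_embed() then projects it to time_embed_dim (1280); the result emb is what every ResBlock receives.

timestep_embedding() is defined as follows, while self.time_embed() is a two-layer MLP (Linear, SiLU, Linear).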

import math

import torch
from einops import repeat


def timestep_embedding(timesteps, dim, max_period=10000, repeat_only=False):
    """
    Create sinusoidal timestep embeddings.
    :param timesteps: a 1-D Tensor of N indices, one per batch element.
                      These may be fractional.
    :param dim: the dimension of the output.
    :param max_period: controls the minimum frequency of the embeddings.
    :return: an [N x dim] Tensor of positional embeddings.
    """
    if not repeat_only:
        half = dim // 2
        freqs = torch.exp(
            -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half
        ).to(device=timesteps.device)
        args = timesteps[:, None].float() * freqs[None]
        embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
        if dim % 2:
            embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
    else:
        embedding = repeat(timesteps, 'b -> b d', d=dim)
    return embedding
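# In UNetModel.__init__: time_embed_dim = model_channels * 4 (= 1280);
# `linear` is the repository's thin wrapper around nn.Linear
self.time_embed = nn.Sequential(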
    linear(model_channels, time_embed_dim),
    nn.SiLU(),
    linear(time_embed_dim, time_embed_dim),
)
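Here is a self-contained sketch of the two steps above. It assumes model_channels = 320 and time_embed_dim = 4 * 320 = 1280, the values implied by the printed emb_layers below. It re-implements only the non-repeat_only branch and uses nn.Linear in place of the repository's linear helper.

```python
import math

import torch
from torch import nn


def sinusoidal_embedding(timesteps, dim, max_period=10000):
    """Same math as the non-repeat_only branch of timestep_embedding()."""
    half = dim // 2
    freqs = torch.exp(-math.log(max_period) * torch.arange(half, dtype=torch.float32) / half)
    args = timesteps[:, None].float() * freqs[None]              # [N, half]
    emb = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)  # [N, 2 * half]
    if dim % 2:  # pad one zero column for odd dims
        emb = torch.cat([emb, torch.zeros_like(emb[:, :1])], dim=-1)
    return emb


model_channels = 320
time_embed_dim = model_channels * 4  # 1280
time_embed = nn.Sequential(
    nn.Linear(model_channels, time_embed_dim),
    nn.SiLU(),
    nn.Linear(time_embed_dim, time_embed_dim),
)

timesteps = torch.tensor([999, 999])                     # [B*2] with B = 1
t_emb = sinusoidal_embedding(timesteps, model_channels)  # [2, 320]
emb = time_embed(t_emb)                                  # [2, 1280]
print(t_emb.shape, emb.shape)
```

At t = 0 every cos entry is 1 and every sin entry is 0. This is a quick way to sanity-check the [cos | sin] layout of the two halves.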

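2-3--Downsampling with self.input_blocks

In forward(), self.input_blocks progressively downsamples the input noise; the tensor shape changes from [B*2, 4, 64, 64] to [B*2, 1280, 8, 8].

The downsampling path consists of 12 modules in total, structured as follows: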

ModuleList(
  (0): TimestepEmbedSequential(
    (0): Conv2d(4, 320, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  )
  (1-2): 2 x TimestepEmbedSequential(
    (0): ResBlock(
      (in_layers): Sequential(
        (0): GroupNorm32(32, 320, eps=1e-05, affine=True)
        (1): SiLU()
        (2): Conv2d(320, 320, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      )
      (h_upd): Identity()
      (x_upd): Identity()
      (emb_layers): Sequential(
        (0): SiLU()
        (1): Linear(in_features=1280, out_features=320, bias=True)
      )
      (out_layers): Sequential(
        (0): GroupNorm32(32, 320, eps=1e-05, affine=True)
        (1): SiLU()
        (2): Dropout(p=0, inplace=False)
        (3): Conv2d(320, 320, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      )
      (skip_connection): Identity()
    )
    (1): SpatialTransformer(
      (norm): GroupNorm(32, 320, eps=1e-06, affine=True)
      (proj_in): Conv2d(320, 320, kernel_size=(1, 1), stride=(1, 1))
      (transformer_blocks): ModuleList(
        (0): BasicTransformerBlock(
          (attn1): CrossAttention(
            (to_q): Linear(in_features=320, out_features=320, bias=False)
            (to_k): Linear(in_features=320, out_features=320, bias=False)
            (to_v): Linear(in_features=320, out_features=320, bias=False)
            (to_out): Sequential(
              (0): Linear(in_features=320, out_features=320, bias=True)
              (1): Dropout(p=0.0, inplace=False)
            )
          )
          (ff): FeedForward(
            (net): Sequential(
              (0): GEGLU(
                (proj): Linear(in_features=320, out_features=2560, bias=True)
              )
              (1): Dropout(p=0.0, inplace=False)
              (2): Linear(in_features=1280, out_features=320, bias=True)
            )
          )
          (attn2): CrossAttention(
            (to_q): Linear(in_features=320, out_features=320, bias=False)
            (to_k): Linear(in_features=768, out_features=320, bias=False)
            (to_v): Linear(in_features=768, out_features=320, bias=False)
            (to_out): Sequential(
              (0): Linear(in_features=320, out_features=320, bias=True)
              (1): Dropout(p=0.0, inplace=False)
            )
          )
          (norm1): LayerNorm((320,), eps=1e-05, elementwise_affine=True)
          (norm2): LayerNorm((320,), eps=1e-05, elementwise_affine=True)
          (norm3): LayerNorm((320,), eps=1e-05, elementwise_affine=True)
        )
      )
      (proj_out): Conv2d(320, 320, kernel_size=(1, 1), stride=(1, 1))
    )
  )
  (3): TimestepEmbedSequential(
    (0): Downsample(
      (op): Conv2d(320, 320, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
    )
  )
  (4): TimestepEmbedSequential(
    (0): ResBlock(
      (in_layers): Sequential(
        (0): GroupNorm32(32, 320, eps=1e-05, affine=True)
        (1): SiLU()
        (2): Conv2d(320, 640, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      )
      (h_upd): Identity()
      (x_upd): Identity()
      (emb_layers): Sequential(
        (0): SiLU()
        (1): Linear(in_features=1280, out_features=640, bias=True)
      )
      (out_layers): Sequential(
        (0): GroupNorm32(32, 640, eps=1e-05, affine=True)
        (1): SiLU()
        (2): Dropout(p=0, inplace=False)
        (3): Conv2d(640, 640, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      )
      (skip_connection): Conv2d(320, 640, kernel_size=(1, 1), stride=(1, 1))
    )
    (1): SpatialTransformer(
      (norm): GroupNorm(32, 640, eps=1e-06, affine=True)
      (proj_in): Conv2d(640, 640, kernel_size=(1, 1), stride=(1, 1))
      (transformer_blocks): ModuleList(
        (0): BasicTransformerBlock(
          (attn1): CrossAttention(
            (to_q): Linear(in_features=640, out_features=640, bias=False)
            (to_k): Linear(in_features=640, out_features=640, bias=False)
            (to_v): Linear(in_features=640, out_features=640, bias=False)
            (to_out): Sequential(
              (0): Linear(in_features=640, out_features=640, bias=True)
              (1): Dropout(p=0.0, inplace=False)
            )
          )
          (ff): FeedForward(
            (net): Sequential(
              (0): GEGLU(
                (proj): Linear(in_features=640, out_features=5120, bias=True)
              )
              (1): Dropout(p=0.0, inplace=False)
              (2): Linear(in_features=2560, out_features=640, bias=True)
            )
          )
          (attn2): CrossAttention(
            (to_q): Linear(in_features=640, out_features=640, bias=False)
            (to_k): Linear(in_features=768, out_features=640, bias=False)
            (to_v): Linear(in_features=768, out_features=640, bias=False)
            (to_out): Sequential(
              (0): Linear(in_features=640, out_features=640, bias=True)
              (1): Dropout(p=0.0, inplace=False)
            )
          )
          (norm1): LayerNorm((640,), eps=1e-05, elementwise_affine=True)
          (norm2): LayerNorm((640,), eps=1e-05, elementwise_affine=True)
          (norm3): LayerNorm((640,), eps=1e-05, elementwise_affine=True)
        )
      )
      (proj_out): Conv2d(640, 640, kernel_size=(1, 1), stride=(1, 1))
    )
  )
  (5): TimestepEmbedSequential(
    (0): ResBlock(
      (in_layers): Sequential(
        (0): GroupNorm32(32, 640, eps=1e-05, affine=True)
        (1): SiLU()
        (2): Conv2d(640, 640, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      )
      (h_upd): Identity()
      (x_upd): Identity()
      (emb_layers): Sequential(
        (0): SiLU()
        (1): Linear(in_features=1280, out_features=640, bias=True)
      )
      (out_layers): Sequential(
        (0): GroupNorm32(32, 640, eps=1e-05, affine=True)
        (1): SiLU()
        (2): Dropout(p=0, inplace=False)
        (3): Conv2d(640, 640, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      )
      (skip_connection): Identity()
    )
    (1): SpatialTransformer(
      (norm): GroupNorm(32, 640, eps=1e-06, affine=True)
      (proj_in): Conv2d(640, 640, kernel_size=(1, 1), stride=(1, 1))
      (transformer_blocks): ModuleList(
        (0): BasicTransformerBlock(
          (attn1): CrossAttention(
            (to_q): Linear(in_features=640, out_features=640, bias=False)
            (to_k): Linear(in_features=640, out_features=640, bias=False)
            (to_v): Linear(in_features=640, out_features=640, bias=False)
            (to_out): Sequential(
              (0): Linear(in_features=640, out_features=640, bias=True)
              (1): Dropout(p=0.0, inplace=False)
            )
          )
          (ff): FeedForward(
            (net): Sequential(
              (0): GEGLU(
                (proj): Linear(in_features=640, out_features=5120, bias=True)
              )
              (1): Dropout(p=0.0, inplace=False)
              (2): Linear(in_features=2560, out_features=640, bias=True)
            )
          )
          (attn2): CrossAttention(
            (to_q): Linear(in_features=640, out_features=640, bias=False)
            (to_k): Linear(in_features=768, out_features=640, bias=False)
            (to_v): Linear(in_features=768, out_features=640, bias=False)
            (to_out): Sequential(
              (0): Linear(in_features=640, out_features=640, bias=True)
              (1): Dropout(p=0.0, inplace=False)
            )
          )
          (norm1): LayerNorm((640,), eps=1e-05, elementwise_affine=True)
          (norm2): LayerNorm((640,), eps=1e-05, elementwise_affine=True)
          (norm3): LayerNorm((640,), eps=1e-05, elementwise_affine=True)
        )
      )
      (proj_out): Conv2d(640, 640, kernel_size=(1, 1), stride=(1, 1))
    )
  )
  (6): TimestepEmbedSequential(
    (0): Downsample(
      (op): Conv2d(640, 640, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
    )
  )
  (7): TimestepEmbedSequential(
    (0): ResBlock(
      (in_layers): Sequential(
        (0): GroupNorm32(32, 640, eps=1e-05, affine=True)
        (1): SiLU()
        (2): Conv2d(640, 1280, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      )
      (h_upd): Identity()
      (x_upd): Identity()
      (emb_layers): Sequential(
        (0): SiLU()
        (1): Linear(in_features=1280, out_features=1280, bias=True)
      )
      (out_layers): Sequential(
        (0): GroupNorm32(32, 1280, eps=1e-05, affine=True)
        (1): SiLU()
        (2): Dropout(p=0, inplace=False)
        (3): Conv2d(1280, 1280, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      )
      (skip_connection): Conv2d(640, 1280, kernel_size=(1, 1), stride=(1, 1))
    )
    (1): SpatialTransformer(
      (norm): GroupNorm(32, 1280, eps=1e-06, affine=True)
      (proj_in): Conv2d(1280, 1280, kernel_size=(1, 1), stride=(1, 1))
      (transformer_blocks): ModuleList(
        (0): BasicTransformerBlock(
          (attn1): CrossAttention(
            (to_q): Linear(in_features=1280, out_features=1280, bias=False)
            (to_k): Linear(in_features=1280, out_features=1280, bias=False)
            (to_v): Linear(in_features=1280, out_features=1280, bias=False)
            (to_out): Sequential(
              (0): Linear(in_features=1280, out_features=1280, bias=True)
              (1): Dropout(p=0.0, inplace=False)
            )
          )
          (ff): FeedForward(
            (net): Sequential(
              (0): GEGLU(
                (proj): Linear(in_features=1280, out_features=10240, bias=True)
              )
              (1): Dropout(p=0.0, inplace=False)
              (2): Linear(in_features=5120, out_features=1280, bias=True)
            )
          )
          (attn2): CrossAttention(
            (to_q): Linear(in_features=1280, out_features=1280, bias=False)
            (to_k): Linear(in_features=768, out_features=1280, bias=False)
            (to_v): Linear(in_features=768, out_features=1280, bias=False)
            (to_out): Sequential(
              (0): Linear(in_features=1280, out_features=1280, bias=True)
              (1): Dropout(p=0.0, inplace=False)
            )
          )
          (norm1): LayerNorm((1280,), eps=1e-05, elementwise_affine=True)
          (norm2): LayerNorm((1280,), eps=1e-05, elementwise_affine=True)
          (norm3): LayerNorm((1280,), eps=1e-05, elementwise_affine=True)
        )
      )
      (proj_out): Conv2d(1280, 1280, kernel_size=(1, 1), stride=(1, 1))
    )
  )
  (8): TimestepEmbedSequential(
    (0): ResBlock(
      (in_layers): Sequential(
        (0): GroupNorm32(32, 1280, eps=1e-05, affine=True)
        (1): SiLU()
        (2): Conv2d(1280, 1280, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      )
      (h_upd): Identity()
      (x_upd): Identity()
      (emb_layers): Sequential(
        (0): SiLU()
        (1): Linear(in_features=1280, out_features=1280, bias=True)
      )
      (out_layers): Sequential(
        (0): GroupNorm32(32, 1280, eps=1e-05, affine=True)
        (1): SiLU()
        (2): Dropout(p=0, inplace=False)
        (3): Conv2d(1280, 1280, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      )
      (skip_connection): Identity()
    )
    (1): SpatialTransformer(
      (norm): GroupNorm(32, 1280, eps=1e-06, affine=True)
      (proj_in): Conv2d(1280, 1280, kernel_size=(1, 1), stride=(1, 1))
      (transformer_blocks): ModuleList(
        (0): BasicTransformerBlock(
          (attn1): CrossAttention(
            (to_q): Linear(in_features=1280, out_features=1280, bias=False)
            (to_k): Linear(in_features=1280, out_features=1280, bias=False)
            (to_v): Linear(in_features=1280, out_features=1280, bias=False)
            (to_out): Sequential(
              (0): Linear(in_features=1280, out_features=1280, bias=True)
              (1): Dropout(p=0.0, inplace=False)
            )
          )
          (ff): FeedForward(
            (net): Sequential(
              (0): GEGLU(
                (proj): Linear(in_features=1280, out_features=10240, bias=True)
              )
              (1): Dropout(p=0.0, inplace=False)
              (2): Linear(in_features=5120, out_features=1280, bias=True)
            )
          )
          (attn2): CrossAttention(
            (to_q): Linear(in_features=1280, out_features=1280, bias=False)
            (to_k): Linear(in_features=768, out_features=1280, bias=False)
            (to_v): Linear(in_features=768, out_features=1280, bias=False)
            (to_out): Sequential(
              (0): Linear(in_features=1280, out_features=1280, bias=True)
              (1): Dropout(p=0.0, inplace=False)
            )
          )
          (norm1): LayerNorm((1280,), eps=1e-05, elementwise_affine=True)
          (norm2): LayerNorm((1280,), eps=1e-05, elementwise_affine=True)
          (norm3): LayerNorm((1280,), eps=1e-05, elementwise_affine=True)
        )
      )
      (proj_out): Conv2d(1280, 1280, kernel_size=(1, 1), stride=(1, 1))
    )
  )
  (9): TimestepEmbedSequential(
    (0): Downsample(
      (op): Conv2d(1280, 1280, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
    )
  )
  (10-11): 2 x TimestepEmbedSequential(
    (0): ResBlock(
      (in_layers): Sequential(
        (0): GroupNorm32(32, 1280, eps=1e-05, affine=True)
        (1): SiLU()
        (2): Conv2d(1280, 1280, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      )
      (h_upd): Identity()
      (x_upd): Identity()
      (emb_layers): Sequential(
        (0): SiLU()
        (1): Linear(in_features=1280, out_features=1280, bias=True)
      )
      (out_layers): Sequential(
        (0): GroupNorm32(32, 1280, eps=1e-05, affine=True)
        (1): SiLU()
        (2): Dropout(p=0, inplace=False)
        (3): Conv2d(1280, 1280, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      )
      (skip_connection): Identity()
    )
  )
)
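The 12-module layout above can be reproduced with a few lines of plain Python that mirror the construction loops in the __init__ code quoted in 2-3-2 and 2-3-3. This sketch assumes channel_mult = (1, 2, 4, 4), num_res_blocks = 2 and attention_resolutions = (4, 2, 1), values consistent with the printed modules. list_input_blocks is a hypothetical helper.

```python
def list_input_blocks(model_channels=320, channel_mult=(1, 2, 4, 4), num_res_blocks=2,
                      attention_resolutions=(4, 2, 1), latent_size=64):
    """Hypothetical helper: (module type, output channels, output resolution) per input block."""
    blocks = [("Conv2d", model_channels, latent_size)]  # Module0
    ch, ds, res = model_channels, 1, latent_size
    for level, mult in enumerate(channel_mult):
        for _ in range(num_res_blocks):
            ch = mult * model_channels
            # A SpatialTransformer is attached only when the downsampling factor ds is listed
            kind = "ResBlock+SpatialTransformer" if ds in attention_resolutions else "ResBlock"
            blocks.append((kind, ch, res))
        if level != len(channel_mult) - 1:  # no Downsample after the last level
            ds *= 2
            res //= 2
            blocks.append(("Downsample", ch, res))
    return blocks


for i, (kind, ch, res) in enumerate(list_input_blocks()):
    print(f"Module{i}: {kind:<28} -> [B*2, {ch}, {res}, {res}]")
```

The last entry ends at 1280 channels and 8x8, matching the [B*2, 1280, 8, 8] output stated above.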

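All 12 modules are wrapped in the TimestepEmbedSequential class, which routes inputs by layer type. ResBlocks (subclasses of TimestepBlock) receive the timestep embedding emb, SpatialTransformers receive the prompt context, and all other layers receive only x.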

class TimestepEmbedSequential(nn.Sequential, TimestepBlock):
    """
    A sequential module that passes timestep embeddings to the children that
    support it as an extra input.
    """

    def forward(self, x, emb, context=None):
        for layer in self:
            if isinstance(layer, TimestepBlock):
                x = layer(x, emb)
            elif isinstance(layer, SpatialTransformer):
                x = layer(x, context)
            else:
                x = layer(x)
        return x
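The routing rule can be exercised in isolation. In this sketch, TimestepBlock is reduced to a marker class. ToyResBlock and ToyContextBlock are hypothetical stand-ins for ResBlock and SpatialTransformer; only the isinstance-based routing mirrors the class above.

```python
import torch
from torch import nn


class TimestepBlock(nn.Module):
    """Marker base class: children whose forward takes (x, emb)."""


class ToyResBlock(TimestepBlock):
    """Hypothetical stand-in for ResBlock: adds a projected timestep embedding."""

    def __init__(self, channels, emb_dim):
        super().__init__()
        self.proj = nn.Linear(emb_dim, channels)

    def forward(self, x, emb):
        return x + self.proj(emb)[:, :, None, None]


class ToyContextBlock(nn.Module):
    """Hypothetical stand-in for SpatialTransformer: consumes the text context."""

    def __init__(self, channels, context_dim):
        super().__init__()
        self.proj = nn.Linear(context_dim, channels)

    def forward(self, x, context):
        return x + self.proj(context.mean(dim=1))[:, :, None, None]


class TimestepEmbedSequential(nn.Sequential, TimestepBlock):
    """Same routing as the original, with ToyContextBlock in place of SpatialTransformer."""

    def forward(self, x, emb, context=None):
        for layer in self:
            if isinstance(layer, TimestepBlock):
                x = layer(x, emb)      # ResBlock-like layers get the timestep embedding
            elif isinstance(layer, ToyContextBlock):
                x = layer(x, context)  # attention-like layers get the prompt
            else:
                x = layer(x)           # plain layers (Conv2d, Downsample) get x only
        return x


block = TimestepEmbedSequential(
    nn.Conv2d(4, 8, kernel_size=3, padding=1),
    ToyResBlock(8, emb_dim=16),
    ToyContextBlock(8, context_dim=32),
)
x = torch.randn(2, 4, 16, 16)
emb = torch.randn(2, 16)
context = torch.randn(2, 77, 32)
out = block(x, emb, context)
print(out.shape)  # torch.Size([2, 8, 16, 16])
```

Because TimestepEmbedSequential itself subclasses TimestepBlock, a nested instance would be called with (x, emb) only.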

2-3-1--Module0

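Module0 is a 2D convolution layer (3x3, 4 -> 320 channels, resolution unchanged) that extracts initial features from the input noise;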

# __init__ (initialization)
self.input_blocks = nn.ModuleList(
    [
        TimestepEmbedSequential(
            conv_nd(dims, in_channels, model_channels, 3, padding=1)
        )
    ]
)

# Output of print(self.input_blocks[0])
TimestepEmbedSequential(
  (0): Conv2d(4, 320, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
)
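Here is a quick shape and parameter check for Module0. It uses a standalone nn.Conv2d with the printed hyperparameters; the batch size B*2 = 2 is an arbitrary choice.

```python
import torch
from torch import nn

conv_in = nn.Conv2d(4, 320, kernel_size=3, stride=1, padding=1)  # same hyperparameters as Module0
x = torch.randn(2, 4, 64, 64)  # [B*2, 4, 64, 64] latent noise
h = conv_in(x)                 # 3x3 with padding 1: spatial size unchanged
n_params = sum(p.numel() for p in conv_in.parameters())
print(h.shape, n_params)  # torch.Size([2, 320, 64, 64]) 11840
```

The 11840 parameters are 320 * 4 * 3 * 3 weights plus 320 biases.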

2-3-2--Module1 and Module2

Module1 and Module2 share the same structure: each consists of one ResBlock followed by one SpatialTransformer;

# __init__ (initialization), inside the loop `for level, mult in enumerate(channel_mult):`
for _ in range(num_res_blocks):
    layers = [
        ResBlock(
            ch,
            time_embed_dim,
            dropout,
            out_channels=mult * model_channels,
            dims=dims,
            use_checkpoint=use_checkpoint,
            use_scale_shift_norm=use_scale_shift_norm,
        )
    ]
    ch = mult * model_channels
    # Attention is added only at the downsampling factors listed in attention_resolutions
    if ds in attention_resolutions:
        if num_head_channels == -1:
            dim_head = ch // num_heads
        else:
            num_heads = ch // num_head_channels
            dim_head = num_head_channels
        if legacy:
            #num_heads = 1
            dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
        layers.append(
            AttentionBlock(
                ch,
                use_checkpoint=use_checkpoint,
                num_heads=num_heads,
                num_head_channels=dim_head,
                use_new_attention_order=use_new_attention_order,
            ) if not use_spatial_transformer else SpatialTransformer(
                ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim
            )
        )
    self.input_blocks.append(TimestepEmbedSequential(*layers))
    self._feature_size += ch
    input_block_chans.append(ch)

# Output of print(self.input_blocks[1])
TimestepEmbedSequential(
  (0): ResBlock(
    (in_layers): Sequential(
      (0): GroupNorm32(32, 320, eps=1e-05, affine=True)
      (1): SiLU()
      (2): Conv2d(320, 320, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    )
    (h_upd): Identity()
    (x_upd): Identity()
    (emb_layers): Sequential(
      (0): SiLU()
      (1): Linear(in_features=1280, out_features=320, bias=True)
    )
    (out_layers): Sequential(
      (0): GroupNorm32(32, 320, eps=1e-05, affine=True)
      (1): SiLU()
      (2): Dropout(p=0, inplace=False)
      (3): Conv2d(320, 320, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    )
    (skip_connection): Identity()
  )
  (1): SpatialTransformer(
    (norm): GroupNorm(32, 320, eps=1e-06, affine=True)
    (proj_in): Conv2d(320, 320, kernel_size=(1, 1), stride=(1, 1))
    (transformer_blocks): ModuleList(
      (0): BasicTransformerBlock(
        (attn1): CrossAttention(
          (to_q): Linear(in_features=320, out_features=320, bias=False)
          (to_k): Linear(in_features=320, out_features=320, bias=False)
          (to_v): Linear(in_features=320, out_features=320, bias=False)
          (to_out): Sequential(
            (0): Linear(in_features=320, out_features=320, bias=True)
            (1): Dropout(p=0.0, inplace=False)
          )
        )
        (ff): FeedForward(
          (net): Sequential(
            (0): GEGLU(
              (proj): Linear(in_features=320, out_features=2560, bias=True)
            )
            (1): Dropout(p=0.0, inplace=False)
            (2): Linear(in_features=1280, out_features=320, bias=True)
          )
        )
        (attn2): CrossAttention(
          (to_q): Linear(in_features=320, out_features=320, bias=False)
          (to_k): Linear(in_features=768, out_features=320, bias=False)
          (to_v): Linear(in_features=768, out_features=320, bias=False)
          (to_out): Sequential(
            (0): Linear(in_features=320, out_features=320, bias=True)
            (1): Dropout(p=0.0, inplace=False)
          )
        )
        (norm1): LayerNorm((320,), eps=1e-05, elementwise_affine=True)
        (norm2): LayerNorm((320,), eps=1e-05, elementwise_affine=True)
        (norm3): LayerNorm((320,), eps=1e-05, elementwise_affine=True)
      )
    )
    (proj_out): Conv2d(320, 320, kernel_size=(1, 1), stride=(1, 1))
  )
)

# Output of print(self.input_blocks[2])
TimestepEmbedSequential(
  (0): ResBlock(
    (in_layers): Sequential(
      (0): GroupNorm32(32, 320, eps=1e-05, affine=True)
      (1): SiLU()
      (2): Conv2d(320, 320, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    )
    (h_upd): Identity()
    (x_upd): Identity()
    (emb_layers): Sequential(
      (0): SiLU()
      (1): Linear(in_features=1280, out_features=320, bias=True)
    )
    (out_layers): Sequential(
      (0): GroupNorm32(32, 320, eps=1e-05, affine=True)
      (1): SiLU()
      (2): Dropout(p=0, inplace=False)
      (3): Conv2d(320, 320, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    )
    (skip_connection): Identity()
  )
  (1): SpatialTransformer(
    (norm): GroupNorm(32, 320, eps=1e-06, affine=True)
    (proj_in): Conv2d(320, 320, kernel_size=(1, 1), stride=(1, 1))
    (transformer_blocks): ModuleList(
      (0): BasicTransformerBlock(
        (attn1): CrossAttention(
          (to_q): Linear(in_features=320, out_features=320, bias=False)
          (to_k): Linear(in_features=320, out_features=320, bias=False)
          (to_v): Linear(in_features=320, out_features=320, bias=False)
          (to_out): Sequential(
            (0): Linear(in_features=320, out_features=320, bias=True)
            (1): Dropout(p=0.0, inplace=False)
          )
        )
        (ff): FeedForward(
          (net): Sequential(
            (0): GEGLU(
              (proj): Linear(in_features=320, out_features=2560, bias=True)
            )
            (1): Dropout(p=0.0, inplace=False)
            (2): Linear(in_features=1280, out_features=320, bias=True)
          )
        )
        (attn2): CrossAttention(
          (to_q): Linear(in_features=320, out_features=320, bias=False)
          (to_k): Linear(in_features=768, out_features=320, bias=False)
          (to_v): Linear(in_features=768, out_features=320, bias=False)
          (to_out): Sequential(
            (0): Linear(in_features=320, out_features=320, bias=True)
            (1): Dropout(p=0.0, inplace=False)
          )
        )
        (norm1): LayerNorm((320,), eps=1e-05, elementwise_affine=True)
        (norm2): LayerNorm((320,), eps=1e-05, elementwise_affine=True)
        (norm3): LayerNorm((320,), eps=1e-05, elementwise_affine=True)
      )
    )
    (proj_out): Conv2d(320, 320, kernel_size=(1, 1), stride=(1, 1))
  )
)
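Reading the printed submodules, a ResBlock adds the projected timestep embedding between two GroupNorm/SiLU/Conv stages and wraps everything in a residual connection. The sketch below is a simplified re-creation under stated assumptions. It assumes use_scale_shift_norm is False, which is consistent with emb_layers outputting out_channels rather than 2 * out_channels. It also omits gradient checkpointing, the float32 cast of GroupNorm32 and any special weight initialization. ResBlockSketch is a hypothetical name.

```python
import torch
from torch import nn


class ResBlockSketch(nn.Module):
    """Simplified ResBlock: no scale-shift norm, no checkpointing, no up/down sampling."""

    def __init__(self, channels, emb_channels, out_channels=None, dropout=0.0):
        super().__init__()
        out_channels = out_channels or channels
        self.in_layers = nn.Sequential(
            nn.GroupNorm(32, channels),
            nn.SiLU(),
            nn.Conv2d(channels, out_channels, kernel_size=3, padding=1),
        )
        self.emb_layers = nn.Sequential(
            nn.SiLU(),
            nn.Linear(emb_channels, out_channels),
        )
        self.out_layers = nn.Sequential(
            nn.GroupNorm(32, out_channels),
            nn.SiLU(),
            nn.Dropout(dropout),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
        )
        # Identity when the channel count is unchanged, otherwise a 1x1 conv
        if out_channels == channels:
            self.skip_connection = nn.Identity()
        else:
            self.skip_connection = nn.Conv2d(channels, out_channels, kernel_size=1)

    def forward(self, x, emb):
        h = self.in_layers(x)
        # [B, out_channels] -> [B, out_channels, 1, 1], broadcast over H and W
        h = h + self.emb_layers(emb)[:, :, None, None]
        h = self.out_layers(h)
        return self.skip_connection(x) + h


emb = torch.randn(2, 1280)       # output of self.time_embed
x = torch.randn(2, 320, 16, 16)  # small spatial size to keep the demo fast
block1 = ResBlockSketch(320, 1280)                    # like Module1's ResBlock
block4 = ResBlockSketch(320, 1280, out_channels=640)  # like Module4's ResBlock
print(block1(x, emb).shape, block4(x, emb).shape)
# torch.Size([2, 320, 16, 16]) torch.Size([2, 640, 16, 16])
```

With out_channels=640, the same class matches the printed ResBlock of Module4. There the 1x1 skip_connection convolution is required because the input and output channel counts differ.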

2-3-3--Module3

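Module3 is a downsampling 2D convolution layer: a 3x3 convolution with stride 2 that halves the spatial resolution (64x64 -> 32x32) while keeping 320 channels.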

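# __init__ (initialization): runs at the end of every level except the last;
# resblock_updown is False here, so a Downsample layer is appended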
if level != len(channel_mult) - 1:
    out_ch = ch
    self.input_blocks.append(
        TimestepEmbedSequential(
            ResBlock(
                ch,
                time_embed_dim,
                dropout,
                out_channels=out_ch,
                dims=dims,
                use_checkpoint=use_checkpoint,
                use_scale_shift_norm=use_scale_shift_norm,
                down=True,
            )
            if resblock_updown
            else Downsample(
                ch, conv_resample, dims=dims, out_channels=out_ch
            )
        )
    )

# Output of print(self.input_blocks[3])
TimestepEmbedSequential(
  (0): Downsample(
    (op): Conv2d(320, 320, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
  )
)
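The resolution sequence 64 -> 32 -> 16 -> 8 follows from the convolution output-size formula in the PyTorch nn.Conv2d documentation. Here is a plain-Python check; conv2d_out_size is a hypothetical helper.

```python
def conv2d_out_size(size, kernel=3, stride=2, padding=1, dilation=1):
    """Output size along one spatial dim, per the nn.Conv2d shape formula."""
    return (size + 2 * padding - dilation * (kernel - 1) - 1) // stride + 1


sizes = [64]
for _ in range(3):  # Module3, Module6, Module9
    sizes.append(conv2d_out_size(sizes[-1]))
print(sizes)  # [64, 32, 16, 8]
```

The same formula with stride=1 returns the input size unchanged. That is why the 3x3, padding-1 convolutions inside the ResBlocks keep the resolution.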

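2-3-4--Module4, Module5, Module7 and Module8

Same structure as Module1 and Module2 (one ResBlock plus one SpatialTransformer); only the feature dimensions differ: 640 channels for Module4/5 and 1280 for Module7/8. Module4 and Module7 change the channel count (320 -> 640 and 640 -> 1280), so their ResBlock uses a 1x1 convolution as skip_connection instead of Identity.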
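The SpatialTransformer in these modules is where the prompt enters the network. In attn2, to_q projects the image tokens (ch = 320/640/1280), while to_k and to_v project the 768-dim CLIP context, as the printed Linear shapes show. Below is a minimal single-layer sketch of that cross-attention. It assumes 8 heads with dim_head = ch // 8, which is an assumption about the configuration. It omits the LayerNorms, proj_in/proj_out and dropout. CrossAttentionSketch is a hypothetical name.

```python
import torch
from torch import nn


class CrossAttentionSketch(nn.Module):
    """Hypothetical minimal cross-attention: queries from image tokens, keys/values from text."""

    def __init__(self, query_dim, context_dim, heads=8):
        super().__init__()
        self.heads = heads
        self.dim_head = query_dim // heads  # assumption: dim_head = ch // heads
        inner_dim = self.dim_head * heads
        self.scale = self.dim_head ** -0.5
        self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
        self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
        self.to_v = nn.Linear(context_dim, inner_dim, bias=False)
        self.to_out = nn.Linear(inner_dim, query_dim)

    def forward(self, x, context):
        # x: [B, N, query_dim] image tokens; context: [B, M, context_dim] text tokens
        b, n, _ = x.shape
        m = context.shape[1]
        h, d = self.heads, self.dim_head
        q = self.to_q(x).view(b, n, h, d).transpose(1, 2)        # [B, h, N, d]
        k = self.to_k(context).view(b, m, h, d).transpose(1, 2)  # [B, h, M, d]
        v = self.to_v(context).view(b, m, h, d).transpose(1, 2)  # [B, h, M, d]
        attn = (q @ k.transpose(-2, -1) * self.scale).softmax(dim=-1)  # [B, h, N, M]
        out = (attn @ v).transpose(1, 2).reshape(b, n, h * d)          # [B, N, h*d]
        return self.to_out(out)


# SpatialTransformer flattens [B, C, H, W] into tokens [B, H*W, C] after proj_in
feature = torch.randn(2, 320, 8, 8)
tokens = feature.flatten(2).transpose(1, 2)  # [2, 64, 320]
context = torch.randn(2, 77, 768)            # CLIP text embedding
attn2 = CrossAttentionSketch(320, 768)
out = attn2(tokens, context)
print(out.shape)  # torch.Size([2, 64, 320])
```

Each image token attends over the 77 text tokens, so the output keeps the image token count and channel width. Only the weights of to_k and to_v depend on the 768-dim context.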

2-3-5--Module6 and Module9

Same structure as Module3: a stride-2 downsampling 2D convolution (32x32 -> 16x16 for Module6, 16x16 -> 8x8 for Module9).

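2-3-6--Module10 and Module11

Module10 and Module11 share the same structure and consist of a single ResBlock only; no SpatialTransformer is attached at the 8x8 resolution.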

# Output of print(self.input_blocks[10])
TimestepEmbedSequential(
  (0): ResBlock(
    (in_layers): Sequential(
      (0): GroupNorm32(32, 1280, eps=1e-05, affine=True)
      (1): SiLU()
      (2): Conv2d(1280, 1280, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    )
    (h_upd): Identity()
    (x_upd): Identity()
    (emb_layers): Sequential(
      (0): SiLU()
      (1): Linear(in_features=1280, out_features=1280, bias=True)
    )
    (out_layers): Sequential(
      (0): GroupNorm32(32, 1280, eps=1e-05, affine=True)
      (1): SiLU()
      (2): Dropout(p=0, inplace=False)
      (3): Conv2d(1280, 1280, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    )
    (skip_connection): Identity()
  )
)

# Output of print(self.input_blocks[11])
TimestepEmbedSequential(
  (0): ResBlock(
    (in_layers): Sequential(
      (0): GroupNorm32(32, 1280, eps=1e-05, affine=True)
      (1): SiLU()
      (2): Conv2d(1280, 1280, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    )
    (h_upd): Identity()
    (x_upd): Identity()
    (emb_layers): Sequential(
      (0): SiLU()
      (1): Linear(in_features=1280, out_features=1280, bias=True)
    )
    (out_layers): Sequential(
      (0): GroupNorm32(32, 1280, eps=1e-05, affine=True)
      (1): SiLU()
      (2): Dropout(p=0, inplace=False)
      (3): Conv2d(1280, 1280, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    )
    (skip_connection): Identity()
  )
)