diffusers source code analysis (4)


diffusers source code analysis (4)

.\diffusers\models\attention_flax.py

# Copyright notice: this code is owned by the HuggingFace team
# Licensed under the Apache 2.0 License; the file may not be used except in compliance with the License
# A copy of the License can be obtained at the linked URL
# The software is distributed on an "AS IS" basis, without warranties or conditions of any kind, either express or implied
# See the License for the specific permissions and limitations

# Import functools for functional-programming utilities
import functools
# Import math for mathematical helpers
import math

# Import flax.linen as the neural-network building toolkit
import flax.linen as nn
# Import jax for accelerated computation
import jax
# Import jax.numpy for NumPy-like array operations
import jax.numpy as jnp


def _query_chunk_attention(query, key, value, precision, key_chunk_size: int = 4096):
    """Multi-head dot-product attention with a limited number of queries."""
    # key dimensions: number of key/value tokens, number of heads, feature size per head
    num_kv, num_heads, k_features = key.shape[-3:]
    # feature size of the values
    v_features = value.shape[-1]
    # make sure key_chunk_size never exceeds the number of key/value tokens
    key_chunk_size = min(key_chunk_size, num_kv)
    # scale the queries to avoid numerical overflow
    query = query / jnp.sqrt(k_features)

    @functools.partial(jax.checkpoint, prevent_cse=False)
    def summarize_chunk(query, key, value):
        # attention weights between the queries and this key chunk
        attn_weights = jnp.einsum("...qhd,...khd->...qhk", query, key, precision=precision)
        # per-query maximum score, used for numerical stability
        max_score = jnp.max(attn_weights, axis=-1, keepdims=True)
        # do not propagate gradients through the maximum
        max_score = jax.lax.stop_gradient(max_score)
        # exponentiated (unnormalized softmax) weights
        exp_weights = jnp.exp(attn_weights - max_score)
        # weighted sum of the values
        exp_values = jnp.einsum("...vhf,...qhv->...qhf", value, exp_weights, precision=precision)
        # drop the trailing singleton axis of the max score
        max_score = jnp.einsum("...qhk->...qh", max_score)
        return (exp_values, exp_weights.sum(axis=-1), max_score)

    def chunk_scanner(chunk_idx):
        # dynamically slice one chunk out of the keys
        key_chunk = jax.lax.dynamic_slice(
            operand=key,
            start_indices=[0] * (key.ndim - 3) + [chunk_idx, 0, 0],  # [...,k,h,d]
            slice_sizes=list(key.shape[:-3]) + [key_chunk_size, num_heads, k_features],  # [...,k,h,d]
        )
        # dynamically slice the matching chunk out of the values
        value_chunk = jax.lax.dynamic_slice(
            operand=value,
            start_indices=[0] * (value.ndim - 3) + [chunk_idx, 0, 0],  # [...,v,h,d]
            slice_sizes=list(value.shape[:-3]) + [key_chunk_size, num_heads, v_features],  # [...,v,h,d]
        )
        return summarize_chunk(query, key_chunk, value_chunk)

    # run the chunked attention over every key chunk
    chunk_values, chunk_weights, chunk_max = jax.lax.map(f=chunk_scanner, xs=jnp.arange(0, num_kv, key_chunk_size))

    # global maximum score across all chunks
    global_max = jnp.max(chunk_max, axis=0, keepdims=True)
    # rescaling factor of each chunk relative to the global maximum
    max_diffs = jnp.exp(chunk_max - global_max)

    # rescale the per-chunk values and weights before normalization
    chunk_values *= jnp.expand_dims(max_diffs, axis=-1)
    chunk_weights *= max_diffs

    # sum the values and weights over all chunks
    all_values = chunk_values.sum(axis=0)
    all_weights = jnp.expand_dims(chunk_weights, -1).sum(axis=0)

    # return the normalized values
    return all_values / all_weights


def jax_memory_efficient_attention(
    query, key, value, precision=jax.lax.Precision.HIGHEST, query_chunk_size: int = 1024, key_chunk_size: int = 4096
):
    r"""
    Flax Memory-efficient multi-head dot product attention. https://arxiv.org/abs/2112.05682v2
    https://github.com/AminRezaei0x443/memory-efficient-attention

    Args:
        query (`jnp.ndarray`): (batch..., query_length, head, query_key_depth_per_head)
        key (`jnp.ndarray`): (batch..., key_value_length, head, query_key_depth_per_head)
        value (`jnp.ndarray`): (batch..., key_value_length, head, value_depth_per_head)
        precision (`jax.lax.Precision`, *optional*, defaults to `jax.lax.Precision.HIGHEST`):
            numerical precision for computation
        query_chunk_size (`int`, *optional*, defaults to 1024):
            chunk size to divide query array; value must divide query_length equally without remainder
        key_chunk_size (`int`, *optional*, defaults to 4096):
            chunk size to divide key and value array; value must divide key_value_length equally without remainder

    Returns:
        (`jnp.ndarray`) with shape of (batch..., query_length, head, value_depth_per_head)
    """
    # sizes of the last three query dimensions
    num_q, num_heads, q_features = query.shape[-3:]

    def chunk_scanner(chunk_idx, _):
        # slice the current chunk out of the query array
        query_chunk = jax.lax.dynamic_slice(
            operand=query,
            start_indices=([0] * (query.ndim - 3)) + [chunk_idx, 0, 0],  # [...,q,h,d]
            slice_sizes=list(query.shape[:-3]) + [min(query_chunk_size, num_q), num_heads, q_features],  # [...,q,h,d]
        )

        return (
            # next chunk index (unused by the scan body, ignore it)
            chunk_idx + query_chunk_size,
            # run chunked attention for this query chunk against all keys and values
            _query_chunk_attention(
                query=query_chunk, key=key, value=value, precision=precision, key_chunk_size=key_chunk_size
            ),
        )

    # scan over all query chunks
    _, res = jax.lax.scan(
        f=chunk_scanner,
        init=0,
        xs=None,
        length=math.ceil(num_q / query_chunk_size),  # number of chunks to process
    )

    # fuse the chunked result back together along the query axis
    return jnp.concatenate(res, axis=-3)
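Chunking the queries and keys does not change the mathematical result: up to floating-point error, the output should match ordinary softmax attention computed in one shot. The sketch below is a quick illustrative check (not part of the library); the tensor sizes and chunk sizes are chosen arbitrarily, with the only constraint being that they divide the sequence lengths evenly.

```py
import jax
import jax.numpy as jnp
from diffusers.models.attention_flax import jax_memory_efficient_attention

# layout expected by jax_memory_efficient_attention: (batch, length, heads, head_dim)
q = jax.random.normal(jax.random.PRNGKey(0), (2, 1024, 8, 64))
k = jax.random.normal(jax.random.PRNGKey(1), (2, 1024, 8, 64))
v = jax.random.normal(jax.random.PRNGKey(2), (2, 1024, 8, 64))

# chunked, memory-efficient attention
out_chunked = jax_memory_efficient_attention(q, k, v, query_chunk_size=256, key_chunk_size=512)

# naive reference: softmax(QK^T / sqrt(d)) V computed in one shot
scores = jnp.einsum("bqhd,bkhd->bhqk", q, k) / jnp.sqrt(q.shape[-1])
out_ref = jnp.einsum("bhqk,bkhd->bqhd", jax.nn.softmax(scores, axis=-1), v)

print(jnp.max(jnp.abs(out_chunked - out_ref)))  # should be close to 0
```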
# A Flax multi-head attention module, following the paper referenced in its docstring
class FlaxAttention(nn.Module):r"""Flax多头注意力模块,详见: https://arxiv.org/abs/1706.03762参数:query_dim (:obj:`int`):输入隐藏状态的维度heads (:obj:`int`, *optional*, defaults to 8):注意力头的数量dim_head (:obj:`int`, *optional*, defaults to 64):每个头内隐藏状态的维度dropout (:obj:`float`, *optional*, defaults to 0.0):dropout比率use_memory_efficient_attention (`bool`, *optional*, defaults to `False`):启用内存高效注意力 https://arxiv.org/abs/2112.05682split_head_dim (`bool`, *optional*, defaults to `False`):是否将头维度拆分为自注意力计算的新轴。通常情况下,启用该标志可以加快Stable Diffusion 2.x和Stable Diffusion XL的计算速度。dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):参数的 `dtype`"""# 定义输入参数的类型和默认值query_dim: intheads: int = 8dim_head: int = 64dropout: float = 0.0use_memory_efficient_attention: bool = Falsesplit_head_dim: bool = Falsedtype: jnp.dtype = jnp.float32# 设置模块的初始化函数def setup(self):# 计算内部维度为每个头的维度与头的数量的乘积inner_dim = self.dim_head * self.heads# 计算缩放因子self.scale = self.dim_head**-0.5# 创建权重矩阵,使用旧的命名 {to_q, to_k, to_v, to_out}self.query = nn.Dense(inner_dim, use_bias=False, dtype=self.dtype, name="to_q")# 创建键的权重矩阵self.key = nn.Dense(inner_dim, use_bias=False, dtype=self.dtype, name="to_k")# 创建值的权重矩阵self.value = nn.Dense(inner_dim, use_bias=False, dtype=self.dtype, name="to_v")# 创建输出的权重矩阵self.proj_attn = nn.Dense(self.query_dim, dtype=self.dtype, name="to_out_0")# 创建dropout层self.dropout_layer = nn.Dropout(rate=self.dropout)# 将张量的头部维度重塑为批次维度def reshape_heads_to_batch_dim(self, tensor):# 解构张量的形状batch_size, seq_len, dim = tensor.shapehead_size = self.heads# 重塑张量形状以分离头维度tensor = tensor.reshape(batch_size, seq_len, head_size, dim // head_size)# 转置张量的维度tensor = jnp.transpose(tensor, (0, 2, 1, 3))# 进一步重塑为批次与头维度合并tensor = tensor.reshape(batch_size * head_size, seq_len, dim // head_size)return tensor# 将张量的批次维度重塑为头部维度def reshape_batch_dim_to_heads(self, tensor):# 解构张量的形状batch_size, seq_len, dim = tensor.shapehead_size = self.heads# 重塑张量形状以合并批次与头维度tensor = tensor.reshape(batch_size // head_size, head_size, seq_len, dim)# 转置张量的维度tensor = jnp.transpose(tensor, (0, 2, 1, 3))# 进一步重塑为合并批次与头维度tensor = tensor.reshape(batch_size // head_size, seq_len, dim * head_size)return tensor# 定义一个 Flax 基础变换器块层,使用 GLU 激活函数,详见:
class FlaxBasicTransformerBlock(nn.Module):r"""Flax 变换器块层,使用 `GLU` (门控线性单元) 激活函数,详见:https://arxiv.org/abs/1706.03762# 参数说明部分Parameters:dim (:obj:`int`):  # 内部隐藏状态的维度Inner hidden states dimensionn_heads (:obj:`int`):  # 注意力头的数量Number of headsd_head (:obj:`int`):  # 每个头内部隐藏状态的维度Hidden states dimension inside each headdropout (:obj:`float`, *optional*, defaults to 0.0):  # 随机失活率Dropout rateonly_cross_attention (`bool`, defaults to `False`):  # 是否仅应用交叉注意力Whether to only apply cross attention.dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):  # 参数数据类型Parameters `dtype`use_memory_efficient_attention (`bool`, *optional*, defaults to `False`):  # 启用内存高效注意力enable memory efficient attention https://arxiv.org/abs/2112.05682split_head_dim (`bool`, *optional*, defaults to `False`):  # 是否将头维度拆分为新轴Whether to split the head dimension into a new axis for the self-attention computation. In most cases,enabling this flag should speed up the computation for Stable Diffusion 2.x and Stable Diffusion XL."""dim: int  # 内部隐藏状态维度的类型声明n_heads: int  # 注意力头数量的类型声明d_head: int  # 每个头的隐藏状态维度的类型声明dropout: float = 0.0  # 随机失活率的默认值only_cross_attention: bool = False  # 默认不只应用交叉注意力dtype: jnp.dtype = jnp.float32  # 默认数据类型为 jnp.float32use_memory_efficient_attention: bool = False  # 默认不启用内存高效注意力split_head_dim: bool = False  # 默认不拆分头维度def setup(self):# 设置自注意力(如果 only_cross_attention 为 True,则为交叉注意力)self.attn1 = FlaxAttention(self.dim,  # 传入的内部隐藏状态维度self.n_heads,  # 传入的注意力头数量self.d_head,  # 传入的每个头的隐藏状态维度self.dropout,  # 传入的随机失活率self.use_memory_efficient_attention,  # 是否使用内存高效注意力self.split_head_dim,  # 是否拆分头维度dtype=self.dtype,  # 传入的数据类型)# 设置交叉注意力self.attn2 = FlaxAttention(self.dim,  # 传入的内部隐藏状态维度self.n_heads,  # 传入的注意力头数量self.d_head,  # 传入的每个头的隐藏状态维度self.dropout,  # 传入的随机失活率self.use_memory_efficient_attention,  # 是否使用内存高效注意力self.split_head_dim,  # 是否拆分头维度dtype=self.dtype,  # 传入的数据类型)# 设置前馈网络self.ff = FlaxFeedForward(dim=self.dim, dropout=self.dropout, dtype=self.dtype)  # 前馈网络初始化# 设置第一个归一化层self.norm1 = nn.LayerNorm(epsilon=1e-5, dtype=self.dtype)  # 归一化层初始化# 设置第二个归一化层self.norm2 = nn.LayerNorm(epsilon=1e-5, dtype=self.dtype)  # 归一化层初始化# 设置第三个归一化层self.norm3 = nn.LayerNorm(epsilon=1e-5, dtype=self.dtype)  # 归一化层初始化# 设置丢弃层self.dropout_layer = nn.Dropout(rate=self.dropout)  # 丢弃层初始化# 定义可调用对象,接收隐藏状态、上下文和确定性标志def __call__(self, hidden_states, context, deterministic=True):# 保存输入的隐藏状态以供后续残差连接使用residual = hidden_states# 如果仅执行交叉注意力,进行相关的处理if self.only_cross_attention:hidden_states = self.attn1(self.norm1(hidden_states), context, deterministic=deterministic)else:# 否则执行自注意力处理hidden_states = self.attn1(self.norm1(hidden_states), deterministic=deterministic)# 将自注意力的输出与输入的残差相加hidden_states = hidden_states + residual# 交叉注意力处理residual = hidden_states# 处理交叉注意力hidden_states = self.attn2(self.norm2(hidden_states), context, deterministic=deterministic)# 将交叉注意力的输出与输入的残差相加hidden_states = hidden_states + residual# 前馈网络处理residual = hidden_states# 应用前馈网络hidden_states = self.ff(self.norm3(hidden_states), deterministic=deterministic)# 将前馈网络的输出与输入的残差相加hidden_states = hidden_states + residual# 返回经过 dropout 处理的最终隐藏状态return self.dropout_layer(hidden_states, deterministic=deterministic)
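The block wires two `FlaxAttention` layers (self-attention, then cross-attention against `context`) and a `FlaxFeedForward`, each pre-normalized and wrapped in a residual connection. A minimal usage sketch follows; the dimensions (320 = 8 heads × 40, a 77×768 text context) are arbitrary choices for illustration, and the module is assumed to be imported from `diffusers.models.attention_flax`.

```py
import jax
import jax.numpy as jnp
from diffusers.models.attention_flax import FlaxBasicTransformerBlock

# a transformer block with inner dimension 320 split across 8 heads of size 40
block = FlaxBasicTransformerBlock(dim=320, n_heads=8, d_head=40)

hidden_states = jnp.ones((1, 64, 320))   # (batch, tokens, dim)
context = jnp.ones((1, 77, 768))         # e.g. text-encoder states; Dense layers infer the input width

params = block.init(jax.random.PRNGKey(0), hidden_states, context)   # deterministic=True by default
out = block.apply(params, hidden_states, context)
print(out.shape)  # (1, 64, 320) — residual connections keep the token dimension unchanged
```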
# A 2D Flax Transformer model, subclassing nn.Module
class FlaxTransformer2DModel(nn.Module):r"""A Spatial Transformer layer with Gated Linear Unit (GLU) activation function as described in:https://arxiv.org/pdf/1506.02025.pdf文档字符串,描述该类的功能和参数。Parameters:in_channels (:obj:`int`):Input number of channelsn_heads (:obj:`int`):Number of headsd_head (:obj:`int`):Hidden states dimension inside each headdepth (:obj:`int`, *optional*, defaults to 1):Number of transformers blockdropout (:obj:`float`, *optional*, defaults to 0.0):Dropout rateuse_linear_projection (`bool`, defaults to `False`): tbdonly_cross_attention (`bool`, defaults to `False`): tbddtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):Parameters `dtype`use_memory_efficient_attention (`bool`, *optional*, defaults to `False`):enable memory efficient attention https://arxiv.org/abs/2112.05682split_head_dim (`bool`, *optional*, defaults to `False`):Whether to split the head dimension into a new axis for the self-attention computation. In most cases,enabling this flag should speed up the computation for Stable Diffusion 2.x and Stable Diffusion XL."""# 定义输入通道数in_channels: int# 定义头的数量n_heads: int# 定义每个头的隐藏状态维度d_head: int# 定义 Transformer 块的数量,默认为 1depth: int = 1# 定义 Dropout 率,默认为 0.0dropout: float = 0.0# 定义是否使用线性投影,默认为 Falseuse_linear_projection: bool = False# 定义是否仅使用交叉注意力,默认为 Falseonly_cross_attention: bool = False# 定义参数的数据类型,默认为 jnp.float32dtype: jnp.dtype = jnp.float32# 定义是否使用内存高效注意力,默认为 Falseuse_memory_efficient_attention: bool = False# 定义是否将头维度拆分为新的轴,默认为 Falsesplit_head_dim: bool = False# 设置模型的组件def setup(self):# 使用 Group Normalization 规范化层,分组数为 32,epsilon 为 1e-5self.norm = nn.GroupNorm(num_groups=32, epsilon=1e-5)# 计算内部维度为头的数量乘以每个头的维度inner_dim = self.n_heads * self.d_head# 根据是否使用线性投影选择输入层if self.use_linear_projection:# 创建一个线性投影层,输出维度为 inner_dim,数据类型为 self.dtypeself.proj_in = nn.Dense(inner_dim, dtype=self.dtype)else:# 创建一个卷积层,输出维度为 inner_dim,卷积核大小为 (1, 1),步幅为 (1, 1),填充方式为 "VALID",数据类型为 self.dtypeself.proj_in = nn.Conv(inner_dim,kernel_size=(1, 1),strides=(1, 1),padding="VALID",dtype=self.dtype,)# 创建一系列 Transformer 块,数量为 depthself.transformer_blocks = [FlaxBasicTransformerBlock(inner_dim,self.n_heads,self.d_head,dropout=self.dropout,only_cross_attention=self.only_cross_attention,dtype=self.dtype,use_memory_efficient_attention=self.use_memory_efficient_attention,split_head_dim=self.split_head_dim,)for _ in range(self.depth)  # 循环生成每个 Transformer 块]# 根据是否使用线性投影选择输出层if self.use_linear_projection:# 创建一个线性投影层,输出维度为 inner_dim,数据类型为 self.dtypeself.proj_out = nn.Dense(inner_dim, dtype=self.dtype)else:# 创建一个卷积层,输出维度为 inner_dim,卷积核大小为 (1, 1),步幅为 (1, 1),填充方式为 "VALID",数据类型为 self.dtypeself.proj_out = nn.Conv(inner_dim,kernel_size=(1, 1),strides=(1, 1),padding="VALID",dtype=self.dtype,)# 创建一个 Dropout 层,Dropout 率为 self.dropoutself.dropout_layer = nn.Dropout(rate=self.dropout)# 定义可调用对象的方法,接收隐藏状态、上下文和确定性标志def __call__(self, hidden_states, context, deterministic=True):# 解构隐藏状态的形状,获取批量大小、高度、宽度和通道数batch, height, width, channels = hidden_states.shape# 保存原始隐藏状态以用于残差连接residual = hidden_states# 对隐藏状态进行归一化处理hidden_states = self.norm(hidden_states)# 如果使用线性投影,则重塑隐藏状态if self.use_linear_projection:# 将隐藏状态重塑为(batch, height * width, channels)的形状hidden_states = hidden_states.reshape(batch, height * width, channels)# 应用输入投影hidden_states = self.proj_in(hidden_states)else:# 直接应用输入投影hidden_states = self.proj_in(hidden_states)# 将隐藏状态重塑为(batch, height * width, channels)的形状hidden_states = hidden_states.reshape(batch, height * width, channels)# 遍历每个变换块,更新隐藏状态for transformer_block in self.transformer_blocks:# 
通过变换块处理隐藏状态和上下文hidden_states = transformer_block(hidden_states, context, deterministic=deterministic)# 如果使用线性投影,则先应用输出投影if self.use_linear_projection:hidden_states = self.proj_out(hidden_states)# 将隐藏状态重塑回原来的形状hidden_states = hidden_states.reshape(batch, height, width, channels)else:# 先重塑隐藏状态hidden_states = hidden_states.reshape(batch, height, width, channels)# 再应用输出投影hidden_states = self.proj_out(hidden_states)# 将隐藏状态与原始状态相加,实现残差连接hidden_states = hidden_states + residual# 返回经过dropout层处理后的隐藏状态return self.dropout_layer(hidden_states, deterministic=deterministic)
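Because Flax convolutions are channels-last, `FlaxTransformer2DModel` takes an NHWC feature map, flattens the spatial grid into tokens, runs the transformer blocks, and reshapes back before the residual addition. The following is a minimal sketch with arbitrary sizes, leaving `use_linear_projection` at its default (`False`, i.e. 1×1 convolutions for the in/out projections):

```py
import jax
import jax.numpy as jnp
from diffusers.models.attention_flax import FlaxTransformer2DModel

model = FlaxTransformer2DModel(in_channels=320, n_heads=8, d_head=40, depth=1)

sample = jnp.ones((1, 32, 32, 320))  # channels-last: (batch, height, width, channels)
context = jnp.ones((1, 77, 768))     # cross-attention conditioning

params = model.init(jax.random.PRNGKey(0), sample, context)
out = model.apply(params, sample, context)
print(out.shape)  # (1, 32, 32, 320) — same spatial shape, thanks to the residual connection
```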
# A Flax feed-forward module, subclassing nn.Module
class FlaxFeedForward(nn.Module):
    r"""
    Flax module that wraps two linear layers separated by a non-linearity. It is the counterpart of PyTorch's
    [`FeedForward`] class, with the following simplifications:
    - The activation function is currently hardcoded to a gated linear unit from https://arxiv.org/abs/2002.05202
    - `dim_out` is equal to `dim`.
    - The number of hidden dimensions is hardcoded to `dim * 4` in [`FlaxGELU`].

    Parameters:
        dim (:obj:`int`):
            Inner hidden states dimension
        dropout (:obj:`float`, *optional*, defaults to 0.0):
            Dropout rate
        dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
            Parameters `dtype`
    """

    # module fields: dimension, dropout rate and parameter dtype
    dim: int
    dropout: float = 0.0
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        # the second linear layer is called net_2 for now, to match the indexing of the PyTorch sequential layer
        self.net_0 = FlaxGEGLU(self.dim, self.dropout, self.dtype)  # GEGLU projection
        self.net_2 = nn.Dense(self.dim, dtype=self.dtype)  # output linear layer

    def __call__(self, hidden_states, deterministic=True):
        hidden_states = self.net_0(hidden_states, deterministic=deterministic)  # gated projection
        hidden_states = self.net_2(hidden_states)  # project back to `dim`
        return hidden_states


# Flax GEGLU activation layer, subclassing nn.Module
class FlaxGEGLU(nn.Module):
    r"""
    Flax implementation of a linear layer followed by the variant of the gated linear unit activation function from
    https://arxiv.org/abs/2002.05202.

    Parameters:
        dim (:obj:`int`):
            Input hidden states dimension
        dropout (:obj:`float`, *optional*, defaults to 0.0):
            Dropout rate
        dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
            Parameters `dtype`
    """

    # module fields: dimension, dropout rate and parameter dtype
    dim: int
    dropout: float = 0.0
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        inner_dim = self.dim * 4  # hidden width of the feed-forward block
        self.proj = nn.Dense(inner_dim * 2, dtype=self.dtype)  # projects to the value half and the gate half
        self.dropout_layer = nn.Dropout(rate=self.dropout)

    def __call__(self, hidden_states, deterministic=True):
        hidden_states = self.proj(hidden_states)  # single projection to 2 * inner_dim
        # split into the linear half and the GELU-gated half
        hidden_linear, hidden_gelu = jnp.split(hidden_states, 2, axis=2)
        return self.dropout_layer(hidden_linear * nn.gelu(hidden_gelu), deterministic=deterministic)
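The GEGLU computation can be written out by hand: one projection to `2 * 4 * dim` features, split in half, with one half gating the other through GELU. The sketch below verifies that against the module output; it assumes the usual Flax parameter-tree layout (`params["params"]["proj"]`) and is illustrative only.

```py
import jax
import jax.numpy as jnp
import flax.linen as nn
from diffusers.models.attention_flax import FlaxGEGLU

geglu = FlaxGEGLU(dim=320)
x = jnp.ones((1, 64, 320))
params = geglu.init(jax.random.PRNGKey(0), x)

# manual path: Dense projection, split, GELU gating (dropout is inactive when deterministic=True)
kernel = params["params"]["proj"]["kernel"]   # (320, 2560)
bias = params["params"]["proj"]["bias"]       # (2560,)
proj = x @ kernel + bias
linear_half, gate_half = jnp.split(proj, 2, axis=2)
manual = linear_half * nn.gelu(gate_half)     # gated output of width 4 * dim = 1280

out = geglu.apply(params, x)
print(jnp.allclose(out, manual))  # True
```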

.\diffusers\models\attention_processor.py

# Copyright notice: this file is owned by the HuggingFace team
# Licensed under the Apache 2.0 License
# You may use this file only in compliance with the License
# A copy of the License can be obtained at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by law or agreed to in writing, the software is distributed "AS IS", without warranties of any kind
# See the License for the specific permissions and limitations

import inspect  # used to inspect the call signatures of attention processors
import math
from typing import Callable, List, Optional, Tuple, Union

import torch
import torch.nn.functional as F  # functional neural-network ops, aliased as F
from torch import nn

from ..image_processor import IPAdapterMaskProcessor
from ..utils import deprecate, logging
from ..utils.import_utils import is_torch_npu_available, is_xformers_available
from ..utils.torch_utils import is_torch_version, maybe_allow_in_graph

# module-level logger
logger = logging.get_logger(__name__)

# import the NPU backend only when it is available
if is_torch_npu_available():
    import torch_npu

# import xformers only when it is available; otherwise fall back to None
if is_xformers_available():
    import xformers
    import xformers.ops
else:
    xformers = None


# decorator that may allow this module to be traced inside a graph
@maybe_allow_in_graph
class Attention(nn.Module):  # 定义 Attention 类,继承自 nn.Moduler"""  # 文档字符串,描述该类是一个交叉注意力层A cross attention layer."""def __init__(  # 初始化方法,定义构造函数self,query_dim: int,  # 查询维度,类型为整数cross_attention_dim: Optional[int] = None,  # 可选的交叉注意力维度,默认为 Noneheads: int = 8,  # 注意力头的数量,默认为 8kv_heads: Optional[int] = None,  # 可选的键值头数量,默认为 Nonedim_head: int = 64,  # 每个头的维度,默认为 64dropout: float = 0.0,  # dropout 概率,默认为 0.0bias: bool = False,  # 是否使用偏置,默认为 Falseupcast_attention: bool = False,  # 是否上升注意力精度,默认为 Falseupcast_softmax: bool = False,  # 是否上升 softmax 精度,默认为 Falsecross_attention_norm: Optional[str] = None,  # 可选的交叉注意力归一化方式,默认为 Nonecross_attention_norm_num_groups: int = 32,  # 交叉注意力归一化的组数量,默认为 32qk_norm: Optional[str] = None,  # 可选的查询键归一化方式,默认为 Noneadded_kv_proj_dim: Optional[int] = None,  # 可选的添加键值投影维度,默认为 Noneadded_proj_bias: Optional[bool] = True,  # 是否为添加的投影使用偏置,默认为 Truenorm_num_groups: Optional[int] = None,  # 可选的归一化组数量,默认为 Nonespatial_norm_dim: Optional[int] = None,  # 可选的空间归一化维度,默认为 Noneout_bias: bool = True,  # 是否使用输出偏置,默认为 Truescale_qk: bool = True,  # 是否缩放查询和键,默认为 Trueonly_cross_attention: bool = False,  # 是否仅使用交叉注意力,默认为 Falseeps: float = 1e-5,  # 为数值稳定性引入的微小常数,默认为 1e-5rescale_output_factor: float = 1.0,  # 输出重标定因子,默认为 1.0residual_connection: bool = False,  # 是否使用残差连接,默认为 False_from_deprecated_attn_block: bool = False,  # 可选参数,指示是否来自弃用的注意力块,默认为 Falseprocessor: Optional["AttnProcessor"] = None,  # 可选的处理器,默认为 Noneout_dim: int = None,  # 输出维度,默认为 Nonecontext_pre_only=None,  # 上下文前处理,默认为 Nonepre_only=False,  # 是否仅进行前处理,默认为 False# 设置是否使用来自 `torch_npu` 的 npu flash attentiondef set_use_npu_flash_attention(self, use_npu_flash_attention: bool) -> None:r"""设置是否使用来自 `torch_npu` 的 npu flash attention。"""# 如果选择使用 npu flash attentionif use_npu_flash_attention:# 创建 NPU 注意力处理器实例processor = AttnProcessorNPU()else:# 设置注意力处理器# 默认情况下使用 AttnProcessor2_0,当使用 torch 2.x 时,# 它利用 torch.nn.functional.scaled_dot_product_attention 进行本地 Flash/内存高效注意力# 仅在其具有默认 `scale` 参数时适用。TODO: 在迁移到 torch 2.1 时移除 scale_qk 检查processor = (AttnProcessor2_0() if hasattr(F, "scaled_dot_product_attention") and self.scale_qk else AttnProcessor())# 设置当前的处理器self.set_processor(processor)# 设置是否使用内存高效的 xformers 注意力def set_use_memory_efficient_attention_xformers(self, use_memory_efficient_attention_xformers: bool, attention_op: Optional[Callable] = None):pass  # 此处可能缺少实现# 设置注意力计算的切片大小def set_attention_slice(self, slice_size: int) -> None:r"""设置注意力计算的切片大小。参数:slice_size (`int`):用于注意力计算的切片大小。"""# 如果切片大小不为 None 且大于可切片头维度if slice_size is not None and slice_size > self.sliceable_head_dim:# 抛出值错误,切片大小必须小于或等于可切片头维度raise ValueError(f"slice_size {slice_size} has to be smaller or equal to {self.sliceable_head_dim}.")# 如果切片大小不为 None 且添加的 kv 投影维度不为 Noneif slice_size is not None and self.added_kv_proj_dim is not None:# 创建带切片大小的 KV 处理器实例processor = SlicedAttnAddedKVProcessor(slice_size)# 如果切片大小不为 Noneelif slice_size is not None:# 创建带切片大小的注意力处理器实例processor = SlicedAttnProcessor(slice_size)# 如果添加的 kv 投影维度不为 Noneelif self.added_kv_proj_dim is not None:# 创建 KV 注意力处理器实例processor = AttnAddedKVProcessor()else:# 设置注意力处理器# 默认情况下使用 AttnProcessor2_0,当使用 torch 2.x 时,# 它利用 torch.nn.functional.scaled_dot_product_attention 进行本地 Flash/内存高效注意力# 仅在其具有默认 `scale` 参数时适用。TODO: 在迁移到 torch 2.1 时移除 scale_qk 检查processor = (AttnProcessor2_0() if hasattr(F, "scaled_dot_product_attention") and self.scale_qk else AttnProcessor())# 设置当前的处理器self.set_processor(processor)# 设置要使用的注意力处理器def set_processor(self, processor: "AttnProcessor") -> None:r"""设置要使用的注意力处理器。参数:processor 
(`AttnProcessor`):要使用的注意力处理器。"""# 如果当前处理器在 `self._modules` 中,且传入的 `processor` 不在其中,则需要从 `self._modules` 中移除当前处理器if (hasattr(self, "processor")  # 检查当前对象是否有处理器属性and isinstance(self.processor, torch.nn.Module)  # 确保当前处理器是一个 PyTorch 模块and not isinstance(processor, torch.nn.Module)  # 检查传入的处理器不是 PyTorch 模块):# 记录日志,指出将移除已训练权重的处理器logger.info(f"You are removing possibly trained weights of {self.processor} with {processor}")# 从模块中移除当前处理器self._modules.pop("processor")# 设置当前对象的处理器为传入的处理器self.processor = processor# 获取正在使用的注意力处理器def get_processor(self, return_deprecated_lora: bool = False) -> "AttentionProcessor":r"""获取正在使用的注意力处理器。参数:return_deprecated_lora (`bool`, *可选*, 默认为 `False`):设置为 `True` 以返回过时的 LoRA 注意力处理器。返回:"AttentionProcessor": 正在使用的注意力处理器。"""# 如果不需要返回过时的 LoRA 处理器,则返回当前处理器if not return_deprecated_lora:return self.processor# 前向传播方法,处理输入的隐藏状态def forward(self,hidden_states: torch.Tensor,  # 输入的隐藏状态张量encoder_hidden_states: Optional[torch.Tensor] = None,  # 可选的编码器隐藏状态张量attention_mask: Optional[torch.Tensor] = None,  # 可选的注意力掩码张量**cross_attention_kwargs,  # 可变参数,用于交叉注意力) -> torch.Tensor:r"""  # 文档字符串,描述此方法的功能和参数The forward method of the `Attention` class.Args:  # 参数说明hidden_states (`torch.Tensor`):  # 查询的隐藏状态,类型为张量The hidden states of the query.encoder_hidden_states (`torch.Tensor`, *optional*):  # 编码器的隐藏状态,可选参数The hidden states of the encoder.attention_mask (`torch.Tensor`, *optional*):  # 注意力掩码,可选参数The attention mask to use. If `None`, no mask is applied.**cross_attention_kwargs:  # 额外的关键字参数,传递给交叉注意力Additional keyword arguments to pass along to the cross attention.Returns:  # 返回值说明`torch.Tensor`: The output of the attention layer.  # 返回注意力层的输出"""# `Attention` 类可以调用不同的注意力处理器/函数# 这里我们简单地将所有张量传递给所选的处理器类# 对于此处定义的标准处理器,`**cross_attention_kwargs` 是空的attn_parameters = set(inspect.signature(self.processor.__call__).parameters.keys())  # 获取处理器调用方法的参数名集合quiet_attn_parameters = {"ip_adapter_masks"}  # 定义不需要警告的参数集合unused_kwargs = [  # 筛选出未被使用的关键字参数k for k, _ in cross_attention_kwargs.items() if k not in attn_parameters and k not in quiet_attn_parameters]if len(unused_kwargs) > 0:  # 如果存在未使用的关键字参数logger.warning(  # 记录警告日志f"cross_attention_kwargs {unused_kwargs} are not expected by {self.processor.__class__.__name__} and will be ignored.")cross_attention_kwargs = {k: w for k, w in cross_attention_kwargs.items() if k in attn_parameters}  # 过滤出有效的关键字参数return self.processor(  # 调用处理器并返回结果self,hidden_states,  # 传递隐藏状态encoder_hidden_states=encoder_hidden_states,  # 传递编码器的隐藏状态attention_mask=attention_mask,  # 传递注意力掩码**cross_attention_kwargs,  # 解包有效的额外关键字参数)def batch_to_head_dim(self, tensor: torch.Tensor) -> torch.Tensor:  # 定义方法,输入张量并返回处理后的张量r"""  # 文档字符串,描述此方法的功能和参数Reshape the tensor from `[batch_size, seq_len, dim]` to `[batch_size // heads, seq_len, dim * heads]`. `heads`  # 将张量从 `[batch_size, seq_len, dim]` 重新形状为 `[batch_size // heads, seq_len, dim * heads]`,`heads` 为初始化时的头数量is the number of heads initialized while constructing the `Attention` class.Args:  # 参数说明tensor (`torch.Tensor`): The tensor to reshape.  # 要重新形状的张量Returns:  # 返回值说明`torch.Tensor`: The reshaped tensor.  
# 返回重新形状后的张量"""head_size = self.heads  # 获取头的数量batch_size, seq_len, dim = tensor.shape  # 解包输入张量的形状tensor = tensor.reshape(batch_size // head_size, head_size, seq_len, dim)  # 重新调整张量的形状tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size // head_size, seq_len, dim * head_size)  # 调整维度顺序并重新形状return tensor  # 返回处理后的张量# 将输入张量从形状 `[batch_size, seq_len, dim]` 转换为 `[batch_size, seq_len, heads, dim // heads]`def head_to_batch_dim(self, tensor: torch.Tensor, out_dim: int = 3) -> torch.Tensor:r"""将张量从 `[batch_size, seq_len, dim]` 重塑为 `[batch_size, seq_len, heads, dim // heads]`,其中 `heads` 是在构造 `Attention` 类时初始化的头数。参数:tensor (`torch.Tensor`): 要重塑的张量。out_dim (`int`, *可选*, 默认值为 `3`): 张量的输出维度。如果为 `3`,则张量被重塑为 `[batch_size * heads, seq_len, dim // heads]`。返回:`torch.Tensor`: 重塑后的张量。"""# 获取头的数量head_size = self.heads# 检查输入张量的维度,如果是三维则提取形状信息if tensor.ndim == 3:batch_size, seq_len, dim = tensor.shapeextra_dim = 1else:# 如果不是三维,提取四维形状信息batch_size, extra_dim, seq_len, dim = tensor.shape# 重塑张量为 `[batch_size, seq_len * extra_dim, head_size, dim // head_size]`tensor = tensor.reshape(batch_size, seq_len * extra_dim, head_size, dim // head_size)# 调整张量维度顺序为 `[batch_size, heads, seq_len * extra_dim, dim // heads]`tensor = tensor.permute(0, 2, 1, 3)# 如果输出维度为 3,进一步重塑张量为 `[batch_size * heads, seq_len * extra_dim, dim // heads]`if out_dim == 3:tensor = tensor.reshape(batch_size * head_size, seq_len * extra_dim, dim // head_size)# 返回重塑后的张量return tensor# 计算注意力得分的函数def get_attention_scores(self, query: torch.Tensor, key: torch.Tensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:r"""计算注意力得分。参数:query (`torch.Tensor`): 查询张量。key (`torch.Tensor`): 键张量。attention_mask (`torch.Tensor`, *可选*): 使用的注意力掩码。如果为 `None`,则不应用掩码。返回:`torch.Tensor`: 注意力概率/得分。"""# 获取查询张量的数据类型dtype = query.dtype# 如果需要上升类型,将查询和键张量转换为浮点型if self.upcast_attention:query = query.float()key = key.float()# 如果没有提供注意力掩码,创建空的输入张量if attention_mask is None:baddbmm_input = torch.empty(query.shape[0], query.shape[1], key.shape[1], dtype=query.dtype, device=query.device)# 设置 beta 为 0beta = 0else:# 如果有注意力掩码,将其用作输入baddbmm_input = attention_mask# 设置 beta 为 1beta = 1# 计算注意力得分attention_scores = torch.baddbmm(baddbmm_input,query,key.transpose(-1, -2),beta=beta,alpha=self.scale,)# 删除临时的输入张量del baddbmm_input# 如果需要上升类型,将注意力得分转换为浮点型if self.upcast_softmax:attention_scores = attention_scores.float()# 计算注意力概率attention_probs = attention_scores.softmax(dim=-1)# 删除注意力得分张量del attention_scores# 将注意力概率转换回原始数据类型attention_probs = attention_probs.to(dtype)# 返回注意力概率return attention_probs# 准备注意力掩码的函数def prepare_attention_mask(self, attention_mask: torch.Tensor, target_length: int, batch_size: int, out_dim: int = 3) -> torch.Tensor:  # 定义一个函数的返回类型为 torch.Tensorr"""  # 开始文档字符串,描述函数的作用和参数Prepare the attention mask for the attention computation.  # 准备注意力计算的注意力掩码Args:  # 参数说明attention_mask (`torch.Tensor`):  # 输入参数,注意力掩码,类型为 torch.TensorThe attention mask to prepare.  # 待准备的注意力掩码target_length (`int`):  # 输入参数,目标长度,类型为 intThe target length of the attention mask. This is the length of the attention mask after padding.  # 注意力掩码的目标长度,经过填充后的长度batch_size (`int`):  # 输入参数,批处理大小,类型为 intThe batch size, which is used to repeat the attention mask.  # 批处理大小,用于重复注意力掩码out_dim (`int`, *optional*, defaults to `3`):  # 可选参数,输出维度,类型为 int,默认为 3The output dimension of the attention mask. Can be either `3` or `4`.  # 注意力掩码的输出维度,可以是 3 或 4Returns:  # 返回说明`torch.Tensor`: The prepared attention mask.  
# 返回准备好的注意力掩码,类型为 torch.Tensor"""  # 结束文档字符串head_size = self.heads  # 获取头部大小,来自类的属性 headsif attention_mask is None:  # 检查注意力掩码是否为 Nonereturn attention_mask  # 如果是 None,直接返回current_length: int = attention_mask.shape[-1]  # 获取当前注意力掩码的长度if current_length != target_length:  # 检查当前长度是否与目标长度不匹配if attention_mask.device.type == "mps":  # 如果设备类型是 "mps"# HACK: MPS: Does not support padding by greater than dimension of input tensor.  # HACK: MPS 不支持填充超过输入张量的维度# Instead, we can manually construct the padding tensor.  # 所以我们手动构建填充张量padding_shape = (attention_mask.shape[0], attention_mask.shape[1], target_length)  # 定义填充张量的形状padding = torch.zeros(padding_shape, dtype=attention_mask.dtype, device=attention_mask.device)  # 创建全零填充张量attention_mask = torch.cat([attention_mask, padding], dim=2)  # 在最后一个维度上拼接填充张量else:  # 如果不是 "mps" 设备# TODO: for pipelines such as stable-diffusion, padding cross-attn mask:  # TODO: 对于如 stable-diffusion 的管道,填充交叉注意力掩码#       we want to instead pad by (0, remaining_length), where remaining_length is:  # 我们希望用 (0, remaining_length) 填充,其中 remaining_length 是#       remaining_length: int = target_length - current_length  # remaining_length 的计算# TODO: re-enable tests/models/test_models_unet_2d_condition.py#test_model_xattn_padding  # TODO: 重新启用相关测试attention_mask = F.pad(attention_mask, (0, target_length), value=0.0)  # 用零填充注意力掩码到目标长度if out_dim == 3:  # 如果输出维度是 3if attention_mask.shape[0] < batch_size * head_size:  # 检查注意力掩码的第一维是否小于批处理大小乘以头部大小attention_mask = attention_mask.repeat_interleave(head_size, dim=0)  # 在第一维上重复注意力掩码elif out_dim == 4:  # 如果输出维度是 4attention_mask = attention_mask.unsqueeze(1)  # 在第一维增加一个维度attention_mask = attention_mask.repeat_interleave(head_size, dim=1)  # 在第二维上重复注意力掩码return attention_mask  # 返回准备好的注意力掩码# 定义一个函数用于规范化编码器的隐藏状态,接受一个张量作为输入并返回一个张量def norm_encoder_hidden_states(self, encoder_hidden_states: torch.Tensor) -> torch.Tensor:r"""规范化编码器隐藏状态。构造 `Attention` 类时需要指定 `self.norm_cross`。参数:encoder_hidden_states (`torch.Tensor`): 编码器的隐藏状态。返回:`torch.Tensor`: 规范化后的编码器隐藏状态。"""# 确保在调用此方法之前已定义 `self.norm_cross`assert self.norm_cross is not None, "self.norm_cross must be defined to call self.norm_encoder_hidden_states"# 检查 `self.norm_cross` 是否为 LayerNorm 类型if isinstance(self.norm_cross, nn.LayerNorm):# 对编码器隐藏状态进行层归一化encoder_hidden_states = self.norm_cross(encoder_hidden_states)# 检查 `self.norm_cross` 是否为 GroupNorm 类型elif isinstance(self.norm_cross, nn.GroupNorm):# GroupNorm 沿通道维度进行归一化,并期望输入形状为 (N, C, *)。# 此时我们希望沿隐藏维度进行归一化,因此需要调整形状# (batch_size, sequence_length, hidden_size) -># (batch_size, hidden_size, sequence_length)encoder_hidden_states = encoder_hidden_states.transpose(1, 2)  # 转置张量以调整维度顺序encoder_hidden_states = self.norm_cross(encoder_hidden_states)  # 对转置后的张量进行归一化encoder_hidden_states = encoder_hidden_states.transpose(1, 2)  # 再次转置回原始顺序else:# 如果 `self.norm_cross` 既不是 LayerNorm 也不是 GroupNorm,则触发断言失败assert False# 返回规范化后的编码器隐藏状态return encoder_hidden_states# 该装饰器在计算图中禁止梯度计算,以节省内存和加快推理速度@torch.no_grad()# 定义一个融合投影的方法,默认参数 fuse 为 Truedef fuse_projections(self, fuse=True):# 获取 to_q 权重的设备信息device = self.to_q.weight.data.device# 获取 to_q 权重的数据类型dtype = self.to_q.weight.data.dtype# 如果不是交叉注意力if not self.is_cross_attention:# 获取权重矩阵的拼接concatenated_weights = torch.cat([self.to_q.weight.data, self.to_k.weight.data, self.to_v.weight.data])# 输入特征数为拼接后权重的列数in_features = concatenated_weights.shape[1]# 输出特征数为拼接后权重的行数out_features = concatenated_weights.shape[0]# 创建一个新的线性投影层并复制权重self.to_qkv = nn.Linear(in_features, out_features, bias=self.use_bias, device=device, dtype=dtype)# 
复制拼接后的权重到新的层self.to_qkv.weight.copy_(concatenated_weights)# 如果使用偏置if self.use_bias:# 拼接 q、k、v 的偏置concatenated_bias = torch.cat([self.to_q.bias.data, self.to_k.bias.data, self.to_v.bias.data])# 复制拼接后的偏置到新的层self.to_qkv.bias.copy_(concatenated_bias)# 如果是交叉注意力else:# 获取 k 和 v 权重的拼接concatenated_weights = torch.cat([self.to_k.weight.data, self.to_v.weight.data])# 输入特征数为拼接后权重的列数in_features = concatenated_weights.shape[1]# 输出特征数为拼接后权重的行数out_features = concatenated_weights.shape[0]# 创建一个新的线性投影层并复制权重self.to_kv = nn.Linear(in_features, out_features, bias=self.use_bias, device=device, dtype=dtype)# 复制拼接后的权重到新的层self.to_kv.weight.copy_(concatenated_weights)# 如果使用偏置if self.use_bias:# 拼接 k 和 v 的偏置concatenated_bias = torch.cat([self.to_k.bias.data, self.to_v.bias.data])# 复制拼接后的偏置到新的层self.to_kv.bias.copy_(concatenated_bias)# 处理 SD3 和其他添加的投影if hasattr(self, "add_q_proj") and hasattr(self, "add_k_proj") and hasattr(self, "add_v_proj"):# 获取额外投影的权重拼接concatenated_weights = torch.cat([self.add_q_proj.weight.data, self.add_k_proj.weight.data, self.add_v_proj.weight.data])# 输入特征数为拼接后权重的列数in_features = concatenated_weights.shape[1]# 输出特征数为拼接后权重的行数out_features = concatenated_weights.shape[0]# 创建一个新的线性投影层并复制权重self.to_added_qkv = nn.Linear(in_features, out_features, bias=self.added_proj_bias, device=device, dtype=dtype)# 复制拼接后的权重到新的层self.to_added_qkv.weight.copy_(concatenated_weights)# 如果使用偏置if self.added_proj_bias:# 拼接额外投影的偏置concatenated_bias = torch.cat([self.add_q_proj.bias.data, self.add_k_proj.bias.data, self.add_v_proj.bias.data])# 复制拼接后的偏置到新的层self.to_added_qkv.bias.copy_(concatenated_bias)# 将融合状态存储到属性中self.fused_projections = fuse
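A minimal usage sketch of the `Attention` layer follows, using only constructor arguments that appear in the `__init__` signature above (the concrete sizes are arbitrary). On PyTorch 2.x the default processor typically resolves to `AttnProcessor2_0`; `fuse_projections()` is the method shown above that concatenates the projection weights for the `Fused*` processors.

```py
import torch
from diffusers.models.attention_processor import Attention

attn = Attention(query_dim=320, cross_attention_dim=768, heads=8, dim_head=40)

hidden_states = torch.randn(2, 64, 320)           # latent tokens (queries)
encoder_hidden_states = torch.randn(2, 77, 768)   # e.g. text-encoder tokens (keys/values)

out = attn(hidden_states, encoder_hidden_states=encoder_hidden_states)
print(out.shape)                        # torch.Size([2, 64, 320])
print(type(attn.processor).__name__)    # typically AttnProcessor2_0 on PyTorch >= 2.0

# Optionally concatenate the key/value projection weights into a single `to_kv` linear layer;
# the fused weights are consumed by the corresponding Fused*AttnProcessor classes.
attn.fuse_projections()
```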
# A processor class that performs the attention-related computation
class AttnProcessor:r"""默认处理器,用于执行与注意力相关的计算。"""# 实现可调用方法,处理注意力计算def __call__(self,attn: Attention,  # 注意力对象hidden_states: torch.Tensor,  # 输入的隐藏状态张量encoder_hidden_states: Optional[torch.Tensor] = None,  # 编码器隐藏状态(可选)attention_mask: Optional[torch.Tensor] = None,  # 注意力掩码(可选)temb: Optional[torch.Tensor] = None,  # 额外的时间嵌入(可选)*args,  # 额外的位置参数**kwargs,  # 额外的关键字参数) -> torch.Tensor:  # 返回处理后的张量# 检查是否有额外参数或已弃用的 scale 参数if len(args) > 0 or kwargs.get("scale", None) is not None:# 构建弃用警告消息deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."# 调用弃用处理函数deprecate("scale", "1.0.0", deprecation_message)# 初始化残差为隐藏状态residual = hidden_states# 如果空间归一化存在,则应用于隐藏状态if attn.spatial_norm is not None:hidden_states = attn.spatial_norm(hidden_states, temb)# 获取输入张量的维度input_ndim = hidden_states.ndim# 如果输入是四维的,则调整形状if input_ndim == 4:# 解包隐藏状态的形状batch_size, channel, height, width = hidden_states.shape# 重新调整形状为(batch_size, channel, height*width)并转置hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)# 根据编码器隐藏状态的存在与否,获取批次大小和序列长度batch_size, sequence_length, _ = (hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape)# 准备注意力掩码attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)# 如果组归一化存在,则应用于隐藏状态if attn.group_norm is not None:hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)# 将隐藏状态转换为查询向量query = attn.to_q(hidden_states)# 如果没有编码器隐藏状态,使用隐藏状态作为编码器隐藏状态if encoder_hidden_states is None:encoder_hidden_states = hidden_states# 如果需要规范化编码器隐藏状态,则应用规范化elif attn.norm_cross:encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)# 从编码器隐藏状态中获取键和值key = attn.to_k(encoder_hidden_states)value = attn.to_v(encoder_hidden_states)# 将查询、键和值转换为批次维度query = attn.head_to_batch_dim(query)key = attn.head_to_batch_dim(key)value = attn.head_to_batch_dim(value)# 计算注意力分数attention_probs = attn.get_attention_scores(query, key, attention_mask)# 通过注意力分数加权求值hidden_states = torch.bmm(attention_probs, value)# 将隐藏状态转换回头维度hidden_states = attn.batch_to_head_dim(hidden_states)# 线性投影hidden_states = attn.to_out[0](hidden_states)# 应用 dropouthidden_states = attn.to_out[1](hidden_states)# 如果输入是四维的,调整回原始形状if input_ndim == 4:hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)# 如果存在残差连接,则将残差加回隐藏状态if attn.residual_connection:hidden_states = hidden_states + residual# 将隐藏状态归一化到输出因子hidden_states = hidden_states / attn.rescale_output_factor# 返回最终的隐藏状态return hidden_states# 定义一个处理器类,用于实现自定义扩散方法的注意力
class CustomDiffusionAttnProcessor(nn.Module):r"""实现自定义扩散方法的注意力处理器。# 定义参数说明Args:train_kv (`bool`, defaults to `True`):  # 是否重新训练对应于文本特征的键值矩阵Whether to newly train the key and value matrices corresponding to the text features.train_q_out (`bool`, defaults to `True`):  # 是否重新训练对应于潜在图像特征的查询矩阵Whether to newly train query matrices corresponding to the latent image features.hidden_size (`int`, *optional*, defaults to `None`):  # 注意力层的隐藏大小The hidden size of the attention layer.cross_attention_dim (`int`, *optional*, defaults to `None`):  # 编码器隐藏状态中的通道数量The number of channels in the `encoder_hidden_states`.out_bias (`bool`, defaults to `True`):  # 是否在 `train_q_out` 中包含偏置参数Whether to include the bias parameter in `train_q_out`.dropout (`float`, *optional*, defaults to 0.0):  # 使用的 dropout 概率The dropout probability to use."""# 初始化方法def __init__(self,  # 初始化方法的第一个参数,表示对象本身train_kv: bool = True,  # 设置键值矩阵训练的默认值为 Truetrain_q_out: bool = True,  # 设置查询矩阵训练的默认值为 Truehidden_size: Optional[int] = None,  # 隐藏层大小,默认为 Nonecross_attention_dim: Optional[int] = None,  # 跨注意力维度,默认为 Noneout_bias: bool = True,  # 输出偏置参数的默认值为 Truedropout: float = 0.0,  # 默认的 dropout 概率为 0.0):super().__init__()  # 调用父类的初始化方法self.train_kv = train_kv  # 保存键值训练标志self.train_q_out = train_q_out  # 保存查询输出训练标志self.hidden_size = hidden_size  # 保存隐藏层大小self.cross_attention_dim = cross_attention_dim  # 保存跨注意力维度# `_custom_diffusion` id 方便序列化和加载if self.train_kv:  # 如果需要训练键值self.to_k_custom_diffusion = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)  # 创建键的线性层self.to_v_custom_diffusion = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)  # 创建值的线性层if self.train_q_out:  # 如果需要训练查询输出self.to_q_custom_diffusion = nn.Linear(hidden_size, hidden_size, bias=False)  # 创建查询的线性层self.to_out_custom_diffusion = nn.ModuleList([])  # 初始化输出层的模块列表self.to_out_custom_diffusion.append(nn.Linear(hidden_size, hidden_size, bias=out_bias))  # 添加线性输出层self.to_out_custom_diffusion.append(nn.Dropout(dropout))  # 添加 dropout 层# 可调用方法def __call__(  # 定义对象被调用时的行为self,  # 第一个参数,表示对象本身attn: Attention,  # 注意力对象hidden_states: torch.Tensor,  # 隐藏状态张量encoder_hidden_states: Optional[torch.Tensor] = None,  # 编码器隐藏状态,默认为 Noneattention_mask: Optional[torch.Tensor] = None,  # 注意力掩码,默认为 None# 返回类型为 torch.Tensor) -> torch.Tensor:# 获取隐藏状态的批量大小和序列长度batch_size, sequence_length, _ = hidden_states.shape# 准备注意力掩码以适应当前批量和序列长度attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)# 如果需要训练查询输出,则使用自定义扩散进行转换if self.train_q_out:query = self.to_q_custom_diffusion(hidden_states).to(attn.to_q.weight.dtype)else:# 否则使用标准的查询转换query = attn.to_q(hidden_states.to(attn.to_q.weight.dtype))# 检查编码器隐藏状态是否为 Noneif encoder_hidden_states is None:# 如果是,则不进行交叉注意力crossattn = Falseencoder_hidden_states = hidden_stateselse:# 否则,启用交叉注意力crossattn = True# 如果需要归一化编码器隐藏状态,则进行归一化if attn.norm_cross:encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)# 如果需要训练键值对if self.train_kv:# 使用自定义扩散获取键和值key = self.to_k_custom_diffusion(encoder_hidden_states.to(self.to_k_custom_diffusion.weight.dtype))value = self.to_v_custom_diffusion(encoder_hidden_states.to(self.to_v_custom_diffusion.weight.dtype))# 将键和值转换为查询的权重数据类型key = key.to(attn.to_q.weight.dtype)value = value.to(attn.to_q.weight.dtype)else:# 否则使用标准的键和值转换key = attn.to_k(encoder_hidden_states)value = attn.to_v(encoder_hidden_states)# 如果进行交叉注意力if crossattn:# 创建与键相同形状的张量以进行detach操作detach = torch.ones_like(key)detach[:, :1, :] = detach[:, :1, :] * 0.0# 应用detach逻辑以阻止梯度流动key = detach * key + (1 
- detach) * key.detach()value = detach * value + (1 - detach) * value.detach()# 将查询、键和值转换为批次维度query = attn.head_to_batch_dim(query)key = attn.head_to_batch_dim(key)value = attn.head_to_batch_dim(value)# 计算注意力分数attention_probs = attn.get_attention_scores(query, key, attention_mask)# 使用注意力分数和值进行批量矩阵乘法hidden_states = torch.bmm(attention_probs, value)# 将隐藏状态转换回头维度hidden_states = attn.batch_to_head_dim(hidden_states)# 如果需要训练查询输出if self.train_q_out:# 线性投影hidden_states = self.to_out_custom_diffusion[0](hidden_states)# 应用dropouthidden_states = self.to_out_custom_diffusion[1](hidden_states)else:# 否则使用标准的线性投影hidden_states = attn.to_out[0](hidden_states)# 应用dropouthidden_states = attn.to_out[1](hidden_states)# 返回最终的隐藏状态return hidden_states
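Since `Attention.forward` simply delegates to whatever object `set_processor` installed, custom behaviour can be injected without touching the layer itself. Below is an illustrative sketch (the `LoggingAttnProcessor` name is made up for this example) that wraps the default `AttnProcessor` path and prints the token counts before running it:

```py
import torch
from diffusers.models.attention_processor import Attention, AttnProcessor


class LoggingAttnProcessor(AttnProcessor):
    """Illustrative processor: report input shapes, then run the default attention path."""

    def __call__(self, attn, hidden_states, encoder_hidden_states=None, attention_mask=None, temb=None, **kwargs):
        print("query tokens:", hidden_states.shape[1])
        if encoder_hidden_states is not None:
            print("key/value tokens:", encoder_hidden_states.shape[1])
        return super().__call__(attn, hidden_states, encoder_hidden_states, attention_mask, temb, **kwargs)


attn = Attention(query_dim=320, cross_attention_dim=768, heads=8, dim_head=40)
attn.set_processor(LoggingAttnProcessor())

out = attn(torch.randn(1, 64, 320), encoder_hidden_states=torch.randn(1, 77, 768))
print(out.shape)  # torch.Size([1, 64, 320])
```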
# An attention processor with additional learnable key and value matrices for the text encoder
class AttnAddedKVProcessor:r"""处理器,用于执行与文本编码器相关的注意力计算"""# 定义调用方法,以实现注意力计算def __call__(self,attn: Attention,  # 注意力对象hidden_states: torch.Tensor,  # 输入的隐藏状态张量encoder_hidden_states: Optional[torch.Tensor] = None,  # 编码器的隐藏状态(可选)attention_mask: Optional[torch.Tensor] = None,  # 注意力掩码(可选)*args,  # 其他位置参数**kwargs,  # 其他关键字参数) -> torch.Tensor:  # 返回类型为张量# 检查是否传递了多余的参数或已弃用的 scale 参数if len(args) > 0 or kwargs.get("scale", None) is not None:deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."# 发出弃用警告deprecate("scale", "1.0.0", deprecation_message)# 将隐藏状态赋值给残差residual = hidden_states# 重塑隐藏状态的形状,并转置维度hidden_states = hidden_states.view(hidden_states.shape[0], hidden_states.shape[1], -1).transpose(1, 2)# 获取批大小和序列长度batch_size, sequence_length, _ = hidden_states.shape# 准备注意力掩码attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)# 如果没有编码器隐藏状态,则使用输入的隐藏状态if encoder_hidden_states is None:encoder_hidden_states = hidden_states# 如果需要进行归一化处理elif attn.norm_cross:encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)# 对隐藏状态进行分组归一化处理hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)# 将隐藏状态转换为查询query = attn.to_q(hidden_states)# 将查询从头维度转换为批维度query = attn.head_to_batch_dim(query)# 将编码器隐藏状态投影为键和值encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)# 将投影结果转换为批维度encoder_hidden_states_key_proj = attn.head_to_batch_dim(encoder_hidden_states_key_proj)encoder_hidden_states_value_proj = attn.head_to_batch_dim(encoder_hidden_states_value_proj)# 如果不是仅进行交叉注意力if not attn.only_cross_attention:# 将隐藏状态转换为键和值key = attn.to_k(hidden_states)value = attn.to_v(hidden_states)# 转换为批维度key = attn.head_to_batch_dim(key)value = attn.head_to_batch_dim(value)# 将编码器键和值与当前键和值拼接key = torch.cat([encoder_hidden_states_key_proj, key], dim=1)value = torch.cat([encoder_hidden_states_value_proj, value], dim=1)else:# 仅使用编码器的键和值key = encoder_hidden_states_key_projvalue = encoder_hidden_states_value_proj# 获取注意力概率attention_probs = attn.get_attention_scores(query, key, attention_mask)# 计算隐藏状态的新值hidden_states = torch.bmm(attention_probs, value)# 将隐藏状态转换回头维度hidden_states = attn.batch_to_head_dim(hidden_states)# 线性投影hidden_states = attn.to_out[0](hidden_states)# 应用 dropouthidden_states = attn.to_out[1](hidden_states)# 重塑隐藏状态,并将残差加回hidden_states = hidden_states.transpose(-1, -2).reshape(residual.shape)hidden_states = hidden_states + residual# 返回处理后的隐藏状态return hidden_states# 定义另一个注意力处理器类
class AttnAddedKVProcessor2_0:r"""# 处理缩放点积注意力的处理器(如果使用 PyTorch 2.0,默认启用),# 其中为文本编码器添加了额外的可学习的键和值矩阵。"""# 初始化方法def __init__(self):# 检查 F 中是否有 "scaled_dot_product_attention" 属性if not hasattr(F, "scaled_dot_product_attention"):# 如果没有,抛出 ImportError,提示用户需要升级到 PyTorch 2.0raise ImportError("AttnAddedKVProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")# 定义调用方法def __call__(self,attn: Attention,  # 输入的注意力机制对象hidden_states: torch.Tensor,  # 隐藏状态张量encoder_hidden_states: Optional[torch.Tensor] = None,  # 可选的编码器隐藏状态张量attention_mask: Optional[torch.Tensor] = None,  # 可选的注意力掩码张量*args,  # 额外的位置参数**kwargs,  # 额外的关键字参数) -> torch.Tensor:  # 指定函数返回类型为 torch.Tensor# 检查参数是否存在或 scale 参数是否被提供if len(args) > 0 or kwargs.get("scale", None) is not None:# 设置弃用消息,告知 scale 参数将被忽略deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."# 调用 deprecate 函数发出弃用警告deprecate("scale", "1.0.0", deprecation_message)# 将输入的 hidden_states 赋值给 residualresidual = hidden_states# 调整 hidden_states 的形状并进行转置hidden_states = hidden_states.view(hidden_states.shape[0], hidden_states.shape[1], -1).transpose(1, 2)# 获取 batch_size 和 sequence_lengthbatch_size, sequence_length, _ = hidden_states.shape# 准备注意力掩码attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size, out_dim=4)# 如果没有提供 encoder_hidden_states,则使用 hidden_statesif encoder_hidden_states is None:encoder_hidden_states = hidden_states# 如果需要归一化交叉隐藏状态elif attn.norm_cross:encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)# 对 hidden_states 进行分组归一化hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)# 计算查询向量query = attn.to_q(hidden_states)# 将查询向量转换为批次维度query = attn.head_to_batch_dim(query, out_dim=4)# 生成 encoder_hidden_states 的键和值的投影encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)# 将键和值转换为批次维度encoder_hidden_states_key_proj = attn.head_to_batch_dim(encoder_hidden_states_key_proj, out_dim=4)encoder_hidden_states_value_proj = attn.head_to_batch_dim(encoder_hidden_states_value_proj, out_dim=4)# 如果不是只进行交叉注意力if not attn.only_cross_attention:# 计算当前 hidden_states 的键和值key = attn.to_k(hidden_states)value = attn.to_v(hidden_states)# 转换为批次维度key = attn.head_to_batch_dim(key, out_dim=4)value = attn.head_to_batch_dim(value, out_dim=4)# 将键和值与 encoder 的键和值连接key = torch.cat([encoder_hidden_states_key_proj, key], dim=2)value = torch.cat([encoder_hidden_states_value_proj, value], dim=2)else:# 如果只进行交叉注意力,使用 encoder 的键和值key = encoder_hidden_states_key_projvalue = encoder_hidden_states_value_proj# 计算缩放点积注意力的输出,形状为 (batch, num_heads, seq_len, head_dim)# TODO: 在迁移到 Torch 2.1 时添加对 attn.scale 的支持hidden_states = F.scaled_dot_product_attention(query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False)# 转置并重塑 hidden_stateshidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, residual.shape[1])# 进行线性投影hidden_states = attn.to_out[0](hidden_states)# 进行 dropouthidden_states = attn.to_out[1](hidden_states)# 转置并重塑回 residual 的形状hidden_states = hidden_states.transpose(-1, -2).reshape(residual.shape)# 将 residual 加到 hidden_states 上hidden_states = hidden_states + residual# 返回最终的 hidden_statesreturn hidden_states
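The difference between the plain processors and their `*2_0` counterparts is only in how the attention itself is evaluated: the older path materializes `softmax(QKᵀ · scale)` explicitly (via `baddbmm` and `softmax`), while the `*2_0` path calls `torch.nn.functional.scaled_dot_product_attention`, which fuses the same computation. A quick illustrative equivalence check with random tensors:

```py
import math
import torch
import torch.nn.functional as F

batch, heads, q_len, kv_len, head_dim = 2, 8, 16, 24, 64
q = torch.randn(batch, heads, q_len, head_dim)
k = torch.randn(batch, heads, kv_len, head_dim)
v = torch.randn(batch, heads, kv_len, head_dim)

# fused kernel used by the *2_0 processors (PyTorch >= 2.0)
out_sdpa = F.scaled_dot_product_attention(q, k, v, dropout_p=0.0, is_causal=False)

# explicit computation used by the older processors
scale = 1.0 / math.sqrt(head_dim)
attn_probs = torch.softmax(q @ k.transpose(-1, -2) * scale, dim=-1)
out_manual = attn_probs @ v

print(torch.allclose(out_sdpa, out_manual, atol=1e-5))  # True, up to numerical tolerance
```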
# JointAttnProcessor2_0: handles SD3-style joint self-attention projections
class JointAttnProcessor2_0:"""Attention processor used typically in processing the SD3-like self-attention projections."""# 初始化方法def __init__(self):# 检查 F 是否有 scaled_dot_product_attention 属性if not hasattr(F, "scaled_dot_product_attention"):# 如果没有,抛出导入错误,提示需要升级 PyTorch 到 2.0raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")# 定义调用方法,接受多个参数def __call__(self,attn: Attention,  # 自注意力对象hidden_states: torch.FloatTensor,  # 当前隐藏状态的张量encoder_hidden_states: torch.FloatTensor = None,  # 编码器的隐藏状态,默认为 Noneattention_mask: Optional[torch.FloatTensor] = None,  # 可选的注意力掩码,默认为 None*args,  # 额外的位置参数**kwargs,  # 额外的关键字参数# 返回一个浮点张量) -> torch.FloatTensor:# 保存输入的隐藏状态,以便后续使用residual = hidden_states# 获取隐藏状态的维度input_ndim = hidden_states.ndim# 如果隐藏状态是四维的if input_ndim == 4:# 解包隐藏状态的形状为批大小、通道、高度和宽度batch_size, channel, height, width = hidden_states.shape# 将隐藏状态重塑为三维,并进行转置hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)# 获取编码器隐藏状态的维度context_input_ndim = encoder_hidden_states.ndim# 如果编码器隐藏状态是四维的if context_input_ndim == 4:# 解包编码器隐藏状态的形状为批大小、通道、高度和宽度batch_size, channel, height, width = encoder_hidden_states.shape# 将编码器隐藏状态重塑为三维,并进行转置encoder_hidden_states = encoder_hidden_states.view(batch_size, channel, height * width).transpose(1, 2)# 获取编码器隐藏状态的批大小batch_size = encoder_hidden_states.shape[0]# 计算 `sample` 投影query = attn.to_q(hidden_states)key = attn.to_k(hidden_states)value = attn.to_v(hidden_states)# 计算 `context` 投影encoder_hidden_states_query_proj = attn.add_q_proj(encoder_hidden_states)encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)# 合并注意力查询、键和值query = torch.cat([query, encoder_hidden_states_query_proj], dim=1)key = torch.cat([key, encoder_hidden_states_key_proj], dim=1)value = torch.cat([value, encoder_hidden_states_value_proj], dim=1)# 获取键的最后一维大小inner_dim = key.shape[-1]# 计算每个头的维度head_dim = inner_dim // attn.heads# 重塑查询、键和值以适应多个头query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)# 计算缩放点积注意力hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False)# 转置并重塑隐藏状态hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)# 转换为查询的类型hidden_states = hidden_states.to(query.dtype)# 拆分注意力输出hidden_states, encoder_hidden_states = (hidden_states[:, : residual.shape[1]],  # 获取原隐藏状态的部分hidden_states[:, residual.shape[1] :],  # 获取编码器隐藏状态的部分)# 进行线性投影hidden_states = attn.to_out[0](hidden_states)# 进行 dropouthidden_states = attn.to_out[1](hidden_states)# 如果上下文不是仅限于编码器if not attn.context_pre_only:# 对编码器隐藏状态进行额外处理encoder_hidden_states = attn.to_add_out(encoder_hidden_states)# 如果输入是四维的,进行转置和重塑if input_ndim == 4:hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)# 如果上下文输入是四维的,进行转置和重塑if context_input_ndim == 4:encoder_hidden_states = encoder_hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)# 返回处理后的隐藏状态和编码器隐藏状态return hidden_states, encoder_hidden_states
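The defining pattern of the joint processors is that the image tokens and the text-context tokens are concatenated along the sequence axis, attended to in a single call, and then split back apart using the image-stream length (`residual.shape[1]` in the code above). The sketch below shows only that token-routing pattern; it deliberately omits the separate learned projections (`to_q` vs. `add_q_proj`) and the output layers.

```py
import torch
import torch.nn.functional as F

heads, head_dim = 8, 64
inner_dim = heads * head_dim
sample_len, context_len = 64, 77

hidden_states = torch.randn(1, sample_len, inner_dim)           # image stream
encoder_hidden_states = torch.randn(1, context_len, inner_dim)  # text stream

# both streams are concatenated along the token axis and attended jointly
tokens = torch.cat([hidden_states, encoder_hidden_states], dim=1)   # (1, 141, 512)
q = k = v = tokens.view(1, -1, heads, head_dim).transpose(1, 2)     # (1, 8, 141, 64)

out = F.scaled_dot_product_attention(q, k, v)
out = out.transpose(1, 2).reshape(1, -1, inner_dim)

# the joint result is split back apart using the length of the image stream
out_sample, out_context = out[:, :sample_len], out[:, sample_len:]
print(out_sample.shape, out_context.shape)  # torch.Size([1, 64, 512]) torch.Size([1, 77, 512])
```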
# PAGJointAttnProcessor2_0: perturbed-attention-guidance variant of the joint self-attention processor
class PAGJointAttnProcessor2_0:"""Attention processor used typically in processing the SD3-like self-attention projections."""# 初始化方法def __init__(self):# 检查是否存在名为"scaled_dot_product_attention"的属性if not hasattr(F, "scaled_dot_product_attention"):# 如果不存在,则抛出导入错误,提示需要升级PyTorch到2.0raise ImportError("PAGJointAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")# 可调用方法,接受注意力对象和隐藏状态def __call__(self,attn: Attention,hidden_states: torch.FloatTensor,encoder_hidden_states: torch.FloatTensor = None,# 其他可选参数attention_mask: Optional[torch.FloatTensor] = None,*args,**kwargs,
# 定义另一个类,PAGCFGJointAttnProcessor2_0,类似于PAGJointAttnProcessor2_0
class PAGCFGJointAttnProcessor2_0:"""Attention processor used typically in processing the SD3-like self-attention projections."""# 初始化方法def __init__(self):# 检查是否存在名为"scaled_dot_product_attention"的属性if not hasattr(F, "scaled_dot_product_attention"):# 如果不存在,则抛出导入错误,提示需要升级PyTorch到2.0raise ImportError("PAGCFGJointAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")# 可调用方法,接受注意力对象和隐藏状态def __call__(self,attn: Attention,hidden_states: torch.FloatTensor,encoder_hidden_states: torch.FloatTensor = None,# 其他可选参数attention_mask: Optional[torch.FloatTensor] = None,*args,**kwargs,
# 定义第三个类,FusedJointAttnProcessor2_0,处理自注意力投影
class FusedJointAttnProcessor2_0:"""Attention processor used typically in processing the SD3-like self-attention projections."""# 初始化方法def __init__(self):# 检查是否存在名为"scaled_dot_product_attention"的属性if not hasattr(F, "scaled_dot_product_attention"):# 如果不存在,则抛出导入错误,提示需要升级PyTorch到2.0raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")# 可调用方法,接受注意力对象和隐藏状态def __call__(self,attn: Attention,hidden_states: torch.FloatTensor,encoder_hidden_states: torch.FloatTensor = None,# 其他可选参数attention_mask: Optional[torch.FloatTensor] = None,*args,**kwargs,) -> torch.FloatTensor:# 将隐藏状态赋值给残差变量residual = hidden_states# 获取隐藏状态的维度input_ndim = hidden_states.ndim# 如果隐藏状态是四维的,进行维度变换if input_ndim == 4:# 解包隐藏状态的形状batch_size, channel, height, width = hidden_states.shape# 将隐藏状态变形为(batch_size, channel, height * width)并转置hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)# 获取编码器隐藏状态的维度context_input_ndim = encoder_hidden_states.ndim# 如果编码器隐藏状态是四维的,进行维度变换if context_input_ndim == 4:# 解包编码器隐藏状态的形状batch_size, channel, height, width = encoder_hidden_states.shape# 将编码器隐藏状态变形为(batch_size, channel, height * width)并转置encoder_hidden_states = encoder_hidden_states.view(batch_size, channel, height * width).transpose(1, 2)# 获取编码器隐藏状态的批量大小batch_size = encoder_hidden_states.shape[0]# `sample` 进行投影qkv = attn.to_qkv(hidden_states)# 计算每个分量的大小split_size = qkv.shape[-1] // 3# 将qkv拆分为query、key和valuequery, key, value = torch.split(qkv, split_size, dim=-1)# `context` 进行投影encoder_qkv = attn.to_added_qkv(encoder_hidden_states)# 计算编码器qkv的分量大小split_size = encoder_qkv.shape[-1] // 3# 将编码器qkv拆分为查询、键和值的投影(encoder_hidden_states_query_proj,encoder_hidden_states_key_proj,encoder_hidden_states_value_proj,) = torch.split(encoder_qkv, split_size, dim=-1)# 进行注意力计算# 将query、key、value进行连接query = torch.cat([query, encoder_hidden_states_query_proj], dim=1)key = torch.cat([key, encoder_hidden_states_key_proj], dim=1)value = torch.cat([value, encoder_hidden_states_value_proj], dim=1)# 获取key的最后一维大小inner_dim = key.shape[-1]# 计算每个头的维度head_dim = inner_dim // attn.heads# 调整query的形状以适应多头注意力query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)# 调整key的形状以适应多头注意力key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)# 调整value的形状以适应多头注意力value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)# 进行缩放点积注意力计算hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False)# 调整hidden_states的形状hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)# 将hidden_states转换为与query相同的数据类型hidden_states = hidden_states.to(query.dtype)# 拆分注意力输出hidden_states, encoder_hidden_states = (# 保留残差形状的部分hidden_states[:, : residual.shape[1]],# 剩余的部分hidden_states[:, residual.shape[1] :],)# 线性投影hidden_states = attn.to_out[0](hidden_states)# 进行dropouthidden_states = attn.to_out[1](hidden_states)# 如果不是只使用上下文,进行编码器输出的投影if not attn.context_pre_only:encoder_hidden_states = attn.to_add_out(encoder_hidden_states)# 如果输入是四维的,调整hidden_states的形状if input_ndim == 4:hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)# 如果上下文输入是四维的,调整encoder_hidden_states的形状if context_input_ndim == 4:encoder_hidden_states = encoder_hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)# 返回hidden_states和encoder_hidden_statesreturn hidden_states, encoder_hidden_states
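The "fused" processors rely on the concatenated weights created by `fuse_projections()` shown earlier: instead of three separate `to_q` / `to_k` / `to_v` matmuls, a single `to_qkv` projection produces `3 * inner_dim` features that are then split. A stand-alone sketch of that split (using a plain `nn.Linear` as a stand-in for the fused layer):

```py
import torch
from torch import nn

inner_dim, tokens = 512, 64
hidden_states = torch.randn(1, tokens, inner_dim)

# one fused projection instead of three separate to_q / to_k / to_v linears
to_qkv = nn.Linear(inner_dim, inner_dim * 3, bias=True)
qkv = to_qkv(hidden_states)                 # (1, 64, 1536)

split_size = qkv.shape[-1] // 3
query, key, value = torch.split(qkv, split_size, dim=-1)
print(query.shape, key.shape, value.shape)  # each torch.Size([1, 64, 512])
```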
# Attention processor class used for Aura Flow
class AuraFlowAttnProcessor2_0:"""Attention processor used typically in processing Aura Flow."""# 初始化方法def __init__(self):# 检查 F 是否具有 scaled_dot_product_attention 属性,并确保 PyTorch 版本符合要求if not hasattr(F, "scaled_dot_product_attention") and is_torch_version("<", "2.1"):# 如果不满足条件,抛出导入错误,提示用户升级 PyTorchraise ImportError("AuraFlowAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to at least 2.1 or above as we use `scale` in `F.scaled_dot_product_attention()`. ")# 可调用方法,用于处理输入的注意力和隐藏状态def __call__(self,attn: Attention,hidden_states: torch.FloatTensor,encoder_hidden_states: torch.FloatTensor = None,*args,**kwargs,
# 定义一个用于处理 Aura Flow 的融合投影注意力处理器类
class FusedAuraFlowAttnProcessor2_0:"""Attention processor used typically in processing Aura Flow with fused projections."""# 初始化方法def __init__(self):# 检查 F 是否具有 scaled_dot_product_attention 属性,并确保 PyTorch 版本符合要求if not hasattr(F, "scaled_dot_product_attention") and is_torch_version("<", "2.1"):# 如果不满足条件,抛出导入错误,提示用户升级 PyTorchraise ImportError("FusedAuraFlowAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to at least 2.1 or above as we use `scale` in `F.scaled_dot_product_attention()`. ")# 可调用方法,用于处理输入的注意力和隐藏状态def __call__(self,attn: Attention,hidden_states: torch.FloatTensor,encoder_hidden_states: torch.FloatTensor = None,*args,**kwargs,
# YiYi to-do: refactor the rope-related functions/classes
def apply_rope(xq, xk, freqs_cis):
    # view the last dimension of the queries as (d/2, 1, 2) pairs so each pair can be rotated
    xq_ = xq.float().reshape(*xq.shape[:-1], -1, 1, 2)
    # do the same for the keys
    xk_ = xk.float().reshape(*xk.shape[:-1], -1, 1, 2)
    # rotate the query pairs with the precomputed frequency table
    xq_out = freqs_cis[..., 0] * xq_[..., 0] + freqs_cis[..., 1] * xq_[..., 1]
    # rotate the key pairs with the same table
    xk_out = freqs_cis[..., 0] * xk_[..., 0] + freqs_cis[..., 1] * xk_[..., 1]
    # restore the original shapes and dtypes
    return xq_out.reshape(*xq.shape).type_as(xq), xk_out.reshape(*xk.shape).type_as(xk)


# A processor class implementing scaled dot-product attention
class FluxSingleAttnProcessor2_0:r"""Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0)."""# 初始化方法def __init__(self):# 检查 F 是否具有 scaled_dot_product_attention 属性if not hasattr(F, "scaled_dot_product_attention"):# 如果不满足条件,抛出导入错误,提示用户升级 PyTorchraise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")# 可调用方法,用于处理输入的注意力和隐藏状态def __call__(self,attn: Attention,hidden_states: torch.Tensor,encoder_hidden_states: Optional[torch.Tensor] = None,attention_mask: Optional[torch.FloatTensor] = None,image_rotary_emb: Optional[torch.Tensor] = None,# 定义函数的返回类型为 torch.Tensor) -> torch.Tensor:# 获取 hidden_states 的维度数量input_ndim = hidden_states.ndim# 如果输入的维度为 4if input_ndim == 4:# 解包 hidden_states 的形状为 batch_size, channel, height, widthbatch_size, channel, height, width = hidden_states.shape# 将 hidden_states 视图调整为 (batch_size, channel, height * width) 并转置hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)# 如果 encoder_hidden_states 为 None,则获取 hidden_states 的形状# 否则获取 encoder_hidden_states 的形状batch_size, _, _ = hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape# 将 hidden_states 转换为查询向量query = attn.to_q(hidden_states)# 如果 encoder_hidden_states 为 None,将其设置为 hidden_statesif encoder_hidden_states is None:encoder_hidden_states = hidden_states# 将 encoder_hidden_states 转换为键向量key = attn.to_k(encoder_hidden_states)# 将 encoder_hidden_states 转换为值向量value = attn.to_v(encoder_hidden_states)# 获取键的最后一个维度的大小inner_dim = key.shape[-1]# 计算每个头的维度head_dim = inner_dim // attn.heads# 将查询向量调整视图为 (batch_size, -1, attn.heads, head_dim) 并转置query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)# 将键向量调整视图为 (batch_size, -1, attn.heads, head_dim) 并转置key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)# 将值向量调整视图为 (batch_size, -1, attn.heads, head_dim) 并转置value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)# 如果存在规范化查询的层,则对查询进行规范化if attn.norm_q is not None:query = attn.norm_q(query)# 如果存在规范化键的层,则对键进行规范化if attn.norm_k is not None:key = attn.norm_k(key)# 如果需要应用 RoPEif image_rotary_emb is not None:# 应用旋转嵌入到查询和键上query, key = apply_rope(query, key, image_rotary_emb)# 计算缩放点积注意力,输出形状为 (batch, num_heads, seq_len, head_dim)hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False)# 转置并调整 hidden_states 的形状为 (batch_size, -1, attn.heads * head_dim)hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)# 将 hidden_states 转换为与查询相同的数据类型hidden_states = hidden_states.to(query.dtype)# 如果输入维度为 4,将 hidden_states 转置并调整形状回原始维度if input_ndim == 4:hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)# 返回处理后的 hidden_statesreturn hidden_states
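To see the shape convention behind `apply_rope` (defined just above), the sketch below hand-builds a broadcast-compatible `freqs_cis` table of 2×2 rotation matrices, one per (position, frequency) pair, and applies it to random query/key tensors. The `make_freqs_cis` helper is made up for this illustration; the table actually produced by the library's embedding utilities may be laid out differently.

```py
import torch
from diffusers.models.attention_processor import apply_rope  # the helper shown earlier in this file


def make_freqs_cis(seq_len: int, head_dim: int, theta: float = 10000.0) -> torch.Tensor:
    # one 2x2 rotation matrix per (position, frequency) pair; apply_rope's [..., 0] / [..., 1]
    # indexing then picks out the two matrix columns
    freqs = 1.0 / (theta ** (torch.arange(0, head_dim, 2).float() / head_dim))  # (head_dim // 2,)
    angles = torch.outer(torch.arange(seq_len).float(), freqs)                  # (seq_len, head_dim // 2)
    cos, sin = torch.cos(angles), torch.sin(angles)
    rot = torch.stack(
        [torch.stack([cos, -sin], dim=-1), torch.stack([sin, cos], dim=-1)], dim=-2
    )                                                                            # (seq_len, head_dim // 2, 2, 2)
    return rot[None, None]                                                       # broadcast over (batch, heads)


B, H, L, D = 2, 4, 16, 64
q = torch.randn(B, H, L, D)
k = torch.randn(B, H, L, D)

freqs_cis = make_freqs_cis(L, D)
q_rot, k_rot = apply_rope(q, k, freqs_cis)
print(q_rot.shape, k_rot.shape)  # torch.Size([2, 4, 16, 64]) twice — shapes are preserved
```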
# 定义一个名为 FluxAttnProcessor2_0 的类,通常用于处理 SD3 类自注意力投影
class FluxAttnProcessor2_0:"""Attention processor used typically in processing the SD3-like self-attention projections."""# 初始化方法def __init__(self):# 检查 F 是否有 scaled_dot_product_attention 属性,如果没有则抛出 ImportErrorif not hasattr(F, "scaled_dot_product_attention"):raise ImportError("FluxAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")# 定义调用方法,使类实例可被调用def __call__(self,attn: Attention,  # 接收 Attention 对象hidden_states: torch.FloatTensor,  # 接收隐藏状态张量encoder_hidden_states: torch.FloatTensor = None,  # 可选的编码器隐藏状态张量attention_mask: Optional[torch.FloatTensor] = None,  # 可选的注意力掩码张量image_rotary_emb: Optional[torch.Tensor] = None,  # 可选的图像旋转嵌入张量):# 此处将实现自注意力的具体处理逻辑# 定义一个名为 CogVideoXAttnProcessor2_0 的类,专用于 CogVideoX 模型的缩放点积注意力处理
class CogVideoXAttnProcessor2_0:r"""Processor for implementing scaled dot-product attention for the CogVideoX model. It applies a rotary embedding onquery and key vectors, but does not include spatial normalization."""# 初始化方法def __init__(self):# 检查 F 是否有 scaled_dot_product_attention 属性,如果没有则抛出 ImportErrorif not hasattr(F, "scaled_dot_product_attention"):raise ImportError("CogVideoXAttnProcessor requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")# 定义调用方法,使类实例可被调用def __call__(self,attn: Attention,  # 接收 Attention 对象hidden_states: torch.Tensor,  # 接收隐藏状态张量encoder_hidden_states: torch.Tensor,  # 接收编码器隐藏状态张量attention_mask: Optional[torch.Tensor] = None,  # 可选的注意力掩码张量image_rotary_emb: Optional[torch.Tensor] = None,  # 可选的图像旋转嵌入张量):# 此处将实现自注意力的具体处理逻辑) -> torch.Tensor:  # 函数返回一个张量,表示隐藏状态text_seq_length = encoder_hidden_states.size(1)  # 获取编码器隐藏状态的序列长度hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)  # 在维度1上连接编码器隐藏状态和当前隐藏状态batch_size, sequence_length, _ = (  # 解包 batch_size 和 sequence_lengthhidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape  # 根据编码器隐藏状态的存在性决定形状)if attention_mask is not None:  # 如果存在注意力掩码attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)  # 准备注意力掩码attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])  # 调整注意力掩码的形状以适应头数query = attn.to_q(hidden_states)  # 将隐藏状态转换为查询向量key = attn.to_k(hidden_states)  # 将隐藏状态转换为键向量value = attn.to_v(hidden_states)  # 将隐藏状态转换为值向量inner_dim = key.shape[-1]  # 获取键向量的最后一个维度大小head_dim = inner_dim // attn.heads  # 计算每个头的维度query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)  # 调整查询向量形状并转置以适应多头注意力key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)  # 调整键向量形状并转置以适应多头注意力value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)  # 调整值向量形状并转置以适应多头注意力if attn.norm_q is not None:  # 如果查询归一化层存在query = attn.norm_q(query)  # 对查询向量进行归一化if attn.norm_k is not None:  # 如果键归一化层存在key = attn.norm_k(key)  # 对键向量进行归一化# Apply RoPE if needed  # 如果需要应用旋转位置编码if image_rotary_emb is not None:  # 如果图像旋转嵌入存在from .embeddings import apply_rotary_emb  # 导入应用旋转嵌入的函数query[:, :, text_seq_length:] = apply_rotary_emb(query[:, :, text_seq_length:], image_rotary_emb)  # 应用旋转嵌入到查询向量的后半部分if not attn.is_cross_attention:  # 如果不是交叉注意力key[:, :, text_seq_length:] = apply_rotary_emb(key[:, :, text_seq_length:], image_rotary_emb)  # 应用旋转嵌入到键向量的后半部分hidden_states = F.scaled_dot_product_attention(  # 计算缩放点积注意力query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False  # 输入查询、键和值,以及注意力掩码)hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)  # 转置和重塑隐藏状态以合并头维度# linear proj  # 线性投影hidden_states = attn.to_out[0](hidden_states)  # 对隐藏状态应用输出线性变换# dropout  # 进行dropout操作hidden_states = attn.to_out[1](hidden_states)  # 对隐藏状态应用dropoutencoder_hidden_states, hidden_states = hidden_states.split(  # 将隐藏状态分割为编码器和当前隐藏状态[text_seq_length, hidden_states.size(1) - text_seq_length], dim=1  # 根据文本序列长度和剩余部分进行分割)return hidden_states, encoder_hidden_states  # 返回当前隐藏状态和编码器隐藏状态
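下面用纯张量运算给出一个最小示意(假设性的形状演示,批量、头数等数值均为随意取值),说明文本 token 与视觉 token 在序列维拼接、而旋转位置编码只作用于视觉那一段的切片方式:

```python
import torch

batch, heads, head_dim = 2, 4, 8
text_seq_length, image_seq_length = 6, 10

text_tokens = torch.randn(batch, text_seq_length, heads * head_dim)
image_tokens = torch.randn(batch, image_seq_length, heads * head_dim)

# 与上文一致:dim=1 上先 encoder_hidden_states(文本),后 hidden_states(视觉)
hidden_states = torch.cat([text_tokens, image_tokens], dim=1)  # (2, 16, 32)

# 变形为多头布局 (batch, heads, seq_len, head_dim)
query = hidden_states.view(batch, -1, heads, head_dim).transpose(1, 2)

# 上文中 apply_rotary_emb 只作用于 query[:, :, text_seq_length:];
# 这里用恒等变换代替真正的旋转,仅演示切片赋值覆盖的范围
query[:, :, text_seq_length:] = query[:, :, text_seq_length:] * 1.0
print(query.shape)  # torch.Size([2, 4, 16, 8])
```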
# 定义一个用于实现 CogVideoX 模型的缩放点积注意力的处理器类
class FusedCogVideoXAttnProcessor2_0:r"""Processor for implementing scaled dot-product attention for the CogVideoX model. It applies a rotary embedding onquery and key vectors, but does not include spatial normalization."""# 初始化方法def __init__(self):# 检查 F 是否具有 scaled_dot_product_attention 属性,如果没有则抛出导入错误if not hasattr(F, "scaled_dot_product_attention"):raise ImportError("CogVideoXAttnProcessor requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")# 定义可调用方法,处理注意力计算def __call__(self,attn: Attention,hidden_states: torch.Tensor,encoder_hidden_states: torch.Tensor,attention_mask: Optional[torch.Tensor] = None,image_rotary_emb: Optional[torch.Tensor] = None,) -> torch.Tensor:# 获取编码器隐藏状态的序列长度text_seq_length = encoder_hidden_states.size(1)# 将编码器和当前隐藏状态按维度 1 连接hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)# 获取批次大小和序列长度,依据编码器隐藏状态是否为 Nonebatch_size, sequence_length, _ = (hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape)# 如果提供了注意力掩码,则准备掩码if attention_mask is not None:attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)# 将掩码调整为适当的形状attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])# 将隐藏状态转换为查询、键、值qkv = attn.to_qkv(hidden_states)# 计算每个部分的大小split_size = qkv.shape[-1] // 3# 分割成查询、键和值query, key, value = torch.split(qkv, split_size, dim=-1)# 获取键的内部维度inner_dim = key.shape[-1]# 计算每个头的维度head_dim = inner_dim // attn.heads# 调整查询、键和值的形状以适应多头注意力query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)# 如果存在查询的归一化,则应用归一化if attn.norm_q is not None:query = attn.norm_q(query)# 如果存在键的归一化,则应用归一化if attn.norm_k is not None:key = attn.norm_k(key)# 如果需要应用 RoPEif image_rotary_emb is not None:from .embeddings import apply_rotary_emb# 对查询的特定部分应用旋转嵌入query[:, :, text_seq_length:] = apply_rotary_emb(query[:, :, text_seq_length:], image_rotary_emb)# 如果不是交叉注意力,则对键的特定部分应用旋转嵌入if not attn.is_cross_attention:key[:, :, text_seq_length:] = apply_rotary_emb(key[:, :, text_seq_length:], image_rotary_emb)# 计算缩放点积注意力hidden_states = F.scaled_dot_product_attention(query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False)# 调整隐藏状态的形状以便输出hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)# 线性投影hidden_states = attn.to_out[0](hidden_states)# 应用 dropouthidden_states = attn.to_out[1](hidden_states)# 将隐藏状态拆分为编码器隐藏状态和当前隐藏状态encoder_hidden_states, hidden_states = hidden_states.split([text_seq_length, hidden_states.size(1) - text_seq_length], dim=1)# 返回当前隐藏状态和编码器隐藏状态return hidden_states, encoder_hidden_states# 定义用于实现内存高效注意力的处理器类
class XFormersAttnAddedKVProcessor:r"""Processor for implementing memory efficient attention using xFormers.# 文档字符串,说明可选参数 attention_op 的作用Args:attention_op (`Callable`, *optional*, defaults to `None`):使用的基本注意力操作符,推荐设置为 `None` 让 xFormers 选择最佳操作符"""# 构造函数,初始化注意力操作符def __init__(self, attention_op: Optional[Callable] = None):# 将传入的注意力操作符赋值给实例变量self.attention_op = attention_op# 可调用方法,用于执行注意力计算def __call__(self,attn: Attention,  # 注意力对象hidden_states: torch.Tensor,  # 隐藏状态张量encoder_hidden_states: Optional[torch.Tensor] = None,  # 编码器隐藏状态,默认为 Noneattention_mask: Optional[torch.Tensor] = None,  # 注意力掩码,默认为 None) -> torch.Tensor:# 将当前隐藏状态保存为残差以便后续使用residual = hidden_states# 调整隐藏状态的形状并转置hidden_states = hidden_states.view(hidden_states.shape[0], hidden_states.shape[1], -1).transpose(1, 2)# 获取批次大小和序列长度batch_size, sequence_length, _ = hidden_states.shape# 准备注意力掩码attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)# 如果没有编码器隐藏状态,则将其设置为当前的隐藏状态if encoder_hidden_states is None:encoder_hidden_states = hidden_states# 如果需要,则对编码器隐藏状态进行归一化处理elif attn.norm_cross:encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)# 对隐藏状态进行分组归一化处理hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)# 生成查询向量query = attn.to_q(hidden_states)# 将查询向量从头部维度转换为批次维度query = attn.head_to_batch_dim(query)# 对编码器隐藏状态进行键和值的投影encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)# 将编码器隐藏状态的键和值转换为批次维度encoder_hidden_states_key_proj = attn.head_to_batch_dim(encoder_hidden_states_key_proj)encoder_hidden_states_value_proj = attn.head_to_batch_dim(encoder_hidden_states_value_proj)# 如果不是仅使用交叉注意力if not attn.only_cross_attention:# 生成当前隐藏状态的键和值key = attn.to_k(hidden_states)value = attn.to_v(hidden_states)# 转换键和值到批次维度key = attn.head_to_batch_dim(key)value = attn.head_to_batch_dim(value)# 将编码器的键和值与当前的键和值连接起来key = torch.cat([encoder_hidden_states_key_proj, key], dim=1)value = torch.cat([encoder_hidden_states_value_proj, value], dim=1)else:# 如果仅使用交叉注意力,则直接使用编码器的键和值key = encoder_hidden_states_key_projvalue = encoder_hidden_states_value_proj# 计算高效的注意力hidden_states = xformers.ops.memory_efficient_attention(query, key, value, attn_bias=attention_mask, op=self.attention_op, scale=attn.scale)# 将结果转换为查询的 dtypehidden_states = hidden_states.to(query.dtype)# 将隐藏状态从批次维度转换回头部维度hidden_states = attn.batch_to_head_dim(hidden_states)# 线性变换hidden_states = attn.to_out[0](hidden_states)# 应用 dropouthidden_states = attn.to_out[1](hidden_states)# 调整隐藏状态的形状以匹配残差hidden_states = hidden_states.transpose(-1, -2).reshape(residual.shape)# 将当前隐藏状态与残差相加hidden_states = hidden_states + residual# 返回最终的隐藏状态return hidden_states
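实际使用中一般不手动实例化这些 xFormers 处理器,而是通过 diffusers 管道提供的开关统一启用。下面是一个假设性的用法草图(模型 ID 仅作示例,需要已安装 xformers 且有可用的 CUDA 设备):

```python
import torch
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5",  # 示例模型 ID
    torch_dtype=torch.float16,
).to("cuda")

# attention_op 保持 None,与上文文档字符串的建议一致,由 xFormers 自行挑选最优算子
pipe.enable_xformers_memory_efficient_attention()

image = pipe("an astronaut riding a horse").images[0]
image.save("astronaut.png")
```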
# 定义一个用于实现基于 xFormers 的内存高效注意力的处理器类
class XFormersAttnProcessor:r"""处理器,用于实现基于 xFormers 的内存高效注意力。参数:attention_op (`Callable`, *可选*, 默认为 `None`):基础[操作符](https://facebookresearch.github.io/xformers/components/ops.html#xformers.ops.AttentionOpBase),用作注意力操作符。建议将其设置为 `None`,并让 xFormers 选择最佳操作符。"""# 初始化方法,接受一个可选的注意力操作符def __init__(self, attention_op: Optional[Callable] = None):# 将传入的注意力操作符赋值给实例变量self.attention_op = attention_op# 定义可调用方法,用于执行注意力计算def __call__(self,attn: Attention,hidden_states: torch.Tensor,encoder_hidden_states: Optional[torch.Tensor] = None,attention_mask: Optional[torch.Tensor] = None,temb: Optional[torch.Tensor] = None,*args,**kwargs,
# 定义一个用于实现 flash attention 的处理器类,使用 torch_npu
class AttnProcessorNPU:r"""处理器,用于使用 torch_npu 实现 flash attention。torch_npu 仅支持 fp16 和 bf16 数据类型。如果使用 fp32,将使用 F.scaled_dot_product_attention 进行计算,但在 NPU 上加速效果不明显。"""# 初始化方法def __init__(self):# 检查是否可用 torch_npu,如果不可用则抛出异常if not is_torch_npu_available():raise ImportError("AttnProcessorNPU requires torch_npu extensions and is supported only on npu devices.")# 定义可调用方法,用于执行注意力计算def __call__(self,attn: Attention,hidden_states: torch.Tensor,encoder_hidden_states: Optional[torch.Tensor] = None,attention_mask: Optional[torch.Tensor] = None,temb: Optional[torch.Tensor] = None,*args,**kwargs,
# 定义一个用于实现 scaled dot-product attention 的处理器类,默认在 PyTorch 2.0 中启用
class AttnProcessor2_0:r"""处理器,用于实现 scaled dot-product attention(如果您使用的是 PyTorch 2.0,默认启用)。"""# 初始化方法def __init__(self):# 检查 F 中是否有 scaled_dot_product_attention 属性,如果没有则抛出异常if not hasattr(F, "scaled_dot_product_attention"):raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")# 定义可调用方法,用于执行注意力计算def __call__(self,attn: Attention,hidden_states: torch.Tensor,encoder_hidden_states: Optional[torch.Tensor] = None,attention_mask: Optional[torch.Tensor] = None,temb: Optional[torch.Tensor] = None,*args,**kwargs,
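作为对照,下面是一个假设性的整模型替换示例(模型 ID 与取值仅作演示,接口以所用 diffusers 版本为准):UNet2DConditionModel 等模型提供 set_attn_processor,传入单个处理器实例时会把所有注意力层统一切换到 AttnProcessor2_0。

```python
from diffusers import UNet2DConditionModel
from diffusers.models.attention_processor import AttnProcessor2_0

unet = UNet2DConditionModel.from_pretrained(
    "runwayml/stable-diffusion-v1-5", subfolder="unet"  # 示例模型 ID
)

# 传入单个处理器实例时,会应用到模型里的每一个注意力层
unet.set_attn_processor(AttnProcessor2_0())
print(next(iter(unet.attn_processors.values())).__class__.__name__)  # AttnProcessor2_0
```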
# 定义一个用于实现 scaled dot-product attention 的处理器类,适用于稳定音频模型
class StableAudioAttnProcessor2_0:r"""处理器,用于实现 scaled dot-product attention(如果您使用的是 PyTorch 2.0,默认启用)。此处理器用于稳定音频模型。它在查询和键向量上应用旋转嵌入,并允许 MHA、GQA 或 MQA。"""# 初始化方法def __init__(self):# 检查 F 中是否有 scaled_dot_product_attention 属性,如果没有则抛出异常if not hasattr(F, "scaled_dot_product_attention"):raise ImportError("StableAudioAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")# 定义方法,用于应用部分旋转嵌入def apply_partial_rotary_emb(self,x: torch.Tensor,freqs_cis: Tuple[torch.Tensor],# 定义返回类型为 torch.Tensor 的函数) -> torch.Tensor:# 从当前模块导入 apply_rotary_emb 函数from .embeddings import apply_rotary_emb# 获取频率余弦的最后一个维度大小,用于旋转rot_dim = freqs_cis[0].shape[-1]# 将输入张量 x 划分为需要旋转和不需要旋转的部分x_to_rotate, x_unrotated = x[..., :rot_dim], x[..., rot_dim:]# 应用旋转嵌入到需要旋转的部分x_rotated = apply_rotary_emb(x_to_rotate, freqs_cis, use_real=True, use_real_unbind_dim=-2)# 将旋转后的部分与未旋转的部分在最后一个维度上连接out = torch.cat((x_rotated, x_unrotated), dim=-1)# 返回连接后的输出张量return out# 定义可调用方法,接收注意力和隐藏状态def __call__(self,# 输入的注意力对象attn: Attention,# 隐藏状态的张量hidden_states: torch.Tensor,# 可选的编码器隐藏状态张量encoder_hidden_states: Optional[torch.Tensor] = None,# 可选的注意力掩码张量attention_mask: Optional[torch.Tensor] = None,# 可选的旋转嵌入张量rotary_emb: Optional[torch.Tensor] = None,
# 定义 HunyuanAttnProcessor2_0 类,处理缩放的点积注意力
class HunyuanAttnProcessor2_0:r"""处理器用于实现缩放的点积注意力(如果使用 PyTorch 2.0,默认启用)。这是HunyuanDiT 模型中使用的。它在查询和键向量上应用归一化层和旋转嵌入。"""# 初始化方法def __init__(self):# 检查 F 中是否有 scaled_dot_product_attention 属性if not hasattr(F, "scaled_dot_product_attention"):# 如果没有,则抛出导入错误,提示需要升级 PyTorch 到 2.0raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")# 定义调用方法def __call__(self,attn: Attention,  # 注意力机制实例hidden_states: torch.Tensor,  # 当前隐藏状态的张量encoder_hidden_states: Optional[torch.Tensor] = None,  # 编码器隐藏状态的可选张量attention_mask: Optional[torch.Tensor] = None,  # 注意力掩码的可选张量temb: Optional[torch.Tensor] = None,  # 时间嵌入的可选张量image_rotary_emb: Optional[torch.Tensor] = None,  # 图像旋转嵌入的可选张量
class FusedHunyuanAttnProcessor2_0:r"""处理器用于实现缩放的点积注意力(如果使用 PyTorch 2.0,默认启用),带有融合的投影层。这是 HunyuanDiT 模型中使用的。它在查询和键向量上应用归一化层和旋转嵌入。"""# 初始化方法def __init__(self):# 检查 F 中是否有 scaled_dot_product_attention 属性if not hasattr(F, "scaled_dot_product_attention"):# 如果没有,则抛出导入错误,提示需要升级 PyTorch 到 2.0raise ImportError("FusedHunyuanAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")# 定义调用方法def __call__(self,attn: Attention,  # 注意力机制实例hidden_states: torch.Tensor,  # 当前隐藏状态的张量encoder_hidden_states: Optional[torch.Tensor] = None,  # 编码器隐藏状态的可选张量attention_mask: Optional[torch.Tensor] = None,  # 注意力掩码的可选张量temb: Optional[torch.Tensor] = None,  # 时间嵌入的可选张量image_rotary_emb: Optional[torch.Tensor] = None,  # 图像旋转嵌入的可选张量
class PAGHunyuanAttnProcessor2_0:r"""处理器用于实现缩放的点积注意力(如果使用 PyTorch 2.0,默认启用)。这是HunyuanDiT 模型中使用的。它在查询和键向量上应用归一化层和旋转嵌入。该处理器变体采用了 [Pertubed Attention Guidance](https://arxiv.org/abs/2403.17377)。"""# 初始化方法def __init__(self):# 检查 F 中是否有 scaled_dot_product_attention 属性if not hasattr(F, "scaled_dot_product_attention"):# 如果没有,则抛出导入错误,提示需要升级 PyTorch 到 2.0raise ImportError("PAGHunyuanAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")# 定义调用方法def __call__(self,attn: Attention,  # 注意力机制实例hidden_states: torch.Tensor,  # 当前隐藏状态的张量encoder_hidden_states: Optional[torch.Tensor] = None,  # 编码器隐藏状态的可选张量attention_mask: Optional[torch.Tensor] = None,  # 注意力掩码的可选张量temb: Optional[torch.Tensor] = None,  # 时间嵌入的可选张量image_rotary_emb: Optional[torch.Tensor] = None,  # 图像旋转嵌入的可选张量
class PAGCFGHunyuanAttnProcessor2_0:r"""处理器用于实现缩放的点积注意力(如果使用 PyTorch 2.0,默认启用)。这是HunyuanDiT 模型中使用的。它在查询和键向量上应用归一化层和旋转嵌入。该处理器变体采用了 [Pertubed Attention Guidance](https://arxiv.org/abs/2403.17377)。"""# 初始化方法,用于创建类的实例def __init__(self):# 检查模块 F 是否具有属性 "scaled_dot_product_attention"if not hasattr(F, "scaled_dot_product_attention"):# 如果没有该属性,则抛出 ImportError,提示用户升级 PyTorchraise ImportError("PAGCFGHunyuanAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")# 可调用方法,允许类的实例像函数一样被调用def __call__(self,attn: Attention,  # 注意力机制对象hidden_states: torch.Tensor,  # 当前隐藏状态的张量encoder_hidden_states: Optional[torch.Tensor] = None,  # 编码器的隐藏状态,可选参数attention_mask: Optional[torch.Tensor] = None,  # 注意力掩码,可选参数temb: Optional[torch.Tensor] = None,  # 时间嵌入,可选参数image_rotary_emb: Optional[torch.Tensor] = None,  # 图像旋转嵌入,可选参数
# 定义一个用于实现缩放点积注意力的处理器类
class LuminaAttnProcessor2_0:r"""Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0). This isused in the LuminaNextDiT model. It applies a s normalization layer and rotary embedding on query and key vector."""# 初始化方法def __init__(self):# 检查 PyTorch 是否具有缩放点积注意力功能if not hasattr(F, "scaled_dot_product_attention"):# 如果没有,抛出导入错误,提示用户升级 PyTorch 到 2.0raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")# 定义调用方法,使类实例可调用def __call__(self,# 接收注意力对象attn: Attention,# 接收隐藏状态张量hidden_states: torch.Tensor,# 接收编码器隐藏状态张量encoder_hidden_states: torch.Tensor,# 可选的注意力掩码张量attention_mask: Optional[torch.Tensor] = None,# 可选的查询旋转嵌入张量query_rotary_emb: Optional[torch.Tensor] = None,# 可选的键旋转嵌入张量key_rotary_emb: Optional[torch.Tensor] = None,# 可选的基本序列长度base_sequence_length: Optional[int] = None,) -> torch.Tensor:  # 函数返回一个张量,表示处理后的隐藏状态from .embeddings import apply_rotary_emb  # 从当前包导入应用旋转嵌入的函数input_ndim = hidden_states.ndim  # 获取隐藏状态的维度数if input_ndim == 4:  # 如果隐藏状态是四维张量batch_size, channel, height, width = hidden_states.shape  # 解包出批次大小、通道、高度和宽度hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)  # 重塑并转置隐藏状态batch_size, sequence_length, _ = hidden_states.shape  # 解包出批次大小和序列长度# Get Query-Key-Value Pair  # 获取查询、键、值对query = attn.to_q(hidden_states)  # 将隐藏状态转换为查询张量key = attn.to_k(encoder_hidden_states)  # 将编码器的隐藏状态转换为键张量value = attn.to_v(encoder_hidden_states)  # 将编码器的隐藏状态转换为值张量query_dim = query.shape[-1]  # 获取查询的最后一个维度(特征维度)inner_dim = key.shape[-1]  # 获取键的最后一个维度head_dim = query_dim // attn.heads  # 计算每个头的维度dtype = query.dtype  # 获取查询张量的数据类型# Get key-value heads  # 获取键值头的数量kv_heads = inner_dim // head_dim  # 计算每个头的键值数量# Apply Query-Key Norm if needed  # 如果需要,应用查询-键归一化if attn.norm_q is not None:  # 如果定义了查询的归一化query = attn.norm_q(query)  # 对查询进行归一化if attn.norm_k is not None:  # 如果定义了键的归一化key = attn.norm_k(key)  # 对键进行归一化query = query.view(batch_size, -1, attn.heads, head_dim)  # 重塑查询张量以适应头的维度key = key.view(batch_size, -1, kv_heads, head_dim)  # 重塑键张量以适应头的维度value = value.view(batch_size, -1, kv_heads, head_dim)  # 重塑值张量以适应头的维度# Apply RoPE if needed  # 如果需要,应用旋转位置嵌入if query_rotary_emb is not None:  # 如果定义了查询的旋转嵌入query = apply_rotary_emb(query, query_rotary_emb, use_real=False)  # 应用旋转嵌入到查询if key_rotary_emb is not None:  # 如果定义了键的旋转嵌入key = apply_rotary_emb(key, key_rotary_emb, use_real=False)  # 应用旋转嵌入到键query, key = query.to(dtype), key.to(dtype)  # 将查询和键转换为相同的数据类型# Apply proportional attention if true  # 如果为真,应用比例注意力if key_rotary_emb is None:  # 如果没有键的旋转嵌入softmax_scale = None  # 设置缩放因子为 Noneelse:  # 如果有键的旋转嵌入if base_sequence_length is not None:  # 如果定义了基础序列长度softmax_scale = math.sqrt(math.log(sequence_length, base_sequence_length)) * attn.scale  # 计算缩放因子else:  # 如果没有定义基础序列长度softmax_scale = attn.scale  # 使用注意力的缩放因子# perform Grouped-query Attention (GQA)  # 执行分组查询注意力n_rep = attn.heads // kv_heads  # 计算每个键值头的重复数量if n_rep >= 1:  # 如果重复数量大于等于 1key = key.unsqueeze(3).repeat(1, 1, 1, n_rep, 1).flatten(2, 3)  # 扩展并重复键value = value.unsqueeze(3).repeat(1, 1, 1, n_rep, 1).flatten(2, 3)  # 扩展并重复值# scaled_dot_product_attention expects attention_mask shape to be  # 缩放点积注意力期望的注意力掩码形状# (batch, heads, source_length, target_length)attention_mask = attention_mask.bool().view(batch_size, 1, 1, -1)  # 将注意力掩码转换为布尔值并调整形状attention_mask = attention_mask.expand(-1, attn.heads, sequence_length, -1)  # 扩展注意力掩码以匹配头的数量query = query.transpose(1, 2)  # 转置查询张量key = key.transpose(1, 2)  # 转置键张量value = value.transpose(1, 2)  # 
# 转置值张量# the output of sdp = (batch, num_heads, seq_len, head_dim)  # 缩放点积注意力的输出形状# TODO: add support for attn.scale when we move to Torch 2.1  # TODO: 在迁移到 Torch 2.1 时支持 attn.scalehidden_states = F.scaled_dot_product_attention(  # 计算缩放点积注意力query, key, value, attn_mask=attention_mask, scale=softmax_scale  # 输入查询、键、值及注意力掩码和缩放因子)hidden_states = hidden_states.transpose(1, 2).to(dtype)  # 转置输出并转换为相应的数据类型return hidden_states  # 返回处理后的隐藏状态
# 定义一个用于实现缩放点积注意力的处理器类,默认启用(如果使用 PyTorch 2.0)
class FusedAttnProcessor2_0:r"""Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0). It usesfused projection layers. For self-attention modules, all projection matrices (i.e., query, key, value) are fused.For cross-attention modules, key and value projection matrices are fused.<Tip warning={true}>This API is currently 🧪 experimental in nature and can change in future.</Tip>"""# 初始化方法def __init__(self):# 检查 F 库是否具有缩放点积注意力功能if not hasattr(F, "scaled_dot_product_attention"):# 如果没有,抛出导入错误,提示用户升级 PyTorch 版本raise ImportError("FusedAttnProcessor2_0 requires at least PyTorch 2.0, to use it. Please upgrade PyTorch to > 2.0.")# 调用方法,处理注意力计算def __call__(self,attn: Attention,  # 注意力模块hidden_states: torch.Tensor,  # 隐藏状态张量encoder_hidden_states: Optional[torch.Tensor] = None,  # 编码器的隐藏状态,可选attention_mask: Optional[torch.Tensor] = None,  # 注意力掩码,可选temb: Optional[torch.Tensor] = None,  # 时间嵌入,可选*args,  # 可变位置参数**kwargs,  # 可变关键字参数):pass  # 此处省略具体实现# 定义一个用于实现内存高效注意力的处理器类,使用 xFormers 方法
class CustomDiffusionXFormersAttnProcessor(nn.Module):r"""Processor for implementing memory efficient attention using xFormers for the Custom Diffusion method.Args:train_kv (`bool`, defaults to `True`):Whether to newly train the key and value matrices corresponding to the text features.train_q_out (`bool`, defaults to `True`):Whether to newly train query matrices corresponding to the latent image features.hidden_size (`int`, *optional*, defaults to `None`):The hidden size of the attention layer.cross_attention_dim (`int`, *optional*, defaults to `None`):The number of channels in the `encoder_hidden_states`.out_bias (`bool`, defaults to `True`):Whether to include the bias parameter in `train_q_out`.dropout (`float`, *optional*, defaults to 0.0):The dropout probability to use.attention_op (`Callable`, *optional*, defaults to `None`):The base[operator](https://facebookresearch.github.io/xformers/components/ops.html#xformers.ops.AttentionOpBase) to useas the attention operator. It is recommended to set to `None`, and allow xFormers to choose the best operator."""# 初始化方法,设置各种参数def __init__(self,train_kv: bool = True,  # 是否训练与文本特征对应的键值矩阵train_q_out: bool = False,  # 是否训练与潜在图像特征对应的查询矩阵hidden_size: Optional[int] = None,  # 注意力层的隐藏大小cross_attention_dim: Optional[int] = None,  # 编码器隐藏状态的通道数out_bias: bool = True,  # 是否在 train_q_out 中包含偏置参数dropout: float = 0.0,  # 使用的丢弃概率attention_op: Optional[Callable] = None,  # 要使用的基础注意力操作):pass  # 此处省略具体实现):# 调用父类的初始化方法super().__init__()# 存储训练键值对的标志self.train_kv = train_kv# 存储训练查询输出的标志self.train_q_out = train_q_out# 存储隐藏层大小self.hidden_size = hidden_size# 存储交叉注意力维度self.cross_attention_dim = cross_attention_dim# 存储注意力操作类型self.attention_op = attention_op# `_custom_diffusion` id 用于简化序列化和加载if self.train_kv:# 创建线性层,将交叉注意力维度或隐藏层大小映射到隐藏层大小,且不使用偏置self.to_k_custom_diffusion = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)# 创建线性层,将交叉注意力维度或隐藏层大小映射到隐藏层大小,且不使用偏置self.to_v_custom_diffusion = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)if self.train_q_out:# 创建线性层,将隐藏层大小映射到隐藏层大小,且不使用偏置self.to_q_custom_diffusion = nn.Linear(hidden_size, hidden_size, bias=False)# 创建一个空的模块列表以存储输出相关的层self.to_out_custom_diffusion = nn.ModuleList([])# 将线性层添加到模块列表中,用于输出映射self.to_out_custom_diffusion.append(nn.Linear(hidden_size, hidden_size, bias=out_bias))# 将 Dropout 层添加到模块列表中,用于正则化self.to_out_custom_diffusion.append(nn.Dropout(dropout))def __call__(# 定义调用方法,接收注意力对象和隐藏状态张量self,attn: Attention,hidden_states: torch.Tensor,# 可选参数:编码器的隐藏状态张量encoder_hidden_states: Optional[torch.Tensor] = None,# 可选参数:注意力掩码张量attention_mask: Optional[torch.Tensor] = None,# 定义函数的返回类型为 torch.Tensor) -> torch.Tensor:# 获取批量大小和序列长度,根据 encoder_hidden_states 是否为 None 决定来源batch_size, sequence_length, _ = (hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape)# 准备注意力掩码attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)# 判断是否在训练阶段并应用不同的查询生成方式if self.train_q_out:query = self.to_q_custom_diffusion(hidden_states).to(attn.to_q.weight.dtype)else:query = attn.to_q(hidden_states.to(attn.to_q.weight.dtype))# 判断是否存在编码器隐藏状态,并设置 crossattn 标志if encoder_hidden_states is None:crossattn = Falseencoder_hidden_states = hidden_stateselse:crossattn = True# 如果需要对编码器隐藏状态进行归一化处理if attn.norm_cross:encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)# 判断是否在训练阶段并应用不同的键值生成方式if self.train_kv:key = self.to_k_custom_diffusion(encoder_hidden_states.to(self.to_k_custom_diffusion.weight.dtype))value = 
self.to_v_custom_diffusion(encoder_hidden_states.to(self.to_v_custom_diffusion.weight.dtype))key = key.to(attn.to_q.weight.dtype)value = value.to(attn.to_q.weight.dtype)else:key = attn.to_k(encoder_hidden_states)value = attn.to_v(encoder_hidden_states)# 如果使用交叉注意力,进行键值的分离和处理if crossattn:detach = torch.ones_like(key)detach[:, :1, :] = detach[:, :1, :] * 0.0key = detach * key + (1 - detach) * key.detach()value = detach * value + (1 - detach) * value.detach()# 将查询、键、值转换为批处理维度并保持连续性query = attn.head_to_batch_dim(query).contiguous()key = attn.head_to_batch_dim(key).contiguous()value = attn.head_to_batch_dim(value).contiguous()# 使用内存高效的注意力计算隐藏状态hidden_states = xformers.ops.memory_efficient_attention(query, key, value, attn_bias=attention_mask, op=self.attention_op, scale=attn.scale)# 将隐藏状态转换为查询的类型hidden_states = hidden_states.to(query.dtype)# 将隐藏状态转换回头部维度hidden_states = attn.batch_to_head_dim(hidden_states)# 根据训练标志决定输出的处理方式if self.train_q_out:# 线性变换hidden_states = self.to_out_custom_diffusion[0](hidden_states)# 进行 dropout 操作hidden_states = self.to_out_custom_diffusion[1](hidden_states)else:# 线性变换hidden_states = attn.to_out[0](hidden_states)# 进行 dropout 操作hidden_states = attn.to_out[1](hidden_states)# 返回处理后的隐藏状态return hidden_states
# 自定义扩散注意力处理器类,继承自 PyTorch 的 nn.Module
class CustomDiffusionAttnProcessor2_0(nn.Module):r"""用于实现自定义扩散方法的注意力处理器,使用 PyTorch 2.0 的内存高效缩放点积注意力。参数:train_kv (`bool`, 默认值为 `True`):是否新训练与文本特征对应的键和值矩阵。train_q_out (`bool`, 默认值为 `True`):是否新训练与潜在图像特征对应的查询矩阵。hidden_size (`int`, *可选*, 默认值为 `None`):注意力层的隐藏大小。cross_attention_dim (`int`, *可选*, 默认值为 `None`):`encoder_hidden_states` 中的通道数。out_bias (`bool`, 默认值为 `True`):是否在 `train_q_out` 中包含偏置参数。dropout (`float`, *可选*, 默认值为 0.0):使用的 dropout 概率。"""# 初始化方法,设置类的属性def __init__(self,train_kv: bool = True,train_q_out: bool = True,hidden_size: Optional[int] = None,cross_attention_dim: Optional[int] = None,out_bias: bool = True,dropout: float = 0.0,):# 调用父类的初始化方法super().__init__()# 设置是否训练键值矩阵的标志self.train_kv = train_kv# 设置是否训练查询输出矩阵的标志self.train_q_out = train_q_out# 设置隐藏层的大小self.hidden_size = hidden_size# 设置交叉注意力的维度self.cross_attention_dim = cross_attention_dim# 如果需要训练键值矩阵,则创建对应的线性层if self.train_kv:# 创建从交叉注意力维度到隐藏层的线性变换,且不使用偏置self.to_k_custom_diffusion = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)# 创建从交叉注意力维度到隐藏层的线性变换,且不使用偏置self.to_v_custom_diffusion = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)# 如果需要训练查询输出,则创建对应的线性层if self.train_q_out:# 创建从隐藏层到隐藏层的线性变换,且不使用偏置self.to_q_custom_diffusion = nn.Linear(hidden_size, hidden_size, bias=False)# 创建一个空的模块列表,用于存储输出层self.to_out_custom_diffusion = nn.ModuleList([])# 将线性层添加到输出模块列表中self.to_out_custom_diffusion.append(nn.Linear(hidden_size, hidden_size, bias=out_bias))# 添加 dropout 层到输出模块列表中self.to_out_custom_diffusion.append(nn.Dropout(dropout))# 定义类的调用方法,处理输入的注意力和隐藏状态def __call__(self,attn: Attention,hidden_states: torch.Tensor,encoder_hidden_states: Optional[torch.Tensor] = None,attention_mask: Optional[torch.Tensor] = None,) -> torch.Tensor:  # 定义返回类型为 torch.Tensorbatch_size, sequence_length, _ = hidden_states.shape  # 解包 hidden_states 的形状,获取批大小和序列长度attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)  # 准备注意力掩码if self.train_q_out:  # 检查是否在训练查询输出query = self.to_q_custom_diffusion(hidden_states)  # 使用自定义扩散方法生成查询向量else:  # 否则query = attn.to_q(hidden_states)  # 使用标准方法生成查询向量if encoder_hidden_states is None:  # 检查编码器隐藏状态是否为空crossattn = False  # 设置交叉注意力标志为假encoder_hidden_states = hidden_states  # 将编码器隐藏状态设置为隐藏状态else:  # 如果编码器隐藏状态不为空crossattn = True  # 设置交叉注意力标志为真if attn.norm_cross:  # 如果需要归一化encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)  # 归一化编码器隐藏状态if self.train_kv:  # 检查是否在训练键值对key = self.to_k_custom_diffusion(encoder_hidden_states.to(self.to_k_custom_diffusion.weight.dtype))  # 生成键向量value = self.to_v_custom_diffusion(encoder_hidden_states.to(self.to_v_custom_diffusion.weight.dtype))  # 生成值向量key = key.to(attn.to_q.weight.dtype)  # 将键向量转换为查询权重的数据类型value = value.to(attn.to_q.weight.dtype)  # 将值向量转换为查询权重的数据类型else:  # 否则key = attn.to_k(encoder_hidden_states)  # 使用标准方法生成键向量value = attn.to_v(encoder_hidden_states)  # 使用标准方法生成值向量if crossattn:  # 如果进行交叉注意力detach = torch.ones_like(key)  # 创建与键相同形状的全1张量detach[:, :1, :] = detach[:, :1, :] * 0.0  # 将第一时间步的值设置为0key = detach * key + (1 - detach) * key.detach()  # 根据 detach 张量计算键的最终值value = detach * value + (1 - detach) * value.detach()  # 根据 detach 张量计算值的最终值inner_dim = hidden_states.shape[-1]  # 获取隐藏状态的最后一维大小head_dim = inner_dim // attn.heads  # 计算每个头的维度query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)  # 重新调整查询的形状并转置key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)  # 重新调整键的形状并转置value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)  # 重新调整值的形状并转置# the 
# output of sdp = (batch, num_heads, seq_len, head_dim)# TODO: add support for attn.scale when we move to Torch 2.1hidden_states = F.scaled_dot_product_attention(  # 计算缩放点积注意力query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False  # 输入查询、键和值以及注意力掩码)hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)  # 转置并重塑隐藏状态hidden_states = hidden_states.to(query.dtype)  # 将隐藏状态转换为查询的类型if self.train_q_out:  # 如果在训练查询输出# linear projhidden_states = self.to_out_custom_diffusion[0](hidden_states)  # 线性变换# dropouthidden_states = self.to_out_custom_diffusion[1](hidden_states)  # 应用 dropoutelse:  # 否则# linear projhidden_states = attn.to_out[0](hidden_states)  # 线性变换# dropouthidden_states = attn.to_out[1](hidden_states)  # 应用 dropoutreturn hidden_states  # 返回最终的隐藏状态
# 定义一个用于实现切片注意力的处理器类
class SlicedAttnProcessor:r"""处理器用于实现切片注意力。参数:slice_size (`int`, *可选*):计算注意力的步骤数量。使用的切片数量为 `attention_head_dim // slice_size`,并且`attention_head_dim` 必须是 `slice_size` 的整数倍。"""# 初始化方法,接受切片大小作为参数def __init__(self, slice_size: int):# 将传入的切片大小保存为实例变量self.slice_size = slice_size# 定义可调用方法,以便实例可以像函数一样被调用def __call__(self,attn: Attention,  # 输入的注意力对象hidden_states: torch.Tensor,  # 当前隐藏状态的张量encoder_hidden_states: Optional[torch.Tensor] = None,  # 编码器的隐藏状态,可选参数attention_mask: Optional[torch.Tensor] = None,  # 注意力掩码,可选参数) -> torch.Tensor:# 保存输入的隐藏状态,用于残差连接residual = hidden_states# 获取隐藏状态的维度数量input_ndim = hidden_states.ndim# 如果输入维度是4,调整隐藏状态的形状if input_ndim == 4:# 解包隐藏状态的形状为批量大小、通道、高度和宽度batch_size, channel, height, width = hidden_states.shape# 将隐藏状态展平为(batch_size, channel, height * width)并转置hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)# 确定序列长度和批量大小,根据是否有编码器隐藏状态决定batch_size, sequence_length, _ = (hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape)# 准备注意力掩码attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)# 如果有分组归一化,应用于隐藏状态if attn.group_norm is not None:hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)# 将隐藏状态转换为查询向量query = attn.to_q(hidden_states)# 获取查询向量的最后一维大小dim = query.shape[-1]# 将查询向量转换为批量维度格式query = attn.head_to_batch_dim(query)# 如果没有编码器隐藏状态,使用当前隐藏状态if encoder_hidden_states is None:encoder_hidden_states = hidden_states# 如果需要,归一化编码器隐藏状态elif attn.norm_cross:encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)# 将编码器隐藏状态转换为键和值key = attn.to_k(encoder_hidden_states)value = attn.to_v(encoder_hidden_states)# 将键和值转换为批量维度格式key = attn.head_to_batch_dim(key)value = attn.head_to_batch_dim(value)# 获取查询向量的批量大小和令牌数量batch_size_attention, query_tokens, _ = query.shape# 初始化隐藏状态张量为零hidden_states = torch.zeros((batch_size_attention, query_tokens, dim // attn.heads), device=query.device, dtype=query.dtype)# 按切片处理查询、键和值for i in range((batch_size_attention - 1) // self.slice_size + 1):# 计算当前切片的起始和结束索引start_idx = i * self.slice_sizeend_idx = (i + 1) * self.slice_size# 获取当前切片的查询、键和注意力掩码query_slice = query[start_idx:end_idx]key_slice = key[start_idx:end_idx]attn_mask_slice = attention_mask[start_idx:end_idx] if attention_mask is not None else None# 计算当前切片的注意力分数attn_slice = attn.get_attention_scores(query_slice, key_slice, attn_mask_slice)# 将注意力分数与值相乘,获取注意力结果attn_slice = torch.bmm(attn_slice, value[start_idx:end_idx])# 将注意力结果存储到隐藏状态中hidden_states[start_idx:end_idx] = attn_slice# 将隐藏状态转换回头维度格式hidden_states = attn.batch_to_head_dim(hidden_states)# 对隐藏状态进行线性变换hidden_states = attn.to_out[0](hidden_states)# 应用 dropouthidden_states = attn.to_out[1](hidden_states)# 如果输入维度是4,调整隐藏状态的形状if input_ndim == 4:hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)# 如果需要残差连接,将残差加到隐藏状态中if attn.residual_connection:hidden_states = hidden_states + residual# 根据缩放因子调整输出hidden_states = hidden_states / attn.rescale_output_factor# 返回最终的隐藏状态return hidden_states
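下面给出一个最小示意,演示与上文完全相同的切片下标计算方式(数值为假设);实际使用时通常直接调用管道层面的 enable_attention_slicing 接口,由 diffusers 自行替换为该处理器(具体以所用版本为准):

```python
# 假设 batch_size_attention = batch_size * heads = 10,slice_size = 4
batch_size_attention, slice_size = 10, 4

for i in range((batch_size_attention - 1) // slice_size + 1):
    start_idx = i * slice_size
    end_idx = (i + 1) * slice_size
    print(f"slice {i}: [{start_idx}, {min(end_idx, batch_size_attention)})")
# slice 0: [0, 4)
# slice 1: [4, 8)
# slice 2: [8, 10)
```

可以看到最后一个切片会自动变短,对应上文中 query[start_idx:end_idx] 这种越界即截断的切片写法。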
# 定义一个处理器类,用于实现切片注意力,并额外学习键和值矩阵
class SlicedAttnAddedKVProcessor:r"""处理器,用于实现带有额外可学习的键和值矩阵的切片注意力,用于文本编码器。参数:slice_size (`int`, *可选*):计算注意力的步数。使用 `attention_head_dim // slice_size` 的切片数量,并且 `attention_head_dim` 必须是 `slice_size` 的倍数。"""# 初始化方法,接收切片大小作为参数def __init__(self, slice_size):# 将传入的切片大小赋值给实例变量self.slice_size = slice_size# 定义调用方法,使类的实例可以像函数一样被调用def __call__(self,attn: "Attention",  # 接收一个注意力对象hidden_states: torch.Tensor,  # 输入的隐藏状态张量encoder_hidden_states: Optional[torch.Tensor] = None,  # 可选的编码器隐藏状态张量attention_mask: Optional[torch.Tensor] = None,  # 可选的注意力掩码张量temb: Optional[torch.Tensor] = None,  # 可选的时间嵌入张量# 返回类型为 torch.Tensor) -> torch.Tensor:# 保存输入的隐藏状态作为残差residual = hidden_states# 如果空间归一化存在,则应用于隐藏状态和时间嵌入if attn.spatial_norm is not None:hidden_states = attn.spatial_norm(hidden_states, temb)# 将隐藏状态重塑为三维张量并转置维度hidden_states = hidden_states.view(hidden_states.shape[0], hidden_states.shape[1], -1).transpose(1, 2)# 获取批量大小和序列长度batch_size, sequence_length, _ = hidden_states.shape# 准备注意力掩码attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)# 如果没有编码器隐藏状态,则将其设置为当前隐藏状态if encoder_hidden_states is None:encoder_hidden_states = hidden_states# 如果需要归一化编码器隐藏状态elif attn.norm_cross:encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)# 对隐藏状态应用组归一化并转置维度hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)# 生成查询向量query = attn.to_q(hidden_states)# 获取查询向量的最后一维大小dim = query.shape[-1]# 将查询向量的维度转换为批次维度query = attn.head_to_batch_dim(query)# 生成编码器隐藏状态的键和值的投影encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)# 将编码器隐藏状态的键和值转换为批次维度encoder_hidden_states_key_proj = attn.head_to_batch_dim(encoder_hidden_states_key_proj)encoder_hidden_states_value_proj = attn.head_to_batch_dim(encoder_hidden_states_value_proj)# 如果不只使用交叉注意力if not attn.only_cross_attention:# 生成当前隐藏状态的键和值key = attn.to_k(hidden_states)value = attn.to_v(hidden_states)# 将键和值转换为批次维度key = attn.head_to_batch_dim(key)value = attn.head_to_batch_dim(value)# 将编码器键与当前键拼接key = torch.cat([encoder_hidden_states_key_proj, key], dim=1)# 将编码器值与当前值拼接value = torch.cat([encoder_hidden_states_value_proj, value], dim=1)else:# 直接使用编码器的键和值key = encoder_hidden_states_key_projvalue = encoder_hidden_states_value_proj# 获取批量大小、查询令牌数量和最后一维大小batch_size_attention, query_tokens, _ = query.shape# 初始化隐藏状态为零张量hidden_states = torch.zeros((batch_size_attention, query_tokens, dim // attn.heads), device=query.device, dtype=query.dtype)# 按切片大小进行迭代处理for i in range((batch_size_attention - 1) // self.slice_size + 1):start_idx = i * self.slice_size  # 切片起始索引end_idx = (i + 1) * self.slice_size  # 切片结束索引# 获取当前查询切片、键切片和注意力掩码切片query_slice = query[start_idx:end_idx]key_slice = key[start_idx:end_idx]attn_mask_slice = attention_mask[start_idx:end_idx] if attention_mask is not None else None# 获取当前切片的注意力分数attn_slice = attn.get_attention_scores(query_slice, key_slice, attn_mask_slice)# 将注意力分数与当前值进行批量矩阵乘法attn_slice = torch.bmm(attn_slice, value[start_idx:end_idx])# 将结果存储到隐藏状态hidden_states[start_idx:end_idx] = attn_slice# 将隐藏状态的维度转换回头部维度hidden_states = attn.batch_to_head_dim(hidden_states)# 线性投影hidden_states = attn.to_out[0](hidden_states)# 应用丢弃层hidden_states = attn.to_out[1](hidden_states)# 转置最后两维并重塑为残差形状hidden_states = hidden_states.transpose(-1, -2).reshape(residual.shape)# 将残差添加到当前隐藏状态hidden_states = hidden_states + residual# 返回最终的隐藏状态return hidden_states
# 定义一个空间归一化类,继承自 nn.Module
class SpatialNorm(nn.Module):"""空间条件归一化,定义在 https://arxiv.org/abs/2209.09002 中。参数:f_channels (`int`):输入到组归一化层的通道数,以及空间归一化层的输出通道数。zq_channels (`int`):量化向量的通道数,如论文中所述。"""# 初始化方法,接收通道数作为参数def __init__(self,f_channels: int,zq_channels: int,):# 调用父类的初始化方法super().__init__()# 创建组归一化层,指定通道数、组数和其他参数self.norm_layer = nn.GroupNorm(num_channels=f_channels, num_groups=32, eps=1e-6, affine=True)# 创建卷积层,输入通道为 zq_channels,输出通道为 f_channelsself.conv_y = nn.Conv2d(zq_channels, f_channels, kernel_size=1, stride=1, padding=0)# 创建另一个卷积层,功能相同但用于偏置项self.conv_b = nn.Conv2d(zq_channels, f_channels, kernel_size=1, stride=1, padding=0)# 前向传播方法,定义输入和输出def forward(self, f: torch.Tensor, zq: torch.Tensor) -> torch.Tensor:# 获取输入张量 f 的空间尺寸f_size = f.shape[-2:]# 对 zq 进行上采样,使其尺寸与 f 相同zq = F.interpolate(zq, size=f_size, mode="nearest")# 对输入 f 应用归一化层norm_f = self.norm_layer(f)# 计算新的输出张量,通过归一化后的 f 和卷积结果结合new_f = norm_f * self.conv_y(zq) + self.conv_b(zq)# 返回新的张量return new_f# 定义一个 IPAdapter 注意力处理器类,继承自 nn.Module
class IPAdapterAttnProcessor(nn.Module):r"""多个 IP-Adapter 的注意力处理器。参数:hidden_size (`int`):注意力层的隐藏尺寸。cross_attention_dim (`int`):`encoder_hidden_states` 中的通道数。num_tokens (`int`, `Tuple[int]` 或 `List[int]`, 默认为 `(4,)`):图像特征的上下文长度。scale (`float` 或 List[`float`], 默认为 1.0):图像提示的权重缩放。"""# 初始化方法,接收多个参数def __init__(self, hidden_size, cross_attention_dim=None, num_tokens=(4,), scale=1.0):# 调用父类的初始化方法super().__init__()# 保存隐藏尺寸self.hidden_size = hidden_size# 保存交叉注意力维度self.cross_attention_dim = cross_attention_dim# 确保 num_tokens 为元组或列表if not isinstance(num_tokens, (tuple, list)):num_tokens = [num_tokens]# 保存 num_tokensself.num_tokens = num_tokens# 确保 scale 为列表if not isinstance(scale, list):scale = [scale] * len(num_tokens)# 验证 scale 和 num_tokens 长度相同if len(scale) != len(num_tokens):raise ValueError("`scale` should be a list of integers with the same length as `num_tokens`.")# 保存缩放因子self.scale = scale# 创建用于键的线性变换列表self.to_k_ip = nn.ModuleList([nn.Linear(cross_attention_dim, hidden_size, bias=False) for _ in range(len(num_tokens))])# 创建用于值的线性变换列表self.to_v_ip = nn.ModuleList([nn.Linear(cross_attention_dim, hidden_size, bias=False) for _ in range(len(num_tokens))])# 定义调用方法,处理输入的注意力信息def __call__(self,attn: Attention,hidden_states: torch.Tensor,encoder_hidden_states: Optional[torch.Tensor] = None,attention_mask: Optional[torch.Tensor] = None,temb: Optional[torch.Tensor] = None,scale: float = 1.0,ip_adapter_masks: Optional[torch.Tensor] = None,
class IPAdapterAttnProcessor2_0(torch.nn.Module):r"""PyTorch 2.0 的 IP-Adapter 注意力处理器。# 定义参数说明文档,列出类构造函数的参数及其类型和默认值Args:hidden_size (`int`):注意力层的隐藏层大小cross_attention_dim (`int`):编码器隐藏状态的通道数num_tokens (`int`, `Tuple[int]` or `List[int]`, defaults to `(4,)`):图像特征的上下文长度scale (`float` or `List[float]`, defaults to 1.0):图像提示的权重比例"""# 初始化类的构造函数,设置类属性def __init__(self, hidden_size, cross_attention_dim=None, num_tokens=(4,), scale=1.0):# 调用父类的构造函数super().__init__()# 检查 PyTorch 是否支持缩放点积注意力if not hasattr(F, "scaled_dot_product_attention"):# 如果不支持,抛出导入错误raise ImportError(f"{self.__class__.__name__} requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")# 设置隐藏层大小属性self.hidden_size = hidden_size# 设置交叉注意力维度属性self.cross_attention_dim = cross_attention_dim# 如果 num_tokens 不是元组或列表,则将其转换为列表if not isinstance(num_tokens, (tuple, list)):num_tokens = [num_tokens]# 设置 num_tokens 属性self.num_tokens = num_tokens# 如果 scale 不是列表,则创建与 num_tokens 长度相同的列表if not isinstance(scale, list):scale = [scale] * len(num_tokens)# 检查 scale 的长度是否与 num_tokens 相同if len(scale) != len(num_tokens):# 如果不同,抛出值错误raise ValueError("`scale` should be a list of integers with the same length as `num_tokens`.")# 设置 scale 属性self.scale = scale# 创建一个包含多个线性层的模块列表,用于输入到 K 的映射self.to_k_ip = nn.ModuleList([nn.Linear(cross_attention_dim, hidden_size, bias=False) for _ in range(len(num_tokens))])# 创建一个包含多个线性层的模块列表,用于输入到 V 的映射self.to_v_ip = nn.ModuleList([nn.Linear(cross_attention_dim, hidden_size, bias=False) for _ in range(len(num_tokens))])# 定义类的调用方法,用于执行注意力计算def __call__(self,attn: Attention,hidden_states: torch.Tensor,encoder_hidden_states: Optional[torch.Tensor] = None,attention_mask: Optional[torch.Tensor] = None,temb: Optional[torch.Tensor] = None,scale: float = 1.0,ip_adapter_masks: Optional[torch.Tensor] = None,
# 定义用于实现 PAG 的处理器类,使用缩放点积注意力(默认启用 PyTorch 2.0)
class PAGIdentitySelfAttnProcessor2_0:r"""处理器用于实现 PAG,使用缩放点积注意力(默认在 PyTorch 2.0 中启用)。PAG 参考: https://arxiv.org/abs/2403.17377"""# 初始化函数def __init__(self):# 检查 F 中是否有缩放点积注意力功能if not hasattr(F, "scaled_dot_product_attention"):# 如果没有,则抛出导入错误,提示需要升级 PyTorchraise ImportError("PAGIdentitySelfAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")# 可调用方法,定义了注意力处理的输入参数def __call__(self,attn: Attention,  # 输入的注意力对象hidden_states: torch.FloatTensor,  # 当前的隐藏状态encoder_hidden_states: Optional[torch.FloatTensor] = None,  # 编码器的隐藏状态(可选)attention_mask: Optional[torch.FloatTensor] = None,  # 注意力掩码(可选)temb: Optional[torch.FloatTensor] = None,  # 额外的时间嵌入(可选)
class PAGCFGIdentitySelfAttnProcessor2_0:r"""处理器用于实现 PAG,使用缩放点积注意力(默认启用 PyTorch 2.0)。PAG 参考: https://arxiv.org/abs/2403.17377"""# 初始化函数def __init__(self):# 检查 F 中是否有缩放点积注意力功能if not hasattr(F, "scaled_dot_product_attention"):# 如果没有,则抛出导入错误,提示需要升级 PyTorchraise ImportError("PAGCFGIdentitySelfAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")# 可调用方法,定义了注意力处理的输入参数def __call__(self,attn: Attention,  # 输入的注意力对象hidden_states: torch.FloatTensor,  # 当前的隐藏状态encoder_hidden_states: Optional[torch.FloatTensor] = None,  # 编码器的隐藏状态(可选)attention_mask: Optional[torch.FloatTensor] = None,  # 注意力掩码(可选)temb: Optional[torch.FloatTensor] = None,  # 额外的时间嵌入(可选)
class LoRAAttnProcessor:# 初始化函数def __init__(self):# 该类的构造函数,目前没有初始化操作passclass LoRAAttnProcessor2_0:# 初始化函数def __init__(self):# 该类的构造函数,目前没有初始化操作passclass LoRAXFormersAttnProcessor:# 初始化函数def __init__(self):# 该类的构造函数,目前没有初始化操作passclass LoRAAttnAddedKVProcessor:# 初始化函数def __init__(self):# 该类的构造函数,目前没有初始化操作pass# 定义一个包含添加键值注意力处理器的元组
ADDED_KV_ATTENTION_PROCESSORS = (AttnAddedKVProcessor,  # 添加键值注意力处理器SlicedAttnAddedKVProcessor,  # 切片添加键值注意力处理器AttnAddedKVProcessor2_0,  # 添加键值注意力处理器版本2.0XFormersAttnAddedKVProcessor,  # XFormers 添加键值注意力处理器
)# 定义一个包含交叉注意力处理器的元组
CROSS_ATTENTION_PROCESSORS = (AttnProcessor,  # 注意力处理器AttnProcessor2_0,  # 注意力处理器版本2.0XFormersAttnProcessor,  # XFormers 注意力处理器SlicedAttnProcessor,  # 切片注意力处理器IPAdapterAttnProcessor,  # IPAdapter 注意力处理器IPAdapterAttnProcessor2_0,  # IPAdapter 注意力处理器版本2.0
)# 定义一个包含所有注意力处理器的联合类型
AttentionProcessor = Union[AttnProcessor,  # 注意力处理器AttnProcessor2_0,  # 注意力处理器版本2.0FusedAttnProcessor2_0,  # 融合注意力处理器版本2.0XFormersAttnProcessor,  # XFormers 注意力处理器SlicedAttnProcessor,  # 切片注意力处理器AttnAddedKVProcessor,  # 添加键值注意力处理器SlicedAttnAddedKVProcessor,  # 切片添加键值注意力处理器AttnAddedKVProcessor2_0,  # 添加键值注意力处理器版本2.0XFormersAttnAddedKVProcessor,  # XFormers 添加键值注意力处理器CustomDiffusionAttnProcessor,  # 自定义扩散注意力处理器CustomDiffusionXFormersAttnProcessor,  # 自定义扩散 XFormers 注意力处理器CustomDiffusionAttnProcessor2_0,  # 自定义扩散注意力处理器版本2.0PAGCFGIdentitySelfAttnProcessor2_0,  # PAGCFG 身份自注意力处理器版本2.0PAGIdentitySelfAttnProcessor2_0,  # PAG 身份自注意力处理器版本2.0PAGCFGHunyuanAttnProcessor2_0,  # PAGCGHunyuan 注意力处理器版本2.0PAGHunyuanAttnProcessor2_0,  # PAG Hunyuan 注意力处理器版本2.0
]
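这些元组主要配合 isinstance 做类型归类(例如模型内部类似 set_default_attn_processor 的逻辑会用它们判断当前挂载的是哪一类处理器,具体以所用版本为准)。下面是一个最小示意,直接实例化一个处理器来演示判断方式:

```python
from diffusers.models.attention_processor import (
    ADDED_KV_ATTENTION_PROCESSORS,
    CROSS_ATTENTION_PROCESSORS,
    AttnProcessor2_0,
)

proc = AttnProcessor2_0()
# isinstance 的第二个参数可以是类的元组,命中其中任意一个即返回 True
print(isinstance(proc, CROSS_ATTENTION_PROCESSORS))     # True
print(isinstance(proc, ADDED_KV_ATTENTION_PROCESSORS))  # False
```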
