diffusers-源码解析-二十九-

news/2024/10/22 12:37:02

diffusers 源码解析(二十九)

.\diffusers\pipelines\deprecated\stable_diffusion_variants\pipeline_stable_diffusion_model_editing.py

# 版权信息,声明版权和许可协议
# Copyright 2024 TIME Authors and The HuggingFace Team. All rights reserved."
# 根据 Apache License 2.0 许可协议进行许可
# 此文件只能在遵守许可证的情况下使用
# 可通过以下网址获取许可证副本
#     http://www.apache.org/licenses/LICENSE-2.0
#
# 除非适用法律要求或书面同意,否则软件按“现状”分发,不提供任何形式的担保或条件
# 具体许可条款和限制请参见许可证# 导入复制模块,用于对象复制
import copy
# 导入检查模块,用于检查对象的信息
import inspect
# 导入类型提示相关的模块
from typing import Any, Callable, Dict, List, Optional, Union# 导入 PyTorch 库
import torch
# 从 transformers 库中导入图像处理器、文本模型和标记器
from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer# 从相对路径导入自定义图像处理器
from ....image_processor import VaeImageProcessor
# 从相对路径导入自定义加载器混合类
from ....loaders import StableDiffusionLoraLoaderMixin, TextualInversionLoaderMixin
# 从相对路径导入自定义模型类
from ....models import AutoencoderKL, UNet2DConditionModel
# 从相对路径导入调整 LoRA 规模的函数
from ....models.lora import adjust_lora_scale_text_encoder
# 从相对路径导入调度器类
from ....schedulers import PNDMScheduler
# 从调度器工具导入调度器混合类
from ....schedulers.scheduling_utils import SchedulerMixin
# 从工具库中导入多个功能模块
from ....utils import USE_PEFT_BACKEND, deprecate, logging, scale_lora_layers, unscale_lora_layers
# 从自定义的 PyTorch 工具库中导入随机张量生成函数
from ....utils.torch_utils import randn_tensor
# 从管道工具导入扩散管道和稳定扩散混合类
from ...pipeline_utils import DiffusionPipeline, StableDiffusionMixin
# 从稳定扩散相关模块导入管道输出类
from ...stable_diffusion.pipeline_output import StableDiffusionPipelineOutput
# 从稳定扩散安全检查器模块导入安全检查器类
from ...stable_diffusion.safety_checker import StableDiffusionSafetyChecker# 创建一个日志记录器,用于记录当前模块的信息
logger = logging.get_logger(__name__)  # pylint: disable=invalid-name# 定义一个常量列表,包含不同的图像描述前缀
AUGS_CONST = ["A photo of ", "An image of ", "A picture of "]# 定义一个稳定扩散模型编辑管道类,继承自多个基类
class StableDiffusionModelEditingPipeline(DiffusionPipeline, StableDiffusionMixin, TextualInversionLoaderMixin, StableDiffusionLoraLoaderMixin
):r"""文本到图像模型编辑的管道。该模型继承自 [`DiffusionPipeline`]。请查阅超类文档以获取所有管道实现的通用方法(下载、保存、在特定设备上运行等)。该管道还继承以下加载方法:- [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] 用于加载文本反转嵌入- [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`] 用于加载 LoRA 权重- [`~loaders.StableDiffusionLoraLoaderMixin.save_lora_weights`] 用于保存 LoRA 权重# 文档字符串,描述类或方法的参数Args:vae ([`AutoencoderKL`]):# Variational Auto-Encoder (VAE) 模型,用于将图像编码和解码为潜在表示Variational Auto-Encoder (VAE) model to encode and decode images to and from latent representations.text_encoder ([`~transformers.CLIPTextModel`]):# 冻结的文本编码器,使用 CLIP 的大型视觉变换模型Frozen text-encoder ([clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)).tokenizer ([`~transformers.CLIPTokenizer`]):# 用于文本标记化的 CLIPTokenizerA `CLIPTokenizer` to tokenize text.unet ([`UNet2DConditionModel`]):# 用于去噪编码图像潜在的 UNet2DConditionModelA `UNet2DConditionModel` to denoise the encoded image latents.scheduler ([`SchedulerMixin`]):# 调度器,与 unet 一起用于去噪编码的图像潜在,可以是 DDIMScheduler、LMSDiscreteScheduler 或 PNDMSchedulerA scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of[`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].safety_checker ([`StableDiffusionSafetyChecker`]):# 分类模块,用于估计生成的图像是否可能被视为冒犯或有害Classification module that estimates whether generated images could be considered offensive or harmful.# 参阅模型卡以获取有关模型潜在危害的更多详细信息Please refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for more detailsabout a model's potential harms.feature_extractor ([`~transformers.CLIPImageProcessor`]):# 用于从生成的图像中提取特征的 CLIPImageProcessor;作为输入传递给 safety_checkerA `CLIPImageProcessor` to extract features from generated images; used as inputs to the `safety_checker`.with_to_k ([`bool`]):# 是否在编辑文本到图像模型时编辑键投影矩阵与值投影矩阵Whether to edit the key projection matrices along with the value projection matrices.with_augs ([`list`]):# 在编辑文本到图像模型时应用的文本增强,设置为 [] 表示不进行增强Textual augmentations to apply while editing the text-to-image model. Set to `[]` for no augmentations."""# 定义模型的 CPU 卸载顺序model_cpu_offload_seq = "text_encoder->unet->vae"# 定义可选组件列表_optional_components = ["safety_checker", "feature_extractor"]# 定义不参与 CPU 卸载的组件_exclude_from_cpu_offload = ["safety_checker"]# 初始化方法,设置模型和参数def __init__(self,vae: AutoencoderKL,text_encoder: CLIPTextModel,tokenizer: CLIPTokenizer,unet: UNet2DConditionModel,scheduler: SchedulerMixin,safety_checker: StableDiffusionSafetyChecker,feature_extractor: CLIPImageProcessor,requires_safety_checker: bool = True,with_to_k: bool = True,with_augs: list = AUGS_CONST,# 从 diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._encode_prompt 复制的代码def _encode_prompt(self,prompt,device,num_images_per_prompt,do_classifier_free_guidance,negative_prompt=None,prompt_embeds: Optional[torch.Tensor] = None,negative_prompt_embeds: Optional[torch.Tensor] = None,lora_scale: Optional[float] = None,**kwargs,):# 定义一个弃用消息,提示用户 `_encode_prompt()` 已弃用,建议使用 `encode_prompt()`,并说明输出格式变化deprecation_message = "`_encode_prompt()` is deprecated and it will be removed in a future version. Use `encode_prompt()` instead. Also, be aware that the output format changed from a concatenated tensor to a tuple."# 调用 deprecate 函数记录弃用警告,指定版本和警告信息,标准警告设置为 Falsedeprecate("_encode_prompt()", "1.0.0", deprecation_message, standard_warn=False)# 调用 encode_prompt 方法,将参数传入以获取提示嵌入的元组prompt_embeds_tuple = self.encode_prompt(prompt=prompt,  # 用户输入的提示文本device=device,  # 指定运行设备num_images_per_prompt=num_images_per_prompt,  # 每个提示生成的图像数量do_classifier_free_guidance=do_classifier_free_guidance,  # 是否使用无分类器的引导negative_prompt=negative_prompt,  # 负面提示文本prompt_embeds=prompt_embeds,  # 现有的提示嵌入(如果有的话)negative_prompt_embeds=negative_prompt_embeds,  # 负面提示嵌入(如果有的话)lora_scale=lora_scale,  # Lora 缩放参数**kwargs,  # 额外参数)# 连接提示嵌入元组的两个部分,适配旧版本的兼容性prompt_embeds = torch.cat([prompt_embeds_tuple[1], prompt_embeds_tuple[0]])# 返回连接后的提示嵌入return prompt_embeds# 从 diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_prompt 复制def encode_prompt(self,prompt,  # 用户输入的提示文本device,  # 指定运行设备num_images_per_prompt,  # 每个提示生成的图像数量do_classifier_free_guidance,  # 是否使用无分类器的引导negative_prompt=None,  # 负面提示文本,默认为 Noneprompt_embeds: Optional[torch.Tensor] = None,  # 现有的提示嵌入,默认为 Nonenegative_prompt_embeds: Optional[torch.Tensor] = None,  # 负面提示嵌入,默认为 Nonelora_scale: Optional[float] = None,  # Lora 缩放参数,默认为 Noneclip_skip: Optional[int] = None,  # 跳过的剪辑层,默认为 None# 从 diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker 复制def run_safety_checker(self, image, device, dtype):  # 定义安全检查函数,接收图像、设备和数据类型# 检查安全检查器是否存在if self.safety_checker is None:has_nsfw_concept = None  # 如果没有安全检查器,则 NSFW 概念为 Noneelse:# 如果图像是张量类型,进行后处理,转换为 PIL 格式if torch.is_tensor(image):feature_extractor_input = self.image_processor.postprocess(image, output_type="pil")else:# 如果图像是 numpy 数组,直接转换为 PIL 格式feature_extractor_input = self.image_processor.numpy_to_pil(image)# 获取安全检查器的输入,转换为张量并移至指定设备safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors="pt").to(device)# 使用安全检查器检查图像,返回处理后的图像和 NSFW 概念image, has_nsfw_concept = self.safety_checker(images=image, clip_input=safety_checker_input.pixel_values.to(dtype))# 返回处理后的图像和 NSFW 概念return image, has_nsfw_concept# 从 diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents 复制# 解码潜在向量的方法def decode_latents(self, latents):# 警告信息,提示此方法已弃用,将在 1.0.0 中移除deprecation_message = "The decode_latents method is deprecated and will be removed in 1.0.0. Please use VaeImageProcessor.postprocess(...) instead"# 调用 deprecate 函数记录弃用信息deprecate("decode_latents", "1.0.0", deprecation_message, standard_warn=False)# 根据配置的缩放因子调整潜在向量latents = 1 / self.vae.config.scaling_factor * latents# 解码潜在向量并返回图像数据image = self.vae.decode(latents, return_dict=False)[0]# 将图像数据归一化到 [0, 1] 范围内image = (image / 2 + 0.5).clamp(0, 1)# 始终将图像数据转换为 float32 类型,以确保兼容性并降低开销image = image.cpu().permute(0, 2, 3, 1).float().numpy()# 返回处理后的图像数据return image# 从 diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs 复制def prepare_extra_step_kwargs(self, generator, eta):# 为调度器步骤准备额外的关键字参数,不同调度器的签名不同# eta (η) 仅在 DDIMScheduler 中使用,其他调度器会忽略# eta 对应于 DDIM 论文中的 η,范围应在 [0, 1] 之间# 检查调度器步骤的参数是否接受 etaaccepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())# 初始化额外步骤关键字参数字典extra_step_kwargs = {}# 如果调度器接受 eta,则将其添加到字典中if accepts_eta:extra_step_kwargs["eta"] = eta# 检查调度器步骤的参数是否接受 generatoraccepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())# 如果调度器接受 generator,则将其添加到字典中if accepts_generator:extra_step_kwargs["generator"] = generator# 返回额外步骤关键字参数字典return extra_step_kwargs# 检查输入参数的方法def check_inputs(self,prompt,height,width,callback_steps,negative_prompt=None,prompt_embeds=None,negative_prompt_embeds=None,callback_on_step_end_tensor_inputs=None,):# 检查高度和宽度是否能被8整除if height % 8 != 0 or width % 8 != 0:# 抛出异常,给出高度和宽度的信息raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")# 检查回调步数是否为正整数if callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0):# 抛出异常,给出回调步数的信息raise ValueError(f"`callback_steps` has to be a positive integer but is {callback_steps} of type"f" {type(callback_steps)}.")# 检查回调输入是否在预期的输入列表中if callback_on_step_end_tensor_inputs is not None and not all(k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs):# 抛出异常,列出不在预期列表中的输入raise ValueError(f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}")# 检查是否同时提供了提示和提示嵌入if prompt is not None and prompt_embeds is not None:# 抛出异常,提示只能提供一个raise ValueError(f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"" only forward one of the two.")# 检查是否同时未提供提示和提示嵌入elif prompt is None and prompt_embeds is None:# 抛出异常,提示至少要提供一个raise ValueError("Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined.")# 检查提示类型是否合法elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):# 抛出异常,提示类型不符合要求raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")# 检查是否同时提供了负提示和负提示嵌入if negative_prompt is not None and negative_prompt_embeds is not None:# 抛出异常,提示只能提供一个raise ValueError(f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"f" {negative_prompt_embeds}. Please make sure to only forward one of the two.")# 检查提示嵌入和负提示嵌入的形状是否一致if prompt_embeds is not None and negative_prompt_embeds is not None:if prompt_embeds.shape != negative_prompt_embeds.shape:# 抛出异常,给出形状不一致的信息raise ValueError("`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"f" {negative_prompt_embeds.shape}.")# 从 diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents 复制# 准备潜在变量的函数,接受多个参数以控制形状和生成方式def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):# 定义潜在变量的形状,包括批量大小、通道数和调整后的高度宽度shape = (batch_size,num_channels_latents,int(height) // self.vae_scale_factor,int(width) // self.vae_scale_factor,)# 检查生成器列表的长度是否与批量大小一致if isinstance(generator, list) and len(generator) != batch_size:# 如果不一致,则抛出值错误并提供相关信息raise ValueError(f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"f" size of {batch_size}. Make sure the batch size matches the length of the generators.")# 如果潜在变量为空,则生成新的随机潜在变量if latents is None:latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)else:# 如果已提供潜在变量,则将其转移到指定设备latents = latents.to(device)# 根据调度器要求的标准差缩放初始噪声latents = latents * self.scheduler.init_noise_sigma# 返回准备好的潜在变量return latents# 装饰器,指示该函数不需要计算梯度@torch.no_grad()def edit_model(self,source_prompt: str,destination_prompt: str,lamb: float = 0.1,restart_params: bool = True,# 装饰器,指示该函数不需要计算梯度@torch.no_grad()def __call__(self,# 允许输入字符串或字符串列表作为提示prompt: Union[str, List[str]] = None,# 可选参数,指定生成图像的高度height: Optional[int] = None,# 可选参数,指定生成图像的宽度width: Optional[int] = None,# 设置推理步骤的默认数量为50num_inference_steps: int = 50,# 设置引导比例的默认值为7.5guidance_scale: float = 7.5,# 可选参数,允许输入负面提示negative_prompt: Optional[Union[str, List[str]]] = None,# 可选参数,指定每个提示生成的图像数量,默认值为1num_images_per_prompt: Optional[int] = 1,# 设置eta的默认值为0.0eta: float = 0.0,# 可选参数,允许输入生成器或生成器列表generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,# 可选参数,允许输入潜在变量的张量latents: Optional[torch.Tensor] = None,# 可选参数,允许输入提示嵌入的张量prompt_embeds: Optional[torch.Tensor] = None,# 可选参数,允许输入负面提示嵌入的张量negative_prompt_embeds: Optional[torch.Tensor] = None,# 可选参数,指定输出类型,默认为"pil"output_type: Optional[str] = "pil",# 可选参数,控制是否返回字典形式的输出,默认为Truereturn_dict: bool = True,# 可选回调函数,用于处理生成过程中每一步的信息callback: Optional[Callable[[int, int, torch.Tensor], None]] = None,# 可选参数,指定回调的步骤数,默认为1callback_steps: int = 1,# 可选参数,允许传入交叉注意力的关键字参数cross_attention_kwargs: Optional[Dict[str, Any]] = None,# 可选参数,允许指定跳过的clip层数clip_skip: Optional[int] = None,

.\diffusers\pipelines\deprecated\stable_diffusion_variants\pipeline_stable_diffusion_paradigms.py

# 版权所有 2024 ParaDiGMS 作者和 HuggingFace 团队。保留所有权利。
#
# 根据 Apache 许可证第 2.0 版(“许可证”)许可;
# 除非遵守许可证,否则您不得使用此文件。
# 您可以在以下地址获取许可证副本:
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# 除非适用法律或书面协议另有约定,依据许可证分发的软件
# 是按“原样”提供的,没有任何形式的明示或暗示的担保或条件。
# 有关许可证的特定权限和限制,请参见许可证。# 导入 inspect 模块以进行对象检查
import inspect
# 从 typing 模块导入类型提示相关的类
from typing import Any, Callable, Dict, List, Optional, Union# 导入 PyTorch 库以进行深度学习操作
import torch
# 从 transformers 库导入 CLIP 图像处理器、文本模型和分词器
from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer# 导入自定义的图像处理器
from ....image_processor import VaeImageProcessor
# 导入与加载相关的混合类
from ....loaders import FromSingleFileMixin, StableDiffusionLoraLoaderMixin, TextualInversionLoaderMixin
# 导入用于自动编码器和条件 UNet 的模型
from ....models import AutoencoderKL, UNet2DConditionModel
# 从 lora 模块导入调整 lora 规模的函数
from ....models.lora import adjust_lora_scale_text_encoder
# 导入 Karras 扩散调度器
from ....schedulers import KarrasDiffusionSchedulers
# 导入实用工具模块中的各种功能
from ....utils import (USE_PEFT_BACKEND,  # 用于 PEFT 后端的标志deprecate,  # 用于标记弃用功能的装饰器logging,  # 日志记录功能replace_example_docstring,  # 替换示例文档字符串的功能scale_lora_layers,  # 调整 lora 层规模的功能unscale_lora_layers,  # 反调整 lora 层规模的功能
)
# 从 torch_utils 模块导入生成随机张量的功能
from ....utils.torch_utils import randn_tensor
# 导入扩散管道和稳定扩散混合类
from ...pipeline_utils import DiffusionPipeline, StableDiffusionMixin
# 导入稳定扩散管道输出类
from ...stable_diffusion.pipeline_output import StableDiffusionPipelineOutput
# 导入稳定扩散安全检查器
from ...stable_diffusion.safety_checker import StableDiffusionSafetyChecker# 创建日志记录器,用于当前模块的日志
logger = logging.get_logger(__name__)  # pylint: disable=invalid-name# 示例文档字符串,展示用法示例
EXAMPLE_DOC_STRING = """Examples:```py>>> import torch>>> from diffusers import DDPMParallelScheduler>>> from diffusers import StableDiffusionParadigmsPipeline>>> scheduler = DDPMParallelScheduler.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="scheduler")>>> pipe = StableDiffusionParadigmsPipeline.from_pretrained(...     "runwayml/stable-diffusion-v1-5", scheduler=scheduler, torch_dtype=torch.float16... )>>> pipe = pipe.to("cuda")>>> ngpu, batch_per_device = torch.cuda.device_count(), 5>>> pipe.wrapped_unet = torch.nn.DataParallel(pipe.unet, device_ids=[d for d in range(ngpu)])>>> prompt = "a photo of an astronaut riding a horse on mars">>> image = pipe(prompt, parallel=ngpu * batch_per_device, num_inference_steps=1000).images[0]```py
"""# 定义 StableDiffusionParadigmsPipeline 类,继承多个混合类以实现功能
class StableDiffusionParadigmsPipeline(DiffusionPipeline,  # 从扩散管道继承StableDiffusionMixin,  # 从稳定扩散混合类继承TextualInversionLoaderMixin,  # 从文本反转加载混合类继承StableDiffusionLoraLoaderMixin,  # 从稳定扩散 lora 加载混合类继承FromSingleFileMixin,  # 从单文件加载混合类继承
):r"""用于文本到图像生成的管道,使用稳定扩散的并行化版本。此模型继承自 [`DiffusionPipeline`]。有关通用方法的文档,请查看超类文档# 实现所有管道的功能(下载、保存、在特定设备上运行等)。# 管道还继承以下加载方法:# - [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] 用于加载文本反转嵌入# - [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`] 用于加载 LoRA 权重# - [`~loaders.StableDiffusionLoraLoaderMixin.save_lora_weights`] 用于保存 LoRA 权重# - [`~loaders.FromSingleFileMixin.from_single_file`] 用于加载 `.ckpt` 文件# 参数说明:# vae ([`AutoencoderKL`]):#    变分自编码器(VAE)模型,用于将图像编码和解码为潜在表示。# text_encoder ([`~transformers.CLIPTextModel`]):#    冻结的文本编码器([clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14))。# tokenizer ([`~transformers.CLIPTokenizer`]):#    一个 `CLIPTokenizer` 用于对文本进行标记化。# unet ([`UNet2DConditionModel`]):#    一个 `UNet2DConditionModel` 用于去噪编码的图像潜在。# scheduler ([`SchedulerMixin`]):#    用于与 `unet` 结合使用的调度器,用于去噪编码的图像潜在。可以是#    [`DDIMScheduler`], [`LMSDiscreteScheduler`] 或 [`PNDMScheduler`] 中的一个。# safety_checker ([`StableDiffusionSafetyChecker`]):#    分类模块,估计生成的图像是否可能被认为是冒犯或有害的。#    请参考 [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) 获取更多关于模型潜在危害的详细信息。# feature_extractor ([`~transformers.CLIPImageProcessor`]):#    一个 `CLIPImageProcessor` 用于从生成的图像中提取特征;作为 `safety_checker` 的输入。# 定义模型的 CPU 离线加载顺序model_cpu_offload_seq = "text_encoder->unet->vae"# 定义可选组件列表_optional_components = ["safety_checker", "feature_extractor"]# 定义排除在 CPU 离线加载之外的组件_exclude_from_cpu_offload = ["safety_checker"]# 初始化方法,接受多个参数def __init__(self,vae: AutoencoderKL,  # 变分自编码器模型text_encoder: CLIPTextModel,  # 文本编码器tokenizer: CLIPTokenizer,  # 文本标记器unet: UNet2DConditionModel,  # UNet2D 条件模型scheduler: KarrasDiffusionSchedulers,  # 调度器safety_checker: StableDiffusionSafetyChecker,  # 安全检查模块feature_extractor: CLIPImageProcessor,  # 特征提取器requires_safety_checker: bool = True,  # 是否需要安全检查器):# 调用父类的初始化方法super().__init__()# 检查是否禁用安全检查器,并且需要安全检查器if safety_checker is None and requires_safety_checker:# 记录警告信息,提醒用户禁用安全检查器的风险logger.warning(f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"" that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered"" results in services or applications open to the public. Both the diffusers team and Hugging Face"" strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling"" it only for use-cases that involve analyzing network behavior or auditing its results. For more"" information, please have a look at https://github.com/huggingface/diffusers/pull/254 .")# 检查是否提供了安全检查器但未提供特征提取器if safety_checker is not None and feature_extractor is None:# 抛出错误,提示用户需要定义特征提取器raise ValueError("Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety"" checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead.")# 注册各个模块到当前实例self.register_modules(vae=vae,  # 变分自编码器text_encoder=text_encoder,  # 文本编码器tokenizer=tokenizer,  # 分词器unet=unet,  # U-Net 模型scheduler=scheduler,  # 调度器safety_checker=safety_checker,  # 安全检查器feature_extractor=feature_extractor,  # 特征提取器)# 计算 VAE 的缩放因子self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)# 初始化图像处理器,使用 VAE 缩放因子self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)# 将是否需要安全检查器的配置注册到当前实例self.register_to_config(requires_safety_checker=requires_safety_checker)# 用于在多个 GPU 上运行多个去噪步骤时,将 unet 包装为 torch.nn.DataParallelself.wrapped_unet = self.unet# 从 diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._encode_prompt 复制的函数def _encode_prompt(self,prompt,  # 输入的提示文本device,  # 设备类型(CPU或GPU)num_images_per_prompt,  # 每个提示生成的图像数量do_classifier_free_guidance,  # 是否使用无分类器引导negative_prompt=None,  # 可选的负面提示文本prompt_embeds: Optional[torch.Tensor] = None,  # 可选的提示嵌入negative_prompt_embeds: Optional[torch.Tensor] = None,  # 可选的负面提示嵌入lora_scale: Optional[float] = None,  # 可选的 LoRA 缩放因子**kwargs,  # 其他可选参数):# 设置弃用信息,提醒用户该方法即将被移除,建议使用新方法deprecation_message = "`_encode_prompt()` is deprecated and it will be removed in a future version. Use `encode_prompt()` instead. Also, be aware that the output format changed from a concatenated tensor to a tuple."# 调用 deprecate 函数,传递弃用信息和版本号deprecate("_encode_prompt()", "1.0.0", deprecation_message, standard_warn=False)# 调用 encode_prompt 方法,获取提示的嵌入元组prompt_embeds_tuple = self.encode_prompt(prompt=prompt,  # 输入提示文本device=device,  # 计算设备num_images_per_prompt=num_images_per_prompt,  # 每个提示生成的图像数量do_classifier_free_guidance=do_classifier_free_guidance,  # 是否使用无分类器引导negative_prompt=negative_prompt,  # 负提示文本prompt_embeds=prompt_embeds,  # 提示嵌入negative_prompt_embeds=negative_prompt_embeds,  # 负提示嵌入lora_scale=lora_scale,  # LORA 缩放因子**kwargs,  # 其他额外参数)# 将返回的嵌入元组进行拼接,以支持向后兼容prompt_embeds = torch.cat([prompt_embeds_tuple[1], prompt_embeds_tuple[0]])# 返回拼接后的提示嵌入return prompt_embeds# 从 diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline 中复制的 encode_prompt 方法def encode_prompt(self,prompt,  # 输入的提示文本device,  # 计算设备num_images_per_prompt,  # 每个提示生成的图像数量do_classifier_free_guidance,  # 是否使用无分类器引导negative_prompt=None,  # 负提示文本,默认为 Noneprompt_embeds: Optional[torch.Tensor] = None,  # 可选的提示嵌入negative_prompt_embeds: Optional[torch.Tensor] = None,  # 可选的负提示嵌入lora_scale: Optional[float] = None,  # 可选的 LORA 缩放因子clip_skip: Optional[int] = None,  # 可选的跳过剪辑参数# 从 diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker 复制的代码def run_safety_checker(self, image, device, dtype):# 检查是否存在安全检查器if self.safety_checker is None:has_nsfw_concept = None  # 如果没有安全检查器,则设置无 NSFW 概念为 Noneelse:# 如果输入是张量格式,则进行后处理为 PIL 格式if torch.is_tensor(image):feature_extractor_input = self.image_processor.postprocess(image, output_type="pil")else:# 如果输入不是张量,转换为 PIL 格式feature_extractor_input = self.image_processor.numpy_to_pil(image)# 使用特征提取器处理图像,返回张量形式safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors="pt").to(device)# 调用安全检查器,检查图像是否包含 NSFW 概念image, has_nsfw_concept = self.safety_checker(images=image,  # 输入图像clip_input=safety_checker_input.pixel_values.to(dtype)  # 安全检查的特征输入)# 返回处理后的图像及是否存在 NSFW 概念return image, has_nsfw_concept# 从 diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs 复制的代码# 定义一个方法,用于准备额外的参数以供调度器步骤使用def prepare_extra_step_kwargs(self, generator, eta):# 为调度器步骤准备额外的关键字参数,因为并非所有调度器都有相同的签名# eta (η) 仅在 DDIMScheduler 中使用,对于其他调度器将被忽略# eta 对应于 DDIM 论文中的 η: https://arxiv.org/abs/2010.02502# 其值应在 [0, 1] 之间# 检查调度器的 step 方法是否接受 eta 参数accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())# 初始化一个字典以存储额外的步骤参数extra_step_kwargs = {}# 如果接受 eta 参数,则将其添加到字典中if accepts_eta:extra_step_kwargs["eta"] = eta# 检查调度器的 step 方法是否接受 generator 参数accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())# 如果接受 generator 参数,则将其添加到字典中if accepts_generator:extra_step_kwargs["generator"] = generator# 返回准备好的额外步骤参数字典return extra_step_kwargs# 定义一个方法,用于检查输入参数的有效性def check_inputs(self,prompt,  # 输入的提示文本height,  # 图像的高度width,   # 图像的宽度callback_steps,  # 回调步骤的频率negative_prompt=None,  # 可选的负面提示文本prompt_embeds=None,  # 可选的提示嵌入negative_prompt_embeds=None,  # 可选的负面提示嵌入callback_on_step_end_tensor_inputs=None,  # 可选的在步骤结束时的回调张量输入):# 检查高度和宽度是否为 8 的倍数if height % 8 != 0 or width % 8 != 0:# 抛出错误,如果高度或宽度不符合要求raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")# 检查回调步数是否为正整数if callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0):# 抛出错误,如果回调步数不符合要求raise ValueError(f"`callback_steps` has to be a positive integer but is {callback_steps} of type"f" {type(callback_steps)}.")# 检查回调结束时的张量输入是否在允许的输入中if callback_on_step_end_tensor_inputs is not None and not all(k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs):# 抛出错误,如果不在允许的输入中raise ValueError(f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}")# 检查是否同时提供了提示和提示嵌入if prompt is not None and prompt_embeds is not None:# 抛出错误,不能同时提供提示和提示嵌入raise ValueError(f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"" only forward one of the two.")# 检查提示和提示嵌入是否均未定义elif prompt is None and prompt_embeds is None:# 抛出错误,必须提供一个提示或提示嵌入raise ValueError("Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined.")# 检查提示的类型elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):# 抛出错误,如果提示不是字符串或列表raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")# 检查是否同时提供了负提示和负提示嵌入if negative_prompt is not None and negative_prompt_embeds is not None:# 抛出错误,不能同时提供负提示和负提示嵌入raise ValueError(f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"f" {negative_prompt_embeds}. Please make sure to only forward one of the two.")# 检查提示嵌入和负提示嵌入的形状是否相同if prompt_embeds is not None and negative_prompt_embeds is not None:if prompt_embeds.shape != negative_prompt_embeds.shape:# 抛出错误,如果形状不一致raise ValueError("`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"f" {negative_prompt_embeds.shape}.")# 从 diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents 复制的代码# 准备潜在变量,设置其形状和属性def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):# 定义潜在变量的形状,考虑批量大小、通道数和缩放因子shape = (batch_size,num_channels_latents,int(height) // self.vae_scale_factor,int(width) // self.vae_scale_factor,)# 检查生成器列表的长度是否与批量大小匹配if isinstance(generator, list) and len(generator) != batch_size:raise ValueError(f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"f" size of {batch_size}. Make sure the batch size matches the length of the generators.")# 如果潜在变量未提供,则生成随机潜在变量if latents is None:latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)else:# 将提供的潜在变量移动到指定设备latents = latents.to(device)# 根据调度器要求的标准差缩放初始噪声latents = latents * self.scheduler.init_noise_sigma# 返回处理后的潜在变量return latents# 计算输入张量在指定维度上的累积和def _cumsum(self, input, dim, debug=False):# 如果调试模式开启,则在CPU上执行累积和以确保可重复性if debug:# cumsum_cuda_kernel没有确定性实现,故在CPU上执行return torch.cumsum(input.cpu().float(), dim=dim).to(input.device)else:# 在指定维度上直接计算累积和return torch.cumsum(input, dim=dim)# 调用方法,接受多种参数以生成输出@torch.no_grad()@replace_example_docstring(EXAMPLE_DOC_STRING)def __call__(self,prompt: Union[str, List[str]] = None,height: Optional[int] = None,width: Optional[int] = None,num_inference_steps: int = 50,parallel: int = 10,tolerance: float = 0.1,guidance_scale: float = 7.5,negative_prompt: Optional[Union[str, List[str]]] = None,num_images_per_prompt: Optional[int] = 1,eta: float = 0.0,generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,latents: Optional[torch.Tensor] = None,prompt_embeds: Optional[torch.Tensor] = None,negative_prompt_embeds: Optional[torch.Tensor] = None,output_type: Optional[str] = "pil",return_dict: bool = True,callback: Optional[Callable[[int, int, torch.Tensor], None]] = None,callback_steps: int = 1,cross_attention_kwargs: Optional[Dict[str, Any]] = None,debug: bool = False,clip_skip: int = None,

.\diffusers\pipelines\deprecated\stable_diffusion_variants\pipeline_stable_diffusion_pix2pix_zero.py

# 版权声明,说明版权信息及持有者
# Copyright 2024 Pix2Pix Zero Authors and The HuggingFace Team. All rights reserved.
#
# 使用 Apache License 2.0 许可协议
# Licensed under the Apache License, Version 2.0 (the "License");
# 该文件仅在遵循许可协议的情况下使用
# you may not use this file except in compliance with the License.
# 许可协议的获取链接
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# 除非适用的法律要求或书面协议,否则软件按“原样”分发
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# 不提供任何明示或暗示的保证或条件
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# 查看许可协议中关于权限和限制的详细信息
# See the License for the specific language governing permissions and
# limitations under the License.import inspect  # 导入 inspect 模块,用于获取对象信息
from dataclasses import dataclass  # 从 dataclasses 模块导入 dataclass 装饰器
from typing import Any, Callable, Dict, List, Optional, Union  # 导入类型提示import numpy as np  # 导入 numpy 模块,常用于数值计算
import PIL.Image  # 导入 PIL.Image 模块,用于图像处理
import torch  # 导入 PyTorch 库,主要用于深度学习
import torch.nn.functional as F  # 导入 PyTorch 的功能性神经网络模块
from transformers import (  # 从 transformers 模块导入以下类BlipForConditionalGeneration,  # 导入用于条件生成的 Blip 模型BlipProcessor,  # 导入 Blip 处理器CLIPImageProcessor,  # 导入 CLIP 图像处理器CLIPTextModel,  # 导入 CLIP 文本模型CLIPTokenizer,  # 导入 CLIP 分词器
)from ....image_processor import PipelineImageInput, VaeImageProcessor  # 从自定义模块导入图像处理相关类
from ....loaders import StableDiffusionLoraLoaderMixin, TextualInversionLoaderMixin  # 导入稳定扩散和文本反转加载器混合类
from ....models import AutoencoderKL, UNet2DConditionModel  # 导入自动编码器和 UNet 模型
from ....models.attention_processor import Attention  # 导入注意力处理器
from ....models.lora import adjust_lora_scale_text_encoder  # 导入调整 Lora 文本编码器规模的函数
from ....schedulers import DDIMScheduler, DDPMScheduler, EulerAncestralDiscreteScheduler, LMSDiscreteScheduler  # 导入多种调度器
from ....schedulers.scheduling_ddim_inverse import DDIMInverseScheduler  # 导入 DDIM 反向调度器
from ....utils import (  # 从自定义工具模块导入实用函数和常量PIL_INTERPOLATION,  # 导入 PIL 图像插值方法USE_PEFT_BACKEND,  # 导入是否使用 PEFT 后端的常量BaseOutput,  # 导入基础输出类deprecate,  # 导入废弃标记装饰器logging,  # 导入日志记录模块replace_example_docstring,  # 导入替换示例文档字符串的函数scale_lora_layers,  # 导入缩放 Lora 层的函数unscale_lora_layers,  # 导入反缩放 Lora 层的函数
)
from ....utils.torch_utils import randn_tensor  # 从 PyTorch 工具模块导入生成随机张量的函数
from ...pipeline_utils import DiffusionPipeline, StableDiffusionMixin  # 从管道工具模块导入扩散管道和稳定扩散混合类
from ...stable_diffusion.pipeline_output import StableDiffusionPipelineOutput  # 导入稳定扩散管道输出类
from ...stable_diffusion.safety_checker import StableDiffusionSafetyChecker  # 导入稳定扩散安全检查器logger = logging.get_logger(__name__)  # 获取当前模块的日志记录器@dataclass  # 将下面的类定义为数据类
class Pix2PixInversionPipelineOutput(BaseOutput, TextualInversionLoaderMixin):  # 定义输出类,继承基础输出和文本反转加载器混合类"""输出类用于稳定扩散管道。参数:latents (`torch.Tensor`)反转的潜在张量images (`List[PIL.Image.Image]` or `np.ndarray`)长度为 `batch_size` 的去噪 PIL 图像列表或形状为 `(batch_size, height, width,num_channels)` 的 numpy 数组。PIL 图像或 numpy 数组呈现扩散管道的去噪图像。"""latents: torch.Tensor  # 定义潜在张量属性images: Union[List[PIL.Image.Image], np.ndarray]  # 定义图像属性,可以是图像列表或 numpy 数组EXAMPLE_DOC_STRING = """  # 定义示例文档字符串的初始部分
```  # 示例文档字符串的开始
```py  # 示例文档字符串的结束
```  # 示例文档字符串的结束
```py  # 示例文档字符串的结束
```  # 示例文档字符串的结束
```py  # 示例文档字符串的结束
```  # 示例文档字符串的结束
```py  # 示例文档字符串的结束
```  # 示例文档字符串的结束
```py  # 示例文档字符串的结束
```  # 示例文档字符串的结束
```py  # 示例文档字符串的结束
```  # 示例文档字符串的结束
```py  # 示例文档字符串的结束
```  # 示例文档字符串的结束
```py  # 示例文档字符串的结束
```  # 示例文档字符串的结束
```py  # 示例文档字符串的结束
```  # 示例文档字符串的结束
```py  # 示例文档字符串的结束
```  # 示例文档字符串的结束
```py  # 示例文档字符串的结束
```  # 示例文档字符串的结束
```py  # 示例文档字符串的结束
```  # 示例文档字符串的结束
```py  # 示例文档字符串的结束
```  # 示例文档字符串的结束
```py  # 示例文档字符串的结束
```  # 示例文档字符串的结束
```py  # 示例文档字符串的结束
```  # 示例文档字符串的结束
```py  # 示例文档字符串的结束
```  # 示例文档字符串的结束
```py  # 示例文档字符串的结束
```  # 示例文档字符串的结束
```py  # 示例文档字符串的结束
```  # 示例文档字符串的结束
```py  # 示例文档字符串的结束
```  # 示例文档字符串的结束
```py  # 示例文档字符串的结束
```  # 示例文档字符串的结束
```py  # 示例文档字符串的结束
```  # 示例文档字符串的结束
```py  # 示例文档字符串的结束
```  # 示例文档字符串的结束
```py  # 示例文档字符串的结束
```  # 示例文档字符串的结束
```py  # 示例文档字符串的结束
```  # 示例文档字符串的结束
```py  # 示例文档字符串的结束
```  # 示例文档字符串的结束
```py  # 示例文档字符串的结束
```  # 示例文档字符串的结束
```py  # 示例文档字符串的结束
```  # 示例文档字符串的结束
```py  # 示例文档字符串的结束
```  # 示例文档字符串的结束
```py  # 示例文档字符串的结束
```  # 示例文档字符串的结束
```py  # 示例文档字符串的结束
```  # 示例文档字符串的结束
```py  # 示例文档字符串的结束
```  # 示例文档字符串的结束
```py  # 示例文档字符串的结束
```  # 示例文档字符串的结束
```py  # 示例文档字符串的结束
```  # 示例文档字符串的结束
```py  # 示例文档字符串的结束
```  # 示例文档字符串的结束
```py  # 示例文档字符串的结束
```  # 示例文档字符串的结束
```py  # 示例文档字符串的结束
```  # 示例文档字符串的结束
```py  # 示例文档字符串的结束
```  # 示例文档字符串的结束
```py  # 示例文档字符串的结束
```  # 示例文档字符串的结束
```py  # 示例文档字符串的结束
```  # 示例文档字符串的结束
```py  # 示例文档字符串的结束
```  # 示例文档字符串的结束
```py  # 示例文档字符串的结束
```  # 示例文档字符串的结束
```py  # 示例文档字符串的结束
```  # 示例文档字符串的结束
```py  # 示例文档字符串的结束
```  # 示例文档字符串的结束# 示例代码展示如何使用 Diffusers 库进行图像生成Examples:```py# 导入所需的库>>> import requests  # 用于发送 HTTP 请求>>> import torch  # 用于处理张量和深度学习模型# 从 Diffusers 库导入必要的类>>> from diffusers import DDIMScheduler, StableDiffusionPix2PixZeroPipeline# 定义下载嵌入文件的函数>>> def download(embedding_url, local_filepath):...     # 发送 GET 请求获取嵌入文件...     r = requests.get(embedding_url)...     # 以二进制模式打开本地文件并写入获取的内容...     with open(local_filepath, "wb") as f:...         f.write(r.content)# 定义模型检查点的名称>>> model_ckpt = "CompVis/stable-diffusion-v1-4"# 从预训练模型加载管道并设置数据类型为 float16>>> pipeline = StableDiffusionPix2PixZeroPipeline.from_pretrained(model_ckpt, torch_dtype=torch.float16)# 根据管道配置创建 DDIM 调度器>>> pipeline.scheduler = DDIMScheduler.from_config(pipeline.scheduler.config)# 将模型移动到 GPU>>> pipeline.to("cuda")# 定义文本提示>>> prompt = "a high resolution painting of a cat in the style of van gough"# 定义源和目标嵌入文件的 URL>>> source_emb_url = "https://hf.co/datasets/sayakpaul/sample-datasets/resolve/main/cat.pt">>> target_emb_url = "https://hf.co/datasets/sayakpaul/sample-datasets/resolve/main/dog.pt"# 遍历源和目标嵌入 URL 进行下载>>> for url in [source_emb_url, target_emb_url]:...     # 调用下载函数,将文件保存到本地...     download(url, url.split("/")[-1])# 从本地加载源嵌入>>> src_embeds = torch.load(source_emb_url.split("/")[-1])# 从本地加载目标嵌入>>> target_embeds = torch.load(target_emb_url.split("/")[-1])# 使用管道生成图像>>> images = pipeline(...     prompt,  # 输入的文本提示...     source_embeds=src_embeds,  # 源嵌入...     target_embeds=target_embeds,  # 目标嵌入...     num_inference_steps=50,  # 推理步骤数...     cross_attention_guidance_amount=0.15,  # 跨注意力引导的强度... ).images  # 生成的图像# 保存生成的第一张图像>>> images[0].save("edited_image_dog.png")  # 将图像保存为 PNG 文件
"""
# 示例文档字符串,提供了使用示例和说明
EXAMPLE_INVERT_DOC_STRING = """Examples:```py>>> import torch  # 导入 PyTorch 库>>> from transformers import BlipForConditionalGeneration, BlipProcessor  # 从 transformers 导入模型和处理器>>> from diffusers import DDIMScheduler, DDIMInverseScheduler, StableDiffusionPix2PixZeroPipeline  # 从 diffusers 导入调度器和管道>>> import requests  # 导入 requests 库,用于发送网络请求>>> from PIL import Image  # 从 PIL 导入 Image 类,用于处理图像>>> captioner_id = "Salesforce/blip-image-captioning-base"  # 定义图像说明生成模型的 ID>>> processor = BlipProcessor.from_pretrained(captioner_id)  # 从预训练模型加载处理器>>> model = BlipForConditionalGeneration.from_pretrained(  # 从预训练模型加载图像说明生成模型...     captioner_id, torch_dtype=torch.float16, low_cpu_mem_usage=True  # 指定数据类型和低内存使用模式... )>>> sd_model_ckpt = "CompVis/stable-diffusion-v1-4"  # 定义稳定扩散模型的检查点 ID>>> pipeline = StableDiffusionPix2PixZeroPipeline.from_pretrained(  # 从预训练模型加载 Pix2Pix 零管道...     sd_model_ckpt,  # 指定检查点...     caption_generator=model,  # 指定图像说明生成器...     caption_processor=processor,  # 指定图像说明处理器...     torch_dtype=torch.float16,  # 指定数据类型...     safety_checker=None,  # 关闭安全检查器... )>>> pipeline.scheduler = DDIMScheduler.from_config(pipeline.scheduler.config)  # 使用调度器配置初始化 DDIM 调度器>>> pipeline.inverse_scheduler = DDIMInverseScheduler.from_config(pipeline.scheduler.config)  # 使用调度器配置初始化 DDIM 反向调度器>>> pipeline.enable_model_cpu_offload()  # 启用模型的 CPU 卸载>>> img_url = "https://github.com/pix2pixzero/pix2pix-zero/raw/main/assets/test_images/cats/cat_6.png"  # 定义要处理的图像 URL>>> raw_image = Image.open(requests.get(img_url, stream=True).raw).convert("RGB").resize((512, 512))  # 从 URL 加载图像并调整大小>>> # 生成说明>>> caption = pipeline.generate_caption(raw_image)  # 生成图像的说明>>> # "a photography of a cat with flowers and dai dai daie - daie - daie kasaii"  # 生成的说明示例>>> inv_latents = pipeline.invert(caption, image=raw_image).latents  # 根据说明和原始图像进行反向处理,获取潜变量>>> # 我们需要生成源和目标嵌入>>> source_prompts = ["a cat sitting on the street", "a cat playing in the field", "a face of a cat"]  # 定义源提示列表>>> target_prompts = ["a dog sitting on the street", "a dog playing in the field", "a face of a dog"]  # 定义目标提示列表>>> source_embeds = pipeline.get_embeds(source_prompts)  # 获取源提示的嵌入表示>>> target_embeds = pipeline.get_embeds(target_prompts)  # 获取目标提示的嵌入表示>>> # 潜变量可以用于编辑真实图像>>> # 在使用稳定扩散 2 或其他使用 v-prediction 的模型时>>> # 将 `cross_attention_guidance_amount` 设置为 0.01 或更低,以避免输入潜变量梯度爆炸>>> image = pipeline(  # 使用管道生成新的图像...     caption,  # 使用生成的说明...     source_embeds=source_embeds,  # 传递源嵌入...     target_embeds=target_embeds,  # 传递目标嵌入...     num_inference_steps=50,  # 指定推理步骤数量...     cross_attention_guidance_amount=0.15,  # 指定交叉注意力指导量...     generator=generator,  # 使用指定的生成器...     latents=inv_latents,  # 传递潜变量...     negative_prompt=caption,  # 使用生成的说明作为负提示... ).images[0]  # 获取生成的图像>>> image.save("edited_image.png")  # 保存生成的图像```py
"""# 从 diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img 导入的 preprocess 函数
def preprocess(image):  # 定义 preprocess 函数,接受图像作为参数# 设置一个警告信息,提示用户 preprocess 方法已被弃用,并将在 diffusers 1.0.0 中移除deprecation_message = "The preprocess method is deprecated and will be removed in diffusers 1.0.0. Please use VaeImageProcessor.preprocess(...) instead"# 调用 deprecate 函数,记录弃用信息,设定标准警告为 Falsedeprecate("preprocess", "1.0.0", deprecation_message, standard_warn=False)# 检查输入的 image 是否是一个 Torch 张量if isinstance(image, torch.Tensor):# 如果是,直接返回该张量return image# 检查输入的 image 是否是一个 PIL 图像elif isinstance(image, PIL.Image.Image):# 如果是,将其封装为一个单元素列表image = [image]# 检查列表中的第一个元素是否为 PIL 图像if isinstance(image[0], PIL.Image.Image):# 获取第一个图像的宽度和高度w, h = image[0].size# 将宽度和高度调整为 8 的整数倍w, h = (x - x % 8 for x in (w, h))  # resize to integer multiple of 8# 对每个图像进行调整大小,转换为 numpy 数组,并在新维度上增加一维image = [np.array(i.resize((w, h), resample=PIL_INTERPOLATION["lanczos"]))[None, :] for i in image]# 将所有图像在第 0 维上连接成一个大的数组image = np.concatenate(image, axis=0)# 将数据转换为 float32 类型并归一化到 [0, 1] 范围image = np.array(image).astype(np.float32) / 255.0# 调整数组维度顺序为 (batch_size, channels, height, width)image = image.transpose(0, 3, 1, 2)# 将像素值范围从 [0, 1] 转换到 [-1, 1]image = 2.0 * image - 1.0# 将 numpy 数组转换为 Torch 张量image = torch.from_numpy(image)# 检查列表中的第一个元素是否为 Torch 张量elif isinstance(image[0], torch.Tensor):# 将多个张量在第 0 维上连接成一个大的张量image = torch.cat(image, dim=0)# 返回处理后的图像return image
# 准备 UNet 模型以执行 Pix2Pix Zero 优化
def prepare_unet(unet: UNet2DConditionModel):# 初始化一个空字典,用于存储 Pix2Pix Zero 注意力处理器pix2pix_zero_attn_procs = {}# 遍历 UNet 的注意力处理器的键for name in unet.attn_processors.keys():# 将处理器名称中的 ".processor" 替换为空module_name = name.replace(".processor", "")# 获取 UNet 中对应的子模块module = unet.get_submodule(module_name)# 如果名称包含 "attn2"if "attn2" in name:# 将处理器设置为 Pix2Pix Zero 模式pix2pix_zero_attn_procs[name] = Pix2PixZeroAttnProcessor(is_pix2pix_zero=True)# 允许该模块进行梯度更新module.requires_grad_(True)else:# 设置为非 Pix2Pix Zero 模式pix2pix_zero_attn_procs[name] = Pix2PixZeroAttnProcessor(is_pix2pix_zero=False)# 不允许该模块进行梯度更新module.requires_grad_(False)# 设置 UNet 的注意力处理器为修改后的处理器字典unet.set_attn_processor(pix2pix_zero_attn_procs)# 返回修改后的 UNet 模型return unetclass Pix2PixZeroL2Loss:# 初始化损失类def __init__(self):# 设置初始损失值为 0self.loss = 0.0# 计算损失的方法def compute_loss(self, predictions, targets):# 更新损失值为预测值与目标值之间的均方差self.loss += ((predictions - targets) ** 2).sum((1, 2)).mean(0)class Pix2PixZeroAttnProcessor:"""注意力处理器类,用于存储注意力权重。在 Pix2Pix Zero 中,该过程发生在交叉注意力块的计算中。"""# 初始化注意力处理器def __init__(self, is_pix2pix_zero=False):# 记录是否为 Pix2Pix Zero 模式self.is_pix2pix_zero = is_pix2pix_zero# 如果是 Pix2Pix Zero 模式,初始化参考交叉注意力映射if self.is_pix2pix_zero:self.reference_cross_attn_map = {}# 定义调用方法def __call__(self,attn: Attention,hidden_states,encoder_hidden_states=None,attention_mask=None,timestep=None,loss=None,):# 获取隐藏状态的批次大小和序列长度batch_size, sequence_length, _ = hidden_states.shape# 准备注意力掩码attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)# 将隐藏状态转换为查询向量query = attn.to_q(hidden_states)# 如果没有编码器隐藏状态,则使用隐藏状态本身if encoder_hidden_states is None:encoder_hidden_states = hidden_states# 如果需要进行交叉规范化elif attn.norm_cross:encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)# 将编码器隐藏状态转换为键和值key = attn.to_k(encoder_hidden_states)value = attn.to_v(encoder_hidden_states)# 将查询、键和值转换为批次维度query = attn.head_to_batch_dim(query)key = attn.head_to_batch_dim(key)value = attn.head_to_batch_dim(value)# 计算注意力分数attention_probs = attn.get_attention_scores(query, key, attention_mask)# 如果是 Pix2Pix Zero 模式且时间步不为 Noneif self.is_pix2pix_zero and timestep is not None:# 新的记录以保存注意力权重if loss is None:self.reference_cross_attn_map[timestep.item()] = attention_probs.detach().cpu()# 计算损失elif loss is not None:# 获取之前的注意力概率prev_attn_probs = self.reference_cross_attn_map.pop(timestep.item())# 计算损失loss.compute_loss(attention_probs, prev_attn_probs.to(attention_probs.device))# 将注意力概率与值相乘以获得新的隐藏状态hidden_states = torch.bmm(attention_probs, value)# 将隐藏状态转换回头维度hidden_states = attn.batch_to_head_dim(hidden_states)# 线性变换hidden_states = attn.to_out[0](hidden_states)# 应用 dropouthidden_states = attn.to_out[1](hidden_states)# 返回新的隐藏状态return hidden_states
# 定义一个用于像素级图像编辑的管道类,基于 Stable Diffusion
class StableDiffusionPix2PixZeroPipeline(DiffusionPipeline, StableDiffusionMixin):r"""使用 Pix2Pix Zero 进行像素级图像编辑的管道。基于 Stable Diffusion。该模型继承自 [`DiffusionPipeline`]。请查阅超类文档以获取库为所有管道实现的通用方法(例如下载或保存、在特定设备上运行等)。参数:vae ([`AutoencoderKL`]):用于将图像编码和解码到潜在表示的变分自编码器(VAE)模型。text_encoder ([`CLIPTextModel`]):冻结的文本编码器。Stable Diffusion 使用[CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel) 的文本部分,特别是 [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) 变体。tokenizer (`CLIPTokenizer`):类的分词器[CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer)。unet ([`UNet2DConditionModel`]): 用于去噪编码图像潜在的条件 U-Net 架构。scheduler ([`SchedulerMixin`]):与 `unet` 一起使用以去噪编码图像潜在的调度器。可以是[`DDIMScheduler`], [`LMSDiscreteScheduler`], [`EulerAncestralDiscreteScheduler`] 或 [`DDPMScheduler`] 的之一。safety_checker ([`StableDiffusionSafetyChecker`]):估计生成图像是否可能被视为攻击性或有害的分类模块。请参阅 [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) 以获取详细信息。feature_extractor ([`CLIPImageProcessor`]):从生成图像中提取特征的模型,以便作为 `safety_checker` 的输入。requires_safety_checker (bool):管道是否需要安全检查器。如果您公开使用该管道,我们建议将其设置为 True。"""# 定义 CPU 卸载的模型组件顺序model_cpu_offload_seq = "text_encoder->unet->vae"# 可选组件列表_optional_components = ["safety_checker","feature_extractor","caption_generator","caption_processor","inverse_scheduler",]# 从 CPU 卸载中排除的组件列表_exclude_from_cpu_offload = ["safety_checker"]# 初始化方法,定义管道的参数def __init__(self,vae: AutoencoderKL,  # 变分自编码器模型text_encoder: CLIPTextModel,  # 文本编码器模型tokenizer: CLIPTokenizer,  # 分词器模型unet: UNet2DConditionModel,  # 条件 U-Net 模型scheduler: Union[DDPMScheduler, DDIMScheduler, EulerAncestralDiscreteScheduler, LMSDiscreteScheduler],  # 调度器类型feature_extractor: CLIPImageProcessor,  # 特征提取器模型safety_checker: StableDiffusionSafetyChecker,  # 安全检查器模型inverse_scheduler: DDIMInverseScheduler,  # 反向调度器caption_generator: BlipForConditionalGeneration,  # 描述生成器caption_processor: BlipProcessor,  # 描述处理器requires_safety_checker: bool = True,  # 是否需要安全检查器的标志# 定义一个构造函数):# 调用父类的构造函数super().__init__()# 如果没有提供安全检查器且需要安全检查器,发出警告if safety_checker is None and requires_safety_checker:logger.warning(# 输出关于禁用安全检查器的警告信息f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"" that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered"" results in services or applications open to the public. Both the diffusers team and Hugging Face"" strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling"" it only for use-cases that involve analyzing network behavior or auditing its results. For more"" information, please have a look at https://github.com/huggingface/diffusers/pull/254 .")# 如果提供了安全检查器但没有提供特征提取器,抛出错误if safety_checker is not None and feature_extractor is None:raise ValueError(# 提示用户必须定义特征提取器以使用安全检查器"Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety"" checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead.")# 注册模块,设置各个组件self.register_modules(vae=vae,text_encoder=text_encoder,tokenizer=tokenizer,unet=unet,scheduler=scheduler,safety_checker=safety_checker,feature_extractor=feature_extractor,caption_processor=caption_processor,caption_generator=caption_generator,inverse_scheduler=inverse_scheduler,)# 计算 VAE 的缩放因子self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)# 创建图像处理器,使用 VAE 缩放因子self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)# 将配置项注册到当前实例self.register_to_config(requires_safety_checker=requires_safety_checker)# 从 StableDiffusionPipeline 类复制的编码提示的方法def _encode_prompt(self,prompt,device,num_images_per_prompt,do_classifier_free_guidance,negative_prompt=None,# 可选参数,表示提示的嵌入prompt_embeds: Optional[torch.Tensor] = None,# 可选参数,表示负面提示的嵌入negative_prompt_embeds: Optional[torch.Tensor] = None,# 可选参数,表示 LORA 的缩放因子lora_scale: Optional[float] = None,# 接收任意额外参数**kwargs,# 开始定义一个方法,处理已弃用的编码提示功能):# 定义弃用信息,说明该方法将被移除,并推荐使用新方法deprecation_message = "`_encode_prompt()` is deprecated and it will be removed in a future version. Use `encode_prompt()` instead. Also, be aware that the output format changed from a concatenated tensor to a tuple."# 调用弃用函数,记录弃用警告deprecate("_encode_prompt()", "1.0.0", deprecation_message, standard_warn=False)# 调用新的编码提示方法,获取结果元组prompt_embeds_tuple = self.encode_prompt(# 传入提示文本prompt=prompt,# 设备类型(CPU/GPU)device=device,# 每个提示生成的图像数量num_images_per_prompt=num_images_per_prompt,# 是否进行分类器自由引导do_classifier_free_guidance=do_classifier_free_guidance,# 负面提示文本negative_prompt=negative_prompt,# 提示嵌入prompt_embeds=prompt_embeds,# 负面提示嵌入negative_prompt_embeds=negative_prompt_embeds,# Lora缩放因子lora_scale=lora_scale,# 其他可选参数**kwargs,)# 将返回的元组中的两个嵌入连接起来以兼容旧版prompt_embeds = torch.cat([prompt_embeds_tuple[1], prompt_embeds_tuple[0]])# 返回最终的提示嵌入return prompt_embeds# 从指定的管道中复制的 encode_prompt 方法定义def encode_prompt(# 提示文本self,prompt,# 设备类型device,# 每个提示生成的图像数量num_images_per_prompt,# 是否进行分类器自由引导do_classifier_free_guidance,# 负面提示文本(可选)negative_prompt=None,# 提示嵌入(可选)prompt_embeds: Optional[torch.Tensor] = None,# 负面提示嵌入(可选)negative_prompt_embeds: Optional[torch.Tensor] = None,# Lora缩放因子(可选)lora_scale: Optional[float] = None,# 跳过的clip层数(可选)clip_skip: Optional[int] = None,# 从指定的管道中复制的 run_safety_checker 方法定义def run_safety_checker(self, image, device, dtype):# 检查是否存在安全检查器if self.safety_checker is None:# 如果没有安全检查器,标记为无概念has_nsfw_concept = Noneelse:# 如果图像是张量,则进行后处理以转换为PIL格式if torch.is_tensor(image):feature_extractor_input = self.image_processor.postprocess(image, output_type="pil")else:# 如果不是张量,则将其转换为PIL格式feature_extractor_input = self.image_processor.numpy_to_pil(image)# 将处理后的图像提取特征,准备进行安全检查safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors="pt").to(device)# 使用安全检查器检查图像,返回图像及其概念状态image, has_nsfw_concept = self.safety_checker(images=image, clip_input=safety_checker_input.pixel_values.to(dtype))# 返回检查后的图像及其概念状态return image, has_nsfw_concept# 从指定的管道中复制的 decode_latents 方法# 解码潜在向量并返回生成的图像def decode_latents(self, latents):# 警告用户该方法已过时,将在1.0.0版本中删除deprecation_message = "The decode_latents method is deprecated and will be removed in 1.0.0. Please use VaeImageProcessor.postprocess(...) instead"# 调用deprecate函数记录该方法的弃用信息deprecate("decode_latents", "1.0.0", deprecation_message, standard_warn=False)# 使用配置的缩放因子对潜在向量进行缩放latents = 1 / self.vae.config.scaling_factor * latents# 解码潜在向量,返回生成的图像image = self.vae.decode(latents, return_dict=False)[0]# 将图像值从[-1, 1]映射到[0, 1]并限制范围image = (image / 2 + 0.5).clamp(0, 1)# 将图像转换为float32格式以确保兼容性,并将其转换为numpy数组image = image.cpu().permute(0, 2, 3, 1).float().numpy()# 返回处理后的图像return image# 从diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs复制def prepare_extra_step_kwargs(self, generator, eta):# 准备额外的参数以供调度器步骤使用,因调度器的参数签名可能不同# eta (η) 仅在DDIMScheduler中使用,其他调度器将忽略它。# eta在DDIM论文中对应于η:https://arxiv.org/abs/2010.02502# eta的值应在[0, 1]之间# 检查调度器步骤是否接受eta参数accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())# 初始化额外参数字典extra_step_kwargs = {}# 如果接受eta,则将其添加到额外参数中if accepts_eta:extra_step_kwargs["eta"] = eta# 检查调度器步骤是否接受generator参数accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())# 如果接受generator,则将其添加到额外参数中if accepts_generator:extra_step_kwargs["generator"] = generator# 返回准备好的额外参数return extra_step_kwargsdef check_inputs(self,prompt,source_embeds,target_embeds,callback_steps,prompt_embeds=None,):# 检查callback_steps是否为正整数if (callback_steps is None) or (callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)):raise ValueError(f"`callback_steps` has to be a positive integer but is {callback_steps} of type"f" {type(callback_steps)}.")# 确保source_embeds和target_embeds不能同时未定义if source_embeds is None and target_embeds is None:raise ValueError("`source_embeds` and `target_embeds` cannot be undefined.")# 检查prompt和prompt_embeds不能同时被定义if prompt is not None and prompt_embeds is not None:raise ValueError(f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"" only forward one of the two.")# 检查prompt和prompt_embeds不能同时未定义elif prompt is None and prompt_embeds is None:raise ValueError("Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined.")# 确保prompt的类型为str或listelif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")#  从 StableDiffusionPipeline 的 prepare_latents 方法复制的内容def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):# 定义潜在张量的形状,包括批次大小、通道数、高度和宽度shape = (batch_size,num_channels_latents,int(height) // self.vae_scale_factor,int(width) // self.vae_scale_factor,)# 检查生成器是否为列表且其长度与批次大小匹配if isinstance(generator, list) and len(generator) != batch_size:raise ValueError(f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"f" size of {batch_size}. Make sure the batch size matches the length of the generators.")# 如果未提供潜在张量,则生成随机张量if latents is None:latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)else:# 如果提供了潜在张量,则将其移动到指定设备latents = latents.to(device)# 按调度器所需的标准差缩放初始噪声latents = latents * self.scheduler.init_noise_sigma# 返回处理后的潜在张量return latents@torch.no_grad()def generate_caption(self, images):"""为给定图像生成标题。"""# 初始化生成标题的文本text = "a photography of"# 保存当前设备prev_device = self.caption_generator.device# 获取执行设备device = self._execution_device# 处理输入图像并转换为张量inputs = self.caption_processor(images, text, return_tensors="pt").to(device=device, dtype=self.caption_generator.dtype)# 将标题生成器移动到指定设备self.caption_generator.to(device)# 生成标题输出outputs = self.caption_generator.generate(**inputs, max_new_tokens=128)# 将标题生成器移回先前设备self.caption_generator.to(prev_device)# 解码输出以获取标题caption = self.caption_processor.batch_decode(outputs, skip_special_tokens=True)[0]# 返回生成的标题return captiondef construct_direction(self, embs_source: torch.Tensor, embs_target: torch.Tensor):"""构造用于引导图像生成过程的编辑方向。"""# 返回目标和源嵌入的均值之差,并增加一个维度return (embs_target.mean(0) - embs_source.mean(0)).unsqueeze(0)@torch.no_grad()def get_embeds(self, prompt: List[str], batch_size: int = 16) -> torch.Tensor:# 获取提示的数量num_prompts = len(prompt)# 初始化嵌入列表embeds = []# 分批处理提示for i in range(0, num_prompts, batch_size):prompt_slice = prompt[i : i + batch_size]# 将提示转换为输入 ID,进行填充和截断input_ids = self.tokenizer(prompt_slice,padding="max_length",max_length=self.tokenizer.model_max_length,truncation=True,return_tensors="pt",).input_ids# 将输入 ID 移动到文本编码器设备input_ids = input_ids.to(self.text_encoder.device)# 获取嵌入并追加到列表embeds.append(self.text_encoder(input_ids)[0])# 将所有嵌入拼接并计算均值return torch.cat(embeds, dim=0).mean(0)[None]# 准备图像的潜在表示,接收图像和其他参数,返回潜在向量def prepare_image_latents(self, image, batch_size, dtype, device, generator=None):# 检查输入图像的类型是否为有效类型if not isinstance(image, (torch.Tensor, PIL.Image.Image, list)):# 抛出类型错误,提示用户输入类型不正确raise ValueError(f"`image` has to be of type `torch.Tensor`, `PIL.Image.Image` or list but is {type(image)}")# 将图像转换到指定的设备和数据类型image = image.to(device=device, dtype=dtype)# 如果图像有四个通道,直接将其作为潜在表示if image.shape[1] == 4:latents = imageelse:# 检查生成器列表的长度是否与批次大小匹配if isinstance(generator, list) and len(generator) != batch_size:# 抛出错误,提示生成器列表长度与批次大小不匹配raise ValueError(f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"f" size of {batch_size}. Make sure the batch size matches the length of the generators.")# 如果生成器是列表,逐个图像编码并生成潜在表示if isinstance(generator, list):latents = [self.vae.encode(image[i : i + 1]).latent_dist.sample(generator[i]) for i in range(batch_size)]# 将潜在表示合并到一个张量中latents = torch.cat(latents, dim=0)else:# 使用单个生成器编码图像并生成潜在表示latents = self.vae.encode(image).latent_dist.sample(generator)# 根据配置的缩放因子调整潜在表示latents = self.vae.config.scaling_factor * latents# 检查潜在表示的批次大小是否与请求的匹配if batch_size != latents.shape[0]:# 如果可以整除,则扩展潜在表示以匹配批次大小if batch_size % latents.shape[0] == 0:# 构建弃用消息,提示用户行为即将被移除deprecation_message = (f"You have passed {batch_size} text prompts (`prompt`), but only {latents.shape[0]} initial"" images (`image`). Initial images are now duplicating to match the number of text prompts. Note"" that this behavior is deprecated and will be removed in a version 1.0.0. Please make sure to update"" your script to pass as many initial images as text prompts to suppress this warning.")# 触发弃用警告,提醒用户修改代码deprecate("len(prompt) != len(image)", "1.0.0", deprecation_message, standard_warn=False)# 计算每个图像需要复制的次数additional_latents_per_image = batch_size // latents.shape[0]# 将潜在表示按需重复以匹配批次大小latents = torch.cat([latents] * additional_latents_per_image, dim=0)else:# 抛出错误,提示无法复制图像以匹配批次大小raise ValueError(f"Cannot duplicate `image` of batch size {latents.shape[0]} to {batch_size} text prompts.")else:# 将潜在表示封装为一个张量latents = torch.cat([latents], dim=0)# 返回最终的潜在表示return latents# 定义一个获取epsilon的函数,输入为模型输出、样本和时间步def get_epsilon(self, model_output: torch.Tensor, sample: torch.Tensor, timestep: int):# 获取反向调度器的预测类型配置pred_type = self.inverse_scheduler.config.prediction_type# 计算在当前时间步的累积alpha值alpha_prod_t = self.inverse_scheduler.alphas_cumprod[timestep]# 计算beta值为1减去alpha值beta_prod_t = 1 - alpha_prod_t# 根据预测类型返回相应的结果if pred_type == "epsilon":return model_outputelif pred_type == "sample":# 根据样本和模型输出计算返回值return (sample - alpha_prod_t ** (0.5) * model_output) / beta_prod_t ** (0.5)elif pred_type == "v_prediction":# 根据alpha和beta值结合模型输出和样本计算返回值return (alpha_prod_t**0.5) * model_output + (beta_prod_t**0.5) * sampleelse:# 如果预测类型无效,抛出异常raise ValueError(f"prediction_type given as {pred_type} must be one of `epsilon`, `sample`, or `v_prediction`")# 定义一个自动相关损失计算的函数,输入为隐藏状态和可选生成器def auto_corr_loss(self, hidden_states, generator=None):# 初始化正则化损失为0reg_loss = 0.0# 遍历隐藏状态的每一个维度for i in range(hidden_states.shape[0]):for j in range(hidden_states.shape[1]):# 选取当前噪声noise = hidden_states[i : i + 1, j : j + 1, :, :]# 进行循环,直到噪声尺寸小于等于8while True:# 随机生成滚动的位移量roll_amount = torch.randint(noise.shape[2] // 2, (1,), generator=generator).item()# 计算并累加正则化损失reg_loss += (noise * torch.roll(noise, shifts=roll_amount, dims=2)).mean() ** 2reg_loss += (noise * torch.roll(noise, shifts=roll_amount, dims=3)).mean() ** 2# 如果噪声尺寸小于等于8,跳出循环if noise.shape[2] <= 8:break# 对噪声进行2x2的平均池化noise = F.avg_pool2d(noise, kernel_size=2)# 返回计算得到的正则化损失return reg_loss# 定义一个计算KL散度的函数,输入为隐藏状态def kl_divergence(self, hidden_states):# 计算隐藏状态的均值mean = hidden_states.mean()# 计算隐藏状态的方差var = hidden_states.var()# 返回KL散度的计算结果return var + mean**2 - 1 - torch.log(var + 1e-7)# 定义调用函数,使用@torch.no_grad()装饰器禁止梯度计算@torch.no_grad()@replace_example_docstring(EXAMPLE_DOC_STRING)def __call__(# 输入参数包括提示、源和目标嵌入、图像的高和宽、推理步骤等prompt: Optional[Union[str, List[str]]] = None,source_embeds: torch.Tensor = None,target_embeds: torch.Tensor = None,height: Optional[int] = None,width: Optional[int] = None,num_inference_steps: int = 50,guidance_scale: float = 7.5,negative_prompt: Optional[Union[str, List[str]]] = None,num_images_per_prompt: Optional[int] = 1,eta: float = 0.0,generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,latents: Optional[torch.Tensor] = None,prompt_embeds: Optional[torch.Tensor] = None,negative_prompt_embeds: Optional[torch.Tensor] = None,cross_attention_guidance_amount: float = 0.1,output_type: Optional[str] = "pil",return_dict: bool = True,callback: Optional[Callable[[int, int, torch.Tensor], None]] = None,callback_steps: Optional[int] = 1,cross_attention_kwargs: Optional[Dict[str, Any]] = None,clip_skip: Optional[int] = None,# 使用@torch.no_grad()和装饰器替换文档字符串@torch.no_grad()@replace_example_docstring(EXAMPLE_INVERT_DOC_STRING)# 定义一个名为 invert 的方法,包含多个可选参数def invert(# 输入提示,默认为 Noneself,prompt: Optional[str] = None,# 输入图像,默认为 Noneimage: PipelineImageInput = None,# 推理步骤的数量,默认为 50num_inference_steps: int = 50,# 指导比例,默认为 1guidance_scale: float = 1,# 随机数生成器,可以是单个或多个生成器,默认为 Nonegenerator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,# 潜在变量,默认为 Nonelatents: Optional[torch.Tensor] = None,# 提示嵌入,默认为 Noneprompt_embeds: Optional[torch.Tensor] = None,# 跨注意力引导量,默认为 0.1cross_attention_guidance_amount: float = 0.1,# 输出类型,默认为 "pil"output_type: Optional[str] = "pil",# 是否返回字典,默认为 Truereturn_dict: bool = True,# 回调函数,默认为 Nonecallback: Optional[Callable[[int, int, torch.Tensor], None]] = None,# 回调步骤,默认为 1callback_steps: Optional[int] = 1,# 跨注意力参数,默认为 Nonecross_attention_kwargs: Optional[Dict[str, Any]] = None,# 自动相关的权重,默认为 20.0lambda_auto_corr: float = 20.0,# KL 散度的权重,默认为 20.0lambda_kl: float = 20.0,# 正则化步骤的数量,默认为 5num_reg_steps: int = 5,# 自动相关滚动的数量,默认为 5num_auto_corr_rolls: int = 5,

.\diffusers\pipelines\deprecated\stable_diffusion_variants\__init__.py

# 从类型检查模块导入类型检查相关功能
from typing import TYPE_CHECKING# 从 utils 模块导入所需的工具和常量
from ....utils import (DIFFUSERS_SLOW_IMPORT,  # 导入用于延迟导入的常量OptionalDependencyNotAvailable,  # 导入可选依赖不可用的异常类_LazyModule,  # 导入延迟模块加载的工具get_objects_from_module,  # 导入从模块获取对象的函数is_torch_available,  # 导入检查 Torch 库是否可用的函数is_transformers_available,  # 导入检查 Transformers 库是否可用的函数
)# 初始化一个空字典用于存放虚拟对象
_dummy_objects = {}
# 初始化一个空字典用于存放模块导入结构
_import_structure = {}try:# 检查 Transformers 和 Torch 库是否都可用if not (is_transformers_available() and is_torch_available()):# 如果不可用,则抛出异常raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:# 导入虚拟对象以避免在依赖不可用时导致错误from ....utils import dummy_torch_and_transformers_objects# 更新虚拟对象字典,填充从虚拟对象模块获取的对象_dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
else:# 如果依赖可用,添加相关的管道到导入结构字典_import_structure["pipeline_cycle_diffusion"] = ["CycleDiffusionPipeline"]_import_structure["pipeline_stable_diffusion_inpaint_legacy"] = ["StableDiffusionInpaintPipelineLegacy"]_import_structure["pipeline_stable_diffusion_model_editing"] = ["StableDiffusionModelEditingPipeline"]_import_structure["pipeline_stable_diffusion_paradigms"] = ["StableDiffusionParadigmsPipeline"]_import_structure["pipeline_stable_diffusion_pix2pix_zero"] = ["StableDiffusionPix2PixZeroPipeline"]# 根据类型检查标志或慢速导入标志进行条件判断
if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:try:# 再次检查依赖是否可用if not (is_transformers_available() and is_torch_available()):# 如果不可用,则抛出异常raise OptionalDependencyNotAvailable()except OptionalDependencyNotAvailable:# 导入虚拟对象以避免在依赖不可用时导致错误from ....utils.dummy_torch_and_transformers_objects import *else:# 导入具体的管道类,确保它们在依赖可用时被加载from .pipeline_cycle_diffusion import CycleDiffusionPipelinefrom .pipeline_stable_diffusion_inpaint_legacy import StableDiffusionInpaintPipelineLegacyfrom .pipeline_stable_diffusion_model_editing import StableDiffusionModelEditingPipelinefrom .pipeline_stable_diffusion_paradigms import StableDiffusionParadigmsPipelinefrom .pipeline_stable_diffusion_pix2pix_zero import StableDiffusionPix2PixZeroPipelineelse:# 如果不是类型检查或慢速导入,则进行懒加载处理import sys# 使用懒加载模块构造当前模块,指定导入结构和模块规格sys.modules[__name__] = _LazyModule(__name__,globals()["__file__"],_import_structure,module_spec=__spec__,)# 遍历虚拟对象字典,将对象属性设置到当前模块for name, value in _dummy_objects.items():setattr(sys.modules[__name__], name, value)

.\diffusers\pipelines\deprecated\stochastic_karras_ve\pipeline_stochastic_karras_ve.py

# 版权声明,表明此文件的版权所有者及保留权利
# 
# 根据 Apache 许可证第 2.0 版(“许可证”)进行许可;
# 除非遵循许可证,否则您不得使用此文件。
# 您可以在以下网址获取许可证的副本:
# 
#     http://www.apache.org/licenses/LICENSE-2.0
# 
# 除非适用法律或书面同意,软件在许可证下分发,按“原样”基础,
# 不提供任何形式的保证或条件,无论是明示或暗示的。
# 请参见许可证以获取有关权限和
# 限制的具体规定。# 从 typing 模块导入所需的类型提示
from typing import List, Optional, Tuple, Union# 导入 PyTorch 库
import torch# 从相对路径导入 UNet2DModel 模型
from ....models import UNet2DModel
# 从相对路径导入调度器 KarrasVeScheduler
from ....schedulers import KarrasVeScheduler
# 从相对路径导入随机张量生成工具
from ....utils.torch_utils import randn_tensor
# 从相对路径导入扩散管道和图像输出
from ...pipeline_utils import DiffusionPipeline, ImagePipelineOutput# 定义 KarrasVePipeline 类,继承自 DiffusionPipeline
class KarrasVePipeline(DiffusionPipeline):r"""无条件图像生成的管道。参数:unet ([`UNet2DModel`]):用于去噪编码图像的 `UNet2DModel`。scheduler ([`KarrasVeScheduler`]):用于与 `unet` 结合去噪编码图像的调度器。"""# 为 linting 添加类型提示unet: UNet2DModel  # 定义 unet 类型为 UNet2DModelscheduler: KarrasVeScheduler  # 定义 scheduler 类型为 KarrasVeScheduler# 初始化函数,接受 UNet2DModel 和 KarrasVeScheduler 作为参数def __init__(self, unet: UNet2DModel, scheduler: KarrasVeScheduler):# 调用父类的初始化函数super().__init__()# 注册模块,将 unet 和 scheduler 注册到当前实例中self.register_modules(unet=unet, scheduler=scheduler)# 装饰器,表明此函数不需要梯度计算@torch.no_grad()def __call__(self,batch_size: int = 1,  # 定义批处理大小,默认为 1num_inference_steps: int = 50,  # 定义推理步骤数量,默认为 50generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,  # 可选生成器output_type: Optional[str] = "pil",  # 可选输出类型,默认为 "pil"return_dict: bool = True,  # 是否返回字典,默认为 True**kwargs,  # 允许额外的关键字参数

.\diffusers\pipelines\deprecated\stochastic_karras_ve\__init__.py

# 从 typing 模块导入 TYPE_CHECKING,用于类型检查
from typing import TYPE_CHECKING# 从相对路径导入工具模块中的 DIFFUSERS_SLOW_IMPORT 和 _LazyModule
from ....utils import DIFFUSERS_SLOW_IMPORT, _LazyModule# 定义一个字典,描述要导入的模块及其对应的类
_import_structure = {"pipeline_stochastic_karras_ve": ["KarrasVePipeline"]}# 如果正在进行类型检查或 DIFFUSERS_SLOW_IMPORT 为真
if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:# 从 pipeline_stochastic_karras_ve 模块导入 KarrasVePipeline 类from .pipeline_stochastic_karras_ve import KarrasVePipeline# 否则
else:# 导入 sys 模块,用于动态修改模块import sys# 使用 _LazyModule 创建懒加载模块,并将其赋值给当前模块名称sys.modules[__name__] = _LazyModule(__name__,  # 当前模块的名称globals()["__file__"],  # 当前模块的文件路径_import_structure,  # 模块结构字典module_spec=__spec__,  # 当前模块的规格)

.\diffusers\pipelines\deprecated\versatile_diffusion\modeling_text_unet.py

# 从 typing 模块导入各种类型注解
from typing import Any, Dict, List, Optional, Tuple, Union# 导入 numpy 库,用于数组和矩阵操作
import numpy as np
# 导入 PyTorch 库,进行深度学习模型的构建和训练
import torch
# 导入 PyTorch 的神经网络模块
import torch.nn as nn
# 导入 PyTorch 的功能性模块,提供常用操作
import torch.nn.functional as F# 从 diffusers.utils 模块导入 deprecate 函数,用于处理弃用警告
from diffusers.utils import deprecate# 导入配置相关的类和函数
from ....configuration_utils import ConfigMixin, register_to_config
# 导入模型相关的基类
from ....models import ModelMixin
# 导入激活函数获取工具
from ....models.activations import get_activation
# 导入注意力处理器相关组件
from ....models.attention_processor import (ADDED_KV_ATTENTION_PROCESSORS,  # 额外键值注意力处理器CROSS_ATTENTION_PROCESSORS,      # 交叉注意力处理器Attention,                       # 注意力机制类AttentionProcessor,              # 注意力处理器基类AttnAddedKVProcessor,            # 额外键值注意力处理器类AttnAddedKVProcessor2_0,         # 版本 2.0 的额外键值注意力处理器AttnProcessor,                   # 基础注意力处理器
)
# 导入嵌入层相关组件
from ....models.embeddings import (GaussianFourierProjection,        # 高斯傅里叶投影类ImageHintTimeEmbedding,           # 图像提示时间嵌入类ImageProjection,                  # 图像投影类ImageTimeEmbedding,               # 图像时间嵌入类TextImageProjection,              # 文本图像投影类TextImageTimeEmbedding,           # 文本图像时间嵌入类TextTimeEmbedding,                # 文本时间嵌入类TimestepEmbedding,                # 时间步嵌入类Timesteps,                        # 时间步类
)
# 导入 ResNet 相关组件
from ....models.resnet import ResnetBlockCondNorm2D
# 导入 2D 双重变换器模型
from ....models.transformers.dual_transformer_2d import DualTransformer2DModel
# 导入 2D 变换器模型
from ....models.transformers.transformer_2d import Transformer2DModel
# 导入 2D 条件 UNet 输出类
from ....models.unets.unet_2d_condition import UNet2DConditionOutput
# 导入工具函数和常量
from ....utils import USE_PEFT_BACKEND, is_torch_version, logging, scale_lora_layers, unscale_lora_layers
# 导入 PyTorch 相关工具函数
from ....utils.torch_utils import apply_freeu# 创建日志记录器实例
logger = logging.get_logger(__name__)  # pylint: disable=invalid-name# 定义获取下采样块的函数
def get_down_block(down_block_type,                    # 下采样块类型num_layers,                         # 层数in_channels,                        # 输入通道数out_channels,                       # 输出通道数temb_channels,                      # 时间嵌入通道数add_downsample,                    # 是否添加下采样resnet_eps,                         # ResNet 中的 epsilon 值resnet_act_fn,                     # ResNet 激活函数num_attention_heads,               # 注意力头数量transformer_layers_per_block,      # 每个块中的变换器层数attention_type,                    # 注意力类型attention_head_dim,                # 注意力头维度resnet_groups=None,                 # ResNet 组数(可选)cross_attention_dim=None,           # 交叉注意力维度(可选)downsample_padding=None,            # 下采样填充(可选)dual_cross_attention=False,         # 是否使用双重交叉注意力use_linear_projection=False,        # 是否使用线性投影only_cross_attention=False,         # 是否仅使用交叉注意力upcast_attention=False,             # 是否上升注意力resnet_time_scale_shift="default",  # ResNet 时间缩放偏移resnet_skip_time_act=False,         # ResNet 是否跳过时间激活resnet_out_scale_factor=1.0,       # ResNet 输出缩放因子cross_attention_norm=None,          # 交叉注意力归一化(可选)dropout=0.0,                       # dropout 概率
):# 如果下采样块类型以 "UNetRes" 开头,则去掉前缀down_block_type = down_block_type[7:] if down_block_type.startswith("UNetRes") else down_block_type# 如果下采样块类型为 "DownBlockFlat",则返回相应的块实例if down_block_type == "DownBlockFlat":return DownBlockFlat(num_layers=num_layers,        # 层数in_channels=in_channels,      # 输入通道数out_channels=out_channels,    # 输出通道数temb_channels=temb_channels,  # 时间嵌入通道数dropout=dropout,              # dropout 概率add_downsample=add_downsample, # 是否添加下采样resnet_eps=resnet_eps,        # ResNet 中的 epsilon 值resnet_act_fn=resnet_act_fn,  # ResNet 激活函数resnet_groups=resnet_groups,   # ResNet 组数(可选)downsample_padding=downsample_padding, # 下采样填充(可选)resnet_time_scale_shift=resnet_time_scale_shift, # ResNet 时间缩放偏移)# 检查下采样块类型是否为 CrossAttnDownBlockFlatelif down_block_type == "CrossAttnDownBlockFlat":# 如果没有指定 cross_attention_dim,则抛出值错误if cross_attention_dim is None:raise ValueError("cross_attention_dim must be specified for CrossAttnDownBlockFlat")# 创建并返回 CrossAttnDownBlockFlat 实例,传入所需参数return CrossAttnDownBlockFlat(# 设置网络层数num_layers=num_layers,# 设置输入通道数in_channels=in_channels,# 设置输出通道数out_channels=out_channels,# 设置时间嵌入通道数temb_channels=temb_channels,# 设置 dropout 比率dropout=dropout,# 设置是否添加下采样层add_downsample=add_downsample,# 设置 ResNet 中的 epsilon 参数resnet_eps=resnet_eps,# 设置 ResNet 激活函数resnet_act_fn=resnet_act_fn,# 设置 ResNet 组的数量resnet_groups=resnet_groups,# 设置下采样的填充参数downsample_padding=downsample_padding,# 设置交叉注意力维度cross_attention_dim=cross_attention_dim,# 设置注意力头的数量num_attention_heads=num_attention_heads,# 设置是否使用双交叉注意力dual_cross_attention=dual_cross_attention,# 设置是否使用线性投影use_linear_projection=use_linear_projection,# 设置是否仅使用交叉注意力only_cross_attention=only_cross_attention,# 设置 ResNet 的时间尺度偏移resnet_time_scale_shift=resnet_time_scale_shift,)# 如果下采样块类型不被支持,则抛出值错误raise ValueError(f"{down_block_type} is not supported.")
# 根据给定参数创建上采样块的函数
def get_up_block(# 上采样块类型up_block_type,# 网络层数num_layers,# 输入通道数in_channels,# 输出通道数out_channels,# 上一层输出通道数prev_output_channel,# 条件嵌入通道数temb_channels,# 是否添加上采样add_upsample,# ResNet 的 epsilon 值resnet_eps,# ResNet 的激活函数resnet_act_fn,# 注意力头数num_attention_heads,# 每个块的 Transformer 层数transformer_layers_per_block,# 分辨率索引resolution_idx,# 注意力类型attention_type,# 注意力头维度attention_head_dim,# ResNet 组数,可选参数resnet_groups=None,# 跨注意力维度,可选参数cross_attention_dim=None,# 是否使用双重跨注意力dual_cross_attention=False,# 是否使用线性投影use_linear_projection=False,# 是否仅使用跨注意力only_cross_attention=False,# 是否上溯注意力upcast_attention=False,# ResNet 时间尺度偏移,默认为 "default"resnet_time_scale_shift="default",# ResNet 是否跳过时间激活resnet_skip_time_act=False,# ResNet 输出缩放因子resnet_out_scale_factor=1.0,# 跨注意力归一化类型,可选参数cross_attention_norm=None,# dropout 概率dropout=0.0,
):# 如果上采样块类型以 "UNetRes" 开头,去掉前缀up_block_type = up_block_type[7:] if up_block_type.startswith("UNetRes") else up_block_type# 如果块类型是 "UpBlockFlat",则返回相应的实例if up_block_type == "UpBlockFlat":return UpBlockFlat(# 传入各个参数num_layers=num_layers,in_channels=in_channels,out_channels=out_channels,prev_output_channel=prev_output_channel,temb_channels=temb_channels,dropout=dropout,add_upsample=add_upsample,resnet_eps=resnet_eps,resnet_act_fn=resnet_act_fn,resnet_groups=resnet_groups,resnet_time_scale_shift=resnet_time_scale_shift,)# 如果块类型是 "CrossAttnUpBlockFlat"elif up_block_type == "CrossAttnUpBlockFlat":# 检查跨注意力维度是否指定if cross_attention_dim is None:raise ValueError("cross_attention_dim must be specified for CrossAttnUpBlockFlat")# 返回相应的跨注意力上采样块实例return CrossAttnUpBlockFlat(# 传入各个参数num_layers=num_layers,in_channels=in_channels,out_channels=out_channels,prev_output_channel=prev_output_channel,temb_channels=temb_channels,dropout=dropout,add_upsample=add_upsample,resnet_eps=resnet_eps,resnet_act_fn=resnet_act_fn,resnet_groups=resnet_groups,cross_attention_dim=cross_attention_dim,num_attention_heads=num_attention_heads,dual_cross_attention=dual_cross_attention,use_linear_projection=use_linear_projection,only_cross_attention=only_cross_attention,resnet_time_scale_shift=resnet_time_scale_shift,)# 如果块类型不支持,抛出异常raise ValueError(f"{up_block_type} is not supported.")# 定义一个 Fourier 嵌入器类,继承自 nn.Module
class FourierEmbedder(nn.Module):# 初始化方法,设置频率和温度def __init__(self, num_freqs=64, temperature=100):# 调用父类构造函数super().__init__()# 保存频率数self.num_freqs = num_freqs# 保存温度self.temperature = temperature# 计算频率带freq_bands = temperature ** (torch.arange(num_freqs) / num_freqs)# 扩展维度以便后续操作freq_bands = freq_bands[None, None, None]# 注册频率带为缓冲区,设为非持久性self.register_buffer("freq_bands", freq_bands, persistent=False)# 定义调用方法,用于处理输入def __call__(self, x):# 将输入与频率带相乘x = self.freq_bands * x.unsqueeze(-1)# 返回处理后的结果,包含正弦和余弦return torch.stack((x.sin(), x.cos()), dim=-1).permute(0, 1, 3, 4, 2).reshape(*x.shape[:2], -1)# 定义 GLIGEN 文本边界框投影类,继承自 nn.Module
class GLIGENTextBoundingboxProjection(nn.Module):# 初始化方法,设置对象的基本参数def __init__(self, positive_len, out_dim, feature_type, fourier_freqs=8):# 调用父类的初始化方法super().__init__()# 存储正样本的长度self.positive_len = positive_len# 存储输出的维度self.out_dim = out_dim# 初始化傅里叶嵌入器,设置频率数量self.fourier_embedder = FourierEmbedder(num_freqs=fourier_freqs)# 计算位置特征的维度,包含 sin 和 cosself.position_dim = fourier_freqs * 2 * 4  # 2: sin/cos, 4: xyxy# 如果输出维度是元组,取第一个元素if isinstance(out_dim, tuple):out_dim = out_dim[0]# 根据特征类型设置线性层if feature_type == "text-only":self.linears = nn.Sequential(# 第一层线性变换,输入为正样本长度加位置维度nn.Linear(self.positive_len + self.position_dim, 512),# 激活函数使用 SiLUnn.SiLU(),# 第二层线性变换nn.Linear(512, 512),# 激活函数使用 SiLUnn.SiLU(),# 输出层nn.Linear(512, out_dim),)# 定义一个全为零的参数,用于文本特征的空值处理self.null_positive_feature = torch.nn.Parameter(torch.zeros([self.positive_len]))# 处理文本和图像的特征类型elif feature_type == "text-image":self.linears_text = nn.Sequential(# 第一层线性变换nn.Linear(self.positive_len + self.position_dim, 512),# 激活函数使用 SiLUnn.SiLU(),# 第二层线性变换nn.Linear(512, 512),# 激活函数使用 SiLUnn.SiLU(),# 输出层nn.Linear(512, out_dim),)self.linears_image = nn.Sequential(# 第一层线性变换nn.Linear(self.positive_len + self.position_dim, 512),# 激活函数使用 SiLUnn.SiLU(),# 第二层线性变换nn.Linear(512, 512),# 激活函数使用 SiLUnn.SiLU(),# 输出层nn.Linear(512, out_dim),)# 定义文本特征的空值处理参数self.null_text_feature = torch.nn.Parameter(torch.zeros([self.positive_len]))# 定义图像特征的空值处理参数self.null_image_feature = torch.nn.Parameter(torch.zeros([self.positive_len]))# 定义位置特征的空值处理参数self.null_position_feature = torch.nn.Parameter(torch.zeros([self.position_dim]))# 前向传播方法定义def forward(self,boxes,masks,positive_embeddings=None,phrases_masks=None,image_masks=None,phrases_embeddings=None,image_embeddings=None,):# 在最后一维增加一个维度,便于后续操作masks = masks.unsqueeze(-1)# 通过傅里叶嵌入函数生成 boxes 的嵌入表示xyxy_embedding = self.fourier_embedder(boxes)# 获取空白位置的特征,并调整形状为 (1, 1, -1)xyxy_null = self.null_position_feature.view(1, 1, -1)# 计算加权嵌入,结合 masks 和空白位置特征xyxy_embedding = xyxy_embedding * masks + (1 - masks) * xyxy_null# 如果存在正样本嵌入if positive_embeddings:# 获取正样本的空白特征,并调整形状为 (1, 1, -1)positive_null = self.null_positive_feature.view(1, 1, -1)# 计算正样本嵌入的加权,结合 masks 和空白特征positive_embeddings = positive_embeddings * masks + (1 - masks) * positive_null# 将正样本嵌入与 xyxy 嵌入连接并通过线性层处理objs = self.linears(torch.cat([positive_embeddings, xyxy_embedding], dim=-1))else:# 在最后一维增加一个维度,便于后续操作phrases_masks = phrases_masks.unsqueeze(-1)image_masks = image_masks.unsqueeze(-1)# 获取文本和图像的空白特征,并调整形状为 (1, 1, -1)text_null = self.null_text_feature.view(1, 1, -1)image_null = self.null_image_feature.view(1, 1, -1)# 计算文本嵌入的加权,结合 phrases_masks 和空白特征phrases_embeddings = phrases_embeddings * phrases_masks + (1 - phrases_masks) * text_null# 计算图像嵌入的加权,结合 image_masks 和空白特征image_embeddings = image_embeddings * image_masks + (1 - image_masks) * image_null# 将文本嵌入与 xyxy 嵌入连接并通过文本线性层处理objs_text = self.linears_text(torch.cat([phrases_embeddings, xyxy_embedding], dim=-1))# 将图像嵌入与 xyxy 嵌入连接并通过图像线性层处理objs_image = self.linears_image(torch.cat([image_embeddings, xyxy_embedding], dim=-1))# 将文本和图像的处理结果在维度 1 上连接objs = torch.cat([objs_text, objs_image], dim=1)# 返回最终的对象结果return objs
# 定义一个名为 UNetFlatConditionModel 的类,继承自 ModelMixin 和 ConfigMixin
class UNetFlatConditionModel(ModelMixin, ConfigMixin):r"""一个条件 2D UNet 模型,它接收一个有噪声的样本、条件状态和时间步,并返回一个样本形状的输出。该模型继承自 [`ModelMixin`]。请查看父类文档以了解其为所有模型实现的通用方法(例如下载或保存)。"""# 设置该模型支持梯度检查点_supports_gradient_checkpointing = True# 定义不进行拆分的模块名称列表_no_split_modules = ["BasicTransformerBlock", "ResnetBlockFlat", "CrossAttnUpBlockFlat"]# 注册到配置的装饰器@register_to_config# 初始化方法,设置类的基本参数def __init__(# 样本大小,可选参数self,sample_size: Optional[int] = None,# 输入通道数,默认为4in_channels: int = 4,# 输出通道数,默认为4out_channels: int = 4,# 是否将输入样本居中,默认为Falsecenter_input_sample: bool = False,# 是否将正弦函数翻转为余弦函数,默认为Trueflip_sin_to_cos: bool = True,# 频率偏移量,默认为0freq_shift: int = 0,# 向下采样块的类型,默认为三个CrossAttnDownBlockFlat和一个DownBlockFlatdown_block_types: Tuple[str] = ("CrossAttnDownBlockFlat","CrossAttnDownBlockFlat","CrossAttnDownBlockFlat","DownBlockFlat",),# 中间块的类型,默认为UNetMidBlockFlatCrossAttnmid_block_type: Optional[str] = "UNetMidBlockFlatCrossAttn",# 向上采样块的类型,默认为一个UpBlockFlat和三个CrossAttnUpBlockFlatup_block_types: Tuple[str] = ("UpBlockFlat","CrossAttnUpBlockFlat","CrossAttnUpBlockFlat","CrossAttnUpBlockFlat",),# 是否仅使用交叉注意力,默认为Falseonly_cross_attention: Union[bool, Tuple[bool]] = False,# 块输出通道数,默认为320, 640, 1280, 1280block_out_channels: Tuple[int] = (320, 640, 1280, 1280),# 每个块的层数,默认为2layers_per_block: Union[int, Tuple[int]] = 2,# 向下采样时的填充大小,默认为1downsample_padding: int = 1,# 中间块的缩放因子,默认为1mid_block_scale_factor: float = 1,# dropout比例,默认为0.0dropout: float = 0.0,# 激活函数类型,默认为siluact_fn: str = "silu",# 归一化的组数,可选参数,默认为32norm_num_groups: Optional[int] = 32,# 归一化的epsilon值,默认为1e-5norm_eps: float = 1e-5,# 交叉注意力的维度,默认为1280cross_attention_dim: Union[int, Tuple[int]] = 1280,# 每个块的变换器层数,默认为1transformer_layers_per_block: Union[int, Tuple[int], Tuple[Tuple]] = 1,# 反向变换器层数的可选配置reverse_transformer_layers_per_block: Optional[Tuple[Tuple[int]]] = None,# 编码器隐藏维度的可选参数encoder_hid_dim: Optional[int] = None,# 编码器隐藏维度类型的可选参数encoder_hid_dim_type: Optional[str] = None,# 注意力头的维度,默认为8attention_head_dim: Union[int, Tuple[int]] = 8,# 注意力头数量的可选参数num_attention_heads: Optional[Union[int, Tuple[int]]] = None,# 是否使用双交叉注意力,默认为Falsedual_cross_attention: bool = False,# 是否使用线性投影,默认为Falseuse_linear_projection: bool = False,# 类嵌入类型的可选参数class_embed_type: Optional[str] = None,# 附加嵌入类型的可选参数addition_embed_type: Optional[str] = None,# 附加时间嵌入维度的可选参数addition_time_embed_dim: Optional[int] = None,# 类嵌入数量的可选参数num_class_embeds: Optional[int] = None,# 是否向上投射注意力,默认为Falseupcast_attention: bool = False,# ResNet时间缩放偏移的默认值resnet_time_scale_shift: str = "default",# ResNet跳过时间激活的设置,默认为Falseresnet_skip_time_act: bool = False,# ResNet输出缩放因子,默认为1.0resnet_out_scale_factor: int = 1.0,# 时间嵌入类型,默认为positionaltime_embedding_type: str = "positional",# 时间嵌入维度的可选参数time_embedding_dim: Optional[int] = None,# 时间嵌入激活函数的可选参数time_embedding_act_fn: Optional[str] = None,# 时间步后激活的可选参数timestep_post_act: Optional[str] = None,# 时间条件投影维度的可选参数time_cond_proj_dim: Optional[int] = None,# 输入卷积核的大小,默认为3conv_in_kernel: int = 3,# 输出卷积核的大小,默认为3conv_out_kernel: int = 3,# 投影类嵌入输入维度的可选参数projection_class_embeddings_input_dim: Optional[int] = None,# 注意力类型,默认为defaultattention_type: str = "default",# 类嵌入是否连接,默认为Falseclass_embeddings_concat: bool = False,# 中间块是否仅使用交叉注意力的可选参数mid_block_only_cross_attention: Optional[bool] = None,# 交叉注意力的归一化类型的可选参数cross_attention_norm: Optional[str] = None,# 附加嵌入类型的头数量,默认为64addition_embed_type_num_heads=64,# 声明该方法为属性@property# 定义一个返回注意力处理器字典的方法def attn_processors(self) -> Dict[str, AttentionProcessor]:r"""返回值:`dict` 的注意力处理器: 一个字典,包含模型中使用的所有注意力处理器,以其权重名称为索引。"""# 初始化一个空字典以递归存储处理器processors = {}# 定义一个递归函数来添加处理器def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):# 如果模块有获取处理器的方法,则添加到字典中if hasattr(module, "get_processor"):processors[f"{name}.processor"] = module.get_processor()# 遍历模块的子模块,递归调用函数for sub_name, child in module.named_children():fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)# 返回处理器字典return processors# 遍历当前模块的子模块,并调用递归函数for name, module in self.named_children():fn_recursive_add_processors(name, module, processors)# 返回所有注意力处理器的字典return processors# 定义一个设置注意力处理器的方法def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):r"""设置用于计算注意力的处理器。参数:processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):实例化的处理器类或处理器类的字典,将被设置为所有 `Attention` 层的处理器。如果 `processor` 是字典,则键需要定义对应的交叉注意力处理器的路径。在设置可训练的注意力处理器时,强烈推荐这种做法。"""# 计算当前注意力处理器的数量count = len(self.attn_processors.keys())# 如果传入的是字典且数量不匹配,则引发错误if isinstance(processor, dict) and len(processor) != count:raise ValueError(f"传入的是处理器字典,但处理器的数量 {len(processor)} 与注意力层的数量 {count} 不匹配。"f" 请确保传入 {count} 个处理器类。")# 定义一个递归函数来设置注意力处理器def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):# 如果模块有设置处理器的方法,则根据传入的处理器设置if hasattr(module, "set_processor"):if not isinstance(processor, dict):module.set_processor(processor)else:module.set_processor(processor.pop(f"{name}.processor"))# 遍历模块的子模块,递归调用函数for sub_name, child in module.named_children():fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)# 遍历当前模块的子模块,并调用递归函数for name, module in self.named_children():fn_recursive_attn_processor(name, module, processor)# 设置默认的注意力处理器def set_default_attn_processor(self):"""禁用自定义注意力处理器,并设置默认的注意力实现。"""# 检查所有注意力处理器是否属于已添加的 KV 注意力处理器if all(proc.__class__ in ADDED_KV_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):# 使用 AttnAddedKVProcessor 作为处理器processor = AttnAddedKVProcessor()# 检查所有注意力处理器是否属于交叉注意力处理器elif all(proc.__class__ in CROSS_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):# 使用 AttnProcessor 作为处理器processor = AttnProcessor()else:# 如果处理器类型不匹配,则引发值错误raise ValueError(f"当注意力处理器的类型为 {next(iter(self.attn_processors.values()))} 时,无法调用 `set_default_attn_processor`")# 设置选定的注意力处理器self.set_attn_processor(processor)# 设置梯度检查点def _set_gradient_checkpointing(self, module, value=False):# 如果模块具有 gradient_checkpointing 属性,则设置其值if hasattr(module, "gradient_checkpointing"):module.gradient_checkpointing = value# 启用 FreeU 机制def enable_freeu(self, s1, s2, b1, b2):r"""启用来自 https://arxiv.org/abs/2309.11497 的 FreeU 机制。缩放因子的后缀表示应用的阶段块。请参考 [官方库](https://github.com/ChenyangSi/FreeU) 以获取已知在不同管道(如 Stable Diffusion v1、v2 和 Stable Diffusion XL)中表现良好的值组合。参数:s1 (`float`):阶段 1 的缩放因子,用于减弱跳过特征的贡献。这是为了减轻增强去噪过程中的“过平滑效应”。s2 (`float`):阶段 2 的缩放因子,用于减弱跳过特征的贡献。这是为了减轻增强去噪过程中的“过平滑效应”。b1 (`float`): 阶段 1 的缩放因子,用于增强主干特征的贡献。b2 (`float`): 阶段 2 的缩放因子,用于增强主干特征的贡献。"""# 遍历上采样块并设置相应的缩放因子for i, upsample_block in enumerate(self.up_blocks):setattr(upsample_block, "s1", s1)  # 设置阶段 1 的缩放因子setattr(upsample_block, "s2", s2)  # 设置阶段 2 的缩放因子setattr(upsample_block, "b1", b1)  # 设置阶段 1 的主干特征缩放因子setattr(upsample_block, "b2", b2)  # 设置阶段 2 的主干特征缩放因子# 禁用 FreeU 机制def disable_freeu(self):"""禁用 FreeU 机制。"""freeu_keys = {"s1", "s2", "b1", "b2"}  # FreeU 机制的关键字集合# 遍历上采样块并将关键字的值设置为 Nonefor i, upsample_block in enumerate(self.up_blocks):for k in freeu_keys:# 如果上采样块具有该属性或属性值不为 None,则将其设置为 Noneif hasattr(upsample_block, k) or getattr(upsample_block, k, None) is not None:setattr(upsample_block, k, None)# 定义一个用于融合 QKV 投影的函数def fuse_qkv_projections(self):# 文档字符串,描述该函数的作用及实验性质"""Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value)are fused. For cross-attention modules, key and value projection matrices are fused.<Tip warning={true}>This API is 🧪 experimental.</Tip>"""# 初始化原始注意力处理器为 Noneself.original_attn_processors = None# 遍历所有注意力处理器for _, attn_processor in self.attn_processors.items():# 检查处理器的类名是否包含 "Added"if "Added" in str(attn_processor.__class__.__name__):# 如果是,抛出错误提示不支持融合raise ValueError("`fuse_qkv_projections()` is not supported for models having added KV projections.")# 保存原始的注意力处理器self.original_attn_processors = self.attn_processors# 遍历所有模块for module in self.modules():# 检查模块是否是 Attention 类的实例if isinstance(module, Attention):# 融合投影module.fuse_projections(fuse=True)# 定义一个用于取消 QKV 投影融合的函数def unfuse_qkv_projections(self):# 文档字符串,描述该函数的作用及实验性质"""Disables the fused QKV projection if enabled.<Tip warning={true}>This API is 🧪 experimental.</Tip>"""# 检查原始注意力处理器是否不为 Noneif self.original_attn_processors is not None:# 恢复到原始的注意力处理器self.set_attn_processor(self.original_attn_processors)# 定义一个用于卸载 LoRA 权重的函数def unload_lora(self):# 文档字符串,描述该函数的作用"""Unloads LoRA weights."""# 发出卸载的弃用警告deprecate("unload_lora","0.28.0","Calling `unload_lora()` is deprecated and will be removed in a future version. Please install `peft` and then call `disable_adapters().",)# 遍历所有模块for module in self.modules():# 检查模块是否具有 set_lora_layer 属性if hasattr(module, "set_lora_layer"):# 将 LoRA 层设置为 Nonemodule.set_lora_layer(None)# 定义前向传播函数def forward(self,sample: torch.Tensor,timestep: Union[torch.Tensor, float, int],encoder_hidden_states: torch.Tensor,class_labels: Optional[torch.Tensor] = None,timestep_cond: Optional[torch.Tensor] = None,attention_mask: Optional[torch.Tensor] = None,cross_attention_kwargs: Optional[Dict[str, Any]] = None,added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None,down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None,mid_block_additional_residual: Optional[torch.Tensor] = None,down_intrablock_additional_residuals: Optional[Tuple[torch.Tensor]] = None,encoder_attention_mask: Optional[torch.Tensor] = None,return_dict: bool = True,
# 定义一个继承自 nn.Linear 的线性多维层
class LinearMultiDim(nn.Linear):# 初始化方法,接受输入特征、输出特征及其他参数def __init__(self, in_features, out_features=None, second_dim=4, *args, **kwargs):# 如果 in_features 是整数,则将其转换为包含三个维度的列表in_features = [in_features, second_dim, 1] if isinstance(in_features, int) else list(in_features)# 如果未提供 out_features,则将其设置为 in_featuresif out_features is None:out_features = in_features# 如果 out_features 是整数,则转换为包含三个维度的列表out_features = [out_features, second_dim, 1] if isinstance(out_features, int) else list(out_features)# 保存输入特征的多维信息self.in_features_multidim = in_features# 保存输出特征的多维信息self.out_features_multidim = out_features# 调用父类的初始化方法,计算输入和输出特征的总数量super().__init__(np.array(in_features).prod(), np.array(out_features).prod())# 定义前向传播方法def forward(self, input_tensor, *args, **kwargs):# 获取输入张量的形状shape = input_tensor.shape# 获取输入特征的维度数量n_dim = len(self.in_features_multidim)# 将输入张量重塑为适合线性层的形状input_tensor = input_tensor.reshape(*shape[0:-n_dim], self.in_features)# 调用父类的前向传播方法,得到输出张量output_tensor = super().forward(input_tensor)# 将输出张量重塑为目标形状output_tensor = output_tensor.view(*shape[0:-n_dim], *self.out_features_multidim)# 返回输出张量return output_tensor# 定义一个平坦的残差块类,继承自 nn.Module
class ResnetBlockFlat(nn.Module):# 初始化方法,接受多个参数,包括通道数、丢弃率等def __init__(self,*,in_channels,out_channels=None,dropout=0.0,temb_channels=512,groups=32,groups_out=None,pre_norm=True,eps=1e-6,time_embedding_norm="default",use_in_shortcut=None,second_dim=4,**kwargs,# 初始化方法的结束,接收参数):# 调用父类的初始化方法super().__init__()# 是否进行预归一化,设置为传入的值self.pre_norm = pre_norm# 将预归一化设置为 Trueself.pre_norm = True# 如果输入通道是整数,则构造一个包含三个维度的列表in_channels = [in_channels, second_dim, 1] if isinstance(in_channels, int) else list(in_channels)# 计算输入通道数的乘积self.in_channels_prod = np.array(in_channels).prod()# 保存输入通道的多维信息self.channels_multidim = in_channels# 如果输出通道不为 Noneif out_channels is not None:# 如果输出通道是整数,构造一个包含三个维度的列表out_channels = [out_channels, second_dim, 1] if isinstance(out_channels, int) else list(out_channels)# 计算输出通道数的乘积out_channels_prod = np.array(out_channels).prod()# 保存输出通道的多维信息self.out_channels_multidim = out_channelselse:# 如果输出通道为 None,则输出通道乘积等于输入通道乘积out_channels_prod = self.in_channels_prod# 输出通道的多维信息与输入通道相同self.out_channels_multidim = self.channels_multidim# 保存时间嵌入的归一化状态self.time_embedding_norm = time_embedding_norm# 如果输出组数为 None,使用传入的组数if groups_out is None:groups_out = groups# 创建第一个归一化层,使用组归一化self.norm1 = torch.nn.GroupNorm(num_groups=groups, num_channels=self.in_channels_prod, eps=eps, affine=True)# 创建第一个卷积层,使用输入通道和输出通道乘积self.conv1 = torch.nn.Conv2d(self.in_channels_prod, out_channels_prod, kernel_size=1, padding=0)# 如果时间嵌入通道不为 Noneif temb_channels is not None:# 创建时间嵌入投影层self.time_emb_proj = torch.nn.Linear(temb_channels, out_channels_prod)else:# 如果时间嵌入通道为 None,则不进行投影self.time_emb_proj = None# 创建第二个归一化层,使用输出组数和输出通道乘积self.norm2 = torch.nn.GroupNorm(num_groups=groups_out, num_channels=out_channels_prod, eps=eps, affine=True)# 创建丢弃层,使用传入的丢弃率self.dropout = torch.nn.Dropout(dropout)# 创建第二个卷积层,使用输出通道乘积self.conv2 = torch.nn.Conv2d(out_channels_prod, out_channels_prod, kernel_size=1, padding=0)# 设置非线性激活函数为 SiLUself.nonlinearity = nn.SiLU()# 检查是否使用输入短路,如果短路使用参数为 None,则根据通道数判断self.use_in_shortcut = (self.in_channels_prod != out_channels_prod if use_in_shortcut is None else use_in_shortcut)# 初始化快捷连接卷积为 Noneself.conv_shortcut = None# 如果使用输入短路if self.use_in_shortcut:# 创建快捷连接卷积层self.conv_shortcut = torch.nn.Conv2d(self.in_channels_prod, out_channels_prod, kernel_size=1, stride=1, padding=0)# 定义前向传播方法,接收输入张量和时间嵌入def forward(self, input_tensor, temb):# 获取输入张量的形状shape = input_tensor.shape# 获取多维通道的维度数n_dim = len(self.channels_multidim)# 调整输入张量形状,合并通道维度并增加两个维度input_tensor = input_tensor.reshape(*shape[0:-n_dim], self.in_channels_prod, 1, 1)# 将张量视图转换为指定形状,保持通道数并增加两个维度input_tensor = input_tensor.view(-1, self.in_channels_prod, 1, 1)# 初始化隐藏状态为输入张量hidden_states = input_tensor# 对隐藏状态进行归一化处理hidden_states = self.norm1(hidden_states)# 应用非线性激活函数hidden_states = self.nonlinearity(hidden_states)# 通过第一个卷积层处理隐藏状态hidden_states = self.conv1(hidden_states)# 如果时间嵌入不为空if temb is not None:# 对时间嵌入进行非线性处理并调整形状temb = self.time_emb_proj(self.nonlinearity(temb))[:, :, None, None]# 将时间嵌入与隐藏状态相加hidden_states = hidden_states + temb# 对隐藏状态进行第二次归一化处理hidden_states = self.norm2(hidden_states)# 再次应用非线性激活函数hidden_states = self.nonlinearity(hidden_states)# 对隐藏状态应用 dropout 操作hidden_states = self.dropout(hidden_states)# 通过第二个卷积层处理隐藏状态hidden_states = self.conv2(hidden_states)# 如果存在短路卷积层if self.conv_shortcut is not None:# 通过短路卷积层处理输入张量input_tensor = self.conv_shortcut(input_tensor)# 将输入张量与隐藏状态相加,生成输出张量output_tensor = input_tensor + hidden_states# 将输出张量调整为指定形状,去掉多余的维度output_tensor = output_tensor.view(*shape[0:-n_dim], -1)# 再次调整输出张量的形状,匹配输出通道的多维结构output_tensor = output_tensor.view(*shape[0:-n_dim], *self.out_channels_multidim)# 返回最终的输出张量return output_tensor
# 定义一个名为 DownBlockFlat 的类,继承自 nn.Module
class DownBlockFlat(nn.Module):# 初始化方法,接受多个参数用于配置模型def __init__(self,in_channels: int,  # 输入通道数out_channels: int,  # 输出通道数temb_channels: int,  # 时间嵌入通道数dropout: float = 0.0,  # dropout 概率num_layers: int = 1,  # ResNet 层数resnet_eps: float = 1e-6,  # ResNet 的 epsilon 值resnet_time_scale_shift: str = "default",  # ResNet 的时间缩放偏移resnet_act_fn: str = "swish",  # ResNet 的激活函数resnet_groups: int = 32,  # ResNet 的分组数resnet_pre_norm: bool = True,  # 是否在 ResNet 前进行归一化output_scale_factor: float = 1.0,  # 输出缩放因子add_downsample: bool = True,  # 是否添加下采样层downsample_padding: int = 1,  # 下采样时的填充):# 调用父类的初始化方法super().__init__()# 初始化一个空列表,用于存放 ResNet 层resnets = []# 循环创建指定数量的 ResNet 层for i in range(num_layers):# 第一层使用输入通道,之后的层使用输出通道in_channels = in_channels if i == 0 else out_channels# 将 ResNet 层添加到列表中resnets.append(ResnetBlockFlat(in_channels=in_channels,  # 当前层的输入通道数out_channels=out_channels,  # 当前层的输出通道数temb_channels=temb_channels,  # 时间嵌入通道数eps=resnet_eps,  # epsilon 值groups=resnet_groups,  # 分组数dropout=dropout,  # dropout 概率time_embedding_norm=resnet_time_scale_shift,  # 时间嵌入归一化方式non_linearity=resnet_act_fn,  # 激活函数output_scale_factor=output_scale_factor,  # 输出缩放因子pre_norm=resnet_pre_norm,  # 是否前归一化))# 将 ResNet 层列表转为 nn.ModuleList 以便于管理self.resnets = nn.ModuleList(resnets)# 根据参数决定是否添加下采样层if add_downsample:self.downsamplers = nn.ModuleList([LinearMultiDim(out_channels,  # 输入通道数use_conv=True,  # 使用卷积out_channels=out_channels,  # 输出通道数padding=downsample_padding,  # 填充name="op"  # 下采样层名称)])else:# 如果不添加下采样层,设置为 Noneself.downsamplers = None# 初始化梯度检查点为 Falseself.gradient_checkpointing = False# 定义前向传播方法def forward(self, hidden_states: torch.Tensor, temb: Optional[torch.Tensor] = None  # 输入的隐藏状态和可选的时间嵌入) -> Tuple[torch.Tensor, Tuple[torch.Tensor, ...]]:# 初始化输出状态为一个空元组output_states = ()# 遍历所有的 ResNet 层for resnet in self.resnets:# 如果在训练模式且开启了梯度检查点if self.training and self.gradient_checkpointing:# 定义一个创建自定义前向传播的方法def create_custom_forward(module):# 定义自定义前向传播函数def custom_forward(*inputs):return module(*inputs)  # 调用模块进行前向传播return custom_forward# 检查 PyTorch 版本,使用不同的调用方式if is_torch_version(">=", "1.11.0"):hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb, use_reentrant=False  # 进行梯度检查点的前向传播)else:hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb  # 进行梯度检查点的前向传播)else:# 正常调用 ResNet 层进行前向传播hidden_states = resnet(hidden_states, temb)# 将当前隐藏状态添加到输出状态中output_states = output_states + (hidden_states,)# 如果存在下采样层if self.downsamplers is not None:# 遍历所有下采样层for downsampler in self.downsamplers:hidden_states = downsampler(hidden_states)  # 对隐藏状态进行下采样# 将下采样后的隐藏状态添加到输出状态中output_states = output_states + (hidden_states,)# 返回最终的隐藏状态和所有输出状态return hidden_states, output_states
# 定义一个名为 CrossAttnDownBlockFlat 的类,继承自 nn.Module
class CrossAttnDownBlockFlat(nn.Module):# 初始化方法,定义类的属性def __init__(# 输入通道数self,in_channels: int,# 输出通道数out_channels: int,# 时间嵌入通道数temb_channels: int,# dropout 概率,默认为 0.0dropout: float = 0.0,# 层数,默认为 1num_layers: int = 1,# 每个块的变换器层数,可以是整数或整数元组,默认为 1transformer_layers_per_block: Union[int, Tuple[int]] = 1,# ResNet 的 epsilon 值,默认为 1e-6resnet_eps: float = 1e-6,# ResNet 的时间尺度偏移设置,默认为 "default"resnet_time_scale_shift: str = "default",# ResNet 的激活函数,默认为 "swish"resnet_act_fn: str = "swish",# ResNet 的组数,默认为 32resnet_groups: int = 32,# 是否使用预归一化,默认为 Trueresnet_pre_norm: bool = True,# 注意力头的数量,默认为 1num_attention_heads: int = 1,# 交叉注意力的维度,默认为 1280cross_attention_dim: int = 1280,# 输出缩放因子,默认为 1.0output_scale_factor: float = 1.0,# 下采样的填充大小,默认为 1downsample_padding: int = 1,# 是否添加下采样层,默认为 Trueadd_downsample: bool = True,# 是否使用双重交叉注意力,默认为 Falsedual_cross_attention: bool = False,# 是否使用线性投影,默认为 Falseuse_linear_projection: bool = False,# 是否只使用交叉注意力,默认为 Falseonly_cross_attention: bool = False,# 是否上溯注意力,默认为 Falseupcast_attention: bool = False,# 注意力类型,默认为 "default"attention_type: str = "default",):# 调用父类的构造函数以初始化基类super().__init__()# 初始化存储 ResNet 块的列表resnets = []# 初始化存储注意力模型的列表attentions = []# 设置是否使用交叉注意力的标志self.has_cross_attention = True# 设置注意力头的数量self.num_attention_heads = num_attention_heads# 如果 transformer_layers_per_block 是一个整数,则将其转换为列表形式if isinstance(transformer_layers_per_block, int):transformer_layers_per_block = [transformer_layers_per_block] * num_layers# 为每一层构建 ResNet 块和注意力模型for i in range(num_layers):# 设置当前层的输入通道数,第一层使用 in_channels,其他层使用 out_channelsin_channels = in_channels if i == 0 else out_channels# 向 resnets 列表添加一个 ResNet 块resnets.append(ResnetBlockFlat(# 设置 ResNet 块的输入通道数in_channels=in_channels,# 设置 ResNet 块的输出通道数out_channels=out_channels,# 设置时间嵌入通道数temb_channels=temb_channels,# 设置 ResNet 块的 epsilon 值eps=resnet_eps,# 设置 ResNet 块的组数groups=resnet_groups,# 设置 dropout 概率dropout=dropout,# 设置时间嵌入的归一化方法time_embedding_norm=resnet_time_scale_shift,# 设置激活函数non_linearity=resnet_act_fn,# 设置输出缩放因子output_scale_factor=output_scale_factor,# 设置是否在前面进行归一化pre_norm=resnet_pre_norm,))# 如果不使用双交叉注意力if not dual_cross_attention:# 向 attentions 列表添加一个 Transformer 2D 模型attentions.append(Transformer2DModel(# 设置注意力头的数量num_attention_heads,# 设置每个注意力头的输出通道数out_channels // num_attention_heads,# 设置输入通道数in_channels=out_channels,# 设置当前层的 Transformer 层数num_layers=transformer_layers_per_block[i],# 设置交叉注意力的维度cross_attention_dim=cross_attention_dim,# 设置归一化的组数norm_num_groups=resnet_groups,# 设置是否使用线性投影use_linear_projection=use_linear_projection,# 设置是否仅使用交叉注意力only_cross_attention=only_cross_attention,# 设置是否提高注意力精度upcast_attention=upcast_attention,# 设置注意力类型attention_type=attention_type,))else:# 向 attentions 列表添加一个双 Transformer 2D 模型attentions.append(DualTransformer2DModel(# 设置注意力头的数量num_attention_heads,# 设置每个注意力头的输出通道数out_channels // num_attention_heads,# 设置输入通道数in_channels=out_channels,# 固定层数为 1num_layers=1,# 设置交叉注意力的维度cross_attention_dim=cross_attention_dim,# 设置归一化的组数norm_num_groups=resnet_groups,))# 将注意力模型列表转换为 PyTorch 的 ModuleListself.attentions = nn.ModuleList(attentions)# 将 ResNet 块列表转换为 PyTorch 的 ModuleListself.resnets = nn.ModuleList(resnets)# 如果需要添加下采样层if add_downsample:# 初始化下采样层为 ModuleListself.downsamplers = nn.ModuleList([LinearMultiDim(# 设置输出通道数out_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op")])else:# 如果不添加下采样层,将其设为 Noneself.downsamplers = None# 初始化梯度检查点标志为 Falseself.gradient_checkpointing = False# 定义前向传播函数,接收隐藏状态和其他可选参数def forward(self,hidden_states: torch.Tensor,  # 当前隐藏状态的张量temb: Optional[torch.Tensor] = None,  # 可选的时间嵌入张量encoder_hidden_states: Optional[torch.Tensor] = None,  # 可选的编码器隐藏状态张量attention_mask: Optional[torch.Tensor] = None,  # 可选的注意力掩码cross_attention_kwargs: Optional[Dict[str, Any]] = None,  # 可选的交叉注意力参数encoder_attention_mask: Optional[torch.Tensor] = None,  # 可选的编码器注意力掩码additional_residuals: Optional[torch.Tensor] = None,  # 可选的额外残差张量) -> Tuple[torch.Tensor, Tuple[torch.Tensor, ...]]:  # 返回隐藏状态和输出状态元组output_states = ()  # 初始化输出状态元组blocks = list(zip(self.resnets, self.attentions))  # 将残差网络和注意力模块配对成块for i, (resnet, attn) in enumerate(blocks):  # 遍历每个块及其索引if self.training and self.gradient_checkpointing:  # 检查是否在训练且启用梯度检查点def create_custom_forward(module, return_dict=None):  # 定义自定义前向传播函数def custom_forward(*inputs):  # 自定义前向传播逻辑if return_dict is not None:  # 如果提供了返回字典return module(*inputs, return_dict=return_dict)  # 返回带字典的结果else:return module(*inputs)  # 否则返回普通结果return custom_forward  # 返回自定义前向传播函数# 设置检查点参数,如果 PyTorch 版本大于等于 1.11.0,则使用非重入模式ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}# 通过检查点机制计算当前块的隐藏状态hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet),  # 创建自定义前向函数的检查点hidden_states,  # 输入当前隐藏状态temb,  # 输入时间嵌入**ckpt_kwargs,  # 传递检查点参数)# 通过注意力模块处理隐藏状态并获取输出hidden_states = attn(hidden_states,encoder_hidden_states=encoder_hidden_states,  # 编码器隐藏状态cross_attention_kwargs=cross_attention_kwargs,  # 交叉注意力参数attention_mask=attention_mask,  # 注意力掩码encoder_attention_mask=encoder_attention_mask,  # 编码器注意力掩码return_dict=False,  # 不返回字典格式)[0]  # 取出第一个输出else:  # 如果不启用梯度检查# 直接通过残差网络处理隐藏状态hidden_states = resnet(hidden_states, temb)# 通过注意力模块处理隐藏状态并获取输出hidden_states = attn(hidden_states,encoder_hidden_states=encoder_hidden_states,  # 编码器隐藏状态cross_attention_kwargs=cross_attention_kwargs,  # 交叉注意力参数attention_mask=attention_mask,  # 注意力掩码encoder_attention_mask=encoder_attention_mask,  # 编码器注意力掩码return_dict=False,  # 不返回字典格式)[0]  # 取出第一个输出# 如果是最后一个块并且提供了额外残差,则将其添加到隐藏状态if i == len(blocks) - 1 and additional_residuals is not None:hidden_states = hidden_states + additional_residuals  # 加上额外残差output_states = output_states + (hidden_states,)  # 将当前隐藏状态添加到输出状态元组中if self.downsamplers is not None:  # 如果存在下采样器for downsampler in self.downsamplers:  # 遍历每个下采样器hidden_states = downsampler(hidden_states)  # 处理当前隐藏状态output_states = output_states + (hidden_states,)  # 将当前隐藏状态添加到输出状态元组中return hidden_states, output_states  # 返回最终的隐藏状态和输出状态元组
# 从 diffusers.models.unets.unet_2d_blocks 中复制,替换 UpBlock2D 为 UpBlockFlat,ResnetBlock2D 为 ResnetBlockFlat,Upsample2D 为 LinearMultiDim
class UpBlockFlat(nn.Module):# 初始化函数,定义输入输出通道及其他参数def __init__(self,in_channels: int,  # 输入通道数prev_output_channel: int,  # 前一层输出通道数out_channels: int,  # 当前层输出通道数temb_channels: int,  # 时间嵌入通道数resolution_idx: Optional[int] = None,  # 分辨率索引dropout: float = 0.0,  # dropout 概率num_layers: int = 1,  # 层数resnet_eps: float = 1e-6,  # ResNet 中的 epsilon 值resnet_time_scale_shift: str = "default",  # 时间尺度偏移设置resnet_act_fn: str = "swish",  # 激活函数类型resnet_groups: int = 32,  # 分组数resnet_pre_norm: bool = True,  # 是否进行预归一化output_scale_factor: float = 1.0,  # 输出缩放因子add_upsample: bool = True,  # 是否添加上采样):# 调用父类构造函数super().__init__()# 初始化一个空列表存储 ResNet 块resnets = []# 遍历层数,构建每一层的 ResNet 块for i in range(num_layers):# 根据层数决定残差跳跃通道数res_skip_channels = in_channels if (i == num_layers - 1) else out_channels# 根据当前层数决定输入通道数resnet_in_channels = prev_output_channel if i == 0 else out_channels# 将 ResNet 块添加到列表中resnets.append(ResnetBlockFlat(in_channels=resnet_in_channels + res_skip_channels,  # 输入通道数out_channels=out_channels,  # 输出通道数temb_channels=temb_channels,  # 时间嵌入通道数eps=resnet_eps,  # epsilon 值groups=resnet_groups,  # 分组数dropout=dropout,  # dropout 概率time_embedding_norm=resnet_time_scale_shift,  # 时间嵌入归一化non_linearity=resnet_act_fn,  # 激活函数output_scale_factor=output_scale_factor,  # 输出缩放因子pre_norm=resnet_pre_norm,  # 预归一化))# 将 ResNet 块列表转换为模块列表self.resnets = nn.ModuleList(resnets)# 如果需要添加上采样层,则创建上采样模块if add_upsample:self.upsamplers = nn.ModuleList([LinearMultiDim(out_channels, use_conv=True, out_channels=out_channels)])else:# 否则设置为 Noneself.upsamplers = None# 初始化梯度检查点标志self.gradient_checkpointing = False# 设置分辨率索引self.resolution_idx = resolution_idx# 前向传播函数def forward(self,hidden_states: torch.Tensor,  # 隐藏状态张量res_hidden_states_tuple: Tuple[torch.Tensor, ...],  # 残差隐藏状态元组temb: Optional[torch.Tensor] = None,  # 可选的时间嵌入张量upsample_size: Optional[int] = None,  # 可选的上采样大小*args,  # 可变参数**kwargs,  # 可变关键字参数) -> torch.Tensor:  # 定义一个返回 torch.Tensor 类型的函数# 如果参数列表 args 长度大于 0 或 kwargs 中的 scale 参数不为 Noneif len(args) > 0 or kwargs.get("scale", None) is not None:# 设置弃用消息,提醒用户 scale 参数已弃用且将被忽略deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."# 调用 deprecate 函数,记录 scale 参数的弃用deprecate("scale", "1.0.0", deprecation_message)# 检查 FreeU 是否启用,取决于 s1, s2, b1 和 b2 的值is_freeu_enabled = (getattr(self, "s1", None)  # 获取 self 中的 s1 属性and getattr(self, "s2", None)  # 获取 self 中的 s2 属性and getattr(self, "b1", None)  # 获取 self 中的 b1 属性and getattr(self, "b2", None)  # 获取 self 中的 b2 属性)# 遍历 self.resnets 中的每个 ResNet 模型for resnet in self.resnets:# 弹出 res 隐藏状态的最后一个元素res_hidden_states = res_hidden_states_tuple[-1]  # 移除 res 隐藏状态元组的最后一个元素res_hidden_states_tuple = res_hidden_states_tuple[:-1]  # FreeU: 仅在前两个阶段进行操作if is_freeu_enabled:# 应用 FreeU 操作,返回更新后的 hidden_states 和 res_hidden_stateshidden_states, res_hidden_states = apply_freeu(self.resolution_idx,  # 当前分辨率索引hidden_states,  # 当前隐藏状态res_hidden_states,  # 之前的隐藏状态s1=self.s1,  # s1 参数s2=self.s2,  # s2 参数b1=self.b1,  # b1 参数b2=self.b2,  # b2 参数)# 将当前的 hidden_states 和 res_hidden_states 在维度 1 上拼接hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)  # 如果处于训练模式并且开启了梯度检查点if self.training and self.gradient_checkpointing:# 定义一个创建自定义前向函数的函数def create_custom_forward(module):# 定义自定义前向函数,接收输入并调用模块def custom_forward(*inputs):return module(*inputs)return custom_forward# 如果 PyTorch 版本大于等于 1.11.0if is_torch_version(">=", "1.11.0"):# 使用梯度检查点来计算 hidden_stateshidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet),  # 使用自定义前向函数hidden_states,  # 当前隐藏状态temb,  # 传入的额外输入use_reentrant=False  # 禁用重入检查)else:# 对于早期版本,使用梯度检查点计算 hidden_stateshidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet),  # 使用自定义前向函数hidden_states,  # 当前隐藏状态temb  # 传入的额外输入)else:# 在非训练模式下直接调用 resnet 处理 hidden_stateshidden_states = resnet(hidden_states, temb)  # 如果存在上采样器if self.upsamplers is not None:# 遍历所有上采样器for upsampler in self.upsamplers:# 使用上采样器对 hidden_states 进行处理,指定上采样尺寸hidden_states = upsampler(hidden_states, upsample_size)  # 返回处理后的 hidden_statesreturn hidden_states  
# 从 diffusers.models.unets.unet_2d_blocks 中复制的代码,修改了类名和一些组件
class CrossAttnUpBlockFlat(nn.Module):# 初始化方法,定义类的基本属性和参数def __init__(# 输入通道数in_channels: int,# 输出通道数out_channels: int,# 上一层输出的通道数prev_output_channel: int,# 额外的时间嵌入通道数temb_channels: int,# 可选的分辨率索引resolution_idx: Optional[int] = None,# dropout 概率dropout: float = 0.0,# 层数num_layers: int = 1,# 每个块的变换器层数,可以是单个整数或元组transformer_layers_per_block: Union[int, Tuple[int]] = 1,# ResNet 的 epsilon 值resnet_eps: float = 1e-6,# ResNet 时间尺度偏移的类型resnet_time_scale_shift: str = "default",# ResNet 激活函数的类型resnet_act_fn: str = "swish",# ResNet 的组数resnet_groups: int = 32,# 是否在 ResNet 中使用预归一化resnet_pre_norm: bool = True,# 注意力头的数量num_attention_heads: int = 1,# 交叉注意力的维度cross_attention_dim: int = 1280,# 输出缩放因子output_scale_factor: float = 1.0,# 是否添加上采样步骤add_upsample: bool = True,# 是否使用双交叉注意力dual_cross_attention: bool = False,# 是否使用线性投影use_linear_projection: bool = False,# 是否仅使用交叉注意力only_cross_attention: bool = False,# 是否上溯注意力upcast_attention: bool = False,# 注意力类型attention_type: str = "default",# 定义构造函数的结束部分):# 调用父类的构造函数super().__init__()# 初始化一个空列表用于存储残差网络块resnets = []# 初始化一个空列表用于存储注意力模型attentions = []# 设置是否使用交叉注意力标志为真self.has_cross_attention = True# 设置注意力头的数量self.num_attention_heads = num_attention_heads# 如果 transformer_layers_per_block 是整数,则将其转换为相同长度的列表if isinstance(transformer_layers_per_block, int):transformer_layers_per_block = [transformer_layers_per_block] * num_layers# 遍历每一层以构建残差网络和注意力模型for i in range(num_layers):# 设置残差跳过通道数,最后一层使用输入通道,否则使用输出通道res_skip_channels = in_channels if (i == num_layers - 1) else out_channels# 设置残差网络输入通道数,第一层使用前一层输出通道,否则使用当前输出通道resnet_in_channels = prev_output_channel if i == 0 else out_channels# 添加一个残差网络块到 resnets 列表中resnets.append(ResnetBlockFlat(# 设置残差网络输入通道数in_channels=resnet_in_channels + res_skip_channels,# 设置残差网络输出通道数out_channels=out_channels,# 设置时间嵌入通道数temb_channels=temb_channels,# 设置残差网络的 epsilon 值eps=resnet_eps,# 设置残差网络的组数groups=resnet_groups,# 设置丢弃率dropout=dropout,# 设置时间嵌入的归一化方法time_embedding_norm=resnet_time_scale_shift,# 设置非线性激活函数non_linearity=resnet_act_fn,# 设置输出缩放因子output_scale_factor=output_scale_factor,# 设置是否进行预归一化pre_norm=resnet_pre_norm,))# 如果不使用双重交叉注意力if not dual_cross_attention:# 添加一个普通的 Transformer2DModel 到 attentions 列表中attentions.append(Transformer2DModel(# 设置注意力头的数量num_attention_heads,# 设置每个注意力头的输出通道数out_channels // num_attention_heads,# 设置输入通道数in_channels=out_channels,# 设置层数num_layers=transformer_layers_per_block[i],# 设置交叉注意力维度cross_attention_dim=cross_attention_dim,# 设置归一化组数norm_num_groups=resnet_groups,# 设置是否使用线性投影use_linear_projection=use_linear_projection,# 设置是否仅使用交叉注意力only_cross_attention=only_cross_attention,# 设置是否上溯注意力upcast_attention=upcast_attention,# 设置注意力类型attention_type=attention_type,))else:# 添加一个双重 Transformer2DModel 到 attentions 列表中attentions.append(DualTransformer2DModel(# 设置注意力头的数量num_attention_heads,# 设置每个注意力头的输出通道数out_channels // num_attention_heads,# 设置输入通道数in_channels=out_channels,# 设置层数为 1num_layers=1,# 设置交叉注意力维度cross_attention_dim=cross_attention_dim,# 设置归一化组数norm_num_groups=resnet_groups,))# 将注意力模型列表转换为 nn.ModuleListself.attentions = nn.ModuleList(attentions)# 将残差网络块列表转换为 nn.ModuleListself.resnets = nn.ModuleList(resnets)# 如果需要添加上采样层if add_upsample:# 将上采样器添加到 nn.ModuleList 中self.upsamplers = nn.ModuleList([LinearMultiDim(out_channels, use_conv=True, out_channels=out_channels)])else:# 否则将上采样器设置为 Noneself.upsamplers = None# 设置梯度检查点标志为假self.gradient_checkpointing = False# 设置分辨率索引self.resolution_idx = resolution_idx# 定义前向传播函数,接收多个输入参数def forward(self,# 隐藏状态,类型为 PyTorch 的张量hidden_states: torch.Tensor,# 包含残差隐藏状态的元组,元素类型为 PyTorch 张量res_hidden_states_tuple: Tuple[torch.Tensor, ...],# 可选的时间嵌入,类型为 PyTorch 的张量temb: Optional[torch.Tensor] = None,# 可选的编码器隐藏状态,类型为 PyTorch 的张量encoder_hidden_states: Optional[torch.Tensor] = None,# 可选的交叉注意力参数,类型为字典,包含任意键值对cross_attention_kwargs: Optional[Dict[str, Any]] = None,# 可选的上采样大小,类型为整数upsample_size: Optional[int] = None,# 可选的注意力掩码,类型为 PyTorch 的张量attention_mask: Optional[torch.Tensor] = None,# 可选的编码器注意力掩码,类型为 PyTorch 的张量encoder_attention_mask: Optional[torch.Tensor] = None,
# 从 diffusers.models.unets.unet_2d_blocks 中复制的 UNetMidBlock2D 代码,替换了 UNetMidBlock2D 为 UNetMidBlockFlat,ResnetBlock2D 为 ResnetBlockFlat
class UNetMidBlockFlat(nn.Module):"""2D UNet 中间块 [`UNetMidBlockFlat`],包含多个残差块和可选的注意力块。参数:in_channels (`int`): 输入通道的数量。temb_channels (`int`): 时间嵌入通道的数量。dropout (`float`, *可选*, 默认值为 0.0): dropout 比率。num_layers (`int`, *可选*, 默认值为 1): 残差块的数量。resnet_eps (`float`, *可选*, 默认值为 1e-6): resnet 块的 epsilon 值。resnet_time_scale_shift (`str`, *可选*, 默认值为 `default`):应用于时间嵌入的归一化类型。这可以帮助提高模型在长范围时间依赖任务上的性能。resnet_act_fn (`str`, *可选*, 默认值为 `swish`): resnet 块的激活函数。resnet_groups (`int`, *可选*, 默认值为 32):resnet 块的分组归一化层使用的组数。attn_groups (`Optional[int]`, *可选*, 默认值为 None): 注意力块的组数。resnet_pre_norm (`bool`, *可选*, 默认值为 `True`):是否在 resnet 块中使用预归一化。add_attention (`bool`, *可选*, 默认值为 `True`): 是否添加注意力块。attention_head_dim (`int`, *可选*, 默认值为 1):单个注意力头的维度。注意力头的数量基于此值和输入通道的数量确定。output_scale_factor (`float`, *可选*, 默认值为 1.0): 输出缩放因子。返回:`torch.Tensor`: 最后一个残差块的输出,是一个形状为 `(batch_size, in_channels,height, width)` 的张量。"""def __init__(self,in_channels: int,temb_channels: int,dropout: float = 0.0,num_layers: int = 1,resnet_eps: float = 1e-6,resnet_time_scale_shift: str = "default",  # 默认,空间resnet_act_fn: str = "swish",resnet_groups: int = 32,attn_groups: Optional[int] = None,resnet_pre_norm: bool = True,add_attention: bool = True,attention_head_dim: int = 1,output_scale_factor: float = 1.0,):# 初始化 UNetMidBlockFlat 类,设置各参数的默认值super().__init__()  # 调用父类的初始化方法self.in_channels = in_channels  # 保存输入通道数self.temb_channels = temb_channels  # 保存时间嵌入通道数self.dropout = dropout  # 保存 dropout 比率self.num_layers = num_layers  # 保存残差块的数量self.resnet_eps = resnet_eps  # 保存 resnet 块的 epsilon 值self.resnet_time_scale_shift = resnet_time_scale_shift  # 保存时间缩放偏移类型self.resnet_act_fn = resnet_act_fn  # 保存激活函数类型self.resnet_groups = resnet_groups  # 保存分组数self.attn_groups = attn_groups  # 保存注意力组数self.resnet_pre_norm = resnet_pre_norm  # 保存是否使用预归一化self.add_attention = add_attention  # 保存是否添加注意力块self.attention_head_dim = attention_head_dim  # 保存注意力头的维度self.output_scale_factor = output_scale_factor  # 保存输出缩放因子def forward(self, hidden_states: torch.Tensor, temb: Optional[torch.Tensor] = None) -> torch.Tensor:# 定义前向传播方法,接受隐藏状态和可选的时间嵌入hidden_states = self.resnets[0](hidden_states, temb)  # 通过第一个残差块处理隐藏状态for attn, resnet in zip(self.attentions, self.resnets[1:]):  # 遍历后续的注意力块和残差块if attn is not None:  # 如果注意力块存在hidden_states = attn(hidden_states, temb=temb)  # 通过注意力块处理隐藏状态hidden_states = resnet(hidden_states, temb)  # 通过残差块处理隐藏状态return hidden_states  # 返回处理后的隐藏状态
# 从 diffusers.models.unets.unet_2d_blocks 中复制,替换 UNetMidBlock2DCrossAttn 为 UNetMidBlockFlatCrossAttn,ResnetBlock2D 为 ResnetBlockFlat
class UNetMidBlockFlatCrossAttn(nn.Module):# 初始化方法,定义模型参数def __init__(self,# 输入通道数in_channels: int,# 时间嵌入通道数temb_channels: int,# 输出通道数,默认为 Noneout_channels: Optional[int] = None,# Dropout 概率,默认为 0.0dropout: float = 0.0,# 层数,默认为 1num_layers: int = 1,# 每个块的 Transformer 层数,默认为 1transformer_layers_per_block: Union[int, Tuple[int]] = 1,# ResNet 的 epsilon 值,默认为 1e-6resnet_eps: float = 1e-6,# ResNet 的时间尺度偏移,默认为 "default"resnet_time_scale_shift: str = "default",# ResNet 的激活函数类型,默认为 "swish"resnet_act_fn: str = "swish",# ResNet 的分组数,默认为 32resnet_groups: int = 32,# 输出的 ResNet 分组数,默认为 Noneresnet_groups_out: Optional[int] = None,# 是否使用预归一化,默认为 Trueresnet_pre_norm: bool = True,# 注意力头数,默认为 1num_attention_heads: int = 1,# 输出缩放因子,默认为 1.0output_scale_factor: float = 1.0,# 交叉注意力维度,默认为 1280cross_attention_dim: int = 1280,# 是否使用双交叉注意力,默认为 Falsedual_cross_attention: bool = False,# 是否使用线性投影,默认为 Falseuse_linear_projection: bool = False,# 是否上升注意力计算精度,默认为 Falseupcast_attention: bool = False,# 注意力类型,默认为 "default"attention_type: str = "default",# 前向传播方法,定义模型的前向计算逻辑def forward(self,# 隐藏状态张量hidden_states: torch.Tensor,# 可选的时间嵌入张量,默认为 Nonetemb: Optional[torch.Tensor] = None,# 可选的编码器隐藏状态张量,默认为 Noneencoder_hidden_states: Optional[torch.Tensor] = None,# 可选的注意力掩码,默认为 Noneattention_mask: Optional[torch.Tensor] = None,# 可选的交叉注意力参数字典,默认为 Nonecross_attention_kwargs: Optional[Dict[str, Any]] = None,# 可选的编码器注意力掩码,默认为 Noneencoder_attention_mask: Optional[torch.Tensor] = None,) -> torch.Tensor:  # 定义函数的返回类型为 torch.Tensorif cross_attention_kwargs is not None:  # 检查 cross_attention_kwargs 是否为 Noneif cross_attention_kwargs.get("scale", None) is not None:  # 检查 scale 是否在 cross_attention_kwargs 中logger.warning("Passing `scale` to `cross_attention_kwargs` is deprecated. `scale` will be ignored.")  # 发出警告,提示 scale 参数已过时hidden_states = self.resnets[0](hidden_states, temb)  # 使用第一个残差网络处理隐藏状态和时间嵌入for attn, resnet in zip(self.attentions, self.resnets[1:]):  # 遍历注意力层和后续的残差网络if self.training and self.gradient_checkpointing:  # 检查是否在训练模式且开启了梯度检查点def create_custom_forward(module, return_dict=None):  # 定义一个函数以创建自定义前向传播def custom_forward(*inputs):  # 定义实际的前向传播函数if return_dict is not None:  # 检查是否需要返回字典形式的输出return module(*inputs, return_dict=return_dict)  # 调用模块并返回字典else:  # 如果不需要字典形式的输出return module(*inputs)  # 直接调用模块并返回结果return custom_forward  # 返回自定义前向传播函数ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}  # 根据 PyTorch 版本设置检查点参数hidden_states = attn(  # 使用注意力层处理隐藏状态hidden_states,  # 输入隐藏状态encoder_hidden_states=encoder_hidden_states,  # 输入编码器的隐藏状态cross_attention_kwargs=cross_attention_kwargs,  # 传递交叉注意力参数attention_mask=attention_mask,  # 传递注意力掩码encoder_attention_mask=encoder_attention_mask,  # 传递编码器注意力掩码return_dict=False,  # 不返回字典)[0]  # 取出输出的第一个元素hidden_states = torch.utils.checkpoint.checkpoint(  # 使用检查点保存内存create_custom_forward(resnet),  # 创建自定义前向传播hidden_states,  # 输入隐藏状态temb,  # 输入时间嵌入**ckpt_kwargs,  # 解包检查点参数)else:  # 如果不在训练模式或不使用梯度检查点hidden_states = attn(  # 使用注意力层处理隐藏状态hidden_states,  # 输入隐藏状态encoder_hidden_states=encoder_hidden_states,  # 输入编码器的隐藏状态cross_attention_kwargs=cross_attention_kwargs,  # 传递交叉注意力参数attention_mask=attention_mask,  # 传递注意力掩码encoder_attention_mask=encoder_attention_mask,  # 传递编码器注意力掩码return_dict=False,  # 不返回字典)[0]  # 取出输出的第一个元素hidden_states = resnet(hidden_states, temb)  # 使用残差网络处理隐藏状态和时间嵌入return hidden_states  # 返回处理后的隐藏状态
# 从 diffusers.models.unets.unet_2d_blocks.UNetMidBlock2DSimpleCrossAttn 复制,替换 UNetMidBlock2DSimpleCrossAttn 为 UNetMidBlockFlatSimpleCrossAttn,ResnetBlock2D 为 ResnetBlockFlat
class UNetMidBlockFlatSimpleCrossAttn(nn.Module):# 初始化方法,设置各层的输入输出参数def __init__(# 输入通道数in_channels: int,# 条件嵌入通道数temb_channels: int,# Dropout 概率dropout: float = 0.0,# 网络层数num_layers: int = 1,# ResNet 的 epsilon 值resnet_eps: float = 1e-6,# ResNet 的时间缩放偏移方式resnet_time_scale_shift: str = "default",# ResNet 激活函数类型resnet_act_fn: str = "swish",# ResNet 中组的数量resnet_groups: int = 32,# 是否使用 ResNet 前归一化resnet_pre_norm: bool = True,# 注意力头的维度attention_head_dim: int = 1,# 输出缩放因子output_scale_factor: float = 1.0,# 交叉注意力的维度cross_attention_dim: int = 1280,# 是否跳过时间激活skip_time_act: bool = False,# 是否仅使用交叉注意力only_cross_attention: bool = False,# 交叉注意力的归一化方式cross_attention_norm: Optional[str] = None,):# 调用父类的初始化方法super().__init__()# 设置是否使用交叉注意力机制self.has_cross_attention = True# 设置注意力头的维度self.attention_head_dim = attention_head_dim# 确定 ResNet 的组数,若未提供则使用默认值resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32)# 计算头的数量self.num_heads = in_channels // self.attention_head_dim# 确保至少有一个 ResNet 块resnets = [# 创建一个 ResNet 块ResnetBlockFlat(in_channels=in_channels,out_channels=in_channels,temb_channels=temb_channels,eps=resnet_eps,groups=resnet_groups,dropout=dropout,time_embedding_norm=resnet_time_scale_shift,non_linearity=resnet_act_fn,output_scale_factor=output_scale_factor,pre_norm=resnet_pre_norm,skip_time_act=skip_time_act,)]# 初始化注意力列表attentions = []# 根据层数创建对应的注意力机制for _ in range(num_layers):# 根据是否支持缩放点积注意力选择处理器processor = (AttnAddedKVProcessor2_0() if hasattr(F, "scaled_dot_product_attention") else AttnAddedKVProcessor())# 添加注意力机制到列表attentions.append(Attention(query_dim=in_channels,cross_attention_dim=in_channels,heads=self.num_heads,dim_head=self.attention_head_dim,added_kv_proj_dim=cross_attention_dim,norm_num_groups=resnet_groups,bias=True,upcast_softmax=True,only_cross_attention=only_cross_attention,cross_attention_norm=cross_attention_norm,processor=processor,))# 添加 ResNet 块到列表resnets.append(ResnetBlockFlat(in_channels=in_channels,out_channels=in_channels,temb_channels=temb_channels,eps=resnet_eps,groups=resnet_groups,dropout=dropout,time_embedding_norm=resnet_time_scale_shift,non_linearity=resnet_act_fn,output_scale_factor=output_scale_factor,pre_norm=resnet_pre_norm,skip_time_act=skip_time_act,))# 将注意力层存入模块列表self.attentions = nn.ModuleList(attentions)# 将 ResNet 块存入模块列表self.resnets = nn.ModuleList(resnets)def forward(# 定义前向传播的方法self,hidden_states: torch.Tensor,# 可选的时间嵌入张量temb: Optional[torch.Tensor] = None,# 可选的编码器隐藏状态张量encoder_hidden_states: Optional[torch.Tensor] = None,# 可选的注意力掩码张量attention_mask: Optional[torch.Tensor] = None,# 可选的交叉注意力参数字典cross_attention_kwargs: Optional[Dict[str, Any]] = None,# 可选的编码器注意力掩码张量encoder_attention_mask: Optional[torch.Tensor] = None,) -> torch.Tensor:# 如果传入的 cross_attention_kwargs 为 None,则初始化为空字典cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {}# 检查 cross_attention_kwargs 中是否有 'scale',如果有则发出警告,说明该参数已弃用if cross_attention_kwargs.get("scale", None) is not None:logger.warning("Passing `scale` to `cross_attention_kwargs` is deprecated. `scale` will be ignored.")# 如果 attention_mask 为 Noneif attention_mask is None:# 如果 encoder_hidden_states 被定义:表示我们在进行交叉注意力,因此应该使用交叉注意力掩码mask = None if encoder_hidden_states is None else encoder_attention_maskelse:# 当 attention_mask 被定义时:我们不检查 encoder_attention_mask# 这是为了与 UnCLIP 兼容,UnCLIP 使用 'attention_mask' 参数作为交叉注意力掩码# TODO: UnCLIP 应通过 encoder_attention_mask 参数而不是 attention_mask 参数来表达交叉注意力掩码#       然后我们可以简化整个 if/else 块为:#         mask = attention_mask if encoder_hidden_states is None else encoder_attention_maskmask = attention_mask# 使用第一个残差网络处理隐藏状态和时间嵌入hidden_states = self.resnets[0](hidden_states, temb)# 遍历所有注意力层和对应的残差网络for attn, resnet in zip(self.attentions, self.resnets[1:]):# 使用注意力层处理隐藏状态hidden_states = attn(hidden_states,encoder_hidden_states=encoder_hidden_states,  # 传递编码器隐藏状态attention_mask=mask,  # 传递掩码**cross_attention_kwargs,  # 传递交叉注意力参数)# 使用残差网络处理隐藏状态hidden_states = resnet(hidden_states, temb)# 返回最终的隐藏状态return hidden_states

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.ryyt.cn/news/74589.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈,一经查实,立即删除!

相关文章

习题6.3

习题6.3代码 import numpy as np import pandas as pd import cvxpy as cp import networkx as nx import matplotlib.pyplot as plt plt.rcParams[font.sans-serif]=[Times New Roman + SimSun + WFM Sans SC] plt.rcParams[mathtext.fontset]=stix Times New Roman + SimSun …

苹果笔记本和微软Surface哪个更适合商务使用

在商务环境中,选择合适的笔记本电脑对于提高工作效率至关重要。本文对苹果笔记本和微软Surface进行比较分析,探讨哪种更适合商务使用。主要考虑因素包括:1.性能和可靠性;2.操作系统与软件兼容性;3.设计与便携性;4.电池续航力;5.价格与性价比;6.售后服务与支持。通过全面…

【日记】今天好忙(459 字)

写在前面今天没有什么可看的,可以不用看。 正文爆炸忙。整个下午我的手似乎就没停过,现在写这则日记,回想那个时候的自己,觉得好陌生。整体来说,那段时间也一片空白,什么印象都没有了。太忙,也没有做其它事情的空间。BAEA 台灯到了。我还以为又送到发改局去了,先去那边…

Nuxt.js 应用中的 build:manifest 事件钩子详解

title: Nuxt.js 应用中的 build:manifest 事件钩子详解 date: 2024/10/22 updated: 2024/10/22 author: cmdragon excerpt: build:manifest 是 Nuxt.js 中的一个生命周期钩子,它在 Vite 和 Webpack 构建清单期间被调用。利用这个钩子,开发者可以自定义 Nitro 渲染在最终 H…

如何进行前端单元测试

​前端单元测试的引入为软件开发流程带来了更高的质量和稳定性,需要遵循以下步骤:一、理解单元测试的重要性;二、选择合适的测试框架;三、编写有效的测试用例;四、模拟外部依赖;五、持续维护和优化测试。单元测试的开始,是对前端代码的核心功能进行验证。一、理解单元测…

Jenkins打包Unity游戏环境变量配置

Jenkins打包Unity游戏失败,通过报错日志会查找到sdk环境有问题,解决sdk的环境问题后会出现ndk环境有问题,为了解决这两个问题导致的打包失败需要在Jenkins中配置环境变量打开 Jenkins 首页,选中Manager Jenkins,再点击 System 选项找到全局属性,勾选Environment variable…

【Linux】shell 脚本 (.sh) 编写及执行

shell脚本shell脚本就是一些命令的集合#!/bin/bash echo "文件开头代表:该文件使用的是bash语法" 一、运行.sh文件 方法一:当前文件执行.sh 文件# 文件必须含有x执行权限 [文件赋x权限:chmod u+x hello.sh] ./test.sh# 文件可以没有x权限 sh test.sh 方法二:绝…

人工智能编程助手MarsCode注册和安装步骤

人工智能编程助手MarsCode注册和安装步骤 字节最近推出了人工智能编程助手MarsCode,功能非常强大。在IDEA中安装和使用MarsCode的步骤如下: 一、注册MarsCode账号注册地址:https://www.marscode.cn/events/s/iSMPHK8a/ 二、在Idea中安装插件点击菜单“File”——“Settings”…