diffusers-源码解析-二十三-

news/2024/10/22 12:37:07

diffusers 源码解析(二十三)

.\diffusers\pipelines\controlnet\pipeline_controlnet_sd_xl_img2img.py

# 版权所有 2024 HuggingFace 团队。保留所有权利。
#
# 根据 Apache 许可证第 2.0 版(“许可证”)许可;
# 除非遵守许可证,否则您不得使用此文件。
# 您可以在以下网址获得许可证副本:
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# 除非适用法律或书面协议另有规定,软件
# 在“按原样”基础上分发,不提供任何形式的保证或条件,
# 无论是明示或暗示的。
# 请参阅许可证以了解管理权限的具体语言和
# 限制条款。import inspect  # 导入 inspect 模块,用于获取对象的摘要信息
from typing import Any, Callable, Dict, List, Optional, Tuple, Union  # 导入类型注解模块import numpy as np  # 导入 numpy,用于数组和矩阵计算
import PIL.Image  # 导入 PIL.Image,用于处理图像
import torch  # 导入 PyTorch,用于深度学习
import torch.nn.functional as F  # 导入 PyTorch 的函数式 API
from transformers import (  # 从 transformers 导入模型和处理器CLIPImageProcessor,  # 导入 CLIP 图像处理器CLIPTextModel,  # 导入 CLIP 文本模型CLIPTextModelWithProjection,  # 导入带投影的 CLIP 文本模型CLIPTokenizer,  # 导入 CLIP 分词器CLIPVisionModelWithProjection,  # 导入带投影的 CLIP 视觉模型
)from diffusers.utils.import_utils import is_invisible_watermark_available  # 导入检查是否可用的隐形水印功能from ...callbacks import MultiPipelineCallbacks, PipelineCallback  # 导入多管道回调和管道回调类
from ...image_processor import PipelineImageInput, VaeImageProcessor  # 导入图像处理相关类
from ...loaders import (  # 导入加载器相关类FromSingleFileMixin,  # 从单文件加载的混合类IPAdapterMixin,  # 图像处理适配器混合类StableDiffusionXLLoraLoaderMixin,  # StableDiffusionXL Lora 加载混合类TextualInversionLoaderMixin,  # 文本反转加载混合类
)
from ...models import AutoencoderKL, ControlNetModel, ImageProjection, UNet2DConditionModel  # 导入不同模型
from ...models.attention_processor import (  # 导入注意力处理器AttnProcessor2_0,  # 注意力处理器版本 2.0XFormersAttnProcessor,  # XFormers 注意力处理器
)
from ...models.lora import adjust_lora_scale_text_encoder  # 导入调整 Lora 标度文本编码器的函数
from ...schedulers import KarrasDiffusionSchedulers  # 导入 Karras 扩散调度器
from ...utils import (  # 导入常用工具USE_PEFT_BACKEND,  # 指示是否使用 PEFT 后端的常量deprecate,  # 导入弃用装饰器logging,  # 导入日志记录模块replace_example_docstring,  # 导入替换示例文档字符串的工具scale_lora_layers,  # 导入缩放 Lora 层的工具unscale_lora_layers,  # 导入反缩放 Lora 层的工具
)
from ...utils.torch_utils import is_compiled_module, randn_tensor  # 导入与 PyTorch 相关的工具
from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin  # 导入扩散管道和稳定扩散混合类
from ..stable_diffusion_xl.pipeline_output import StableDiffusionXLPipelineOutput  # 导入稳定扩散 XL 管道输出类if is_invisible_watermark_available():  # 如果隐形水印功能可用from ..stable_diffusion_xl.watermark import StableDiffusionXLWatermarker  # 导入稳定扩散 XL 水印类from .multicontrolnet import MultiControlNetModel  # 导入多控制网模型logger = logging.get_logger(__name__)  # 获取当前模块的日志记录器,禁止 pylint 检查EXAMPLE_DOC_STRING = """  # 示例文档字符串的空模板
"""# 从 diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents 复制的函数
def retrieve_latents(  # 定义函数以检索潜在变量encoder_output: torch.Tensor,  # 输入为编码器输出的张量generator: Optional[torch.Generator] = None,  # 可选的随机数生成器sample_mode: str = "sample"  # 采样模式,默认为“sample”
):if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":  # 如果编码器输出有潜在分布并且模式为采样return encoder_output.latent_dist.sample(generator)  # 从潜在分布中采样并返回elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":  # 如果编码器输出有潜在分布并且模式为“argmax”return encoder_output.latent_dist.mode()  # 返回潜在分布的众数elif hasattr(encoder_output, "latents"):  # 如果编码器输出有潜在变量return encoder_output.latents  # 直接返回潜在变量else:  # 如果以上条件都不满足raise AttributeError("Could not access latents of provided encoder_output")  # 抛出属性错误,说明无法访问潜在变量class StableDiffusionXLControlNetImg2ImgPipeline(  # 定义 StableDiffusionXL 控制网络图像到图像的管道类DiffusionPipeline,  # 继承自扩散管道# 继承稳定扩散模型的混合类StableDiffusionMixin,# 继承文本反转加载器的混合类TextualInversionLoaderMixin,# 继承稳定扩散 XL Lora 加载器的混合类StableDiffusionXLLoraLoaderMixin,# 继承单文件加载器的混合类FromSingleFileMixin,# 继承 IP 适配器的混合类IPAdapterMixin,
# 文档字符串,描述使用 ControlNet 指导的图像生成管道r"""Pipeline for image-to-image generation using Stable Diffusion XL with ControlNet guidance.This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods thelibrary implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)The pipeline also inherits the following loading methods:- [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings- [`~loaders.StableDiffusionXLLoraLoaderMixin.load_lora_weights`] for loading LoRA weights- [`~loaders.StableDiffusionXLLoraLoaderMixin.save_lora_weights`] for saving LoRA weights- [`~loaders.IPAdapterMixin.load_ip_adapter`] for loading IP Adapters"""# 定义模型在 CPU 上卸载的顺序model_cpu_offload_seq = "text_encoder->text_encoder_2->image_encoder->unet->vae"# 定义可选组件的列表,用于管道的初始化_optional_components = ["tokenizer",  # 词汇表,用于文本编码"tokenizer_2",  # 第二个词汇表,用于文本编码"text_encoder",  # 文本编码器,用于生成文本嵌入"text_encoder_2",  # 第二个文本编码器,可能有不同的功能"feature_extractor",  # 特征提取器,用于图像特征的提取"image_encoder",  # 图像编码器,将图像转换为嵌入]# 定义回调张量输入的列表,用于处理管道中的输入_callback_tensor_inputs = ["latents",  # 潜在变量,用于生成模型的输入"prompt_embeds",  # 正向提示的嵌入表示"negative_prompt_embeds",  # 负向提示的嵌入表示"add_text_embeds",  # 额外文本嵌入,用于补充输入"add_time_ids",  # 附加的时间标识符,用于时间相关的处理"negative_pooled_prompt_embeds",  # 负向池化提示的嵌入表示"add_neg_time_ids",  # 附加的负向时间标识符]# 构造函数,初始化管道所需的组件def __init__(self,  # 构造函数的第一个参数,指向类的实例vae: AutoencoderKL,  # 变分自编码器,用于图像的重建text_encoder: CLIPTextModel,  # 文本编码器,使用 CLIP 模型text_encoder_2: CLIPTextModelWithProjection,  # 第二个文本编码器,带投影功能的 CLIP 模型tokenizer: CLIPTokenizer,  # 第一个 CLIP 词汇表tokenizer_2: CLIPTokenizer,  # 第二个 CLIP 词汇表unet: UNet2DConditionModel,  # U-Net 模型,用于生成图像controlnet: Union[ControlNetModel, List[ControlNetModel], Tuple[ControlNetModel], MultiControlNetModel],  # 控制网络模型,用于引导生成scheduler: KarrasDiffusionSchedulers,  # 调度器,控制扩散过程requires_aesthetics_score: bool = False,  # 是否需要美学评分,默认为 Falseforce_zeros_for_empty_prompt: bool = True,  # 对于空提示强制使用零值,默认为 Trueadd_watermarker: Optional[bool] = None,  # 是否添加水印,默认为 Nonefeature_extractor: CLIPImageProcessor = None,  # 特征提取器,默认为 Noneimage_encoder: CLIPVisionModelWithProjection = None,  # 图像编码器,默认为 None):# 调用父类的构造函数进行初始化super().__init__()# 检查 controlnet 是否为列表或元组,如果是则将其封装为 MultiControlNetModel 对象if isinstance(controlnet, (list, tuple)):controlnet = MultiControlNetModel(controlnet)# 注册多个模块,包括 VAE、文本编码器、tokenizer、UNet 等self.register_modules(vae=vae,text_encoder=text_encoder,text_encoder_2=text_encoder_2,tokenizer=tokenizer,tokenizer_2=tokenizer_2,unet=unet,controlnet=controlnet,scheduler=scheduler,feature_extractor=feature_extractor,image_encoder=image_encoder,)# 计算 VAE 的缩放因子,通常用于图像尺寸调整self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)# 创建 VAE 图像处理器,设置缩放因子并开启 RGB 转换self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True)# 创建控制图像处理器,设置缩放因子,开启 RGB 转换,但不进行标准化self.control_image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True, do_normalize=False)# 根据输入参数或默认值确定是否添加水印add_watermarker = add_watermarker if add_watermarker is not None else is_invisible_watermark_available()# 如果需要水印,则初始化水印对象if add_watermarker:self.watermark = StableDiffusionXLWatermarker()else:# 否则将水印设置为 Noneself.watermark = None# 注册配置,强制空提示使用零值self.register_to_config(force_zeros_for_empty_prompt=force_zeros_for_empty_prompt)# 注册配置,标记是否需要美学评分self.register_to_config(requires_aesthetics_score=requires_aesthetics_score)# 从 StableDiffusionXLPipeline 复制的 encode_prompt 方法def encode_prompt(self,# 定义 prompt 字符串及其相关参数prompt: str,prompt_2: Optional[str] = None,device: Optional[torch.device] = None,num_images_per_prompt: int = 1,do_classifier_free_guidance: bool = True,negative_prompt: Optional[str] = None,negative_prompt_2: Optional[str] = None,prompt_embeds: Optional[torch.Tensor] = None,negative_prompt_embeds: Optional[torch.Tensor] = None,pooled_prompt_embeds: Optional[torch.Tensor] = None,negative_pooled_prompt_embeds: Optional[torch.Tensor] = None,lora_scale: Optional[float] = None,clip_skip: Optional[int] = None,# 从 StableDiffusionPipeline 复制的 encode_image 方法# 定义一个方法来编码图像,参数包括图像、设备、每个提示的图像数量和可选的隐藏状态输出def encode_image(self, image, device, num_images_per_prompt, output_hidden_states=None):# 获取图像编码器参数的数据类型dtype = next(self.image_encoder.parameters()).dtype# 检查输入的图像是否为张量类型if not isinstance(image, torch.Tensor):# 如果不是,将其转换为张量,并提取像素值image = self.feature_extractor(image, return_tensors="pt").pixel_values# 将图像移动到指定设备并转换为相应的数据类型image = image.to(device=device, dtype=dtype)# 检查是否需要输出隐藏状态if output_hidden_states:# 获取图像编码器的隐藏状态,选择倒数第二个隐藏层image_enc_hidden_states = self.image_encoder(image, output_hidden_states=True).hidden_states[-2]# 将隐藏状态按每个提示的图像数量重复image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0)# 获取无条件图像编码的隐藏状态,使用全零张量作为输入uncond_image_enc_hidden_states = self.image_encoder(torch.zeros_like(image), output_hidden_states=True).hidden_states[-2]# 将无条件隐藏状态按每个提示的图像数量重复uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0)# 返回图像编码的隐藏状态和无条件图像编码的隐藏状态return image_enc_hidden_states, uncond_image_enc_hidden_stateselse:# 获取图像编码的嵌入表示image_embeds = self.image_encoder(image).image_embeds# 将嵌入表示按每个提示的图像数量重复image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0)# 创建与图像嵌入同样形状的全零张量作为无条件嵌入uncond_image_embeds = torch.zeros_like(image_embeds)# 返回图像嵌入和无条件图像嵌入return image_embeds, uncond_image_embeds# 从 diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_ip_adapter_image_embeds 复制的方法def prepare_ip_adapter_image_embeds(# 定义方法的参数,包括 IP 适配器图像、图像嵌入、设备、每个提示的图像数量和分类器自由引导的标志self, ip_adapter_image, ip_adapter_image_embeds, device, num_images_per_prompt, do_classifier_free_guidance):# 初始化一个空列表,用于存储图像嵌入image_embeds = []# 如果启用了无分类器自由引导,则初始化负图像嵌入列表if do_classifier_free_guidance:negative_image_embeds = []# 如果输入适配器图像嵌入为 Noneif ip_adapter_image_embeds is None:# 检查输入适配器图像是否为列表类型,如果不是,则转换为列表if not isinstance(ip_adapter_image, list):ip_adapter_image = [ip_adapter_image]# 检查输入适配器图像的长度是否与 IP 适配器数量相等if len(ip_adapter_image) != len(self.unet.encoder_hid_proj.image_projection_layers):# 如果不相等,抛出值错误raise ValueError(f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {len(self.unet.encoder_hid_proj.image_projection_layers)} IP Adapters.")# 遍历输入适配器图像和相应的图像投影层for single_ip_adapter_image, image_proj_layer in zip(ip_adapter_image, self.unet.encoder_hid_proj.image_projection_layers):# 确定是否输出隐藏状态,依据图像投影层的类型output_hidden_state = not isinstance(image_proj_layer, ImageProjection)# 编码单个图像,获取嵌入和负嵌入single_image_embeds, single_negative_image_embeds = self.encode_image(single_ip_adapter_image, device, 1, output_hidden_state)# 将图像嵌入添加到列表中,增加一个维度image_embeds.append(single_image_embeds[None, :])# 如果启用了无分类器自由引导,则将负图像嵌入添加到列表中if do_classifier_free_guidance:negative_image_embeds.append(single_negative_image_embeds[None, :])else:# 如果输入适配器图像嵌入已存在for single_image_embeds in ip_adapter_image_embeds:# 如果启用了无分类器自由引导,将嵌入分成负嵌入和正嵌入if do_classifier_free_guidance:single_negative_image_embeds, single_image_embeds = single_image_embeds.chunk(2)# 添加负图像嵌入到列表中negative_image_embeds.append(single_negative_image_embeds)# 添加正图像嵌入到列表中image_embeds.append(single_image_embeds)# 初始化一个空列表,用于存储处理后的输入适配器图像嵌入ip_adapter_image_embeds = []# 遍历图像嵌入,执行重复操作以匹配每个提示的图像数量for i, single_image_embeds in enumerate(image_embeds):# 将单个图像嵌入沿着维度 0 重复指定次数single_image_embeds = torch.cat([single_image_embeds] * num_images_per_prompt, dim=0)# 如果启用了无分类器自由引导,处理负嵌入if do_classifier_free_guidance:single_negative_image_embeds = torch.cat([negative_image_embeds[i]] * num_images_per_prompt, dim=0)# 将负嵌入与正嵌入合并single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds], dim=0)# 将嵌入移动到指定的设备single_image_embeds = single_image_embeds.to(device=device)# 将处理后的嵌入添加到列表中ip_adapter_image_embeds.append(single_image_embeds)# 返回处理后的输入适配器图像嵌入列表return ip_adapter_image_embeds# 从 diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs 复制的内容# 准备额外的参数用于调度器步骤,因为并非所有调度器具有相同的参数签名def prepare_extra_step_kwargs(self, generator, eta):# eta (η) 仅在 DDIMScheduler 中使用,其他调度器会忽略此参数# eta 对应于 DDIM 论文中的 η: https://arxiv.org/abs/2010.02502# 该值应在 [0, 1] 之间# 检查调度器的步骤方法是否接受 eta 参数accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())# 创建一个空字典用于存放额外参数extra_step_kwargs = {}# 如果调度器接受 eta,则将其添加到额外参数中if accepts_eta:extra_step_kwargs["eta"] = eta# 检查调度器的步骤方法是否接受 generator 参数accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())# 如果调度器接受 generator,则将其添加到额外参数中if accepts_generator:extra_step_kwargs["generator"] = generator# 返回包含额外参数的字典return extra_step_kwargs# 检查输入参数的有效性def check_inputs(self,prompt,prompt_2,image,strength,num_inference_steps,callback_steps,negative_prompt=None,negative_prompt_2=None,prompt_embeds=None,negative_prompt_embeds=None,pooled_prompt_embeds=None,negative_pooled_prompt_embeds=None,ip_adapter_image=None,ip_adapter_image_embeds=None,controlnet_conditioning_scale=1.0,control_guidance_start=0.0,control_guidance_end=1.0,callback_on_step_end_tensor_inputs=None,# 从 diffusers.pipelines.controlnet.pipeline_controlnet_sd_xl.StableDiffusionXLControlNetPipeline.check_image 复制的参数# 检查输入图像的类型和形状,确保与提示的批量大小一致def check_image(self, image, prompt, prompt_embeds):# 判断输入是否为 PIL 图像image_is_pil = isinstance(image, PIL.Image.Image)# 判断输入是否为 PyTorch 张量image_is_tensor = isinstance(image, torch.Tensor)# 判断输入是否为 NumPy 数组image_is_np = isinstance(image, np.ndarray)# 判断输入是否为 PIL 图像列表image_is_pil_list = isinstance(image, list) and isinstance(image[0], PIL.Image.Image)# 判断输入是否为 PyTorch 张量列表image_is_tensor_list = isinstance(image, list) and isinstance(image[0], torch.Tensor)# 判断输入是否为 NumPy 数组列表image_is_np_list = isinstance(image, list) and isinstance(image[0], np.ndarray)# 如果输入不符合任何类型,抛出类型错误if (not image_is_piland not image_is_tensorand not image_is_npand not image_is_pil_listand not image_is_tensor_listand not image_is_np_list):raise TypeError(f"image must be passed and be one of PIL image, numpy array, torch tensor, list of PIL images, list of numpy arrays or list of torch tensors, but is {type(image)}")# 如果输入为 PIL 图像,设置批量大小为 1if image_is_pil:image_batch_size = 1else:# 否则,根据输入的长度确定批量大小image_batch_size = len(image)# 如果提示不为 None 且为字符串,设置提示批量大小为 1if prompt is not None and isinstance(prompt, str):prompt_batch_size = 1# 如果提示为列表,根据列表长度设置批量大小elif prompt is not None and isinstance(prompt, list):prompt_batch_size = len(prompt)# 如果提示嵌入不为 None,使用其第一维的大小作为批量大小elif prompt_embeds is not None:prompt_batch_size = prompt_embeds.shape[0]# 如果图像批量大小不为 1,且与提示批量大小不一致,抛出值错误if image_batch_size != 1 and image_batch_size != prompt_batch_size:raise ValueError(f"If image batch size is not 1, image batch size must be same as prompt batch size. image batch size: {image_batch_size}, prompt batch size: {prompt_batch_size}")# 从 diffusers.pipelines.controlnet.pipeline_controlnet_sd_xl 导入的 prepare_image 方法def prepare_control_image(self,image,width,height,batch_size,num_images_per_prompt,device,dtype,do_classifier_free_guidance=False,guess_mode=False,):# 预处理输入图像并转换为指定的数据类型image = self.control_image_processor.preprocess(image, height=height, width=width).to(dtype=torch.float32)# 获取图像批量大小image_batch_size = image.shape[0]# 如果图像批量大小为 1,重复次数设置为 batch_sizeif image_batch_size == 1:repeat_by = batch_sizeelse:# 如果图像批量大小与提示批量大小相同,设置重复次数为每个提示的图像数量repeat_by = num_images_per_prompt# 重复图像以匹配所需的批量大小image = image.repeat_interleave(repeat_by, dim=0)# 将图像转移到指定设备和数据类型image = image.to(device=device, dtype=dtype)# 如果启用分类器自由引导并且不在猜测模式下,复制图像以增加维度if do_classifier_free_guidance and not guess_mode:image = torch.cat([image] * 2)# 返回处理后的图像return image# 从 diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img 导入的 get_timesteps 方法# 获取时间步的函数,接收推理步骤数、强度和设备参数def get_timesteps(self, num_inference_steps, strength, device):# 计算原始时间步,使用 init_timestep,确保不超过推理步骤数init_timestep = min(int(num_inference_steps * strength), num_inference_steps)# 计算开始时间步,确保不小于零t_start = max(num_inference_steps - init_timestep, 0)# 从调度器获取时间步,截取从 t_start 开始的所有时间步timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :]# 如果调度器具有设置开始索引的方法,则调用该方法if hasattr(self.scheduler, "set_begin_index"):self.scheduler.set_begin_index(t_start * self.scheduler.order)# 返回时间步和剩余的推理步骤数return timesteps, num_inference_steps - t_start# 从 StableDiffusionXLImg2ImgPipeline 复制的准备潜在变量的函数def prepare_latents(self, image, timestep, batch_size, num_images_per_prompt, dtype, device, generator=None, add_noise=True# 从 StableDiffusionXLImg2ImgPipeline 复制的获取附加时间 ID 的函数def _get_add_time_ids(self,original_size,crops_coords_top_left,target_size,aesthetic_score,negative_aesthetic_score,negative_original_size,negative_crops_coords_top_left,negative_target_size,dtype,text_encoder_projection_dim=None,):# 检查配置是否需要美学评分if self.config.requires_aesthetics_score:# 创建包含原始大小、裁剪坐标及美学评分的列表add_time_ids = list(original_size + crops_coords_top_left + (aesthetic_score,))# 创建包含负样本原始大小、裁剪坐标及负美学评分的列表add_neg_time_ids = list(negative_original_size + negative_crops_coords_top_left + (negative_aesthetic_score,))else:# 创建包含原始大小、裁剪坐标和目标大小的列表add_time_ids = list(original_size + crops_coords_top_left + target_size)# 创建包含负样本原始大小、裁剪坐标及负目标大小的列表add_neg_time_ids = list(negative_original_size + crops_coords_top_left + negative_target_size)# 计算通过添加时间嵌入维度和文本编码器投影维度得到的通过嵌入维度passed_add_embed_dim = (self.unet.config.addition_time_embed_dim * len(add_time_ids) + text_encoder_projection_dim)# 获取模型期望的添加嵌入维度expected_add_embed_dim = self.unet.add_embedding.linear_1.in_features# 检查期望的嵌入维度是否大于传递的嵌入维度,并符合特定条件if (expected_add_embed_dim > passed_add_embed_dimand (expected_add_embed_dim - passed_add_embed_dim) == self.unet.config.addition_time_embed_dim):# 抛出值错误,说明创建的嵌入维度不符合预期raise ValueError(f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. Please make sure to enable `requires_aesthetics_score` with `pipe.register_to_config(requires_aesthetics_score=True)` to make sure `aesthetic_score` {aesthetic_score} and `negative_aesthetic_score` {negative_aesthetic_score} is correctly used by the model.")# 检查期望的嵌入维度是否小于传递的嵌入维度,并符合特定条件elif (expected_add_embed_dim < passed_add_embed_dimand (passed_add_embed_dim - expected_add_embed_dim) == self.unet.config.addition_time_embed_dim):# 抛出值错误,说明创建的嵌入维度不符合预期raise ValueError(f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. Please make sure to disable `requires_aesthetics_score` with `pipe.register_to_config(requires_aesthetics_score=False)` to make sure `target_size` {target_size} is correctly used by the model.")# 检查期望的嵌入维度是否与传递的嵌入维度不相等elif expected_add_embed_dim != passed_add_embed_dim:# 抛出值错误,说明模型配置不正确raise ValueError(f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. The model has an incorrect config. Please check `unet.config.time_embedding_type` and `text_encoder_2.config.projection_dim`.")# 将添加的时间 ID 转换为张量,并指定数据类型add_time_ids = torch.tensor([add_time_ids], dtype=dtype)# 将添加的负时间 ID 转换为张量,并指定数据类型add_neg_time_ids = torch.tensor([add_neg_time_ids], dtype=dtype)# 返回添加的时间 ID 和添加的负时间 IDreturn add_time_ids, add_neg_time_ids# 从 diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_upscale.StableDiffusionUpscalePipeline.upcast_vae 复制而来# 定义一个方法,用于将 VAE 模型的参数类型提升def upcast_vae(self):# 获取当前 VAE 模型的数据类型dtype = self.vae.dtype# 将 VAE 模型转换为 float32 数据类型self.vae.to(dtype=torch.float32)# 检查 VAE 解码器中第一个注意力处理器的类型,以确定是否使用了特定版本的处理器use_torch_2_0_or_xformers = isinstance(self.vae.decoder.mid_block.attentions[0].processor,(AttnProcessor2_0,XFormersAttnProcessor,),)# 如果使用了 xformers 或 torch_2_0,注意力块不需要为 float32 类型,从而节省大量内存if use_torch_2_0_or_xformers:# 将后量化卷积层转换为原始数据类型self.vae.post_quant_conv.to(dtype)# 将解码器输入卷积层转换为原始数据类型self.vae.decoder.conv_in.to(dtype)# 将解码器中间块转换为原始数据类型self.vae.decoder.mid_block.to(dtype)# 定义一个属性,返回当前的引导缩放比例@propertydef guidance_scale(self):# 返回内部存储的引导缩放比例return self._guidance_scale# 定义一个属性,返回当前的剪辑跳过值@propertydef clip_skip(self):# 返回内部存储的剪辑跳过值return self._clip_skip# 定义一个属性,用于判断是否进行无分类器引导,依据是引导缩放比例是否大于 1# 此属性的定义参考了 Imagen 论文中的方程 (2)# 当 `guidance_scale = 1` 时,相当于不进行无分类器引导@propertydef do_classifier_free_guidance(self):# 如果引导缩放比例大于 1,返回 True,否则返回 Falsereturn self._guidance_scale > 1# 定义一个属性,返回当前的交叉注意力参数@propertydef cross_attention_kwargs(self):# 返回内部存储的交叉注意力参数return self._cross_attention_kwargs# 定义一个属性,返回当前的时间步数@propertydef num_timesteps(self):# 返回内部存储的时间步数return self._num_timesteps# 装饰器,表示在执行下面的方法时不计算梯度@torch.no_grad()# 装饰器,用于替换示例文档字符串@replace_example_docstring(EXAMPLE_DOC_STRING)# 定义一个可调用的类方法,接受多个参数用于处理图像生成def __call__(# 主提示字符串或字符串列表,默认为 Noneself,prompt: Union[str, List[str]] = None,# 第二个提示字符串或字符串列表,默认为 Noneprompt_2: Optional[Union[str, List[str]]] = None,# 输入图像,用于图像生成的基础,默认为 Noneimage: PipelineImageInput = None,# 控制图像,用于影响生成的图像,默认为 Nonecontrol_image: PipelineImageInput = None,# 输出图像的高度,默认为 Noneheight: Optional[int] = None,# 输出图像的宽度,默认为 Nonewidth: Optional[int] = None,# 图像生成的强度,默认为 0.8strength: float = 0.8,# 进行推理的步数,默认为 50num_inference_steps: int = 50,# 引导尺度,控制图像生成的引导程度,默认为 5.0guidance_scale: float = 5.0,# 负面提示字符串或字符串列表,默认为 Nonenegative_prompt: Optional[Union[str, List[str]]] = None,# 第二个负面提示字符串或字符串列表,默认为 Nonenegative_prompt_2: Optional[Union[str, List[str]]] = None,# 每个提示生成的图像数量,默认为 1num_images_per_prompt: Optional[int] = 1,# 采样的 eta 值,默认为 0.0eta: float = 0.0,# 随机数生成器,可选,默认为 Nonegenerator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,# 潜在变量,默认为 Nonelatents: Optional[torch.Tensor] = None,# 提示的嵌入向量,默认为 Noneprompt_embeds: Optional[torch.Tensor] = None,# 负面提示的嵌入向量,默认为 Nonenegative_prompt_embeds: Optional[torch.Tensor] = None,# 聚合的提示嵌入向量,默认为 Nonepooled_prompt_embeds: Optional[torch.Tensor] = None,# 负面聚合提示嵌入向量,默认为 Nonenegative_pooled_prompt_embeds: Optional[torch.Tensor] = None,# 输入适配器图像,默认为 Noneip_adapter_image: Optional[PipelineImageInput] = None,# 输入适配器图像的嵌入向量,默认为 Noneip_adapter_image_embeds: Optional[List[torch.Tensor]] = None,# 输出类型,默认为 "pil"output_type: Optional[str] = "pil",# 是否返回字典,默认为 Truereturn_dict: bool = True,# 交叉注意力参数,默认为 Nonecross_attention_kwargs: Optional[Dict[str, Any]] = None,# 控制网络的条件缩放,默认为 0.8controlnet_conditioning_scale: Union[float, List[float]] = 0.8,# 猜测模式,默认为 Falseguess_mode: bool = False,# 控制引导的开始位置,默认为 0.0control_guidance_start: Union[float, List[float]] = 0.0,# 控制引导的结束位置,默认为 1.0control_guidance_end: Union[float, List[float]] = 1.0,# 原始图像的尺寸,默认为 Noneoriginal_size: Tuple[int, int] = None,# 裁剪坐标的左上角,默认为 (0, 0)crops_coords_top_left: Tuple[int, int] = (0, 0),# 目标尺寸,默认为 Nonetarget_size: Tuple[int, int] = None,# 负面原始图像的尺寸,默认为 Nonenegative_original_size: Optional[Tuple[int, int]] = None,# 负面裁剪坐标的左上角,默认为 (0, 0)negative_crops_coords_top_left: Tuple[int, int] = (0, 0),# 负目标尺寸,默认为 Nonenegative_target_size: Optional[Tuple[int, int]] = None,# 审美分数,默认为 6.0aesthetic_score: float = 6.0,# 负面审美分数,默认为 2.5negative_aesthetic_score: float = 2.5,# 跳过的剪辑层数,默认为 Noneclip_skip: Optional[int] = None,# 步骤结束时的回调函数,可选,默认为 Nonecallback_on_step_end: Optional[Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks]] = None,# 结束步骤时的张量输入回调,默认为 ["latents"]callback_on_step_end_tensor_inputs: List[str] = ["latents"],# 其他额外参数,默认为空**kwargs,

.\diffusers\pipelines\controlnet\pipeline_flax_controlnet.py

# 版权所有 2024 HuggingFace 团队。保留所有权利。
#
# 根据 Apache 许可证,版本 2.0(“许可证”)授权;
# 除非遵守许可证,否则不得使用此文件。
# 可以在以下网址获取许可证副本:
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# 除非适用法律要求或书面协议另有规定,
# 否则根据许可证分发的软件是“按原样”提供的,
# 不提供任何形式的担保或条件,无论是明示或暗示。
# 有关许可证下的特定语言的权限和限制,请参见许可证。import warnings  # 导入警告模块,用于处理警告信息
from functools import partial  # 从 functools 导入 partial,用于部分函数应用
from typing import Dict, List, Optional, Union  # 导入类型提示,方便函数参数和返回值的类型注释import jax  # 导入 JAX,用于高性能数值计算
import jax.numpy as jnp  # 导入 JAX 的 NumPy 接口,提供数组操作功能
import numpy as np  # 导入 NumPy,提供数值计算功能
from flax.core.frozen_dict import FrozenDict  # 从 flax 导入 FrozenDict,用于不可变字典
from flax.jax_utils import unreplicate  # 从 flax 导入 unreplicate,用于在 JAX 中处理设备数据
from flax.training.common_utils import shard  # 从 flax 导入 shard,用于数据并行
from PIL import Image  # 从 PIL 导入 Image,用于图像处理
from transformers import CLIPImageProcessor, CLIPTokenizer, FlaxCLIPTextModel  # 导入 CLIP 相关模块,处理图像和文本from ...models import FlaxAutoencoderKL, FlaxControlNetModel, FlaxUNet2DConditionModel  # 导入模型定义
from ...schedulers import (  # 导入调度器,用于训练过程中的控制FlaxDDIMScheduler,FlaxDPMSolverMultistepScheduler,FlaxLMSDiscreteScheduler,FlaxPNDMScheduler,
)
from ...utils import PIL_INTERPOLATION, logging, replace_example_docstring  # 导入工具函数和常量
from ..pipeline_flax_utils import FlaxDiffusionPipeline  # 导入扩散管道
from ..stable_diffusion import FlaxStableDiffusionPipelineOutput  # 导入稳定扩散管道输出
from ..stable_diffusion.safety_checker_flax import FlaxStableDiffusionSafetyChecker  # 导入安全检查器logger = logging.get_logger(__name__)  # 获取当前模块的日志记录器,方便调试和信息输出# 设置为 True 以使用 Python 循环而不是 jax.fori_loop,以便于调试
DEBUG = False  # 调试模式标志,默认为关闭状态EXAMPLE_DOC_STRING = """  # 示例文档字符串,可能用于文档生成或示例展示```  # 示例结束标志Examples:```py>>> import jax  # 导入 JAX 库,用于高性能数值计算>>> import numpy as np  # 导入 NumPy 库,支持数组操作>>> import jax.numpy as jnp  # 导入 JAX 的 NumPy,支持自动微分和GPU加速>>> from flax.jax_utils import replicate  # 从 Flax 导入 replicate 函数,用于参数复制>>> from flax.training.common_utils import shard  # 从 Flax 导入 shard 函数,用于数据分片>>> from diffusers.utils import load_image, make_image_grid  # 从 diffusers 导入图像加载和网格生成工具>>> from PIL import Image  # 导入 PIL 库,用于图像处理>>> from diffusers import FlaxStableDiffusionControlNetPipeline, FlaxControlNetModel  # 导入用于稳定扩散模型和控制网的类>>> def create_key(seed=0):  # 定义函数创建随机数生成器的密钥...     return jax.random.PRNGKey(seed)  # 返回一个以 seed 为种子的 PRNG 密钥>>> rng = create_key(0)  # 创建随机数生成器的密钥,种子为 0>>> # get canny image  # 获取 Canny 边缘检测图像>>> canny_image = load_image(  # 使用 load_image 函数加载图像...     "https://huggingface.co/datasets/YiYiXu/test-doc-assets/resolve/main/blog_post_cell_10_output_0.jpeg"  # 指定图像的 URL... )>>> prompts = "best quality, extremely detailed"  # 定义用于生成图像的正向提示>>> negative_prompts = "monochrome, lowres, bad anatomy, worst quality, low quality"  # 定义生成图像时要避免的负向提示>>> # load control net and stable diffusion v1-5  # 加载控制网络和稳定扩散模型 v1-5>>> controlnet, controlnet_params = FlaxControlNetModel.from_pretrained(  # 从预训练模型加载控制网络及其参数...     "lllyasviel/sd-controlnet-canny", from_pt=True, dtype=jnp.float32  # 指定模型名称、来源及数据类型... )>>> pipe, params = FlaxStableDiffusionControlNetPipeline.from_pretrained(  # 从预训练模型加载稳定扩散管道及其参数...     "runwayml/stable-diffusion-v1-5", controlnet=controlnet, revision="flax", dtype=jnp.float32  # 指定模型名称、控制网、版本和数据类型... )>>> params["controlnet"] = controlnet_params  # 将控制网参数存入管道参数中>>> num_samples = jax.device_count()  # 获取当前设备的数量,设置样本数量>>> rng = jax.random.split(rng, jax.device_count())  # 将随机数生成器的密钥根据设备数量进行分割>>> prompt_ids = pipe.prepare_text_inputs([prompts] * num_samples)  # 准备正向提示的输入,针对每个样本复制>>> negative_prompt_ids = pipe.prepare_text_inputs([negative_prompts] * num_samples)  # 准备负向提示的输入,针对每个样本复制>>> processed_image = pipe.prepare_image_inputs([canny_image] * num_samples)  # 准备处理后的图像输入,针对每个样本复制>>> p_params = replicate(params)  # 复制参数以便在多个设备上使用>>> prompt_ids = shard(prompt_ids)  # 将正向提示的输入数据进行分片>>> negative_prompt_ids = shard(negative_prompt_ids)  # 将负向提示的输入数据进行分片>>> processed_image = shard(processed_image)  # 将处理后的图像输入数据进行分片>>> output = pipe(  # 调用管道生成输出...     prompt_ids=prompt_ids,  # 传入正向提示 ID...     image=processed_image,  # 传入处理后的图像...     params=p_params,  # 传入复制的参数...     prng_seed=rng,  # 传入随机数生成器的密钥...     num_inference_steps=50,  # 设置推理的步骤数...     neg_prompt_ids=negative_prompt_ids,  # 传入负向提示 ID...     jit=True,  # 启用 JIT 编译... ).images  # 获取生成的图像>>> output_images = pipe.numpy_to_pil(np.asarray(output.reshape((num_samples,) + output.shape[-3:])))  # 将输出图像转换为 PIL 格式>>> output_images = make_image_grid(output_images, num_samples // 4, 4)  # 将图像生成网格格式,指定每行显示的图像数量>>> output_images.save("generated_image.png")  # 保存生成的图像为 PNG 文件``` 
# 定义一个类,基于 Flax 实现 Stable Diffusion 的控制网文本到图像生成管道
class FlaxStableDiffusionControlNetPipeline(FlaxDiffusionPipeline):r"""基于 Flax 的管道,用于使用 Stable Diffusion 和 ControlNet 指导进行文本到图像生成。此模型继承自 [`FlaxDiffusionPipeline`]。有关所有管道实现的通用方法(下载、保存、在特定设备上运行等),请查看超类文档。参数:vae ([`FlaxAutoencoderKL`]):用于将图像编码和解码为潜在表示的变分自编码器(VAE)模型。text_encoder ([`~transformers.FlaxCLIPTextModel`]):冻结的文本编码器([clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14))。tokenizer ([`~transformers.CLIPTokenizer`]):用于对文本进行分词的 `CLIPTokenizer`。unet ([`FlaxUNet2DConditionModel`]):一个 `FlaxUNet2DConditionModel`,用于去噪编码后的图像潜在表示。controlnet ([`FlaxControlNetModel`]):在去噪过程中为 `unet` 提供额外的条件信息。scheduler ([`SchedulerMixin`]):用于与 `unet` 结合使用的调度器,以去噪编码的图像潜在表示。可以是[`FlaxDDIMScheduler`], [`FlaxLMSDiscreteScheduler`], [`FlaxPNDMScheduler`] 或[`FlaxDPMSolverMultistepScheduler`] 中的一个。safety_checker ([`FlaxStableDiffusionSafetyChecker`]):分类模块,评估生成的图像是否可能被视为冒犯或有害。有关模型潜在危害的更多细节,请参阅 [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5)。feature_extractor ([`~transformers.CLIPImageProcessor`]):一个 `CLIPImageProcessor`,用于提取生成图像的特征;用于 `safety_checker` 的输入。"""# 初始化方法,定义所需参数及其类型def __init__(# 变分自编码器(VAE)模型,用于图像编码和解码vae: FlaxAutoencoderKL,# 冻结的文本编码器模型text_encoder: FlaxCLIPTextModel,# 文本分词器tokenizer: CLIPTokenizer,# 去噪模型unet: FlaxUNet2DConditionModel,# 控制网模型controlnet: FlaxControlNetModel,# 图像去噪的调度器scheduler: Union[FlaxDDIMScheduler, FlaxPNDMScheduler, FlaxLMSDiscreteScheduler, FlaxDPMSolverMultistepScheduler],# 安全检查模块safety_checker: FlaxStableDiffusionSafetyChecker,# 特征提取器feature_extractor: CLIPImageProcessor,# 数据类型,默认为 32 位浮点数dtype: jnp.dtype = jnp.float32,):# 调用父类的初始化方法super().__init__()# 设置数据类型属性self.dtype = dtype# 检查安全检查器是否为 Noneif safety_checker is None:# 记录警告,告知用户已禁用安全检查器logger.warning(f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"" that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered"" results in services or applications open to the public. Both the diffusers team and Hugging Face"" strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling"" it only for use-cases that involve analyzing network behavior or auditing its results. For more"" information, please have a look at https://github.com/huggingface/diffusers/pull/254 .")# 注册各个模块,方便后续使用self.register_modules(vae=vae,  # 变分自编码器text_encoder=text_encoder,  # 文本编码器tokenizer=tokenizer,  # 分词器unet=unet,  # UNet 模型controlnet=controlnet,  # 控制网络scheduler=scheduler,  # 调度器safety_checker=safety_checker,  # 安全检查器feature_extractor=feature_extractor,  # 特征提取器)# 计算 VAE 的缩放因子self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)def prepare_text_inputs(self, prompt: Union[str, List[str]]):# 检查 prompt 类型是否为字符串或列表if not isinstance(prompt, (str, list)):# 如果类型不符,抛出值错误raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")# 使用分词器处理输入文本text_input = self.tokenizer(prompt,  # 输入的提示文本padding="max_length",  # 填充到最大长度max_length=self.tokenizer.model_max_length,  # 设置最大长度为分词器的最大模型长度truncation=True,  # 如果超过最大长度,则截断return_tensors="np",  # 返回 NumPy 格式的张量)# 返回处理后的输入 IDreturn text_input.input_idsdef prepare_image_inputs(self, image: Union[Image.Image, List[Image.Image]]):# 检查图像类型是否为 PIL.Image 或列表if not isinstance(image, (Image.Image, list)):# 如果类型不符,抛出值错误raise ValueError(f"image has to be of type `PIL.Image.Image` or list but is {type(image)}")# 如果输入是单个图像,将其转换为列表if isinstance(image, Image.Image):image = [image]# 对所有图像进行预处理,并合并为一个数组processed_images = jnp.concatenate([preprocess(img, jnp.float32) for img in image])# 返回处理后的图像数组return processed_imagesdef _get_has_nsfw_concepts(self, features, params):# 使用安全检查器检查是否存在不适当内容概念has_nsfw_concepts = self.safety_checker(features, params)# 返回检查结果return has_nsfw_concepts# 定义一个安全检查的私有方法,接收图像、模型参数和是否使用 JIT 编译的标志def _run_safety_checker(self, images, safety_model_params, jit=False):# 当 jit 为 True 时,safety_model_params 应该已经被复制# 将输入的图像数组转换为 PIL 图像格式pil_images = [Image.fromarray(image) for image in images]# 使用特征提取器处理 PIL 图像,返回其像素值features = self.feature_extractor(pil_images, return_tensors="np").pixel_values# 如果启用 JIT 编译if jit:# 对特征进行分片处理features = shard(features)# 检查特征中是否存在 NSFW 概念has_nsfw_concepts = _p_get_has_nsfw_concepts(self, features, safety_model_params)# 取消特征的分片has_nsfw_concepts = unshard(has_nsfw_concepts)# 取消模型参数的复制safety_model_params = unreplicate(safety_model_params)else:# 否则,直接获取 NSFW 概念的检查结果has_nsfw_concepts = self._get_has_nsfw_concepts(features, safety_model_params)# 初始化一个标志,指示图像是否已经被复制images_was_copied = False# 遍历每个 NSFW 概念的检查结果for idx, has_nsfw_concept in enumerate(has_nsfw_concepts):# 如果检测到 NSFW 概念if has_nsfw_concept:# 如果还没有复制图像if not images_was_copied:# 标记为已复制,并进行图像复制images_was_copied = Trueimages = images.copy()# 将对应的图像替换为全黑图像images[idx] = np.zeros(images[idx].shape, dtype=np.uint8)  # black image# 如果存在任何 NSFW 概念if any(has_nsfw_concepts):# 发出警告,提示可能检测到不适宜内容warnings.warn("Potential NSFW content was detected in one or more images. A black image will be returned"" instead. Try again with a different prompt and/or seed.")# 返回处理后的图像和 NSFW 概念的检查结果return images, has_nsfw_concepts# 定义一个生成图像的私有方法,接收多个参数以控制生成过程def _generate(self,prompt_ids: jnp.ndarray,  # 输入的提示 ID 数组image: jnp.ndarray,  # 输入的图像数据params: Union[Dict, FrozenDict],  # 模型参数,可能是字典或不可变字典prng_seed: jax.Array,  # 随机种子,用于随机数生成num_inference_steps: int,  # 推理步骤的数量guidance_scale: float,  # 指导比例,用于控制生成质量latents: Optional[jnp.ndarray] = None,  # 潜在变量,默认值为 Noneneg_prompt_ids: Optional[jnp.ndarray] = None,  # 负提示 ID,默认值为 Nonecontrolnet_conditioning_scale: float = 1.0,  # 控制网络的条件缩放比例@replace_example_docstring(EXAMPLE_DOC_STRING)# 定义可调用的方法,接收多个参数以控制生成过程def __call__(self,prompt_ids: jnp.ndarray,  # 输入的提示 ID 数组image: jnp.ndarray,  # 输入的图像数据params: Union[Dict, FrozenDict],  # 模型参数,可能是字典或不可变字典prng_seed: jax.Array,  # 随机种子,用于随机数生成num_inference_steps: int = 50,  # 默认推理步骤的数量为 50guidance_scale: Union[float, jnp.ndarray] = 7.5,  # 默认指导比例为 7.5latents: jnp.ndarray = None,  # 潜在变量,默认值为 Noneneg_prompt_ids: jnp.ndarray = None,  # 负提示 ID,默认值为 Nonecontrolnet_conditioning_scale: Union[float, jnp.ndarray] = 1.0,  # 默认控制网络的条件缩放比例为 1.0return_dict: bool = True,  # 默认返回字典格式jit: bool = False,  # 默认不启用 JIT 编译
# 静态参数为 pipe 和 num_inference_steps,任何更改都会触发重新编译。
# 非静态参数是(分片)输入张量,这些张量在它们的第一维上被映射(因此为 `0`)。
@partial(jax.pmap,  # 使用 JAX 的 pmap 并行映射功能in_axes=(None, 0, 0, 0, 0, None, 0, 0, 0, 0),  # 指定输入张量的轴static_broadcasted_argnums=(0, 5),  # 指定静态广播参数的索引
)
def _p_generate(  # 定义生成函数pipe,  # 生成管道对象prompt_ids,  # 提示 IDimage,  # 输入图像params,  # 生成参数prng_seed,  # 随机数生成种子num_inference_steps,  # 推理步骤数guidance_scale,  # 指导尺度latents,  # 潜在变量neg_prompt_ids,  # 负提示 IDcontrolnet_conditioning_scale,  # 控制网条件尺度
):return pipe._generate(  # 调用生成管道的生成方法prompt_ids,  # 提示 IDimage,  # 输入图像params,  # 生成参数prng_seed,  # 随机数生成种子num_inference_steps,  # 推理步骤数guidance_scale,  # 指导尺度latents,  # 潜在变量neg_prompt_ids,  # 负提示 IDcontrolnet_conditioning_scale,  # 控制网条件尺度)@partial(jax.pmap, static_broadcasted_argnums=(0,))  # 使用 JAX 的 pmap,并指定静态广播参数
def _p_get_has_nsfw_concepts(pipe, features, params):  # 定义检查是否有 NSFW 概念的函数return pipe._get_has_nsfw_concepts(features, params)  # 调用管道的相关方法def unshard(x: jnp.ndarray):  # 定义反分片函数,接受一个张量# einops.rearrange(x, 'd b ... -> (d b) ...')  # 注释掉的排列操作num_devices, batch_size = x.shape[:2]  # 获取设备数量和批量大小rest = x.shape[2:]  # 获取其余维度return x.reshape(num_devices * batch_size, *rest)  # 重新调整形状以合并设备和批量维度def preprocess(image, dtype):  # 定义图像预处理函数image = image.convert("RGB")  # 将图像转换为 RGB 模式w, h = image.size  # 获取图像的宽和高w, h = (x - x % 64 for x in (w, h))  # 将宽高调整为64的整数倍image = image.resize((w, h), resample=PIL_INTERPOLATION["lanczos"])  # 调整图像大小,使用 Lanczos 插值法image = jnp.array(image).astype(dtype) / 255.0  # 转换为 NumPy 数组并归一化到 [0, 1]image = image[None].transpose(0, 3, 1, 2)  # 添加新维度并调整通道顺序return image  # 返回处理后的图像

.\diffusers\pipelines\controlnet\__init__.py

# 导入类型检查工具
from typing import TYPE_CHECKING# 从 utils 模块导入必要的工具和常量
from ...utils import (DIFFUSERS_SLOW_IMPORT,  # 导入慢导入标志OptionalDependencyNotAvailable,  # 导入可选依赖不可用异常_LazyModule,  # 导入延迟模块工具get_objects_from_module,  # 导入从模块获取对象的函数is_flax_available,  # 导入检查 Flax 可用性的函数is_torch_available,  # 导入检查 Torch 可用性的函数is_transformers_available,  # 导入检查 Transformers 可用性的函数
)# 初始化一个空字典用于存储虚拟对象
_dummy_objects = {}
# 初始化一个空字典用于存储导入结构
_import_structure = {}try:# 检查 Transformers 和 Torch 是否可用if not (is_transformers_available() and is_torch_available()):# 如果不可用,抛出异常raise OptionalDependencyNotAvailable()
# 捕获可选依赖不可用的异常
except OptionalDependencyNotAvailable:# 从 utils 导入虚拟对象from ...utils import dummy_torch_and_transformers_objects  # noqa F403# 更新虚拟对象字典_dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
else:# 如果依赖可用,更新导入结构字典_import_structure["multicontrolnet"] = ["MultiControlNetModel"]_import_structure["pipeline_controlnet"] = ["StableDiffusionControlNetPipeline"]_import_structure["pipeline_controlnet_blip_diffusion"] = ["BlipDiffusionControlNetPipeline"]_import_structure["pipeline_controlnet_img2img"] = ["StableDiffusionControlNetImg2ImgPipeline"]_import_structure["pipeline_controlnet_inpaint"] = ["StableDiffusionControlNetInpaintPipeline"]_import_structure["pipeline_controlnet_inpaint_sd_xl"] = ["StableDiffusionXLControlNetInpaintPipeline"]_import_structure["pipeline_controlnet_sd_xl"] = ["StableDiffusionXLControlNetPipeline"]_import_structure["pipeline_controlnet_sd_xl_img2img"] = ["StableDiffusionXLControlNetImg2ImgPipeline"]try:# 检查 Transformers 和 Flax 是否可用if not (is_transformers_available() and is_flax_available()):# 如果不可用,抛出异常raise OptionalDependencyNotAvailable()
# 捕获可选依赖不可用的异常
except OptionalDependencyNotAvailable:# 从 utils 导入虚拟 Flax 和 Transformers 对象from ...utils import dummy_flax_and_transformers_objects  # noqa F403# 更新虚拟对象字典_dummy_objects.update(get_objects_from_module(dummy_flax_and_transformers_objects))
else:# 如果依赖可用,更新导入结构字典_import_structure["pipeline_flax_controlnet"] = ["FlaxStableDiffusionControlNetPipeline"]# 如果类型检查或慢导入标志被设置
if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:try:# 检查 Transformers 和 Torch 是否可用if not (is_transformers_available() and is_torch_available()):# 如果不可用,抛出异常raise OptionalDependencyNotAvailable()# 捕获可选依赖不可用的异常except OptionalDependencyNotAvailable:# 导入虚拟的 Torch 和 Transformers 对象from ...utils.dummy_torch_and_transformers_objects import *else:# 如果依赖可用,导入相应模块from .multicontrolnet import MultiControlNetModelfrom .pipeline_controlnet import StableDiffusionControlNetPipelinefrom .pipeline_controlnet_blip_diffusion import BlipDiffusionControlNetPipelinefrom .pipeline_controlnet_img2img import StableDiffusionControlNetImg2ImgPipelinefrom .pipeline_controlnet_inpaint import StableDiffusionControlNetInpaintPipelinefrom .pipeline_controlnet_inpaint_sd_xl import StableDiffusionXLControlNetInpaintPipelinefrom .pipeline_controlnet_sd_xl import StableDiffusionXLControlNetPipelinefrom .pipeline_controlnet_sd_xl_img2img import StableDiffusionXLControlNetImg2ImgPipelinetry:# 检查 Transformers 和 Flax 是否可用if not (is_transformers_available() and is_flax_available()):# 如果不可用,抛出异常raise OptionalDependencyNotAvailable()# 捕获可选依赖项不可用的异常except OptionalDependencyNotAvailable:# 从 dummy 模块导入所有内容,忽略 F403 警告from ...utils.dummy_flax_and_transformers_objects import *  # noqa F403else:# 从 pipeline_flax_controlnet 模块导入 FlaxStableDiffusionControlNetPipelinefrom .pipeline_flax_controlnet import FlaxStableDiffusionControlNetPipeline
# 如果之前的条件不满足,执行以下代码
else:# 导入 sys 模块,用于访问和操作 Python 解释器的运行时环境import sys# 将当前模块替换为一个延迟加载的模块对象sys.modules[__name__] = _LazyModule(__name__,  # 模块名称globals()["__file__"],  # 当前文件的路径_import_structure,  # 模块的导入结构module_spec=__spec__,  # 模块的规格对象)# 遍历假对象字典,将每个对象属性设置到当前模块for name, value in _dummy_objects.items():setattr(sys.modules[__name__], name, value)  # 设置模块属性

.\diffusers\pipelines\controlnet_hunyuandit\pipeline_hunyuandit_controlnet.py

# 版权声明,指明文件的版权归 HunyuanDiT 和 HuggingFace 团队所有
# 本文件在 Apache 2.0 许可证下授权使用
# 除非遵循许可证,否则不能使用此文件
# 许可证的副本可以在以下网址获取
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# 除非适用法律规定或书面协议另有约定,否则软件在"按现状"基础上提供,不附带任何明示或暗示的保证
# 查看许可证以了解特定语言的权限和限制# 导入用于获取函数信息的 inspect 模块
import inspect
# 导入类型提示所需的类型
from typing import Callable, Dict, List, Optional, Tuple, Union# 导入 numpy 库
import numpy as np
# 导入 PyTorch 库
import torch
# 从 transformers 库导入相关模型和分词器
from transformers import BertModel, BertTokenizer, CLIPImageProcessor, MT5Tokenizer, T5EncoderModel# 从 diffusers 库导入 StableDiffusionPipelineOutput
from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput# 导入多管道回调类
from ...callbacks import MultiPipelineCallbacks, PipelineCallback
# 导入图像处理类
from ...image_processor import PipelineImageInput, VaeImageProcessor
# 导入自动编码器和模型
from ...models import AutoencoderKL, HunyuanDiT2DControlNetModel, HunyuanDiT2DModel, HunyuanDiT2DMultiControlNetModel
# 导入 2D 旋转位置嵌入函数
from ...models.embeddings import get_2d_rotary_pos_embed
# 导入稳定扩散安全检查器
from ...pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
# 导入扩散调度器
from ...schedulers import DDPMScheduler
# 导入实用工具函数
from ...utils import (is_torch_xla_available,  # 检查是否可用 XLAlogging,  # 导入日志记录模块replace_example_docstring,  # 替换示例文档字符串的工具
)
# 导入 PyTorch 相关的随机张量函数
from ...utils.torch_utils import randn_tensor
# 导入扩散管道工具类
from ..pipeline_utils import DiffusionPipeline# 检查是否可用 XLA,并根据结果导入相应模块
if is_torch_xla_available():import torch_xla.core.xla_model as xm  # 导入 XLA 核心模型XLA_AVAILABLE = True  # 设置 XLA 可用标志为 True
else:XLA_AVAILABLE = False  # 设置 XLA 可用标志为 False# 创建一个日志记录器实例,记录当前模块的日志
logger = logging.get_logger(__name__)  # pylint: disable=invalid-name# 示例文档字符串,用于说明使用方法
EXAMPLE_DOC_STRING = """
# 示例代码展示如何使用 HunyuanDiT 进行图像生成Examples:```py# 从 diffusers 库导入所需的模型和管道from diffusers import HunyuanDiT2DControlNetModel, HunyuanDiTControlNetPipeline# 导入 PyTorch 库import torch# 从预训练模型加载 HunyuanDiT2DControlNetModel,并指定数据类型为 float16controlnet = HunyuanDiT2DControlNetModel.from_pretrained("Tencent-Hunyuan/HunyuanDiT-v1.1-ControlNet-Diffusers-Canny", torch_dtype=torch.float16)# 从预训练模型加载 HunyuanDiTControlNetPipeline,传入 controlnet 和数据类型pipe = HunyuanDiTControlNetPipeline.from_pretrained("Tencent-Hunyuan/HunyuanDiT-v1.1-Diffusers", controlnet=controlnet, torch_dtype=torch.float16)# 将管道移动到 CUDA 设备以加速处理pipe.to("cuda")# 从 diffusers.utils 导入加载图像的工具from diffusers.utils import load_image# 从指定 URL 加载条件图像cond_image = load_image("https://huggingface.co/Tencent-Hunyuan/HunyuanDiT-v1.1-ControlNet-Diffusers-Canny/resolve/main/canny.jpg?download=true")## HunyuanDiT 支持英语和中文提示,因此也可以使用英文提示# 定义图像生成的提示内容,描述夜晚的场景prompt = "在夜晚的酒店门前,一座古老的中国风格的狮子雕像矗立着,它的眼睛闪烁着光芒,仿佛在守护着这座建筑。背景是夜晚的酒店前,构图方式是特写,平视,居中构图。这张照片呈现了真实摄影风格,蕴含了中国雕塑文化,同时展现了神秘氛围"# prompt="At night, an ancient Chinese-style lion statue stands in front of the hotel, its eyes gleaming as if guarding the building. The background is the hotel entrance at night, with a close-up, eye-level, and centered composition. This photo presents a realistic photographic style, embodies Chinese sculpture culture, and reveals a mysterious atmosphere."# 使用提示、图像尺寸、条件图像和推理步骤生成图像,并获取生成的第一张图像image = pipe(prompt,height=1024,width=1024,control_image=cond_image,num_inference_steps=50,).images[0]```  

"""

文档字符串,通常用于描述模块或类的功能

"""

定义一个标准宽高比的 NumPy 数组

STANDARD_RATIO = np.array(
[
1.0, # 1:1
4.0 / 3.0, # 4:3
3.0 / 4.0, # 3:4
16.0 / 9.0, # 16:9
9.0 / 16.0, # 9:16
]
)

定义一个标准尺寸的列表,每个比例对应不同的宽高组合

STANDARD_SHAPE = [
[(1024, 1024), (1280, 1280)], # 1:1
[(1024, 768), (1152, 864), (1280, 960)], # 4:3
[(768, 1024), (864, 1152), (960, 1280)], # 3:4
[(1280, 768)], # 16:9
[(768, 1280)], # 9:16
]

根据标准尺寸计算每个形状的面积,并将结果存储在 NumPy 数组中

STANDARD_AREA = [np.array([w * h for w, h in shapes]) for shapes in STANDARD_SHAPE]

定义一个支持的尺寸列表,包含不同的宽高组合

SUPPORTED_SHAPE = [
(1024, 1024),
(1280, 1280), # 1:1
(1024, 768),
(1152, 864),
(1280, 960), # 4:3
(768, 1024),
(864, 1152),
(960, 1280), # 3:4
(1280, 768), # 16:9
(768, 1280), # 9:16
]

定义一个函数,用于将目标宽高映射到标准形状

def map_to_standard_shapes(target_width, target_height):
# 计算目标宽高比
target_ratio = target_width / target_height
# 找到与目标宽高比最接近的标准宽高比的索引
closest_ratio_idx = np.argmin(np.abs(STANDARD_RATIO - target_ratio))
# 找到与目标面积最接近的标准形状的索引
closest_area_idx = np.argmin(np.abs(STANDARD_AREA[closest_ratio_idx] - target_width * target_height))
# 获取对应的标准宽和高
width, height = STANDARD_SHAPE[closest_ratio_idx][closest_area_idx]
# 返回标准宽和高
return width, height

定义一个函数,用于计算源图像的缩放裁剪区域以适应目标大小

def get_resize_crop_region_for_grid(src, tgt_size):
# 获取目标尺寸的高度和宽度
th = tw = tgt_size
# 获取源图像的高度和宽度
h, w = src

# 计算源图像的宽高比
r = h / w# 根据宽高比决定缩放方式
# 如果高度大于宽度
if r > 1:# 将目标高度作为缩放高度resize_height = th# 根据高度缩放计算对应的宽度resize_width = int(round(th / h * w))
else:# 否则,将目标宽度作为缩放宽度resize_width = tw# 根据宽度缩放计算对应的高度resize_height = int(round(tw / w * h))# 计算裁剪区域的顶部和左边位置
crop_top = int(round((th - resize_height) / 2.0))
crop_left = int(round((tw - resize_width) / 2.0))# 返回裁剪区域的起始和结束坐标
return (crop_top, crop_left), (crop_top + resize_height, crop_left + resize_width)

从 diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.rescale_noise_cfg 复制的函数

def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
"""
根据 guidance_rescalenoise_cfg 进行重新缩放。基于论文Common Diffusion Noise Schedules and
Sample Steps are Flawed
中的发现。见第3.4节
"""
# 计算噪声预测文本的标准差
std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True)
# 计算噪声配置的标准差
std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True)
# 重新缩放来自引导的结果(修复过度曝光问题)
noise_pred_rescaled = noise_cfg * (std_text / std_cfg)
# 按照引导缩放因子与原始引导结果进行混合,以避免生成“单调”的图像
noise_cfg = guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg
# 返回重新缩放后的噪声配置
return noise_cfg

定义 HunyuanDiT 控制网络管道类,继承自 DiffusionPipeline

class HunyuanDiTControlNetPipeline(DiffusionPipeline):
r"""
使用 HunyuanDiT 进行英语/中文到图像生成的管道。

该模型继承自 [`DiffusionPipeline`]. 请查看超类文档以获取库为所有管道实现的通用方法
(例如下载或保存,在特定设备上运行等)。HunyuanDiT 使用两个文本编码器:[mT5](https://huggingface.co/google/mt5-base) 和 [双语 CLIP](自行微调)
"""
# 参数说明
Args:vae ([`AutoencoderKL`]):  # 变分自编码器模型,用于将图像编码和解码为潜在表示,这里使用'sdxl-vae-fp16-fix'Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. We use`sdxl-vae-fp16-fix`.text_encoder (Optional[`~transformers.BertModel`, `~transformers.CLIPTextModel`]):  # 冻结的文本编码器,使用CLIP模型Frozen text-encoder ([clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)). HunyuanDiT uses a fine-tuned [bilingual CLIP].tokenizer (Optional[`~transformers.BertTokenizer`, `~transformers.CLIPTokenizer`]):  # 文本标记化器,可以是BertTokenizer或CLIPTokenizerA `BertTokenizer` or `CLIPTokenizer` to tokenize text.transformer ([`HunyuanDiT2DModel`]):  # HunyuanDiT模型,由腾讯Hunyuan设计The HunyuanDiT model designed by Tencent Hunyuan.text_encoder_2 (`T5EncoderModel`):  # mT5嵌入模型,特别是't5-v1_1-xxl'The mT5 embedder. Specifically, it is 't5-v1_1-xxl'.tokenizer_2 (`MT5Tokenizer`):  # mT5嵌入模型的标记化器The tokenizer for the mT5 embedder.scheduler ([`DDPMScheduler`]):  # 调度器,用于与HunyuanDiT结合,去噪编码的图像潜在表示A scheduler to be used in combination with HunyuanDiT to denoise the encoded image latents.controlnet ([`HunyuanDiT2DControlNetModel`] or `List[HunyuanDiT2DControlNetModel]` or [`HunyuanDiT2DControlNetModel`]):  # 提供额外的条件信息以辅助去噪过程Provides additional conditioning to the `unet` during the denoising process. If you set multipleControlNets as a list, the outputs from each ControlNet are added together to create one combinedadditional conditioning.
"""# 定义模型在CPU上卸载的顺序
model_cpu_offload_seq = "text_encoder->text_encoder_2->transformer->vae"
# 可选组件列表,可能会在初始化中使用
_optional_components = ["safety_checker",  # 安全检查器"feature_extractor",  # 特征提取器"text_encoder_2",  # 第二个文本编码器"tokenizer_2",  # 第二个标记化器"text_encoder",  # 第一个文本编码器"tokenizer",  # 第一个标记化器
]
# 从CPU卸载中排除的组件
_exclude_from_cpu_offload = ["safety_checker"]  # 不允许卸载安全检查器
# 回调张量输入的列表,用于传递给模型
_callback_tensor_inputs = ["latents",  # 潜在变量"prompt_embeds",  # 提示的嵌入表示"negative_prompt_embeds",  # 负提示的嵌入表示"prompt_embeds_2",  # 第二个提示的嵌入表示"negative_prompt_embeds_2",  # 第二个负提示的嵌入表示
]# 初始化方法定义,接收多个参数以构造模型
def __init__(self,vae: AutoencoderKL,  # 变分自编码器模型text_encoder: BertModel,  # 文本编码器tokenizer: BertTokenizer,  # 文本标记化器transformer: HunyuanDiT2DModel,  # HunyuanDiT模型scheduler: DDPMScheduler,  # 调度器safety_checker: StableDiffusionSafetyChecker,  # 安全检查器feature_extractor: CLIPImageProcessor,  # 特征提取器controlnet: Union[  # 控制网络,可以是单个或多个模型HunyuanDiT2DControlNetModel,List[HunyuanDiT2DControlNetModel],Tuple[HunyuanDiT2DControlNetModel],HunyuanDiT2DMultiControlNetModel,],text_encoder_2=T5EncoderModel,  # 第二个文本编码器,默认使用T5模型tokenizer_2=MT5Tokenizer,  # 第二个标记化器,默认使用MT5标记化器requires_safety_checker: bool = True,  # 是否需要安全检查器,默认是True
# 初始化父类
):super().__init__()# 注册多个模块,提供必要的组件以供使用self.register_modules(vae=vae,  # 注册变分自编码器text_encoder=text_encoder,  # 注册文本编码器tokenizer=tokenizer,  # 注册分词器tokenizer_2=tokenizer_2,  # 注册第二个分词器transformer=transformer,  # 注册变换器scheduler=scheduler,  # 注册调度器safety_checker=safety_checker,  # 注册安全检查器feature_extractor=feature_extractor,  # 注册特征提取器text_encoder_2=text_encoder_2,  # 注册第二个文本编码器controlnet=controlnet,  # 注册控制网络)# 检查安全检查器是否为 None 并且需要使用安全检查器if safety_checker is None and requires_safety_checker:# 记录警告信息,提醒用户禁用安全检查器的后果logger.warning(f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"" that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered"" results in services or applications open to the public. Both the diffusers team and Hugging Face"" strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling"" it only for use-cases that involve analyzing network behavior or auditing its results. For more"" information, please have a look at https://github.com/huggingface/diffusers/pull/254 .")# 检查安全检查器不为 None 且特征提取器为 Noneif safety_checker is not None and feature_extractor is None:# 抛出错误,提醒用户必须定义特征提取器raise ValueError("Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety"" checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead.")# 计算 VAE 的缩放因子,如果存在 VAE 配置则使用其通道数量,否则默认为 8self.vae_scale_factor = (2 ** (len(self.vae.config.block_out_channels) - 1) if hasattr(self, "vae") and self.vae is not None else 8)# 初始化图像处理器,传入 VAE 缩放因子self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)# 注册到配置中,指明是否需要安全检查器self.register_to_config(requires_safety_checker=requires_safety_checker)# 设置默认样本大小,根据变换器配置或默认为 128self.default_sample_size = (self.transformer.config.sample_sizeif hasattr(self, "transformer") and self.transformer is not Noneelse 128)# 从其他模块复制的方法,用于编码提示
def encode_prompt(self,prompt: str,  # 输入的提示文本device: torch.device = None,  # 设备参数,指定在哪个设备上处理dtype: torch.dtype = None,  # 数据类型参数,指定张量的数据类型num_images_per_prompt: int = 1,  # 每个提示生成的图像数量do_classifier_free_guidance: bool = True,  # 是否执行无分类器的引导negative_prompt: Optional[str] = None,  # 可选的负面提示文本prompt_embeds: Optional[torch.Tensor] = None,  # 可选的提示嵌入张量negative_prompt_embeds: Optional[torch.Tensor] = None,  # 可选的负面提示嵌入张量prompt_attention_mask: Optional[torch.Tensor] = None,  # 可选的提示注意力掩码negative_prompt_attention_mask: Optional[torch.Tensor] = None,  # 可选的负面提示注意力掩码max_sequence_length: Optional[int] = None,  # 可选的最大序列长度text_encoder_index: int = 0,  # 文本编码器索引,默认值为 0
# 从其他模块复制的方法,用于运行安全检查器
# 定义运行安全检查器的方法,接收图像、设备和数据类型作为参数
def run_safety_checker(self, image, device, dtype):# 如果安全检查器未定义,设置无敏感内容标志为 Noneif self.safety_checker is None:has_nsfw_concept = Noneelse:# 如果输入图像是张量,则后处理为 PIL 格式if torch.is_tensor(image):feature_extractor_input = self.image_processor.postprocess(image, output_type="pil")else:# 如果输入图像不是张量,则将其转换为 PIL 格式feature_extractor_input = self.image_processor.numpy_to_pil(image)# 使用特征提取器处理输入图像并将其转换为指定设备上的张量safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors="pt").to(device)# 调用安全检查器,返回处理后的图像和无敏感内容标志image, has_nsfw_concept = self.safety_checker(images=image, clip_input=safety_checker_input.pixel_values.to(dtype))# 返回处理后的图像和无敏感内容标志return image, has_nsfw_concept# 从 diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs 复制def prepare_extra_step_kwargs(self, generator, eta):# 为调度器步骤准备额外参数,因为不是所有调度器的参数签名相同# eta(η)仅在 DDIMScheduler 中使用,其他调度器将忽略# eta 在 DDIM 论文中的对应关系:https://arxiv.org/abs/2010.02502# 其值应在 [0, 1] 之间# 检查调度器的步骤是否接受 eta 参数accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())extra_step_kwargs = {}# 如果接受 eta,则将其添加到额外参数字典中if accepts_eta:extra_step_kwargs["eta"] = eta# 检查调度器的步骤是否接受 generator 参数accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())# 如果接受 generator,则将其添加到额外参数字典中if accepts_generator:extra_step_kwargs["generator"] = generator# 返回额外参数字典return extra_step_kwargs# 从 diffusers.pipelines.hunyuandit.pipeline_hunyuandit.HunyuanDiTPipeline.check_inputs 复制def check_inputs(self,prompt,height,width,negative_prompt=None,prompt_embeds=None,negative_prompt_embeds=None,prompt_attention_mask=None,negative_prompt_attention_mask=None,prompt_embeds_2=None,negative_prompt_embeds_2=None,prompt_attention_mask_2=None,negative_prompt_attention_mask_2=None,callback_on_step_end_tensor_inputs=None,# 从 diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents 复制
# 准备潜在变量的函数,接收多个参数以配置潜在变量的形状和生成方式def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):# 根据批大小、通道数、高度和宽度计算潜在变量的形状shape = (batch_size,num_channels_latents,int(height) // self.vae_scale_factor,int(width) // self.vae_scale_factor,)# 检查生成器列表的长度是否与批大小匹配if isinstance(generator, list) and len(generator) != batch_size:# 如果不匹配,抛出一个值错误raise ValueError(f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"f" size of {batch_size}. Make sure the batch size matches the length of the generators.")# 如果潜在变量为 None,则生成随机潜在变量if latents is None:latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)else:# 如果已提供潜在变量,将其移动到指定设备latents = latents.to(device)# 按调度器要求的标准差缩放初始噪声latents = latents * self.scheduler.init_noise_sigma# 返回处理后的潜在变量return latents# 准备图像的函数,从外部调用def prepare_image(self,image,width,height,batch_size,num_images_per_prompt,device,dtype,do_classifier_free_guidance=False,guess_mode=False,):# 检查图像是否为张量,如果是则不处理if isinstance(image, torch.Tensor):passelse:# 否则对图像进行预处理,调整为指定的高度和宽度image = self.image_processor.preprocess(image, height=height, width=width)# 获取图像的批大小image_batch_size = image.shape[0]# 如果图像批大小为1,则重复次数为批大小if image_batch_size == 1:repeat_by = batch_sizeelse:# 否则图像批大小与提示批大小相同repeat_by = num_images_per_prompt# 沿着维度0重复图像image = image.repeat_interleave(repeat_by, dim=0)# 将图像移动到指定设备,并转换为指定数据类型image = image.to(device=device, dtype=dtype)# 如果启用了无分类器自由引导,并且未启用猜测模式,则将图像复制两次if do_classifier_free_guidance and not guess_mode:image = torch.cat([image] * 2)# 返回处理后的图像return image# 获取指导比例的属性@propertydef guidance_scale(self):# 返回当前的指导比例return self._guidance_scale# 获取指导重标定的属性@propertydef guidance_rescale(self):# 返回当前的指导重标定值return self._guidance_rescale# 此属性定义了类似于论文中指导权重的定义@propertydef do_classifier_free_guidance(self):# 如果指导比例大于1,则启用无分类器自由引导return self._guidance_scale > 1# 获取时间步数的属性@propertydef num_timesteps(self):# 返回当前的时间步数return self._num_timesteps# 获取中断状态的属性@propertydef interrupt(self):# 返回当前中断状态return self._interrupt# 在不计算梯度的情况下运行,替换示例文档字符串@torch.no_grad()@replace_example_docstring(EXAMPLE_DOC_STRING)
# 定义一个可调用的类方法def __call__(# 提示内容,可以是字符串或字符串列表self,prompt: Union[str, List[str]] = None,# 输出图像的高度height: Optional[int] = None,# 输出图像的宽度width: Optional[int] = None,# 推理步骤的数量,默认为 50num_inference_steps: Optional[int] = 50,# 引导比例,默认为 5.0guidance_scale: Optional[float] = 5.0,# 控制图像输入,默认为 Nonecontrol_image: PipelineImageInput = None,# 控制网条件比例,可以是单一值或值列表,默认为 1.0controlnet_conditioning_scale: Union[float, List[float]] = 1.0,# 负提示内容,可以是字符串或字符串列表,默认为 Nonenegative_prompt: Optional[Union[str, List[str]]] = None,# 每个提示生成的图像数量,默认为 1num_images_per_prompt: Optional[int] = 1,# 用于生成的随机性,默认为 0.0eta: Optional[float] = 0.0,# 随机数生成器,可以是单个或列表,默认为 Nonegenerator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,# 潜在变量,默认为 Nonelatents: Optional[torch.Tensor] = None,# 提示的嵌入,默认为 Noneprompt_embeds: Optional[torch.Tensor] = None,# 第二组提示的嵌入,默认为 Noneprompt_embeds_2: Optional[torch.Tensor] = None,# 负提示的嵌入,默认为 Nonenegative_prompt_embeds: Optional[torch.Tensor] = None,# 第二组负提示的嵌入,默认为 Nonenegative_prompt_embeds_2: Optional[torch.Tensor] = None,# 提示的注意力掩码,默认为 Noneprompt_attention_mask: Optional[torch.Tensor] = None,# 第二组提示的注意力掩码,默认为 Noneprompt_attention_mask_2: Optional[torch.Tensor] = None,# 负提示的注意力掩码,默认为 Nonenegative_prompt_attention_mask: Optional[torch.Tensor] = None,# 第二组负提示的注意力掩码,默认为 Nonenegative_prompt_attention_mask_2: Optional[torch.Tensor] = None,# 输出类型,默认为 "pil"output_type: Optional[str] = "pil",# 是否返回字典格式,默认为 Truereturn_dict: bool = True,# 在步骤结束时的回调函数callback_on_step_end: Optional[Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks]] = None,# 回调时的张量输入列表,默认为 ["latents"]callback_on_step_end_tensor_inputs: List[str] = ["latents"],# 引导重标定,默认为 0.0guidance_rescale: float = 0.0,# 原始图像大小,默认为 (1024, 1024)original_size: Optional[Tuple[int, int]] = (1024, 1024),# 目标图像大小,默认为 Nonetarget_size: Optional[Tuple[int, int]] = None,# 裁剪坐标,默认为 (0, 0)crops_coords_top_left: Tuple[int, int] = (0, 0),# 是否使用分辨率分箱,默认为 Trueuse_resolution_binning: bool = True,

# `.\diffusers\pipelines\controlnet_hunyuandit\__init__.py````py
# 从 typing 模块导入 TYPE_CHECKING,用于静态类型检查
from typing import TYPE_CHECKING# 从父模块的 utils 导入多个工具函数和常量
from ...utils import (DIFFUSERS_SLOW_IMPORT,  # 导入慢导入的标志OptionalDependencyNotAvailable,  # 导入可选依赖不可用的异常_LazyModule,  # 导入延迟加载模块的类get_objects_from_module,  # 导入从模块获取对象的函数is_torch_available,  # 导入检查 PyTorch 是否可用的函数is_transformers_available,  # 导入检查 Transformers 是否可用的函数
)# 创建一个空字典,用于存储假对象
_dummy_objects = {}
# 创建一个空字典,用于存储导入结构
_import_structure = {}# 尝试检查是否可用的依赖
try:# 如果 Transformers 和 Torch 不可用,抛出异常if not (is_transformers_available() and is_torch_available()):raise OptionalDependencyNotAvailable()
# 捕获可选依赖不可用的异常
except OptionalDependencyNotAvailable:# 从 utils 导入假对象(dummy objects),避免直接依赖from ...utils import dummy_torch_and_transformers_objects  # noqa F403# 更新 _dummy_objects 字典,包含假对象_dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
# 如果依赖可用,更新导入结构
else:# 将 HunyuanDiTControlNetPipeline 加入导入结构_import_structure["pipeline_hunyuandit_controlnet"] = ["HunyuanDiTControlNetPipeline"]# 检查类型是否在检查模式或是否需要慢导入
if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:# 尝试检查是否可用的依赖try:# 如果 Transformers 和 Torch 不可用,抛出异常if not (is_transformers_available() and is_torch_available()):raise OptionalDependencyNotAvailable()# 捕获可选依赖不可用的异常except OptionalDependencyNotAvailable:# 从 utils 导入所有假对象,避免直接依赖from ...utils.dummy_torch_and_transformers_objects import *else:# 导入真实的 HunyuanDiTControlNetPipeline 类from .pipeline_hunyuandit_controlnet import HunyuanDiTControlNetPipeline# 如果不在类型检查或不需要慢导入
else:# 导入 sys 模块import sys# 将当前模块替换为延迟加载模块,传递必要的信息sys.modules[__name__] = _LazyModule(__name__,  # 模块名称globals()["__file__"],  # 模块文件路径_import_structure,  # 导入结构module_spec=__spec__,  # 模块规格)# 遍历假对象并设置到当前模块for name, value in _dummy_objects.items():setattr(sys.modules[__name__], name, value)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.ryyt.cn/news/74591.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈,一经查实,立即删除!

相关文章

diffusers-源码解析-二十六-

diffusers 源码解析(二十六) .\diffusers\pipelines\deepfloyd_if\pipeline_if_inpainting_superresolution.py # 导入 html 模块,用于处理 HTML 文本 import html # 导入 inspect 模块,用于获取对象的信息 import inspect # 导入 re 模块,用于正则表达式匹配 import re #…

diffusers-源码解析-二十九-

diffusers 源码解析(二十九) .\diffusers\pipelines\deprecated\stable_diffusion_variants\pipeline_stable_diffusion_model_editing.py # 版权信息,声明版权和许可协议 # Copyright 2024 TIME Authors and The HuggingFace Team. All rights reserved." # 根据 Apac…

习题6.3

习题6.3代码 import numpy as np import pandas as pd import cvxpy as cp import networkx as nx import matplotlib.pyplot as plt plt.rcParams[font.sans-serif]=[Times New Roman + SimSun + WFM Sans SC] plt.rcParams[mathtext.fontset]=stix Times New Roman + SimSun …

苹果笔记本和微软Surface哪个更适合商务使用

在商务环境中,选择合适的笔记本电脑对于提高工作效率至关重要。本文对苹果笔记本和微软Surface进行比较分析,探讨哪种更适合商务使用。主要考虑因素包括:1.性能和可靠性;2.操作系统与软件兼容性;3.设计与便携性;4.电池续航力;5.价格与性价比;6.售后服务与支持。通过全面…

【日记】今天好忙(459 字)

写在前面今天没有什么可看的,可以不用看。 正文爆炸忙。整个下午我的手似乎就没停过,现在写这则日记,回想那个时候的自己,觉得好陌生。整体来说,那段时间也一片空白,什么印象都没有了。太忙,也没有做其它事情的空间。BAEA 台灯到了。我还以为又送到发改局去了,先去那边…

Nuxt.js 应用中的 build:manifest 事件钩子详解

title: Nuxt.js 应用中的 build:manifest 事件钩子详解 date: 2024/10/22 updated: 2024/10/22 author: cmdragon excerpt: build:manifest 是 Nuxt.js 中的一个生命周期钩子,它在 Vite 和 Webpack 构建清单期间被调用。利用这个钩子,开发者可以自定义 Nitro 渲染在最终 H…

如何进行前端单元测试

​前端单元测试的引入为软件开发流程带来了更高的质量和稳定性,需要遵循以下步骤:一、理解单元测试的重要性;二、选择合适的测试框架;三、编写有效的测试用例;四、模拟外部依赖;五、持续维护和优化测试。单元测试的开始,是对前端代码的核心功能进行验证。一、理解单元测…

Jenkins打包Unity游戏环境变量配置

Jenkins打包Unity游戏失败,通过报错日志会查找到sdk环境有问题,解决sdk的环境问题后会出现ndk环境有问题,为了解决这两个问题导致的打包失败需要在Jenkins中配置环境变量打开 Jenkins 首页,选中Manager Jenkins,再点击 System 选项找到全局属性,勾选Environment variable…