FLUX Source Code Walkthrough (Complete)


.\flux\demo_gr.py

# Standard library imports
import os
import time
import uuid
from io import BytesIO

# Third-party imports
import gradio as gr
import numpy as np
import torch
from einops import rearrange
from PIL import ExifTags, Image
from transformers import pipeline

# Local flux imports
from flux.cli import SamplingOptions
from flux.sampling import denoise, get_noise, get_schedule, prepare, unpack
from flux.util import configs, embed_watermark, load_ae, load_clip, load_flow_model, load_t5

# Classification threshold for flagging NSFW (not safe for work) images
NSFW_THRESHOLD = 0.85
# Load every model the demo needs.
def get_models(name: str, device: torch.device, offload: bool, is_schnell: bool):
    # T5 max sequence length depends on the model variant
    t5 = load_t5(device, max_length=256 if is_schnell else 512)
    clip = load_clip(device)
    # Keep the flow model and autoencoder on CPU when offloading is enabled
    model = load_flow_model(name, device="cpu" if offload else device)
    ae = load_ae(name, device="cpu" if offload else device)
    # NSFW classifier used to filter generated images
    nsfw_classifier = pipeline("image-classification", model="Falconsai/nsfw_image_detection", device=device)
    return model, ae, t5, clip, nsfw_classifier
# Holds the loaded models and exposes the generation entry point for the UI.
class FluxGenerator:
    def __init__(self, model_name: str, device: str, offload: bool):
        self.device = torch.device(device)
        self.offload = offload
        self.model_name = model_name
        self.is_schnell = model_name == "flux-schnell"
        self.model, self.ae, self.t5, self.clip, self.nsfw_classifier = get_models(
            model_name,
            device=self.device,
            offload=self.offload,
            is_schnell=self.is_schnell,
        )

    # Generate an image under torch inference mode
    @torch.inference_mode()
    def generate_image(
        self,
        width,
        height,
        num_steps,
        guidance,
        seed,
        prompt,
        init_image=None,
        image2image_strength=0.0,
        add_sampling_metadata=True,
    ):
        ...  # method body not included in this excerpt
# Build the Gradio demo UI.
def create_demo(model_name: str, device: str = "cuda" if torch.cuda.is_available() else "cpu", offload: bool = False):
    generator = FluxGenerator(model_name, device, offload)
    is_schnell = model_name == "flux-schnell"

    with gr.Blocks() as demo:
        gr.Markdown(f"# Flux Image Generation Demo - Model: {model_name}")

        with gr.Row():
            with gr.Column():
                prompt = gr.Textbox(label="Prompt", value="a photo of a forest with mist swirling around the tree trunks. The word \"FLUX\" is painted over it in big, red brush strokes with visible texture")
                # Image-to-image is only available for flux-dev
                do_img2img = gr.Checkbox(label="Image to Image", value=False, interactive=not is_schnell)
                # Hidden until the image-to-image checkbox is enabled
                init_image = gr.Image(label="Input Image", visible=False)
                image2image_strength = gr.Slider(0.0, 1.0, 0.8, step=0.1, label="Noising strength", visible=False)

                with gr.Accordion("Advanced Options", open=False):
                    width = gr.Slider(128, 8192, 1360, step=16, label="Width")
                    height = gr.Slider(128, 8192, 768, step=16, label="Height")
                    # schnell defaults to 4 steps, dev to 50
                    num_steps = gr.Slider(1, 50, 4 if is_schnell else 50, step=1, label="Number of steps")
                    guidance = gr.Slider(1.0, 10.0, 3.5, step=0.1, label="Guidance", interactive=not is_schnell)
                    seed = gr.Textbox(-1, label="Seed (-1 for random)")
                    add_sampling_metadata = gr.Checkbox(label="Add sampling parameters to metadata?", value=True)

                generate_btn = gr.Button("Generate")

            with gr.Column():
                output_image = gr.Image(label="Generated Image")
                seed_output = gr.Number(label="Used Seed")
                warning_text = gr.Textbox(label="Warning", visible=False)
                download_btn = gr.File(label="Download full-resolution")

        # Show or hide the image-to-image widgets when the checkbox is toggled
        def update_img2img(do_img2img):
            return {
                init_image: gr.update(visible=do_img2img),
                image2image_strength: gr.update(visible=do_img2img),
            }

        do_img2img.change(update_img2img, do_img2img, [init_image, image2image_strength])

        # Wire the generate button to the generator
        generate_btn.click(
            fn=generator.generate_image,
            inputs=[width, height, num_steps, guidance, seed, prompt, init_image, image2image_strength, add_sampling_metadata],
            outputs=[output_image, seed_output, download_btn, warning_text],
        )

    return demo
# Command-line entry point.
if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser(description="Flux")
    parser.add_argument("--name", type=str, default="flux-schnell", choices=list(configs.keys()), help="Model name")
    parser.add_argument("--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu", help="Device to use")
    parser.add_argument("--offload", action="store_true", help="Offload model to CPU when not in use")
    parser.add_argument("--share", action="store_true", help="Create a public link to your demo")
    args = parser.parse_args()

    demo = create_demo(args.name, args.device, args.offload)
    demo.launch(share=args.share)

.\flux\demo_st.py

# Standard library imports
import os
import re
import time
from glob import iglob
from io import BytesIO

# Third-party imports
import streamlit as st
import torch
from einops import rearrange
from fire import Fire
from PIL import ExifTags, Image
from st_keyup import st_keyup               # text input that fires on every keystroke
from torchvision import transforms
from transformers import pipeline

# flux imports used below (not shown in the original excerpt)
from flux.cli import SamplingOptions
from flux.sampling import denoise, get_noise, get_schedule, prepare, unpack
from flux.util import configs, embed_watermark, load_ae, load_clip, load_flow_model, load_t5

# Classification threshold for flagging NSFW images
NSFW_THRESHOLD = 0.85


# Cache the loaded models across Streamlit reruns
@st.cache_resource()
def get_models(name: str, device: torch.device, offload: bool, is_schnell: bool):
    # T5 max sequence length depends on the model variant
    t5 = load_t5(device, max_length=256 if is_schnell else 512)
    clip = load_clip(device)
    # Keep the flow model and autoencoder on CPU when offloading is enabled
    model = load_flow_model(name, device="cpu" if offload else device)
    ae = load_ae(name, device="cpu" if offload else device)
    # NSFW classifier used to filter generated images
    nsfw_classifier = pipeline("image-classification", model="Falconsai/nsfw_image_detection", device=device)
    return model, ae, t5, clip, nsfw_classifier
# Read an uploaded image and convert it to a normalized tensor in [-1, 1].
def get_image() -> torch.Tensor | None:
    image = st.file_uploader("Input", type=["jpg", "JPEG", "png"])
    if image is None:
        return None
    image = Image.open(image).convert("RGB")

    transform = transforms.Compose(
        [
            transforms.ToTensor(),
            transforms.Lambda(lambda x: 2.0 * x - 1.0),
        ]
    )
    img: torch.Tensor = transform(image)
    return img[None, ...]  # add a batch dimension
# Main application logic, run under torch inference mode.
@torch.inference_mode()
def main(
    device: str = "cuda" if torch.cuda.is_available() else "cpu",
    offload: bool = False,
    output_dir: str = "output",
):
    torch_device = torch.device(device)
    names = list(configs.keys())
    name = st.selectbox("Which model to load?", names)
    if name is None or not st.checkbox("Load model", False):
        return

    is_schnell = name == "flux-schnell"
    model, ae, t5, clip, nsfw_classifier = get_models(
        name,
        device=torch_device,
        offload=offload,
        is_schnell=is_schnell,
    )

    # Image-to-image is only supported by flux-dev
    do_img2img = (
        st.checkbox(
            "Image to Image",
            False,
            disabled=is_schnell,
            help="Partially noise an image and denoise again to get variations.\n\nOnly works for flux-dev",
        )
        and not is_schnell
    )
    if do_img2img:
        init_image = get_image()
        if init_image is None:
            st.warning("Please add an image to do image to image")
        image2image_strength = st.number_input("Noising strength", min_value=0.0, max_value=1.0, value=0.8)
        if init_image is not None:
            h, w = init_image.shape[-2:]
            st.write(f"Got image of size {w}x{h} ({h*w/1e6:.2f}MP)")
        resize_img = st.checkbox("Resize image", False) or init_image is None
    else:
        init_image = None
        resize_img = True
        image2image_strength = 0.0

    # Round width and height down to multiples of 16 so the latent can be packed
    width = int(16 * (st.number_input("Width", min_value=128, value=1360, step=16, disabled=not resize_img) // 16))
    height = int(16 * (st.number_input("Height", min_value=128, value=768, step=16, disabled=not resize_img) // 16))
    num_steps = int(st.number_input("Number of steps", min_value=1, value=(4 if is_schnell else 50)))
    guidance = float(st.number_input("Guidance", min_value=1.0, value=3.5, disabled=is_schnell))
    seed_str = st.text_input("Seed", disabled=is_schnell)
    if seed_str.isdecimal():
        seed = int(seed_str)
    else:
        st.info("No seed set, set to positive integer to enable")
        seed = None

    save_samples = st.checkbox("Save samples?", not is_schnell)
    add_sampling_metadata = st.checkbox("Add sampling parameters to metadata?", True)

    default_prompt = (
        "a photo of a forest with mist swirling around the tree trunks. The word "
        '"FLUX" is painted over it in big, red brush strokes with visible texture'
    )
    # Prompt box with a 300 ms debounce so typing does not trigger a rerun per key
    prompt = st_keyup("Enter a prompt", value=default_prompt, debounce=300, key="interactive_text")

    # Determine the next free output index in output_dir
    output_name = os.path.join(output_dir, "img_{idx}.jpg")
    if not os.path.exists(output_dir):
        os.makedirs(output_dir)
        idx = 0
    else:
        fns = [fn for fn in iglob(output_name.format(idx="*")) if re.search(r"img_[0-9]+\.jpg$", fn)]
        if len(fns) > 0:
            idx = max(int(fn.split("_")[-1].split(".")[0]) for fn in fns) + 1
        else:
            idx = 0

    rng = torch.Generator(device="cpu")
    if "seed" not in st.session_state:
        st.session_state.seed = rng.seed()

    def increment_counter():
        st.session_state.seed += 1

    def decrement_counter():
        # the seed must stay non-negative
        if st.session_state.seed > 0:
            st.session_state.seed -= 1

    opts = SamplingOptions(
        prompt=prompt,
        width=width,
        height=height,
        num_steps=num_steps,
        guidance=guidance,
        seed=seed,
    )

    # For flux-schnell, offer buttons that step the session seed up and down
    if name == "flux-schnell":
        cols = st.columns([5, 1, 1, 5])
        with cols[1]:
            st.button("↩", on_click=increment_counter)
        with cols[2]:
            st.button("↪", on_click=decrement_counter)

    # Show the most recent sample (if any) together with its metadata
    samples = st.session_state.get("samples", None)
    if samples is not None:
        st.image(samples["img"], caption=samples["prompt"])
        st.download_button(
            "Download full-resolution",
            samples["bytes"],
            file_name="generated.jpg",
            mime="image/jpg",
        )
        st.write(f"Seed: {samples['seed']}")
# Application entry point: expose main() on the command line via Fire.
def app():
    Fire(main)


if __name__ == "__main__":
    app()

FLUX.1 [dev] Grid

FLUX.1 [dev] is a 12 billion parameter rectified flow transformer capable of generating images from text descriptions.
For more information, please read our blog post.

Key Features

  1. Cutting-edge output quality, second only to our state-of-the-art model FLUX.1 [pro].
  2. Competitive prompt following, matching the performance of closed source alternatives.
  3. Trained using guidance distillation, making FLUX.1 [dev] more efficient.
  4. Open weights to drive new scientific research, and empower artists to develop innovative workflows.
  5. Generated outputs can be used for personal, scientific, and commercial purposes, as described in the flux-1-dev-non-commercial-license.

Usage

We provide a reference implementation of FLUX.1 [dev], as well as sampling code, in a dedicated GitHub repository.
Developers and creatives looking to build on top of FLUX.1 [dev] are encouraged to use this as a starting point.
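
As a quick orientation, here is a minimal sampling sketch built on the reference CLI (the flux.cli module walked through later in this document). It assumes the repository's flux package is installed and that the weights are available locally or downloadable from Hugging Face:

from flux.cli import main

# One-off sampling run with the flux-dev defaults (50 steps, guidance 3.5)
main(
    name="flux-dev",
    width=1360,
    height=768,
    prompt="a photo of a forest with mist swirling around the tree trunks",
    output_dir="output",
)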

API Endpoints

The FLUX.1 models are also available via API from the following sources:

  1. bfl.ml (currently FLUX.1 [pro])
  2. replicate.com
  3. fal.ai

ComfyUI

FLUX.1 [dev] is also available in Comfy UI for local inference with a node-based workflow.


Limitations

  • This model is not intended or able to provide factual information.
  • As a statistical model this checkpoint might amplify existing societal biases.
  • The model may fail to generate output that matches the prompts.
  • Prompt following is heavily influenced by the prompting style.

Out-of-Scope Use

The model and its derivatives may not be used:

  • In any way that violates any applicable national, federal, state, local or international law or regulation.
  • For the purpose of exploiting, harming or attempting to exploit or harm minors in any way; including but not limited to the solicitation, creation, acquisition, or dissemination of child exploitative content.
  • To generate or disseminate verifiably false information and/or content with the purpose of harming others.
  • To generate or disseminate personal identifiable information that can be used to harm an individual.
  • To harass, abuse, threaten, stalk, or bully individuals or groups of individuals.
  • To create non-consensual nudity or illegal pornographic content.
  • For fully automated decision making that adversely impacts an individual's legal rights or otherwise creates or modifies a binding, enforceable obligation.
  • Generating or facilitating large-scale disinformation campaigns.

License

This model falls under the FLUX.1 [dev] Non-Commercial License.

FLUX.1 [schnell] Grid

FLUX.1 [schnell] is a 12 billion parameter rectified flow transformer capable of generating images from text descriptions.
For more information, please read our blog post.

Key Features

  1. Cutting-edge output quality and competitive prompt following, matching the performance of closed source alternatives.
  2. Trained using latent adversarial diffusion distillation, FLUX.1 [schnell] can generate high-quality images in only 1 to 4 steps.
  3. Released under the apache-2.0 license, the model can be used for personal, scientific, and commercial purposes.

Usage

We provide a reference implementation of FLUX.1 [schnell], as well as sampling code, in a dedicated GitHub repository.
Developers and creatives looking to build on top of FLUX.1 [schnell] are encouraged to use this as a starting point.
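
A minimal, hedged sketch of the same CLI entry point for the fast variant (the prompt here is illustrative):

from flux.cli import main

# flux-schnell ignores the guidance input and needs only a few steps
main(name="flux-schnell", num_steps=4, prompt="a tiny red cabin in a snowy forest")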

API Endpoints

The FLUX.1 models are also available via API from the following sources:

  1. bfl.ml (currently FLUX.1 [pro])
  2. replicate.com
  3. fal.ai

ComfyUI

FLUX.1 [schnell] is also available in Comfy UI for local inference with a node-based workflow.


Limitations

  • This model is not intended or able to provide factual information.
  • As a statistical model this checkpoint might amplify existing societal biases.
  • The model may fail to generate output that matches the prompts.
  • Prompt following is heavily influenced by the prompting style.

Out-of-Scope Use

The model and its derivatives may not be used:

  • In any way that violates any applicable national, federal, state, local or international law or regulation.
  • For the purpose of exploiting, harming or attempting to exploit or harm minors in any way; including but not limited to the solicitation, creation, acquisition, or dissemination of child exploitative content.
  • To generate or disseminate verifiably false information and/or content with the purpose of harming others.
  • To generate or disseminate personal identifiable information that can be used to harm an individual.
  • To harass, abuse, threaten, stalk, or bully individuals or groups of individuals.
  • To create non-consensual nudity or illegal pornographic content.
  • For fully automated decision making that adversely impacts an individual's legal rights or otherwise creates or modifies a binding, enforceable obligation.
  • Generating or facilitating large-scale disinformation campaigns.

.\flux\src\flux\api.py

# NOTE: the file header below (imports, endpoint constant, exception type, and
# the __init__ signature) is reconstructed from the usage in this excerpt;
# the default argument values are assumptions.
import io
import os
import time
from pathlib import Path

import requests
from PIL import Image

API_ENDPOINT = "https://api.bfl.ml"


class ApiException(Exception):
    """Raised when the API returns an error response."""

    def __init__(self, status_code: int, detail: str | list[dict] | None = None):
        super().__init__()
        self.detail = detail
        self.status_code = status_code


class ImageRequest:
    def __init__(
        self,
        prompt: str,
        width: int = 1024,
        height: int = 768,
        name: str = "flux.1-pro",
        num_steps: int = 50,
        prompt_upsampling: bool = False,
        seed: int | None = None,
        validate: bool = True,
        launch: bool = True,
        api_key: str | None = None,
    ):
        """
        Manages an image generation request to the API.

        Args:
            prompt: Prompt to sample
            width: Width of the image in pixel
            height: Height of the image in pixel
            name: Name of the model
            num_steps: Number of network evaluations
            prompt_upsampling: Use prompt upsampling
            seed: Fix the generation seed
            validate: Run input validation
            launch: Directly launches request
            api_key: Your API key if not provided by the environment

        Raises:
            ValueError: For invalid input
            ApiException: For errors raised from the API
        """
        if validate:
            if name not in ["flux.1-pro"]:
                raise ValueError(f"Invalid model {name}")
            elif width % 32 != 0:
                raise ValueError(f"width must be divisible by 32, got {width}")
            elif not (256 <= width <= 1440):
                raise ValueError(f"width must be between 256 and 1440, got {width}")
            elif height % 32 != 0:
                raise ValueError(f"height must be divisible by 32, got {height}")
            elif not (256 <= height <= 1440):
                raise ValueError(f"height must be between 256 and 1440, got {height}")
            elif not (1 <= num_steps <= 50):
                raise ValueError(f"steps must be between 1 and 50, got {num_steps}")

        self.request_json = {
            "prompt": prompt,
            "width": width,
            "height": height,
            "variant": name,
            "steps": num_steps,
            "prompt_upsampling": prompt_upsampling,
        }
        if seed is not None:
            self.request_json["seed"] = seed

        self.request_id: str | None = None
        self.result: dict | None = None
        self._image_bytes: bytes | None = None
        self._url: str | None = None
        # Fall back to the environment if no API key was given
        if api_key is None:
            self.api_key = os.environ.get("BFL_API_KEY")
        else:
            self.api_key = api_key

        if launch:
            self.request()

    def request(self):
        """Request to generate the image."""
        if self.request_id is not None:
            return
        response = requests.post(
            f"{API_ENDPOINT}/v1/image",
            headers={
                "accept": "application/json",
                "x-key": self.api_key,
                "Content-Type": "application/json",
            },
            json=self.request_json,
        )
        result = response.json()
        if response.status_code != 200:
            raise ApiException(status_code=response.status_code, detail=result.get("detail"))
        self.request_id = response.json()["id"]

    def retrieve(self) -> dict:
        """Wait for the generation to finish and retrieve the response."""
        if self.request_id is None:
            self.request()
        # Poll until the result becomes available
        while self.result is None:
            response = requests.get(
                f"{API_ENDPOINT}/v1/get_result",
                headers={
                    "accept": "application/json",
                    "x-key": self.api_key,
                },
                params={
                    "id": self.request_id,
                },
            )
            result = response.json()
            if "status" not in result:
                raise ApiException(status_code=response.status_code, detail=result.get("detail"))
            elif result["status"] == "Ready":
                self.result = result["result"]
            elif result["status"] == "Pending":
                time.sleep(0.5)
            else:
                raise ApiException(status_code=200, detail=f"API returned status '{result['status']}'")
        return self.result

    @property
    def bytes(self) -> bytes:
        """Generated image as bytes."""
        if self._image_bytes is None:
            response = requests.get(self.url)
            if response.status_code == 200:
                self._image_bytes = response.content
            else:
                raise ApiException(status_code=response.status_code)
        return self._image_bytes

    @property
    def url(self) -> str:
        """Public URL to retrieve the image from."""
        if self._url is None:
            result = self.retrieve()
            self._url = result["sample"]
        return self._url

    @property
    def image(self) -> Image.Image:
        """Load the image as a PIL Image."""
        return Image.open(io.BytesIO(self.bytes))

    def save(self, path: str):
        """Save the generated image to a local path."""
        suffix = Path(self.url).suffix
        if not path.endswith(suffix):
            path = path + suffix
        Path(path).resolve().parent.mkdir(parents=True, exist_ok=True)
        with open(path, "wb") as file:
            file.write(self.bytes)
# Only run when the script is executed directly.
if __name__ == "__main__":
    from fire import Fire

    # Expose ImageRequest as a command-line interface
    Fire(ImageRequest)
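
A minimal usage sketch for the client above; it assumes a valid BFL_API_KEY in the environment (the launch=True default fires the request from the constructor):

from flux.api import ImageRequest

request = ImageRequest("A beautiful beach", width=1024, height=768)
print(request.url)            # public URL of the generated sample
request.save("outputs/api")   # the file suffix is appended automatically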

.\flux\src\flux\cli.py

# Standard library imports
import os
import re
import time
from dataclasses import dataclass
from glob import iglob

# Third-party imports
import torch
from einops import rearrange
from fire import Fire
from PIL import ExifTags, Image
from transformers import pipeline

# Local flux imports
from flux.sampling import denoise, get_noise, get_schedule, prepare, unpack
from flux.util import configs, embed_watermark, load_ae, load_clip, load_flow_model, load_t5

# Classification threshold for flagging NSFW images
NSFW_THRESHOLD = 0.85


# Options that fully describe one sampling run
@dataclass
class SamplingOptions:
    prompt: str           # text prompt
    width: int            # image width in pixels
    height: int           # image height in pixels
    num_steps: int        # number of sampling steps
    guidance: float       # guidance strength (flux-dev only)
    seed: int | None      # random seed, or None for a fresh one
# Interactive prompt parser: reads the next prompt or a slash command and
# updates the SamplingOptions accordingly.
def parse_prompt(options: SamplingOptions) -> SamplingOptions | None:
    user_question = "Next prompt (write /h for help, /q to quit and leave empty to repeat):\n"
    usage = (
        "Usage: Either write your prompt directly, leave this field empty "
        "to repeat the prompt or write a command starting with a slash:\n"
        "- '/w <width>' will set the width of the generated image\n"
        "- '/h <height>' will set the height of the generated image\n"
        "- '/s <seed>' sets the next seed\n"
        "- '/g <guidance>' sets the guidance (flux-dev only)\n"
        "- '/n <steps>' sets the number of steps\n"
        "- '/q' to quit"
    )

    # Keep reading input until it does not start with a slash
    while (prompt := input(user_question)).startswith("/"):
        if prompt.startswith("/w"):
            if prompt.count(" ") != 1:
                print(f"Got invalid command '{prompt}'\n{usage}")
                continue
            _, width = prompt.split()
            # round down to a multiple of 16
            options.width = 16 * (int(width) // 16)
            print(
                f"Setting resolution to {options.width} x {options.height} "
                f"({options.height * options.width / 1e6:.2f}MP)"
            )
        elif prompt.startswith("/h"):
            if prompt.count(" ") != 1:
                print(f"Got invalid command '{prompt}'\n{usage}")
                continue
            _, height = prompt.split()
            options.height = 16 * (int(height) // 16)
            print(
                f"Setting resolution to {options.width} x {options.height} "
                f"({options.height * options.width / 1e6:.2f}MP)"
            )
        elif prompt.startswith("/g"):
            if prompt.count(" ") != 1:
                print(f"Got invalid command '{prompt}'\n{usage}")
                continue
            _, guidance = prompt.split()
            options.guidance = float(guidance)
            print(f"Setting guidance to {options.guidance}")
        elif prompt.startswith("/s"):
            if prompt.count(" ") != 1:
                print(f"Got invalid command '{prompt}'\n{usage}")
                continue
            _, seed = prompt.split()
            options.seed = int(seed)
            print(f"Setting seed to {options.seed}")
        elif prompt.startswith("/n"):
            if prompt.count(" ") != 1:
                print(f"Got invalid command '{prompt}'\n{usage}")
                continue
            _, steps = prompt.split()
            options.num_steps = int(steps)
            print(f"Setting number of steps to {options.num_steps}")
        elif prompt.startswith("/q"):
            print("Quitting")
            return None
        else:
            if not prompt.startswith("/h"):
                print(f"Got invalid command '{prompt}'\n{usage}")
            print(usage)

    # A non-empty input becomes the next prompt
    if prompt != "":
        options.prompt = prompt
    return options
@torch.inference_mode()
def main(
    name: str = "flux-schnell",
    width: int = 1360,
    height: int = 768,
    seed: int | None = None,
    prompt: str = (
        "a photo of a forest with mist swirling around the tree trunks. The word "
        '"FLUX" is painted over it in big, red brush strokes with visible texture'
    ),
    device: str = "cuda" if torch.cuda.is_available() else "cpu",
    num_steps: int | None = None,
    loop: bool = False,
    guidance: float = 3.5,
    offload: bool = False,
    output_dir: str = "output",
    add_sampling_metadata: bool = True,
):
    """
    Sample the flux model. Either interactively (set `--loop`) or run for a
    single image.

    Args:
        name: Name of the model to load
        height: height of the sample in pixels (should be a multiple of 16)
        width: width of the sample in pixels (should be a multiple of 16)
        seed: Set a seed for sampling
        output_dir: directory for the output images, saved as img_{idx}.jpg
            with an increasing index
        prompt: Prompt used for sampling
        device: Pytorch device
        num_steps: number of sampling steps (default 4 for schnell, 50 for guidance distilled)
        loop: start an interactive session and sample multiple times
        guidance: guidance value used for guidance distillation
        add_sampling_metadata: Add the prompt to the image Exif metadata
    """
    # NSFW classifier used to filter generated images
    nsfw_classifier = pipeline("image-classification", model="Falconsai/nsfw_image_detection", device=device)

    if name not in configs:
        available = ", ".join(configs.keys())
        raise ValueError(f"Got unknown model name: {name}, chose from {available}")

    torch_device = torch.device(device)
    if num_steps is None:
        num_steps = 4 if name == "flux-schnell" else 50

    # Round down to multiples of 16 so the latent can be packed into 2x2 patches
    height = 16 * (height // 16)
    width = 16 * (width // 16)

    # Determine the next free output index in output_dir
    output_name = os.path.join(output_dir, "img_{idx}.jpg")
    if not os.path.exists(output_dir):
        os.makedirs(output_dir)
        idx = 0
    else:
        fns = [fn for fn in iglob(output_name.format(idx="*")) if re.search(r"img_[0-9]+\.jpg$", fn)]
        if len(fns) > 0:
            idx = max(int(fn.split("_")[-1].split(".")[0]) for fn in fns) + 1
        else:
            idx = 0

    # Load all components of the sampling pipeline
    t5 = load_t5(torch_device, max_length=256 if name == "flux-schnell" else 512)
    clip = load_clip(torch_device)
    model = load_flow_model(name, device="cpu" if offload else torch_device)
    ae = load_ae(name, device="cpu" if offload else torch_device)

    rng = torch.Generator(device="cpu")
    opts = SamplingOptions(
        prompt=prompt,
        width=width,
        height=height,
        num_steps=num_steps,
        guidance=guidance,
        seed=seed,
    )

    # In loop mode, let the user adjust the options interactively first
    if loop:
        opts = parse_prompt(opts)

    while opts is not None:
        if opts.seed is None:
            opts.seed = rng.seed()
        print(f"Generating with seed {opts.seed}:\n{opts.prompt}")
        t0 = time.perf_counter()

        # Draw the initial noise
        x = get_noise(
            1,
            opts.height,
            opts.width,
            device=torch_device,
            dtype=torch.bfloat16,
            seed=opts.seed,
        )
        opts.seed = None  # do not reuse the seed in loop mode

        # With offloading, only the text encoders live on the GPU during preparation
        if offload:
            ae = ae.cpu()
            torch.cuda.empty_cache()
            t5, clip = t5.to(torch_device), clip.to(torch_device)

        # Embed the prompt and pack the noise into the model's input format
        inp = prepare(t5, clip, x, prompt=opts.prompt)
        timesteps = get_schedule(opts.num_steps, inp["img"].shape[1], shift=(name != "flux-schnell"))

        # Swap the text encoders out for the flow model
        if offload:
            t5, clip = t5.cpu(), clip.cpu()
            torch.cuda.empty_cache()
            model = model.to(torch_device)

        # Denoise the initial noise
        x = denoise(model, **inp, timesteps=timesteps, guidance=opts.guidance)

        # Swap the flow model out for the autoencoder's decoder
        if offload:
            model.cpu()
            torch.cuda.empty_cache()
            ae.decoder.to(x.device)

        # Decode the latents to pixel space
        x = unpack(x.float(), opts.height, opts.width)
        with torch.autocast(device_type=torch_device.type, dtype=torch.bfloat16):
            x = ae.decode(x)
        t1 = time.perf_counter()

        fn = output_name.format(idx=idx)
        print(f"Done in {t1 - t0:.1f}s. Saving {fn}")

        # Bring the image into PIL format, watermark it, and convert to bytes
        x = x.clamp(-1, 1)
        x = embed_watermark(x.float())
        x = rearrange(x[0], "c h w -> h w c")
        img = Image.fromarray((127.5 * (x + 1.0)).cpu().byte().numpy())

        # Only save images that pass the NSFW check
        nsfw_score = [x["score"] for x in nsfw_classifier(img) if x["label"] == "nsfw"][0]
        if nsfw_score < NSFW_THRESHOLD:
            exif_data = Image.Exif()
            exif_data[ExifTags.Base.Software] = "AI generated;txt2img;flux"
            exif_data[ExifTags.Base.Make] = "Black Forest Labs"
            exif_data[ExifTags.Base.Model] = name
            if add_sampling_metadata:
                exif_data[ExifTags.Base.ImageDescription] = prompt
            img.save(fn, exif=exif_data, quality=95, subsampling=0)
            idx += 1
        else:
            print("Your generated image may contain NSFW content.")

        # In loop mode, ask for the next prompt; otherwise stop
        if loop:
            print("-" * 80)
            opts = parse_prompt(opts)
        else:
            opts = None
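
The EXIF block written above can be read back with PIL; a small sketch (the path output/img_0.jpg is illustrative):

from PIL import ExifTags, Image

img = Image.open("output/img_0.jpg")
exif = img.getexif()
print(exif[ExifTags.Base.Software])               # "AI generated;txt2img;flux"
print(exif.get(ExifTags.Base.ImageDescription))   # the prompt, if metadata was enabled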
# Entry point: expose main() on the command line via Fire.
def app():
    Fire(main)


if __name__ == "__main__":
    app()

.\flux\src\flux\math.py

import torch
from einops import rearrange
from torch import Tensor


# Scaled dot-product attention with rotary position embeddings applied to q and k
def attention(q: Tensor, k: Tensor, v: Tensor, pe: Tensor) -> Tensor:
    q, k = apply_rope(q, k, pe)
    x = torch.nn.functional.scaled_dot_product_attention(q, k, v)
    # merge the head dimension back into the feature dimension
    x = rearrange(x, "B H L D -> B L (H D)")
    return x


# Rotary position embedding (RoPE): build per-position 2x2 rotation matrices
def rope(pos: Tensor, dim: int, theta: int) -> Tensor:
    assert dim % 2 == 0
    scale = torch.arange(0, dim, 2, dtype=torch.float64, device=pos.device) / dim
    omega = 1.0 / (theta**scale)                       # rotation frequency per coordinate pair
    out = torch.einsum("...n,d->...nd", pos, omega)    # rotation angle per position and pair
    out = torch.stack([torch.cos(out), -torch.sin(out), torch.sin(out), torch.cos(out)], dim=-1)
    out = rearrange(out, "b n d (i j) -> b n d i j", i=2, j=2)
    return out.float()


# Apply the rotation matrices from rope() to the query and key tensors
def apply_rope(xq: Tensor, xk: Tensor, freqs_cis: Tensor) -> tuple[Tensor, Tensor]:
    # view the last dimension as pairs of coordinates
    xq_ = xq.float().reshape(*xq.shape[:-1], -1, 1, 2)
    xk_ = xk.float().reshape(*xk.shape[:-1], -1, 1, 2)
    # rotate each pair by the corresponding 2x2 matrix
    xq_out = freqs_cis[..., 0] * xq_[..., 0] + freqs_cis[..., 1] * xq_[..., 1]
    xk_out = freqs_cis[..., 0] * xk_[..., 0] + freqs_cis[..., 1] * xk_[..., 1]
    return xq_out.reshape(*xq.shape).type_as(xq), xk_out.reshape(*xk.shape).type_as(xk)
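
A small shape sanity check for these helpers (toy sizes; pe gets the extra singleton head dimension that EmbedND adds via unsqueeze(1)):

import torch
from flux.math import attention, rope

B, H, L, D = 1, 2, 8, 16                           # batch, heads, sequence length, head dim
pos = torch.arange(L, dtype=torch.float32)[None]   # (1, L) position ids
pe = rope(pos, D, 10_000)[:, None]                 # (1, 1, L, D/2, 2, 2) rotation matrices

q, k, v = (torch.randn(B, H, L, D) for _ in range(3))
out = attention(q, k, v, pe=pe)
print(out.shape)                                   # torch.Size([1, 8, 32]) == (B, L, H*D)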

.\flux\src\flux\model.py

from dataclasses import dataclass

import torch
from torch import Tensor, nn

from flux.modules.layers import (
    DoubleStreamBlock,
    EmbedND,
    LastLayer,
    MLPEmbedder,
    SingleStreamBlock,
    timestep_embedding,
)


# Hyperparameters of the Flux transformer
@dataclass
class FluxParams:
    in_channels: int           # number of input channels
    vec_in_dim: int            # dimension of the pooled (CLIP) conditioning vector
    context_in_dim: int        # dimension of the (T5) context tokens
    hidden_size: int           # transformer hidden size
    mlp_ratio: float           # MLP hidden size as a multiple of hidden_size
    num_heads: int             # number of attention heads
    depth: int                 # number of double-stream blocks
    depth_single_blocks: int   # number of single-stream blocks
    axes_dim: list[int]        # RoPE dimensions per positional axis
    theta: int                 # RoPE base frequency
    qkv_bias: bool             # whether QKV projections use a bias
    guidance_embed: bool       # whether to embed a guidance signal (flux-dev)
class Flux(nn.Module):
    """
    Transformer model for flow matching on sequences.
    """

    def __init__(self, params: FluxParams):
        super().__init__()

        self.params = params
        self.in_channels = params.in_channels
        self.out_channels = self.in_channels
        if params.hidden_size % params.num_heads != 0:
            raise ValueError(
                f"Hidden size {params.hidden_size} must be divisible by num_heads {params.num_heads}"
            )
        pe_dim = params.hidden_size // params.num_heads
        if sum(params.axes_dim) != pe_dim:
            raise ValueError(f"Got {params.axes_dim} but expected positional dim {pe_dim}")
        self.hidden_size = params.hidden_size
        self.num_heads = params.num_heads

        # Input projections and conditioning embedders
        self.pe_embedder = EmbedND(dim=pe_dim, theta=params.theta, axes_dim=params.axes_dim)
        self.img_in = nn.Linear(self.in_channels, self.hidden_size, bias=True)
        self.time_in = MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size)
        self.vector_in = MLPEmbedder(params.vec_in_dim, self.hidden_size)
        # The guidance embedder is only used by guidance-distilled models (flux-dev)
        self.guidance_in = (
            MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size) if params.guidance_embed else nn.Identity()
        )
        self.txt_in = nn.Linear(params.context_in_dim, self.hidden_size)

        # Double-stream blocks process image and text tokens in parallel streams
        self.double_blocks = nn.ModuleList(
            [
                DoubleStreamBlock(
                    self.hidden_size,
                    self.num_heads,
                    mlp_ratio=params.mlp_ratio,
                    qkv_bias=params.qkv_bias,
                )
                for _ in range(params.depth)
            ]
        )

        # Single-stream blocks process the concatenated sequence
        self.single_blocks = nn.ModuleList(
            [
                SingleStreamBlock(self.hidden_size, self.num_heads, mlp_ratio=params.mlp_ratio)
                for _ in range(params.depth_single_blocks)
            ]
        )

        self.final_layer = LastLayer(self.hidden_size, 1, self.out_channels)

    def forward(
        self,
        img: Tensor,
        img_ids: Tensor,
        txt: Tensor,
        txt_ids: Tensor,
        timesteps: Tensor,
        y: Tensor,
        guidance: Tensor | None = None,
    ) -> Tensor:
        if img.ndim != 3 or txt.ndim != 3:
            raise ValueError("Input img and txt tensors must have 3 dimensions.")

        # Project the inputs and build the conditioning vector from the timestep,
        # the guidance strength (if distilled), and the pooled text vector
        img = self.img_in(img)
        vec = self.time_in(timestep_embedding(timesteps, 256))
        if self.params.guidance_embed:
            if guidance is None:
                raise ValueError("Didn't get guidance strength for guidance distilled model.")
            vec = vec + self.guidance_in(timestep_embedding(guidance, 256))
        vec = vec + self.vector_in(y)
        txt = self.txt_in(txt)

        # Positional encoding over the concatenated text and image ids
        ids = torch.cat((txt_ids, img_ids), dim=1)
        pe = self.pe_embedder(ids)

        for block in self.double_blocks:
            img, txt = block(img=img, txt=txt, vec=vec, pe=pe)

        # Concatenate the streams and run the single-stream blocks
        img = torch.cat((txt, img), 1)
        for block in self.single_blocks:
            img = block(img, vec=vec, pe=pe)
        # Drop the text tokens again
        img = img[:, txt.shape[1] :, ...]

        img = self.final_layer(img, vec)  # (N, T, patch_size ** 2 * out_channels)
        return img

.\flux\src\flux\modules\autoencoder.py

from dataclasses import dataclass

import torch
from einops import rearrange
from torch import Tensor, nn


# Hyperparameters of the autoencoder
@dataclass
class AutoEncoderParams:
    resolution: int        # training image resolution
    in_channels: int       # number of input channels
    ch: int                # base channel count
    out_ch: int            # number of output channels
    ch_mult: list[int]     # channel multiplier per resolution level
    num_res_blocks: int    # residual blocks per level
    z_channels: int        # number of latent channels
    scale_factor: float    # latent scaling factor
    shift_factor: float    # latent shift factor


# Swish activation: x * sigmoid(x)
def swish(x: Tensor) -> Tensor:
    return x * torch.sigmoid(x)


# Single-head self-attention block over spatial positions
class AttnBlock(nn.Module):
    def __init__(self, in_channels: int):
        super().__init__()
        self.in_channels = in_channels

        self.norm = nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)

        # 1x1 convolutions compute the query, key, value, and output projections
        self.q = nn.Conv2d(in_channels, in_channels, kernel_size=1)
        self.k = nn.Conv2d(in_channels, in_channels, kernel_size=1)
        self.v = nn.Conv2d(in_channels, in_channels, kernel_size=1)
        self.proj_out = nn.Conv2d(in_channels, in_channels, kernel_size=1)

    def attention(self, h_: Tensor) -> Tensor:
        h_ = self.norm(h_)
        q = self.q(h_)
        k = self.k(h_)
        v = self.v(h_)

        # Flatten the spatial dimensions into a sequence for attention
        b, c, h, w = q.shape
        q = rearrange(q, "b c h w -> b 1 (h w) c").contiguous()
        k = rearrange(k, "b c h w -> b 1 (h w) c").contiguous()
        v = rearrange(v, "b c h w -> b 1 (h w) c").contiguous()
        h_ = nn.functional.scaled_dot_product_attention(q, k, v)

        return rearrange(h_, "b 1 (h w) c -> b c h w", h=h, w=w, c=c, b=b)

    def forward(self, x: Tensor) -> Tensor:
        # Residual connection around the attention block
        return x + self.proj_out(self.attention(x))


# Residual block with two GroupNorm/swish/conv stages
class ResnetBlock(nn.Module):
    def __init__(self, in_channels: int, out_channels: int):
        super().__init__()
        self.in_channels = in_channels
        out_channels = in_channels if out_channels is None else out_channels
        self.out_channels = out_channels

        self.norm1 = nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
        self.norm2 = nn.GroupNorm(num_groups=32, num_channels=out_channels, eps=1e-6, affine=True)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
        # If the channel counts differ, project the shortcut with a 1x1 conv
        if self.in_channels != self.out_channels:
            self.nin_shortcut = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)

    def forward(self, x):
        h = x
        h = self.norm1(h)
        h = swish(h)
        h = self.conv1(h)

        h = self.norm2(h)
        h = swish(h)
        h = self.conv2(h)

        if self.in_channels != self.out_channels:
            x = self.nin_shortcut(x)

        return x + h


# Strided convolution that halves the spatial resolution
class Downsample(nn.Module):
    def __init__(self, in_channels: int):
        super().__init__()
        # no asymmetric padding in torch conv, must do it ourselves
        self.conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=2, padding=0)

    def forward(self, x: Tensor):
        # Pad one pixel on the right and bottom before the strided conv
        pad = (0, 1, 0, 1)
        x = nn.functional.pad(x, pad, mode="constant", value=0)
        x = self.conv(x)
        return x
# Nearest-neighbor upsampling followed by a 3x3 convolution
class Upsample(nn.Module):
    def __init__(self, in_channels: int):
        super().__init__()
        self.conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1)

    def forward(self, x: Tensor):
        # double the spatial resolution with nearest-neighbor interpolation
        x = nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
        x = self.conv(x)
        return x


# Encoder: image -> latent moments (mean and logvar)
class Encoder(nn.Module):
    def __init__(
        self,
        resolution: int,
        in_channels: int,
        ch: int,
        ch_mult: list[int],
        num_res_blocks: int,
        z_channels: int,
    ):
        super().__init__()
        self.ch = ch
        self.num_resolutions = len(ch_mult)
        self.num_res_blocks = num_res_blocks
        self.resolution = resolution
        self.in_channels = in_channels

        # input convolution
        self.conv_in = nn.Conv2d(in_channels, self.ch, kernel_size=3, stride=1, padding=1)

        # downsampling tower
        curr_res = resolution
        in_ch_mult = (1,) + tuple(ch_mult)
        self.in_ch_mult = in_ch_mult
        self.down = nn.ModuleList()
        block_in = self.ch
        for i_level in range(self.num_resolutions):
            block = nn.ModuleList()
            attn = nn.ModuleList()
            # channel counts for this level
            block_in = ch * in_ch_mult[i_level]
            block_out = ch * ch_mult[i_level]
            for _ in range(self.num_res_blocks):
                block.append(ResnetBlock(in_channels=block_in, out_channels=block_out))
                block_in = block_out
            down = nn.Module()
            down.block = block
            down.attn = attn
            if i_level != self.num_resolutions - 1:
                down.downsample = Downsample(block_in)
                curr_res = curr_res // 2
            self.down.append(down)

        # middle: two residual blocks around one attention block
        self.mid = nn.Module()
        self.mid.block_1 = ResnetBlock(in_channels=block_in, out_channels=block_in)
        self.mid.attn_1 = AttnBlock(block_in)
        self.mid.block_2 = ResnetBlock(in_channels=block_in, out_channels=block_in)

        # output: the final conv produces 2 * z_channels (mean and logvar)
        self.norm_out = nn.GroupNorm(num_groups=32, num_channels=block_in, eps=1e-6, affine=True)
        self.conv_out = nn.Conv2d(block_in, 2 * z_channels, kernel_size=3, stride=1, padding=1)

    def forward(self, x: Tensor) -> Tensor:
        # downsampling
        hs = [self.conv_in(x)]
        for i_level in range(self.num_resolutions):
            for i_block in range(self.num_res_blocks):
                h = self.down[i_level].block[i_block](hs[-1])
                if len(self.down[i_level].attn) > 0:
                    h = self.down[i_level].attn[i_block](h)
                hs.append(h)
            if i_level != self.num_resolutions - 1:
                hs.append(self.down[i_level].downsample(hs[-1]))

        # middle
        h = hs[-1]
        h = self.mid.block_1(h)
        h = self.mid.attn_1(h)
        h = self.mid.block_2(h)

        # end
        h = self.norm_out(h)
        h = swish(h)
        h = self.conv_out(h)
        return h


# Decoder: latent -> image
class Decoder(nn.Module):
    def __init__(
        self,
        ch: int,
        out_ch: int,
        ch_mult: list[int],
        num_res_blocks: int,
        in_channels: int,
        resolution: int,
        z_channels: int,
    ):
        super().__init__()
        self.ch = ch
        self.num_resolutions = len(ch_mult)
        self.num_res_blocks = num_res_blocks
        self.resolution = resolution
        self.in_channels = in_channels
        # overall spatial scaling factor between latent and image
        self.ffactor = 2 ** (self.num_resolutions - 1)

        # channel count and resolution at the lowest level
        block_in = ch * ch_mult[self.num_resolutions - 1]
        curr_res = resolution // 2 ** (self.num_resolutions - 1)
        self.z_shape = (1, z_channels, curr_res, curr_res)

        # z to block_in
        self.conv_in = nn.Conv2d(z_channels, block_in, kernel_size=3, stride=1, padding=1)

        # middle
        self.mid = nn.Module()
        self.mid.block_1 = ResnetBlock(in_channels=block_in, out_channels=block_in)
        self.mid.attn_1 = AttnBlock(block_in)
        self.mid.block_2 = ResnetBlock(in_channels=block_in, out_channels=block_in)

        # upsampling tower
        self.up = nn.ModuleList()
        for i_level in reversed(range(self.num_resolutions)):
            block = nn.ModuleList()
            attn = nn.ModuleList()
            block_out = ch * ch_mult[i_level]
            for _ in range(self.num_res_blocks + 1):
                block.append(ResnetBlock(in_channels=block_in, out_channels=block_out))
                block_in = block_out
            up = nn.Module()
            up.block = block
            up.attn = attn
            if i_level != 0:
                up.upsample = Upsample(block_in)
                curr_res = curr_res * 2
            self.up.insert(0, up)  # prepend to keep the level order consistent

        # end
        self.norm_out = nn.GroupNorm(num_groups=32, num_channels=block_in, eps=1e-6, affine=True)
        self.conv_out = nn.Conv2d(block_in, out_ch, kernel_size=3, stride=1, padding=1)

    def forward(self, z: Tensor) -> Tensor:
        # z to block_in
        h = self.conv_in(z)

        # middle
        h = self.mid.block_1(h)
        h = self.mid.attn_1(h)
        h = self.mid.block_2(h)

        # upsampling
        for i_level in reversed(range(self.num_resolutions)):
            for i_block in range(self.num_res_blocks + 1):
                h = self.up[i_level].block[i_block](h)
                if len(self.up[i_level].attn) > 0:
                    h = self.up[i_level].attn[i_block](h)
            if i_level != 0:
                h = self.up[i_level].upsample(h)

        # end
        h = self.norm_out(h)
        h = swish(h)
        h = self.conv_out(h)
        return h
# Diagonal Gaussian: split the encoder output into mean/logvar and sample
class DiagonalGaussian(nn.Module):
    def __init__(self, sample: bool = True, chunk_dim: int = 1):
        super().__init__()
        self.sample = sample
        self.chunk_dim = chunk_dim

    def forward(self, z: Tensor) -> Tensor:
        mean, logvar = torch.chunk(z, 2, dim=self.chunk_dim)
        if self.sample:
            # reparameterization: mean + std * eps
            std = torch.exp(0.5 * logvar)
            return mean + std * torch.randn_like(mean)
        else:
            return mean


# Full autoencoder: encoder + diagonal Gaussian + decoder
class AutoEncoder(nn.Module):
    def __init__(self, params: AutoEncoderParams):
        super().__init__()
        self.encoder = Encoder(
            resolution=params.resolution,
            in_channels=params.in_channels,
            ch=params.ch,
            ch_mult=params.ch_mult,
            num_res_blocks=params.num_res_blocks,
            z_channels=params.z_channels,
        )
        self.decoder = Decoder(
            resolution=params.resolution,
            in_channels=params.in_channels,
            ch=params.ch,
            out_ch=params.out_ch,
            ch_mult=params.ch_mult,
            num_res_blocks=params.num_res_blocks,
            z_channels=params.z_channels,
        )
        self.reg = DiagonalGaussian()

        self.scale_factor = params.scale_factor
        self.shift_factor = params.shift_factor

    def encode(self, x: Tensor) -> Tensor:
        z = self.reg(self.encoder(x))
        # normalize the latent with the configured scale and shift
        z = self.scale_factor * (z - self.shift_factor)
        return z

    def decode(self, z: Tensor) -> Tensor:
        # undo the latent normalization before decoding
        z = z / self.scale_factor + self.shift_factor
        return self.decoder(z)

    def forward(self, x: Tensor) -> Tensor:
        return self.decode(self.encode(x))
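
The encode/decode pair wraps the network in an affine latent normalization; a tiny sketch of that mapping (the factors are the values from the configs in flux/util.py):

import torch

scale_factor, shift_factor = 0.3611, 0.1159
z = torch.randn(1, 16, 32, 32)                  # a raw latent, as produced by the encoder

z_norm = scale_factor * (z - shift_factor)      # what encode() applies
z_back = z_norm / scale_factor + shift_factor   # what decode() undoes first
print(torch.allclose(z, z_back, atol=1e-6))     # True: the affine map is invertible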

.\flux\src\flux\modules\conditioner.py

from torch import Tensor, nn
from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5Tokenizer


# Wraps a Hugging Face text encoder (CLIP or T5) behind a uniform interface
class HFEmbedder(nn.Module):
    def __init__(self, version: str, max_length: int, **hf_kwargs):
        super().__init__()
        # CLIP checkpoints are identified by their "openai" prefix
        self.is_clip = version.startswith("openai")
        self.max_length = max_length
        # CLIP contributes the pooled vector, T5 the per-token hidden states
        self.output_key = "pooler_output" if self.is_clip else "last_hidden_state"

        if self.is_clip:
            self.tokenizer: CLIPTokenizer = CLIPTokenizer.from_pretrained(version, max_length=max_length)
            self.hf_module: CLIPTextModel = CLIPTextModel.from_pretrained(version, **hf_kwargs)
        else:
            self.tokenizer: T5Tokenizer = T5Tokenizer.from_pretrained(version, max_length=max_length)
            self.hf_module: T5EncoderModel = T5EncoderModel.from_pretrained(version, **hf_kwargs)

        # inference only: eval mode, no gradients
        self.hf_module = self.hf_module.eval().requires_grad_(False)

    def forward(self, text: list[str]) -> Tensor:
        batch_encoding = self.tokenizer(
            text,
            truncation=True,                   # truncate over-long inputs
            max_length=self.max_length,
            return_length=False,
            return_overflowing_tokens=False,
            padding="max_length",              # pad up to the maximum length
            return_tensors="pt",
        )

        outputs = self.hf_module(
            input_ids=batch_encoding["input_ids"].to(self.hf_module.device),
            attention_mask=None,
            output_hidden_states=False,
        )
        return outputs[self.output_key]
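
A usage sketch for HFEmbedder; the CLIP version string is an assumption based on the startswith("openai") check above (the actual versions are set by the loaders in flux/util.py), and the call downloads weights from Hugging Face:

from flux.modules.conditioner import HFEmbedder

# assumed model version; max_length=77 is CLIP's context size
clip = HFEmbedder("openai/clip-vit-large-patch14", max_length=77)
vec = clip(["a photo of a forest"])
print(vec.shape)   # torch.Size([1, 768]) - the pooled CLIP embedding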

.\flux\src\flux\modules\layers.py

import math
from dataclasses import dataclass

import torch
from einops import rearrange
from torch import Tensor, nn

from flux.math import attention, rope


# N-dimensional positional embedding: one RoPE per positional axis
class EmbedND(nn.Module):
    def __init__(self, dim: int, theta: int, axes_dim: list[int]):
        super().__init__()
        self.dim = dim
        self.theta = theta
        self.axes_dim = axes_dim

    def forward(self, ids: Tensor) -> Tensor:
        n_axes = ids.shape[-1]
        # apply rope per axis and concatenate along the frequency dimension
        emb = torch.cat(
            [rope(ids[..., i], self.axes_dim[i], self.theta) for i in range(n_axes)],
            dim=-3,
        )
        # add a singleton head dimension for broadcasting
        return emb.unsqueeze(1)
def timestep_embedding(t: Tensor, dim, max_period=10000, time_factor: float = 1000.0):
    """
    Create sinusoidal timestep embeddings.
    :param t: a 1-D Tensor of N indices, one per batch element. These may be fractional.
    :param dim: the dimension of the output.
    :param max_period: controls the minimum frequency of the embeddings.
    :return: an (N, D) Tensor of positional embeddings.
    """
    t = time_factor * t
    half = dim // 2
    freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half).to(
        t.device
    )

    args = t[:, None].float() * freqs[None]
    embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
    # pad with a zero column if dim is odd
    if dim % 2:
        embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
    if torch.is_floating_point(t):
        embedding = embedding.to(t)
    return embedding
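
A quick look at what the embedding produces (fractional timesteps are allowed, matching the flow-matching schedule in sampling.py):

import torch
from flux.modules.layers import timestep_embedding

t = torch.tensor([0.0, 0.5, 1.0])        # three fractional timesteps
emb = timestep_embedding(t, dim=256)
print(emb.shape)                          # torch.Size([3, 256])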
# Two-layer MLP with SiLU activation, used for timestep and vector embeddings
class MLPEmbedder(nn.Module):
    def __init__(self, in_dim: int, hidden_dim: int):
        super().__init__()
        self.in_layer = nn.Linear(in_dim, hidden_dim, bias=True)
        self.silu = nn.SiLU()
        self.out_layer = nn.Linear(hidden_dim, hidden_dim, bias=True)

    def forward(self, x: Tensor) -> Tensor:
        return self.out_layer(self.silu(self.in_layer(x)))


# Root-mean-square normalization with a learned scale
class RMSNorm(torch.nn.Module):
    def __init__(self, dim: int):
        super().__init__()
        self.scale = nn.Parameter(torch.ones(dim))

    def forward(self, x: Tensor):
        x_dtype = x.dtype
        x = x.float()
        # reciprocal root mean square over the last dimension
        rrms = torch.rsqrt(torch.mean(x**2, dim=-1, keepdim=True) + 1e-6)
        return (x * rrms).to(dtype=x_dtype) * self.scale


# RMS-normalize queries and keys before attention
class QKNorm(torch.nn.Module):
    def __init__(self, dim: int):
        super().__init__()
        self.query_norm = RMSNorm(dim)
        self.key_norm = RMSNorm(dim)

    def forward(self, q: Tensor, k: Tensor, v: Tensor) -> tuple[Tensor, Tensor]:
        q = self.query_norm(q)
        k = self.key_norm(k)
        # cast back to the dtype/device of the value tensor
        return q.to(v), k.to(v)


# Multi-head self-attention with QK normalization and RoPE
class SelfAttention(nn.Module):
    def __init__(self, dim: int, num_heads: int = 8, qkv_bias: bool = False):
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.norm = QKNorm(head_dim)
        self.proj = nn.Linear(dim, dim)

    def forward(self, x: Tensor, pe: Tensor) -> Tensor:
        qkv = self.qkv(x)
        # split the joint projection into per-head queries, keys, and values
        q, k, v = rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
        q, k = self.norm(q, k, v)
        # rotary-encoded scaled dot-product attention
        x = attention(q, k, v, pe=pe)
        x = self.proj(x)
        return x
# 定义一个包含三个张量的结构体 ModulationOut
@dataclass
class ModulationOut:shift: Tensorscale: Tensorgate: Tensor# 定义一个继承自 nn.Module 的 Modulation 类
class Modulation(nn.Module):# 初始化方法,设置维度和是否双倍def __init__(self, dim: int, double: bool):super().__init__()self.is_double = double  # 存储是否为双倍标志self.multiplier = 6 if double else 3  # 根据标志设置 multiplierself.lin = nn.Linear(dim, self.multiplier * dim, bias=True)  # 定义线性层# 前向传播方法,处理输入张量并返回结果def forward(self, vec: Tensor) -> tuple[ModulationOut, ModulationOut | None]:# 应用激活函数后,进行线性变换,并将结果按 multiplier 切分out = self.lin(nn.functional.silu(vec))[:, None, :].chunk(self.multiplier, dim=-1)# 返回切分后的结果,前半部分和后半部分(如果是双倍)return (ModulationOut(*out[:3]),  # 前三部分ModulationOut(*out[3:]) if self.is_double else None,  # 后三部分(如果是双倍))# 定义一个继承自 nn.Module 的 DoubleStreamBlock 类
class DoubleStreamBlock(nn.Module):# 初始化方法,设置隐藏层大小、注意力头数、MLP 比例等def __init__(self, hidden_size: int, num_heads: int, mlp_ratio: float, qkv_bias: bool = False):super().__init__()mlp_hidden_dim = int(hidden_size * mlp_ratio)  # 计算 MLP 隐藏层维度self.num_heads = num_heads  # 存储注意力头数self.hidden_size = hidden_size  # 存储隐藏层大小self.img_mod = Modulation(hidden_size, double=True)  # 定义图像模调模块self.img_norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)  # 定义图像的第一层归一化self.img_attn = SelfAttention(dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias)  # 定义图像的自注意力模块self.img_norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)  # 定义图像的第二层归一化self.img_mlp = nn.Sequential(  # 定义图像的 MLP 网络nn.Linear(hidden_size, mlp_hidden_dim, bias=True),  # 第一层线性变换nn.GELU(approximate="tanh"),  # 激活函数nn.Linear(mlp_hidden_dim, hidden_size, bias=True),  # 第二层线性变换)self.txt_mod = Modulation(hidden_size, double=True)  # 定义文本模调模块self.txt_norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)  # 定义文本的第一层归一化self.txt_attn = SelfAttention(dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias)  # 定义文本的自注意力模块self.txt_norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)  # 定义文本的第二层归一化self.txt_mlp = nn.Sequential(  # 定义文本的 MLP 网络nn.Linear(hidden_size, mlp_hidden_dim, bias=True),  # 第一层线性变换nn.GELU(approximate="tanh"),  # 激活函数nn.Linear(mlp_hidden_dim, hidden_size, bias=True),  # 第二层线性变换)# 前向传播函数,处理图像和文本输入,返回更新后的图像和文本def forward(self, img: Tensor
class SingleStreamBlock(nn.Module):
    """
    A DiT block with parallel linear layers as described in
    https://arxiv.org/abs/2302.05442 and adapted modulation interface.
    """

    def __init__(
        self,
        hidden_size: int,
        num_heads: int,
        mlp_ratio: float = 4.0,
        qk_scale: float | None = None,
    ):
        super().__init__()

        self.hidden_dim = hidden_size
        self.num_heads = num_heads
        head_dim = hidden_size // num_heads
        self.scale = qk_scale or head_dim**-0.5

        self.mlp_hidden_dim = int(hidden_size * mlp_ratio)
        # qkv and mlp_in fused into one projection
        self.linear1 = nn.Linear(hidden_size, hidden_size * 3 + self.mlp_hidden_dim)
        # attention output and mlp_out fused into one projection
        self.linear2 = nn.Linear(hidden_size + self.mlp_hidden_dim, hidden_size)

        self.norm = QKNorm(head_dim)

        self.hidden_size = hidden_size
        self.pre_norm = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)

        self.mlp_act = nn.GELU(approximate="tanh")
        self.modulation = Modulation(hidden_size, double=False)

    def forward(self, x: Tensor, vec: Tensor, pe: Tensor) -> Tensor:
        mod, _ = self.modulation(vec)
        # pre-normalize and modulate the input
        x_mod = (1 + mod.scale) * self.pre_norm(x) + mod.shift
        qkv, mlp = torch.split(self.linear1(x_mod), [3 * self.hidden_size, self.mlp_hidden_dim], dim=-1)

        q, k, v = rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
        q, k = self.norm(q, k, v)

        # compute attention
        attn = attention(q, k, v, pe=pe)
        # activate the parallel MLP stream, then project attention + MLP together
        output = self.linear2(torch.cat((attn, self.mlp_act(mlp)), 2))
        # gated residual connection
        return x + mod.gate * output


# Final projection from the hidden sequence to output patches
class LastLayer(nn.Module):
    def __init__(self, hidden_size: int, patch_size: int, out_channels: int):
        super().__init__()
        self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        self.linear = nn.Linear(hidden_size, patch_size * patch_size * out_channels, bias=True)
        # adaptive LayerNorm modulation from the conditioning vector
        self.adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(hidden_size, 2 * hidden_size, bias=True))

    def forward(self, x: Tensor, vec: Tensor) -> Tensor:
        shift, scale = self.adaLN_modulation(vec).chunk(2, dim=1)
        x = (1 + scale[:, None, :]) * self.norm_final(x) + shift[:, None, :]
        x = self.linear(x)
        return x

.\flux\src\flux\sampling.py

import math
from typing import Callable

import torch
from einops import rearrange, repeat
from torch import Tensor

from .model import Flux
from .modules.conditioner import HFEmbedder
# Draw the initial latent noise for sampling
def get_noise(
    num_samples: int,
    height: int,
    width: int,
    device: torch.device,
    dtype: torch.dtype,
    seed: int,
):
    return torch.randn(
        num_samples,
        16,
        # allow for packing into 2x2 patches
        2 * math.ceil(height / 16),
        2 * math.ceil(width / 16),
        device=device,
        dtype=dtype,
        generator=torch.Generator(device=device).manual_seed(seed),
    )
# Embed the prompt and pack the latent image into the model's input format
def prepare(t5: HFEmbedder, clip: HFEmbedder, img: Tensor, prompt: str | list[str]) -> dict[str, Tensor]:
    bs, c, h, w = img.shape
    if bs == 1 and not isinstance(prompt, str):
        bs = len(prompt)

    # pack 2x2 latent patches into a token sequence
    img = rearrange(img, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2)
    if img.shape[0] == 1 and bs > 1:
        img = repeat(img, "1 ... -> bs ...", bs=bs)

    # positional ids per image token: (unused axis, row, column)
    img_ids = torch.zeros(h // 2, w // 2, 3)
    img_ids[..., 1] = img_ids[..., 1] + torch.arange(h // 2)[:, None]
    img_ids[..., 2] = img_ids[..., 2] + torch.arange(w // 2)[None, :]
    img_ids = repeat(img_ids, "h w c -> b (h w) c", b=bs)

    if isinstance(prompt, str):
        prompt = [prompt]
    # T5 provides the token-level context, CLIP the pooled vector
    txt = t5(prompt)
    if txt.shape[0] == 1 and bs > 1:
        txt = repeat(txt, "1 ... -> bs ...", bs=bs)
    txt_ids = torch.zeros(bs, txt.shape[1], 3)

    vec = clip(prompt)
    if vec.shape[0] == 1 and bs > 1:
        vec = repeat(vec, "1 ... -> bs ...", bs=bs)

    return {
        "img": img,
        "img_ids": img_ids.to(img.device),
        "txt": txt.to(img.device),
        "txt_ids": txt_ids.to(img.device),
        "vec": vec.to(img.device),
    }
# Shift a timestep t in (0, 1] toward higher noise levels, controlled by mu
def time_shift(mu: float, sigma: float, t: Tensor):
    return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma)


# Linear interpolation of the shift parameter as a function of sequence length
def get_lin_function(
    x1: float = 256, y1: float = 0.5, x2: float = 4096, y2: float = 1.15
) -> Callable[[float], float]:
    m = (y2 - y1) / (x2 - x1)  # slope
    b = y1 - m * x1            # intercept
    return lambda x: m * x + b
# Build the list of sampling timesteps, optionally shifted by sequence length
def get_schedule(
    num_steps: int,
    image_seq_len: int,
    base_shift: float = 0.5,
    max_shift: float = 1.15,
    shift: bool = True,
) -> list[float]:
    # linear timesteps from 1 (pure noise) to 0 (clean image)
    timesteps = torch.linspace(1, 0, num_steps + 1)

    if shift:
        # estimate mu from the image sequence length
        mu = get_lin_function(y1=base_shift, y2=max_shift)(image_seq_len)
        timesteps = time_shift(mu, 1.0, timesteps)

    return timesteps.tolist()
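
A sketch comparing the two schedules; shift=False corresponds to flux-schnell, shift=True to flux-dev, and the sequence length matches a 1360x768 sample:

from flux.sampling import get_schedule

seq_len = (768 // 16) * (1360 // 16)              # packed latent length, 4080 tokens
plain = get_schedule(4, seq_len, shift=False)     # schnell-style linear schedule
shifted = get_schedule(4, seq_len, shift=True)    # dev-style shifted schedule
print([f"{t:.2f}" for t in plain])                # ['1.00', '0.75', '0.50', '0.25', '0.00']
print([f"{t:.2f}" for t in shifted])              # same endpoints, interior steps pushed up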
# Euler integration of the learned flow from noise to image
def denoise(
    model: Flux,
    # model input
    img: Tensor,
    img_ids: Tensor,
    txt: Tensor,
    txt_ids: Tensor,
    vec: Tensor,
    # sampling parameters
    timesteps: list[float],
    guidance: float = 4.0,
):
    # the same guidance strength is used for every image in the batch
    guidance_vec = torch.full((img.shape[0],), guidance, device=img.device, dtype=img.dtype)
    for t_curr, t_prev in zip(timesteps[:-1], timesteps[1:]):
        t_vec = torch.full((img.shape[0],), t_curr, dtype=img.dtype, device=img.device)
        # predict the velocity at the current timestep
        pred = model(
            img=img,
            img_ids=img_ids,
            txt=txt,
            txt_ids=txt_ids,
            y=vec,
            timesteps=t_vec,
            guidance=guidance_vec,
        )

        # one Euler step along the predicted velocity
        img = img + (t_prev - t_curr) * pred

    return img
# Unpacks the latent token sequence back into image form
def unpack(x: Tensor, height: int, width: int) -> Tensor:
    return rearrange(
        x,
        # invert the 2x2 patch packing performed in prepare()
        "b (h w) (c ph pw) -> b c (h ph) (w pw)",
        h=math.ceil(height / 16),
        w=math.ceil(width / 16),
        ph=2,
        pw=2,
    )
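
Putting this module's pieces together, a bare-bones text-to-image loop might look as follows (a sketch under stated assumptions: flux-schnell with its 4-step, shift-free schedule; dtype and memory handling simplified; ae.decode comes from the autoencoder loaded via flux.util below):

import torch
from flux.sampling import denoise, get_noise, get_schedule, prepare, unpack
from flux.util import load_ae, load_clip, load_flow_model, load_t5

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
t5 = load_t5(device, max_length=256)
clip = load_clip(device)
model = load_flow_model("flux-schnell", device=device)
ae = load_ae("flux-schnell", device=device)

width, height = 1360, 768
x = get_noise(1, height, width, device=device, dtype=torch.bfloat16, seed=42)
inp = prepare(t5, clip, x, prompt="a photo of a forest")
# schnell is distilled: few steps, no resolution shift
timesteps = get_schedule(4, inp["img"].shape[1], shift=False)
x = denoise(model, **inp, timesteps=timesteps, guidance=3.5)
x = unpack(x.float(), height, width)
img = ae.decode(x)  # decode the latent back to pixel space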

.\flux\src\flux\util.py

# Import the os module for environment-variable access
import os
# Import the dataclass decorator for defining data classes
from dataclasses import dataclass

# Import PyTorch for tensor operations
import torch
# Import rearrange from einops for tensor reshaping
from einops import rearrange
# Import hf_hub_download for fetching model files from the Hugging Face Hub
from huggingface_hub import hf_hub_download
# Import WatermarkEncoder for embedding watermarks in images
from imwatermark import WatermarkEncoder
# Import load_file (aliased as load_sft) for loading safetensors checkpoints
from safetensors.torch import load_file as load_sft

# Import the Flux model and its parameter dataclass
from flux.model import Flux, FluxParams
# Import the autoencoder and its parameter dataclass
from flux.modules.autoencoder import AutoEncoder, AutoEncoderParams
# Import HFEmbedder for text conditioning
from flux.modules.conditioner import HFEmbedder
# Data class holding the full specification of a model
@dataclass
class ModelSpec:
    # Flux model parameters
    params: FluxParams
    # autoencoder parameters
    ae_params: AutoEncoderParams
    # local checkpoint path (may be None)
    ckpt_path: str | None
    # local autoencoder checkpoint path (may be None)
    ae_path: str | None
    # Hugging Face repo id (may be None)
    repo_id: str | None
    # filename of the flow checkpoint inside the repo (may be None)
    repo_flow: str | None
    # filename of the autoencoder checkpoint inside the repo (may be None)
    repo_ae: str | None
# Dictionary mapping model names to their specifications
configs = {
    # "flux-dev" model specification
    "flux-dev": ModelSpec(
        repo_id="black-forest-labs/FLUX.1-dev",
        repo_flow="flux1-dev.safetensors",
        repo_ae="ae.safetensors",
        # local checkpoint path taken from the environment
        ckpt_path=os.getenv("FLUX_DEV"),
        params=FluxParams(
            in_channels=64,
            vec_in_dim=768,
            context_in_dim=4096,
            hidden_size=3072,
            mlp_ratio=4.0,
            num_heads=24,
            depth=19,
            depth_single_blocks=38,
            axes_dim=[16, 56, 56],
            theta=10_000,
            qkv_bias=True,
            guidance_embed=True,
        ),
        # local autoencoder path taken from the environment
        ae_path=os.getenv("AE"),
        ae_params=AutoEncoderParams(
            resolution=256,
            in_channels=3,
            ch=128,
            out_ch=3,
            ch_mult=[1, 2, 4, 4],
            num_res_blocks=2,
            z_channels=16,
            scale_factor=0.3611,
            shift_factor=0.1159,
        ),
    ),
    # "flux-schnell" model specification; same architecture, but guidance_embed=False
    "flux-schnell": ModelSpec(
        repo_id="black-forest-labs/FLUX.1-schnell",
        repo_flow="flux1-schnell.safetensors",
        repo_ae="ae.safetensors",
        ckpt_path=os.getenv("FLUX_SCHNELL"),
        params=FluxParams(
            in_channels=64,
            vec_in_dim=768,
            context_in_dim=4096,
            hidden_size=3072,
            mlp_ratio=4.0,
            num_heads=24,
            depth=19,
            depth_single_blocks=38,
            axes_dim=[16, 56, 56],
            theta=10_000,
            qkv_bias=True,
            guidance_embed=False,
        ),
        ae_path=os.getenv("AE"),
        ae_params=AutoEncoderParams(
            resolution=256,
            in_channels=3,
            ch=128,
            out_ch=3,
            ch_mult=[1, 2, 4, 4],
            num_res_blocks=2,
            z_channels=16,
            scale_factor=0.3611,
            shift_factor=0.1159,
        ),
    ),
}
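
A usage sketch (illustrative only): the dict is queried like any mapping, and the two entries differ mainly in their checkpoint sources and in guidance_embed:

from flux.util import configs

spec = configs["flux-schnell"]
print(spec.repo_id)                # black-forest-labs/FLUX.1-schnell
print(spec.params.guidance_embed)  # False -- the distilled schnell model takes no guidance signal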
# Prints a warning listing missing and unexpected state-dict keys
def print_load_warning(missing: list[str], unexpected: list[str]) -> None:
    # both kinds present: report both, separated by a horizontal rule
    if len(missing) > 0 and len(unexpected) > 0:
        print(f"Got {len(missing)} missing keys:\n\t" + "\n\t".join(missing))
        print("\n" + "-" * 79 + "\n")
        print(f"Got {len(unexpected)} unexpected keys:\n\t" + "\n\t".join(unexpected))
    # only missing keys
    elif len(missing) > 0:
        print(f"Got {len(missing)} missing keys:\n\t" + "\n\t".join(missing))
    # only unexpected keys
    elif len(unexpected) > 0:
        print(f"Got {len(unexpected)} unexpected keys:\n\t" + "\n\t".join(unexpected))
# Loads the flow model; downloads from the Hugging Face Hub when no local checkpoint is configured
def load_flow_model(name: str, device: str | torch.device = "cuda", hf_download: bool = True):
    print("Init model")
    ckpt_path = configs[name].ckpt_path
    # fall back to a Hub download when no local checkpoint path is set
    if (
        ckpt_path is None
        and configs[name].repo_id is not None
        and configs[name].repo_flow is not None
        and hf_download
    ):
        ckpt_path = hf_hub_download(configs[name].repo_id, configs[name].repo_flow)

    # build on the meta device when a checkpoint will be loaded, otherwise on the target device
    with torch.device("meta" if ckpt_path is not None else device):
        model = Flux(configs[name].params).to(torch.bfloat16)

    if ckpt_path is not None:
        print("Loading checkpoint")
        # load the safetensors checkpoint onto the target device
        sd = load_sft(ckpt_path, device=str(device))
        # assign the loaded weights and report any missing / unexpected keys
        missing, unexpected = model.load_state_dict(sd, strict=False, assign=True)
        print_load_warning(missing, unexpected)
    return model
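
The meta-device trick avoids allocating real storage for weights that the checkpoint will immediately replace; assign=True then lets load_state_dict substitute the loaded tensors for the meta ones. A minimal standalone illustration (a toy module, not from the source; requires PyTorch 2.1+ for assign):

import torch

with torch.device("meta"):
    layer = torch.nn.Linear(4, 4)  # parameters are meta tensors with no storage
print(layer.weight.device)         # meta

sd = {"weight": torch.randn(4, 4), "bias": torch.zeros(4)}
layer.load_state_dict(sd, assign=True)  # swap in real tensors
print(layer.weight.device)              # cpu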
# Loads the T5 text encoder with the given maximum sequence length
def load_t5(device: str | torch.device = "cuda", max_length: int = 512) -> HFEmbedder:
    # T5-XXL embedder in bfloat16, truncating prompts to max_length tokens
    return HFEmbedder("google/t5-v1_1-xxl", max_length=max_length, torch_dtype=torch.bfloat16).to(device)
# Loads the CLIP text encoder
def load_clip(device: str | torch.device = "cuda") -> HFEmbedder:
    # CLIP ViT-L/14 embedder in bfloat16; CLIP's context window is 77 tokens
    return HFEmbedder("openai/clip-vit-large-patch14", max_length=77, torch_dtype=torch.bfloat16).to(device)
# Loads the autoencoder; downloads from the Hub when no local checkpoint is configured
def load_ae(name: str, device: str | torch.device = "cuda", hf_download: bool = True) -> AutoEncoder:
    ckpt_path = configs[name].ae_path
    # fall back to a Hub download when no local path is set
    if (
        ckpt_path is None
        and configs[name].repo_id is not None
        and configs[name].repo_ae is not None
        and hf_download
    ):
        ckpt_path = hf_hub_download(configs[name].repo_id, configs[name].repo_ae)

    print("Init AE")
    # build on the meta device when a checkpoint will be loaded, otherwise on the target device
    with torch.device("meta" if ckpt_path is not None else device):
        ae = AutoEncoder(configs[name].ae_params)

    if ckpt_path is not None:
        sd = load_sft(ckpt_path, device=str(device))
        # assign the loaded weights and report any missing / unexpected keys
        missing, unexpected = ae.load_state_dict(sd, strict=False, assign=True)
        print_load_warning(missing, unexpected)
    return ae
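
Taken together, the four loaders mirror what the Gradio demo assembles (a sketch; the model name and max_length choice are illustrative):

import torch
from flux.util import load_ae, load_clip, load_flow_model, load_t5

device = "cuda" if torch.cuda.is_available() else "cpu"
t5 = load_t5(device, max_length=256)  # 256 for flux-schnell, 512 for flux-dev
clip = load_clip(device)
model = load_flow_model("flux-schnell", device=device)
ae = load_ae("flux-schnell", device=device)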
# Watermark embedder class
class WatermarkEmbedder:
    def __init__(self, watermark):
        # store the watermark bits and their count
        self.watermark = watermark
        self.num_bits = len(WATERMARK_BITS)
        # set up the encoder with the watermark bits
        self.encoder = WatermarkEncoder()
        self.encoder.set_watermark("bits", self.watermark)

    def __call__(self, image: torch.Tensor) -> torch.Tensor:
        """
        Adds a predefined watermark to the input image

        Args:
            image: ([N,] B, RGB, H, W) in range [-1, 1]

        Returns:
            same as input but watermarked
        """
        # map pixel values from [-1, 1] to [0, 1]
        image = 0.5 * image + 0.5
        # a 4D tensor (B, C, H, W) gets a leading singleton dimension added
        squeeze = len(image.shape) == 4
        if squeeze:
            image = image[None, ...]
        n = image.shape[0]
        # torch (n, b, c, h, w) in [0, 1] -> numpy (n*b, h, w, c) in [0, 255], RGB -> BGR
        # the watermarking library expects cv2-style BGR input
        image_np = rearrange((255 * image).detach().cpu(), "n b c h w -> (n b) h w c").numpy()[:, :, :, ::-1]
        # encode the watermark into every image in the flattened batch
        for k in range(image_np.shape[0]):
            image_np[k] = self.encoder.encode(image_np[k], "dwtDct")
        # back to a torch tensor with the original layout and device (BGR -> RGB)
        image = torch.from_numpy(rearrange(image_np[:, :, :, ::-1], "(n b) h w c -> n b c h w", n=n)).to(
            image.device
        )
        # normalize from [0, 255] back to [0, 1]
        image = torch.clamp(image / 255, min=0.0, max=1.0)
        if squeeze:
            # drop the leading dimension that was added above
            image = image[0]
        # map back to [-1, 1]
        image = 2 * image - 1
        return image
# A fixed 48-bit message, chosen at random
WATERMARK_MESSAGE = 0b001010101111111010000111100111001111010100101110
# bin(x)[2:] converts x to a binary string (without the '0b' prefix); int() turns each digit into 0 or 1
WATERMARK_BITS = [int(bit) for bit in bin(WATERMARK_MESSAGE)[2:]]
# module-level embedder built from the extracted bits
embed_watermark = WatermarkEmbedder(WATERMARK_BITS)
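
Two quick observations as a sketch (easy to verify interactively): bin() drops leading zeros, so the 48-digit literal actually yields 46 bits, and the embedder preserves shape and value range:

import torch
from flux.util import WATERMARK_BITS, embed_watermark

print(len(WATERMARK_BITS))  # 46 -- bin() strips the two leading zeros of the literal

img = torch.rand(2, 3, 256, 256) * 2 - 1  # stand-in for a decoded batch in [-1, 1]
print(embed_watermark(img).shape)         # torch.Size([2, 3, 256, 256])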

.\flux\src\flux\__init__.py

# Try to import version and version_tuple from the package's _version module
try:
    from ._version import version as __version__  # type: ignore
    from ._version import version_tuple
# If the import fails (module missing), fall back to default version info
except ImportError:
    __version__ = "unknown (no version information available)"
    version_tuple = (0, 0, "unknown", "noinfo")  # placeholder version tuple

# Import Path for handling filesystem paths
from pathlib import Path

# Package name, with underscores replaced by hyphens
PACKAGE = __package__.replace("_", "-")
# Directory containing this file
PACKAGE_ROOT = Path(__file__).parent

.\flux\src\flux\__main__.py

# Import the app function from the cli module in the same package
from .cli import app

# Run the app when this module is executed as the main program
if __name__ == "__main__":
    app()
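
Because the package ships this __main__ module, the CLI can also be launched with `python -m flux`, which is equivalent to importing and calling app() directly.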
