摘要
1) 一句话总结
本文记录了从头训练 1.2B 参数文本到图像模型的消融实验结果,详细评估了表示对齐、训练目标、Token 路由、数据策略及优化器设置对训练效率和生成质量的具体影响。
2) 核心要点
- 基线设置:采用纯流匹配(Flow Matching)架构(1.2B 参数,256×256 分辨率),主要通过 FID、CMMD、DINO-MMD 和网络吞吐量进行评估。
- 表示对齐(REPA):引入 REPA 并结合冻结的 DINOv2 编码器能有效加速早期收敛,在速度与质量间取得最佳平衡;在潜在空间使用 REPA-E-VAE 也能在保持吞吐量(3.39 batches/sec)的同时使 FID 下降约 6 个点。
- 训练目标调整:对比流匹配可作为低成本的正则化手段;JiT(直接预测干净图像)在 256 分辨率下优势不显,但能极大稳定 1024×1024 高分辨率下的无 VAE 像素级训练。
- Token 路由与稀疏化:TREAD 和 SPRINT 技术在 1024 高分辨率下表现优异,不仅大幅提升吞吐量,还显著改善了生成质量(TREAD 使 FID 从 17.42 降至 14.10)。
- 数据与提示词策略:长且详细的提示词能提供更丰富的监督信号,显著优于短提示词;最佳实践是使用合成数据快速锁定全局结构,再用真实数据匹配纹理。
- 高质量微调(SFT):仅使用 3350 张高质量图像进行 20K 步的监督微调,即可为模型带来显著的构图和画面质感提升。
- 优化器选择:采用 Muon 优化器可带来实质性的质量提升,使 FID 从 18.20 降至 15.55。
3) 风险与不足
- REPA 的后期容量不匹配:REPA 在训练后期会限制模型生成高频细节,建议仅作为“预热”手段,在模型特征跟上后将其关闭。
- iREPA 表现不一致:iREPA(空间结构对齐)在不同视觉编码器上效果不稳定(在 DINOv3 上会导致性能下降),不建议作为默认配置。
- 低分辨率下的 Token 路由惩罚:在 256×256 分辨率下使用 Token 路由仅带来微小的速度提升,但会导致质量指标下降。
- BF16 权重存储陷阱:将模型权重存储为 BF16 会对数值敏感操作产生负面影响,导致 FID 明显恶化(从 18.20 升至 21.87)。必须严格遵守“使用 BF16 计算,但将权重和优化器状态保持在 FP32”的规则。
正文
欢迎回来!这是我们关于从头开始训练高效文本到图像模型系列文章的第二部分。
在第一篇文章中,我们介绍了我们的目标:完全在开源环境下、大规模地从头训练一个具有竞争力的文本到图像基础模型。我们主要关注了架构选择,并发布了早期的小型版本(1.2B参数)作为预览。
本文将重点从架构转向训练。我们的目标是记录在尝试让模型训练更快、收敛更可靠以及学习更好表示时,真正发挥作用的因素。我们将以实验日志的形式,复现或调整近期的一些想法,并在统一的设置下进行实现,报告它们在实践中对优化和收敛的影响。最后,我们不仅会孤立地报告这些技术,还会探索将它们结合使用时的效果。
在下一篇文章中,我们将以代码形式发布完整的训练配方。我们还将进行一次公开的“速通(speedrun)”测试,将最佳组件整合到一个配置中,并进行端到端的压力测试。
基线模型 (The Baseline)
在引入任何训练效率提升技术之前,我们首先建立一个干净的参考基线。这个基线有意保持简单:使用标准组件,避免辅助目标,不依赖架构捷径或节省算力的技巧。
具体来说,这是一个纯粹的流匹配(Flow Matching)训练设置。我们使用 PRX-1.2B 模型作为基线,在 Flux VAE 潜在空间中进行训练,并在所有比较中保持配置固定(除非另有说明)。
基线训练设置如下:
- 训练步数: 100k
- 数据集: 100万张由 MidJourneyV6 生成的公开合成图像
- 分辨率: 256×256
- 全局批次大小: 256
- 优化器: AdamW (lr: 1e-4, weight_decay: 0.0, eps: 1e-15, betas: 0.9, 0.95)
- 文本编码器: GemmaT5
- 位置编码: Rotary (RoPE)
- 注意力掩码: Padding mask
- EMA: 禁用
评估指标 (Benchmarking Metrics)
为了监控模型随时间的表现,我们依赖以下几个指标:
- FID (Fréchet Inception Distance): 衡量生成图像与真实图像分布的接近程度。较低的值通常与较高的样本保真度相关。
- CMMD (CLIP Maximum Mean Discrepancy): 使用 CLIP 图像嵌入计算分布距离。它通常比 FID 更能反映感知质量。
- DINO-MMD: 基于 DINOv2 图像嵌入的 MMD 距离,提供自监督视觉主干下分布偏移的补充视角。
- 网络吞吐量: 每秒处理的样本数(samples/s),作为端到端训练效率的衡量标准。
表示对齐 (Representation Alignment)
扩散模型和流模型通常只使用单一目标进行训练。表示对齐通过保留去噪目标,并添加一个辅助损失来直接监督中间特征(使用强大的冻结视觉编码器),从而加速早期学习。
REPA
REPA 在基础流匹配目标上添加了表示匹配项。学生模型被训练为从噪声状态中产生对噪声鲁棒、与数据一致的 Patch 表示,以便后续层可以专注于预测向量场和生成细节。
- 观察结果: 我们使用冻结的 DINOv2 和 DINOv3 进行了测试。添加对齐功能一致地提高了质量指标。DINOv3 实现了最佳的整体数据,但代价是训练速度变慢。DINOv2 则提供了一个更高效的权衡,在较小的速度损失下依然带来了实质性的收益。
iREPA
iREPA 认为应该对齐的是空间结构,而不是全局语义。它引入了轻量级的 3×3 卷积投影和空间归一化。
- 观察结果: 在 DINOv2 上应用这些调整时,收敛更加平滑,指标稳步提升。然而,在 DINOv3 上,这些调整并没有带来同样的提升,反而倾向于降低性能。鉴于这种不一致性,我们可能不会将其作为默认配置。
关于在整个训练过程中使用 REPA
REPA 是一个强大的早期加速器,但在训练后期可能会遇到“容量不匹配”的问题,从而限制模型生成高频细节。我们的建议是:将其作为“预热”阶段的过渡手段,在模型自身的生成特征跟上后将其关闭。
Token 潜在空间中的对齐
除了对齐中间特征,我们还可以直接塑造潜在空间。我们比较了两个预训练的自编码器:REPA-E-VAE(添加了 REPA 对齐目标)和 Flux2-AE(未添加 REPA)。
- 观察结果: 两种干预都使 FID 大幅下降了约 6 个点。Flux2-AE 在大多数指标上占优,但伴随着巨大的吞吐量惩罚(从 3.95 降至 1.79 batches/sec)。REPA-E-VAE 则是一个平衡的选择,达到了与 Flux2-AE 基本相同的 FID,同时吞吐量更接近基线(3.39 batches/sec)。
训练目标:超越原生流匹配
对比流匹配 (Contrastive Flow Matching)
对比流匹配通过添加一个对比项,将条件流与批次中的其他流推开,从而解决条件生成中的流重叠和“平均化”行为。
- 观察结果: 在我们的运行中,它在表示驱动的指标(CMMD 和 DINO-MMD)上产生了微小但可衡量的改进。虽然 FID 没有明显改善,但吞吐量成本可以忽略不计。我们将保留它作为一种低成本的正则化手段。
JiT (预测干净图像)
与其让网络预测偏离流形(off-manifold)的噪声或速度,不如让它直接预测干净的图像(x-prediction)。
- 观察结果: 在 256×256 潜在空间中,这种方法的优势并不明显,甚至会降低 CMMD 和 DINO-MMD。然而,它的真正优势在于稳定高维训练。使用 JiT,我们能够直接在 1024×1024 像素空间(无 VAE)上使用 32×32 的 Patch 进行训练,且优化保持稳定和快速。这为高分辨率、无 Tokenizer 的文本到图像训练打开了大门。
Token 路由与稀疏化以降低计算成本
为了降低 Transformer 处理大量 Token 的计算成本,我们测试了两种路由和稀疏化方法:
- TREAD: 随机选择一部分 Token 绕过某些层,随后再重新注入。
- SPRINT: 在计算最密集的中间层引入稀疏性(丢弃大部分 Token),同时保留密集的残差路径以维持全分辨率信息。
观察结果:
- 在 256×256 分辨率下,路由仅带来适度的吞吐量提升(7-9%),且伴随着质量指标的下降。
- 在 1024×1024 分辨率(1024 个 Token)下,情况完全改变。TREAD 和 SPRINT 都带来了巨大的吞吐量提升。更重要的是,质量不仅没有下降,反而显著提高。TREAD 的 FID 从 17.42 降至 14.10。在这个机制下,路由不再是边缘优化,而是提升训练速度和质量的重要手段。
数据 (Data)
长标题 vs 短标题
我们将包含构图、属性、光照等细节的长标题与简短的单行标题进行了对比。
- 观察结果: 切换到短标题会严重损害收敛性。长标题提供了更丰富的监督信号,通过将隐式的选择转化为显式的约束,消除了不确定性,实际上让学习问题变得更容易。
使用合成图像进行引导
我们比较了使用合成数据(MidjourneyV6)和真实数据(Pexels)训练的模型。
- 观察结果: 合成数据训练的模型在 CMMD 和 DINO-MMD(全局结构)上得分更高,而真实数据训练的模型在 FID(纹理和低级统计)上表现更好。实用的策略是:使用合成数据快速引导并锁定全局结构,后期使用真实图像来匹配摄影纹理。
使用 Alchemist 进行 SFT
我们在一个仅包含 3350 张高质量图像的精选数据集(Alchemist)上进行了 20K 步的监督微调(SFT)。
- 观察结果: 尽管数据集很小,但它产生了巨大的影响,为模型添加了明显的“风格层”,带来了更好的构图和更丰富的画面质感。
更多实用的训练技巧
Muon 优化器
我们尝试了 Muon 优化器,它试图在没有完整二阶方法开销的情况下应用条件更好的更新步骤。
- 观察结果: Muon 使得指标有了立竿见影的改善(FID 从 18.20 降至 15.55),证明优化器的选择可以带来实质性的质量提升。
精度陷阱:BF16 转换 vs 存储
使用 BF16 进行前向和反向传播计算是标准做法,但将模型权重存储为 BF16 会对数值敏感的操作(如归一化层、注意力 Softmax、RoPE 等)产生负面影响。
- 观察结果: 错误地将权重存储为 BF16 会导致 FID 明显恶化(从 18.20 升至 21.87)。严格的规则是:使用 BF16 进行计算,但将权重(和优化器状态)保持在 FP32。
总结与下一步计划
我们对 PRX 训练进行了一系列系统的消融实验。最大的收益来自表示对齐(REPA 提升早期收敛,更好的 Tokenizer 大幅提升质量)。目标函数的调整喜忧参半,但 x-prediction 使得稳定的 1024 分辨率像素级训练成为可能。Token 路由在高分辨率下取得了巨大成功。数据方面,长标题至关重要,合成数据与真实数据的结合策略非常有效,小规模 SFT 提升了画面质感。此外,Muon 优化器和避免 BF16 权重存储陷阱也对训练大有裨益。
在接下来的几周内,我们将发布 PRX 训练框架的完整源代码,并进行一次公开的 24 小时“速通”测试,看看将这些最佳理念结合在一起能达到怎样的高度。