
Commit 40a6a87

fix: resolve precision issues in SDXL VAE under fp16 (#888)
* fix: resolve precision issues in SDXL VAE under fp16
* add --force-sdxl-vae-conv-scale option
* update docs
1 parent e370258 commit 40a6a87
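
Context for the fix: `ggml_conv_2d` only operates under FP16, and the SDXL VAE produces activations large enough to exceed FP16's finite range (max 65504), which surfaces as NaNs and all-black images. Until now the README and a runtime warning pointed users at madebyollin's fp16-fix VAE; this commit instead keeps the stock VAE usable by scaling activations by 1/32 before each VAE convolution and by 32 afterwards. Convolution is linear, so the result is unchanged while the intermediate values stay in range. A minimal sketch of the idea (plain C++, independent of ggml; the tap count and magnitudes are illustrative assumptions, not values from the VAE):

```cpp
// Why pre-scaling keeps an fp16 conv accumulator in range: one output element
// of a convolution is a dot product; emulate its accumulator in float.
#include <cstdio>

int main() {
    const float FP16_MAX = 65504.0f;      // largest finite half-precision value
    const float scale    = 1.0f / 32.0f;  // the scale this commit applies to the SDXL VAE

    const int taps = 512;                 // illustrative kernel*channels size
    float acc_raw = 0.0f, acc_scaled = 0.0f;
    for (int i = 0; i < taps; i++) {
        float x = 200.0f, w = 1.0f;       // illustrative activation/weight magnitudes
        acc_raw    += x * w;              // reaches 102400: overflows fp16 -> inf/NaN
        acc_scaled += (x * scale) * w;    // reaches 3200: well inside fp16 range
    }
    printf("raw:    %.1f (overflows fp16: %s)\n", acc_raw, acc_raw > FP16_MAX ? "yes" : "no");
    printf("scaled: %.1f -> unscaled: %.1f\n", acc_scaled, acc_scaled / scale);
    return 0;
}
```

Running it shows the raw accumulator at 102400 (past 65504), while the scaled path peaks at 3200 and recovers the same result after un-scaling.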

File tree

8 files changed, +66 -44 lines changed


README.md

Lines changed: 1 addition & 1 deletion
@@ -17,7 +17,6 @@ API and command-line option may change frequently.***
 - Image Models
     - SD1.x, SD2.x, [SD-Turbo](https://huggingface.co/stabilityai/sd-turbo)
     - SDXL, [SDXL-Turbo](https://huggingface.co/stabilityai/sdxl-turbo)
-        - !!!The VAE in SDXL encounters NaN issues under FP16, but unfortunately, the ggml_conv_2d only operates under FP16. Hence, a parameter is needed to specify the VAE that has fixed the FP16 NaN issue. You can find it here: [SDXL VAE FP16 Fix](https://huggingface.co/madebyollin/sdxl-vae-fp16-fix/blob/main/sdxl_vae.safetensors).
     - [SD3/SD3.5](./docs/sd3.md)
     - [Flux-dev/Flux-schnell](./docs/flux.md)
     - [Chroma](./docs/chroma.md)
@@ -365,6 +364,7 @@ arguments:
   --vae-tile-size [X]x[Y]            tile size for vae tiling (default: 32x32)
   --vae-relative-tile-size [X]x[Y]   relative tile size for vae tiling, in fraction of image size if < 1, in number of tiles per dim if >=1 (overrides --vae-tile-size)
   --vae-tile-overlap OVERLAP         tile overlap for vae tiling, in fraction of tile size (default: 0.5)
+  --force-sdxl-vae-conv-scale        force use of conv scale on sdxl vae
   --vae-on-cpu                       keep vae in cpu (for low vram)
   --clip-on-cpu                      keep clip in cpu (for low vram)
   --diffusion-fa                     use flash attention in the diffusion model (for low vram)

conditioner.hpp

Lines changed: 1 addition & 1 deletion
@@ -1457,7 +1457,7 @@ struct Qwen2_5_VLCLIPEmbedder : public Conditioner {
                          const ConditionerParams& conditioner_params) {
         std::string prompt;
         std::vector<std::pair<int, ggml_tensor*>> image_embeds;
-        size_t system_prompt_length = 0;
+        size_t system_prompt_length          = 0;
         int prompt_template_encode_start_idx = 34;
         if (qwenvl->enable_vision && conditioner_params.ref_images.size() > 0) {
             LOG_INFO("QwenImageEditPlusPipeline");

examples/cli/main.cpp

Lines changed: 5 additions & 0 deletions
@@ -131,6 +131,7 @@ struct SDParams {
     prediction_t prediction = DEFAULT_PRED;

     sd_tiling_params_t vae_tiling_params = {false, 0, 0, 0.5f, 0.0f, 0.0f};
+    bool force_sdxl_vae_conv_scale       = false;

     SDParams() {
         sd_sample_params_init(&sample_params);
@@ -198,6 +199,7 @@ void print_params(SDParams params) {
     printf(" seed: %zd\n", params.seed);
     printf(" batch_count: %d\n", params.batch_count);
     printf(" vae_tiling: %s\n", params.vae_tiling_params.enabled ? "true" : "false");
+    printf(" force_sdxl_vae_conv_scale: %s\n", params.force_sdxl_vae_conv_scale ? "true" : "false");
     printf(" upscale_repeats: %d\n", params.upscale_repeats);
     printf(" chroma_use_dit_mask: %s\n", params.chroma_use_dit_mask ? "true" : "false");
     printf(" chroma_use_t5_mask: %s\n", params.chroma_use_t5_mask ? "true" : "false");
@@ -292,6 +294,7 @@ void print_usage(int argc, const char* argv[]) {
     printf("  --vae-tile-size [X]x[Y]            tile size for vae tiling (default: 32x32)\n");
     printf("  --vae-relative-tile-size [X]x[Y]   relative tile size for vae tiling, in fraction of image size if < 1, in number of tiles per dim if >=1 (overrides --vae-tile-size)\n");
     printf("  --vae-tile-overlap OVERLAP         tile overlap for vae tiling, in fraction of tile size (default: 0.5)\n");
+    printf("  --force-sdxl-vae-conv-scale        force use of conv scale on sdxl vae\n");
     printf("  --vae-on-cpu                       keep vae in cpu (for low vram)\n");
     printf("  --clip-on-cpu                      keep clip in cpu (for low vram)\n");
     printf("  --diffusion-fa                     use flash attention in the diffusion model (for low vram)\n");
@@ -562,6 +565,7 @@ void parse_args(int argc, const char** argv, SDParams& params) {

     options.bool_options = {
         {"", "--vae-tiling", "", true, &params.vae_tiling_params.enabled},
+        {"", "--force-sdxl-vae-conv-scale", "", true, &params.force_sdxl_vae_conv_scale},
         {"", "--offload-to-cpu", "", true, &params.offload_params_to_cpu},
         {"", "--control-net-cpu", "", true, &params.control_net_cpu},
         {"", "--clip-on-cpu", "", true, &params.clip_on_cpu},
@@ -1382,6 +1386,7 @@ int main(int argc, const char* argv[]) {
         params.diffusion_flash_attn,
         params.diffusion_conv_direct,
         params.vae_conv_direct,
+        params.force_sdxl_vae_conv_scale,
         params.chroma_use_dit_mask,
         params.chroma_use_t5_mask,
         params.chroma_t5_mask_pad,

ggml_extend.hpp

Lines changed: 36 additions & 34 deletions
@@ -975,38 +975,28 @@ __STATIC_INLINE__ struct ggml_tensor* ggml_nn_conv_2d(struct ggml_context* ctx,
                                                       struct ggml_tensor* x,
                                                       struct ggml_tensor* w,
                                                       struct ggml_tensor* b,
-                                                      int s0 = 1,
-                                                      int s1 = 1,
-                                                      int p0 = 0,
-                                                      int p1 = 0,
-                                                      int d0 = 1,
-                                                      int d1 = 1) {
-    x = ggml_conv_2d(ctx, w, x, s0, s1, p0, p1, d0, d1);
-    if (b != NULL) {
-        b = ggml_reshape_4d(ctx, b, 1, 1, b->ne[0], 1);
-        // b = ggml_repeat(ctx, b, x);
-        x = ggml_add_inplace(ctx, x, b);
+                                                      int s0 = 1,
+                                                      int s1 = 1,
+                                                      int p0 = 0,
+                                                      int p1 = 0,
+                                                      int d0 = 1,
+                                                      int d1 = 1,
+                                                      bool direct = false,
+                                                      float scale = 1.f) {
+    if (scale != 1.f) {
+        x = ggml_scale(ctx, x, scale);
+    }
+    if (direct) {
+        x = ggml_conv_2d_direct(ctx, w, x, s0, s1, p0, p1, d0, d1);
+    } else {
+        x = ggml_conv_2d(ctx, w, x, s0, s1, p0, p1, d0, d1);
+    }
+    if (scale != 1.f) {
+        x = ggml_scale(ctx, x, 1.f / scale);
     }
-    return x;
-}
-
-// w: [OC*IC, KD, KH, KW]
-// x: [N*IC, ID, IH, IW]
-__STATIC_INLINE__ struct ggml_tensor* ggml_nn_conv_2d_direct(struct ggml_context* ctx,
-                                                             struct ggml_tensor* x,
-                                                             struct ggml_tensor* w,
-                                                             struct ggml_tensor* b,
-                                                             int s0 = 1,
-                                                             int s1 = 1,
-                                                             int p0 = 0,
-                                                             int p1 = 0,
-                                                             int d0 = 1,
-                                                             int d1 = 1) {
-    x = ggml_conv_2d_direct(ctx, w, x, s0, s1, p0, p1, d0, d1);
     if (b != NULL) {
         b = ggml_reshape_4d(ctx, b, 1, 1, b->ne[0], 1);
-        // b = ggml_repeat(ctx, b, x);
-        x = ggml_add(ctx, x, b);
+        x = ggml_add_inplace(ctx, x, b);
     }
     return x;
 }
@@ -2067,6 +2057,7 @@ class Conv2d : public UnaryBlock {
     std::pair<int, int> dilation;
     bool bias;
     bool direct = false;
+    float scale = 1.f;

     void init_params(struct ggml_context* ctx, const String2GGMLType& tensor_types, const std::string prefix = "") {
         enum ggml_type wtype = GGML_TYPE_F16;
@@ -2097,6 +2088,10 @@ class Conv2d : public UnaryBlock {
         direct = true;
     }

+    void set_scale(float scale_value) {
+        scale = scale_value;
+    }
+
     std::string get_desc() {
         return "Conv2d";
     }
@@ -2107,11 +2102,18 @@ class Conv2d : public UnaryBlock {
         if (bias) {
             b = params["bias"];
         }
-        if (direct) {
-            return ggml_nn_conv_2d_direct(ctx, x, w, b, stride.second, stride.first, padding.second, padding.first, dilation.second, dilation.first);
-        } else {
-            return ggml_nn_conv_2d(ctx, x, w, b, stride.second, stride.first, padding.second, padding.first, dilation.second, dilation.first);
-        }
+        return ggml_nn_conv_2d(ctx,
+                               x,
+                               w,
+                               b,
+                               stride.second,
+                               stride.first,
+                               padding.second,
+                               padding.first,
+                               dilation.second,
+                               dilation.first,
+                               direct,
+                               scale);
     }
 };
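
Note the ordering in the merged helper: scale, convolve (direct or im2col path), un-scale, then add the bias. Because convolution is linear, `(1/s) * conv(s*x, w) + b == conv(x, w) + b`, so the transform is exact up to rounding; adding the bias inside the scaled region would instead require scaling `b` as well. A standalone 1-D check of the identity (plain C++, independent of ggml; sizes and values are arbitrary):

```cpp
#include <cstdio>
#include <vector>

// Naive "valid" 1-D convolution: out[i] = sum_k x[i+k] * w[k].
static std::vector<float> conv1d(const std::vector<float>& x, const std::vector<float>& w) {
    std::vector<float> out(x.size() - w.size() + 1, 0.0f);
    for (size_t i = 0; i < out.size(); i++)
        for (size_t k = 0; k < w.size(); k++)
            out[i] += x[i + k] * w[k];
    return out;
}

int main() {
    std::vector<float> x = {4096.f, -2048.f, 1024.f, 8192.f, -512.f};
    std::vector<float> w = {0.5f, -1.f, 0.25f};
    const float s = 1.f / 32.f, bias = 3.f;

    std::vector<float> xs(x.size());
    for (size_t i = 0; i < x.size(); i++)
        xs[i] = x[i] * s;                    // pre-scale the input

    auto ref = conv1d(x, w);                 // unscaled reference path
    auto scl = conv1d(xs, w);                // scaled path
    for (size_t i = 0; i < ref.size(); i++)  // un-scale, then add bias: results match
        printf("ref=%.1f  scaled-path=%.1f\n", ref[i] + bias, scl[i] / s + bias);
    return 0;
}
```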

qwen_image.hpp

Lines changed: 1 addition & 1 deletion
@@ -535,7 +535,7 @@ namespace Qwen {
             }
         }
         LOG_ERROR("qwen_image_params.num_layers: %ld", qwen_image_params.num_layers);
-        qwen_image   = QwenImageModel(qwen_image_params);
+        qwen_image = QwenImageModel(qwen_image_params);
         qwen_image.init(params_ctx, tensor_types, prefix);
     }

stable-diffusion.cpp

Lines changed: 9 additions & 7 deletions
@@ -330,13 +330,6 @@ class StableDiffusionGGML {

         if (sd_version_is_sdxl(version)) {
             scale_factor = 0.13025f;
-            if (strlen(SAFE_STR(sd_ctx_params->vae_path)) == 0 && strlen(SAFE_STR(sd_ctx_params->taesd_path)) == 0) {
-                LOG_WARN(
-                    "!!!It looks like you are using SDXL model. "
-                    "If you find that the generated images are completely black, "
-                    "try specifying SDXL VAE FP16 Fix with the --vae parameter. "
-                    "You can find it here: https://huggingface.co/madebyollin/sdxl-vae-fp16-fix/blob/main/sdxl_vae.safetensors");
-            }
         } else if (sd_version_is_sd3(version)) {
             scale_factor = 1.5305f;
         } else if (sd_version_is_flux(version)) {
@@ -517,6 +510,15 @@ class StableDiffusionGGML {
                 LOG_INFO("Using Conv2d direct in the vae model");
                 first_stage_model->enable_conv2d_direct();
             }
+            if (version == VERSION_SDXL &&
+                (strlen(SAFE_STR(sd_ctx_params->vae_path)) == 0 || sd_ctx_params->force_sdxl_vae_conv_scale)) {
+                float vae_conv_2d_scale = 1.f / 32.f;
+                LOG_WARN(
+                    "No VAE specified with --vae or --force-sdxl-vae-conv-scale flag set, "
+                    "using Conv2D scale %.3f",
+                    vae_conv_2d_scale);
+                first_stage_model->set_conv2d_scale(vae_conv_2d_scale);
+            }
             first_stage_model->alloc_params_buffer();
             first_stage_model->get_param_tensors(tensors, "first_stage_model");
         } else {
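
Net effect of the two hunks: the old black-image warning is gone, and the 1/32 Conv2D scale is applied automatically for SDXL whenever no external VAE is given (presumably no `taesd_path` check is needed here because this branch only runs when the full VAE, not TAESD, is in use), while `--force-sdxl-vae-conv-scale` re-enables it for a user-supplied VAE. A hypothetical invocation forcing the scale alongside an external VAE (model and file names are placeholders):

```
sd -m sdxl_model.safetensors --vae sdxl_vae.safetensors --force-sdxl-vae-conv-scale -p "a lovely cat"
```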

stable-diffusion.h

Lines changed: 1 addition & 0 deletions
@@ -164,6 +164,7 @@ typedef struct {
     bool diffusion_flash_attn;
     bool diffusion_conv_direct;
     bool vae_conv_direct;
+    bool force_sdxl_vae_conv_scale;
     bool chroma_use_dit_mask;
     bool chroma_use_t5_mask;
     int chroma_t5_mask_pad;

vae.hpp

Lines changed: 12 additions & 0 deletions
@@ -530,6 +530,7 @@ struct VAE : public GGMLRunner {
                        struct ggml_context* output_ctx) = 0;
     virtual void get_param_tensors(std::map<std::string, struct ggml_tensor*>& tensors, const std::string prefix) = 0;
     virtual void enable_conv2d_direct(){};
+    virtual void set_conv2d_scale(float scale) { SD_UNUSED(scale); };
 };

 struct AutoEncoderKL : public VAE {
@@ -558,6 +559,17 @@ struct AutoEncoderKL : public VAE {
         }
     }

+    void set_conv2d_scale(float scale) {
+        std::vector<GGMLBlock*> blocks;
+        ae.get_all_blocks(blocks);
+        for (auto block : blocks) {
+            if (block->get_desc() == "Conv2d") {
+                auto conv_block = (Conv2d*)block;
+                conv_block->set_scale(scale);
+            }
+        }
+    }
+
     std::string get_desc() {
         return "vae";
     }
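
The base-class default makes `set_conv2d_scale` a no-op for VAEs that don't override it, and `AutoEncoderKL` finds the conv blocks by matching each block's `get_desc()` string before downcasting. A self-contained sketch of that dispatch pattern (stand-in types, not the project's headers):

```cpp
#include <cstdio>
#include <string>
#include <vector>

struct Block {
    virtual ~Block() = default;
    virtual std::string get_desc() = 0;
};

struct Conv2dBlock : Block {                    // stand-in for the real Conv2d
    float scale = 1.f;
    std::string get_desc() override { return "Conv2d"; }
    void set_scale(float s) { scale = s; }
};

struct NormBlock : Block {                      // non-conv blocks are left untouched
    std::string get_desc() override { return "GroupNorm"; }
};

int main() {
    Conv2dBlock c1, c2;
    NormBlock n;
    // In the real code this list comes from ae.get_all_blocks(blocks).
    std::vector<Block*> blocks = {&c1, &n, &c2};

    for (Block* b : blocks)
        if (b->get_desc() == "Conv2d")          // identify by description string...
            static_cast<Conv2dBlock*>(b)->set_scale(1.f / 32.f);  // ...then downcast

    printf("c1.scale=%f c2.scale=%f\n", c1.scale, c2.scale);
    return 0;
}
```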
