Commit 9727c6b

fix: resolve VAE tiling problem in Qwen Image (#873)
1 parent beb99a2 commit 9727c6b

2 files changed: +42 -27 lines changed

ggml_extend.hpp

Lines changed: 30 additions & 24 deletions
@@ -483,12 +483,15 @@ __STATIC_INLINE__ void ggml_split_tensor_2d(struct ggml_tensor* input,
     int64_t width = output->ne[0];
     int64_t height = output->ne[1];
     int64_t channels = output->ne[2];
+    int64_t ne3 = output->ne[3];
     GGML_ASSERT(input->type == GGML_TYPE_F32 && output->type == GGML_TYPE_F32);
     for (int iy = 0; iy < height; iy++) {
         for (int ix = 0; ix < width; ix++) {
             for (int k = 0; k < channels; k++) {
-                float value = ggml_tensor_get_f32(input, ix + x, iy + y, k);
-                ggml_tensor_set_f32(output, value, ix, iy, k);
+                for (int l = 0; l < ne3; l++) {
+                    float value = ggml_tensor_get_f32(input, ix + x, iy + y, k, l);
+                    ggml_tensor_set_f32(output, value, ix, iy, k, l);
+                }
             }
         }
     }
@@ -511,6 +514,7 @@ __STATIC_INLINE__ void ggml_merge_tensor_2d(struct ggml_tensor* input,
     int64_t width = input->ne[0];
     int64_t height = input->ne[1];
     int64_t channels = input->ne[2];
+    int64_t ne3 = input->ne[3];

     int64_t img_width = output->ne[0];
     int64_t img_height = output->ne[1];
@@ -519,24 +523,26 @@ __STATIC_INLINE__ void ggml_merge_tensor_2d(struct ggml_tensor* input,
     for (int iy = y_skip; iy < height; iy++) {
         for (int ix = x_skip; ix < width; ix++) {
             for (int k = 0; k < channels; k++) {
-                float new_value = ggml_tensor_get_f32(input, ix, iy, k);
-                if (overlap_x > 0 || overlap_y > 0) { // blend colors in overlapped area
-                    float old_value = ggml_tensor_get_f32(output, x + ix, y + iy, k);
-
-                    const float x_f_0 = (overlap_x > 0 && x > 0) ? (ix - x_skip) / float(overlap_x) : 1;
-                    const float x_f_1 = (overlap_x > 0 && x < (img_width - width)) ? (width - ix) / float(overlap_x) : 1;
-                    const float y_f_0 = (overlap_y > 0 && y > 0) ? (iy - y_skip) / float(overlap_y) : 1;
-                    const float y_f_1 = (overlap_y > 0 && y < (img_height - height)) ? (height - iy) / float(overlap_y) : 1;
-
-                    const float x_f = std::min(std::min(x_f_0, x_f_1), 1.f);
-                    const float y_f = std::min(std::min(y_f_0, y_f_1), 1.f);
-
-                    ggml_tensor_set_f32(
-                        output,
-                        old_value + new_value * ggml_smootherstep_f32(y_f) * ggml_smootherstep_f32(x_f),
-                        x + ix, y + iy, k);
-                } else {
-                    ggml_tensor_set_f32(output, new_value, x + ix, y + iy, k);
+                for (int l = 0; l < ne3; l++) {
+                    float new_value = ggml_tensor_get_f32(input, ix, iy, k, l);
+                    if (overlap_x > 0 || overlap_y > 0) { // blend colors in overlapped area
+                        float old_value = ggml_tensor_get_f32(output, x + ix, y + iy, k, l);
+
+                        const float x_f_0 = (overlap_x > 0 && x > 0) ? (ix - x_skip) / float(overlap_x) : 1;
+                        const float x_f_1 = (overlap_x > 0 && x < (img_width - width)) ? (width - ix) / float(overlap_x) : 1;
+                        const float y_f_0 = (overlap_y > 0 && y > 0) ? (iy - y_skip) / float(overlap_y) : 1;
+                        const float y_f_1 = (overlap_y > 0 && y < (img_height - height)) ? (height - iy) / float(overlap_y) : 1;
+
+                        const float x_f = std::min(std::min(x_f_0, x_f_1), 1.f);
+                        const float y_f = std::min(std::min(y_f_0, y_f_1), 1.f);
+
+                        ggml_tensor_set_f32(
+                            output,
+                            old_value + new_value * ggml_smootherstep_f32(y_f) * ggml_smootherstep_f32(x_f),
+                            x + ix, y + iy, k, l);
+                    } else {
+                        ggml_tensor_set_f32(output, new_value, x + ix, y + iy, k, l);
+                    }
                }
             }
         }
@@ -852,8 +858,8 @@ __STATIC_INLINE__ void sd_tiling_non_square(ggml_tensor* input,
     }

     struct ggml_init_params params = {};
-    params.mem_size += input_tile_size_x * input_tile_size_y * input->ne[2] * sizeof(float); // input chunk
-    params.mem_size += output_tile_size_x * output_tile_size_y * output->ne[2] * sizeof(float); // output chunk
+    params.mem_size += input_tile_size_x * input_tile_size_y * input->ne[2] * input->ne[3] * sizeof(float); // input chunk
+    params.mem_size += output_tile_size_x * output_tile_size_y * output->ne[2] * output->ne[3] * sizeof(float); // output chunk
     params.mem_size += 3 * ggml_tensor_overhead();
     params.mem_buffer = NULL;
     params.no_alloc = false;
@@ -868,8 +874,8 @@ __STATIC_INLINE__ void sd_tiling_non_square(ggml_tensor* input,
     }

     // tiling
-    ggml_tensor* input_tile = ggml_new_tensor_4d(tiles_ctx, GGML_TYPE_F32, input_tile_size_x, input_tile_size_y, input->ne[2], 1);
-    ggml_tensor* output_tile = ggml_new_tensor_4d(tiles_ctx, GGML_TYPE_F32, output_tile_size_x, output_tile_size_y, output->ne[2], 1);
+    ggml_tensor* input_tile = ggml_new_tensor_4d(tiles_ctx, GGML_TYPE_F32, input_tile_size_x, input_tile_size_y, input->ne[2], input->ne[3]);
+    ggml_tensor* output_tile = ggml_new_tensor_4d(tiles_ctx, GGML_TYPE_F32, output_tile_size_x, output_tile_size_y, output->ne[2], output->ne[3]);
     int num_tiles = num_tiles_x * num_tiles_y;
     LOG_INFO("processing %i tiles", num_tiles);
     pretty_progress(0, num_tiles, 0.0f);
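
The common thread in these hunks is the fourth tensor dimension: the tile split/merge helpers previously looped only over ne[0..2] and allocated tiles with ne[3] hard-coded to 1, so any latent whose data extends into ne[3] (which is how the Qwen Image path lays things out, per the stable-diffusion.cpp change below) lost everything past the first slice during tiled VAE runs. The remaining hunks make the tile buffers and the scratch-memory estimate account for ne[3] as well. Below is a minimal standalone sketch of the failure mode, using made-up shapes and a hypothetical idx() helper rather than the repo's ggml accessors.

// Standalone illustration (not the repo's API): a 3-level copy loop over a
// 4-D buffer only touches the i3 == 0 slice when ne[3] > 1.
#include <cstddef>
#include <cstdint>
#include <cstdio>
#include <vector>

// hypothetical helper: flat offset of element (i0, i1, i2, i3),
// with ne[0] as the fastest-varying dimension, as in ggml's layout
static size_t idx(const int64_t ne[4], int64_t i0, int64_t i1, int64_t i2, int64_t i3) {
    return ((i3 * ne[2] + i2) * ne[1] + i1) * ne[0] + i0;
}

int main() {
    const int64_t ne[4] = {4, 4, 2, 3};  // made-up tile shape with ne[3] = 3
    std::vector<float> src(4 * 4 * 2 * 3, 1.0f);
    std::vector<float> dst(src.size(), 0.0f);

    // pre-fix behaviour: loops over ne[0..2] only, so i3 is implicitly 0
    for (int64_t i2 = 0; i2 < ne[2]; i2++)
        for (int64_t i1 = 0; i1 < ne[1]; i1++)
            for (int64_t i0 = 0; i0 < ne[0]; i0++)
                dst[idx(ne, i0, i1, i2, 0)] = src[idx(ne, i0, i1, i2, 0)];

    float copied = 0.0f;
    for (float v : dst) copied += v;
    // prints "copied 32 of 96 elements": two thirds of the tile never arrives
    printf("copied %.0f of %zu elements\n", copied, src.size());
    return 0;
}

With the added inner loop over ne3 and the 4-argument get/set calls, every slice is copied, and the overlap blending in ggml_merge_tensor_2d is applied per slice instead of only to the first one.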

stable-diffusion.cpp

Lines changed: 12 additions & 3 deletions
@@ -1440,10 +1440,19 @@ class StableDiffusionGGML {
        if (vae_tiling_params.enabled && !encode_video) {
            // TODO wan2.2 vae support?
            int C = sd_version_is_dit(version) ? 16 : 4;
-           if (!use_tiny_autoencoder) {
-               C *= 2;
+           int ne2;
+           int ne3;
+           if (sd_version_is_qwen_image(version)) {
+               ne2 = 1;
+               ne3 = C*x->ne[3];
+           } else {
+               if (!use_tiny_autoencoder) {
+                   C *= 2;
+               }
+               ne2 = C;
+               ne3 = x->ne[3];
            }
-           result = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, W, H, C, x->ne[3]);
+           result = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, W, H, ne2, ne3);
        }

        if (sd_version_is_qwen_image(version)) {
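
On the stable-diffusion.cpp side, the tiled result tensor is no longer forced into a [W, H, C, x->ne[3]] shape: for Qwen Image the channel data is carried along ne[3], so the result is allocated with ne2 = 1 and ne3 = C * x->ne[3], while other models keep the old layout (including the channel doubling for the non-tiny autoencoder). A small sketch of that branch follows; W, H, x_ne3, and the two booleans are illustrative stand-ins, not values taken from the code's actual call sites.

// Sketch of the shape selection above, with made-up inputs.
#include <cstdio>

int main() {
    const bool is_qwen_image        = true;   // hypothetical: a Qwen Image latent
    const bool use_tiny_autoencoder = false;
    const int  W = 128, H = 128;
    const int  x_ne3 = 1;                     // stands in for x->ne[3]

    int C = 16;                               // DiT-family latent channel count
    int ne2, ne3;
    if (is_qwen_image) {
        ne2 = 1;                              // channels are carried in ne[3] instead
        ne3 = C * x_ne3;
    } else {
        if (!use_tiny_autoencoder) {
            C *= 2;                           // full VAE path doubles the channels here
        }
        ne2 = C;
        ne3 = x_ne3;
    }
    // Qwen Image: [128, 128, 1, 16]; other DiT models: [128, 128, 32, 1]
    printf("result shape: [%d, %d, %d, %d]\n", W, H, ne2, ne3);
    return 0;
}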
