Skip to content

Commit e370258

Browse files
feat: added prediction argument (#334)
1 parent a7d6d29 commit e370258

File tree

4 files changed

+153
-52
lines changed

4 files changed

+153
-52
lines changed

README.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -358,6 +358,7 @@ arguments:
358358
--rng {std_default, cuda} RNG (default: cuda)
359359
-s SEED, --seed SEED RNG seed (default: 42, use random seed for < 0)
360360
-b, --batch-count COUNT number of images to generate
361+
--prediction {eps, v, edm_v, sd3_flow, flux_flow} Prediction type override
361362
--clip-skip N ignore last layers of CLIP network; 1 ignores none, 2 ignores one layer (default: -1)
362363
<= 0 represents unspecified, will be 1 for SD1.x, 2 for SD2.x
363364
--vae-tiling process vae in tiles to reduce memory usage

examples/cli/main.cpp

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -84,6 +84,7 @@ struct SDParams {
8484

8585
std::string prompt;
8686
std::string negative_prompt;
87+
8788
int clip_skip = -1; // <= 0 represents unspecified
8889
int width = 512;
8990
int height = 512;
@@ -127,6 +128,8 @@ struct SDParams {
127128
int chroma_t5_mask_pad = 1;
128129
float flow_shift = INFINITY;
129130

131+
prediction_t prediction = DEFAULT_PRED;
132+
130133
sd_tiling_params_t vae_tiling_params = {false, 0, 0, 0.5f, 0.0f, 0.0f};
131134

132135
SDParams() {
@@ -188,6 +191,7 @@ void print_params(SDParams params) {
188191
printf(" sample_params: %s\n", SAFE_STR(sample_params_str));
189192
printf(" high_noise_sample_params: %s\n", SAFE_STR(high_noise_sample_params_str));
190193
printf(" moe_boundary: %.3f\n", params.moe_boundary);
194+
printf(" prediction: %s\n", sd_prediction_name(params.prediction));
191195
printf(" flow_shift: %.2f\n", params.flow_shift);
192196
printf(" strength(img2img): %.2f\n", params.strength);
193197
printf(" rng: %s\n", sd_rng_type_name(params.rng_type));
@@ -281,6 +285,7 @@ void print_usage(int argc, const char* argv[]) {
281285
printf(" --rng {std_default, cuda} RNG (default: cuda)\n");
282286
printf(" -s SEED, --seed SEED RNG seed (default: 42, use random seed for < 0)\n");
283287
printf(" -b, --batch-count COUNT number of images to generate\n");
288+
printf(" --prediction {eps, v, edm_v, sd3_flow, flux_flow} Prediction type override.\n");
284289
printf(" --clip-skip N ignore last layers of CLIP network; 1 ignores none, 2 ignores one layer (default: -1)\n");
285290
printf(" <= 0 represents unspecified, will be 1 for SD1.x, 2 for SD2.x\n");
286291
printf(" --vae-tiling process vae in tiles to reduce memory usage\n");
@@ -651,6 +656,20 @@ void parse_args(int argc, const char** argv, SDParams& params) {
651656
return 1;
652657
};
653658

659+
auto on_prediction_arg = [&](int argc, const char** argv, int index) {
660+
if (++index >= argc) {
661+
return -1;
662+
}
663+
const char* arg = argv[index];
664+
params.prediction = str_to_prediction(arg);
665+
if (params.prediction == PREDICTION_COUNT) {
666+
fprintf(stderr, "error: invalid prediction type %s\n",
667+
arg);
668+
return -1;
669+
}
670+
return 1;
671+
};
672+
654673
auto on_sample_method_arg = [&](int argc, const char** argv, int index) {
655674
if (++index >= argc) {
656675
return -1;
@@ -807,6 +826,7 @@ void parse_args(int argc, const char** argv, SDParams& params) {
807826
{"", "--rng", "", on_rng_arg},
808827
{"-s", "--seed", "", on_seed_arg},
809828
{"", "--sampling-method", "", on_sample_method_arg},
829+
{"", "--prediction", "", on_prediction_arg},
810830
{"", "--scheduler", "", on_schedule_arg},
811831
{"", "--skip-layers", "", on_skip_layers_arg},
812832
{"", "--high-noise-sampling-method", "", on_high_noise_sample_method_arg},
@@ -1354,6 +1374,7 @@ int main(int argc, const char* argv[]) {
13541374
params.n_threads,
13551375
params.wtype,
13561376
params.rng_type,
1377+
params.prediction,
13571378
params.offload_params_to_cpu,
13581379
params.clip_on_cpu,
13591380
params.control_net_cpu,

stable-diffusion.cpp

Lines changed: 118 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -700,64 +700,102 @@ class StableDiffusionGGML {
700700
ggml_backend_is_cpu(clip_backend) ? "RAM" : "VRAM");
701701
}
702702

703-
// check is_using_v_parameterization_for_sd2
704-
if (sd_version_is_sd2(version)) {
705-
if (is_using_v_parameterization_for_sd2(ctx, sd_version_is_inpaint(version))) {
706-
is_using_v_parameterization = true;
707-
}
708-
} else if (sd_version_is_sdxl(version)) {
709-
if (model_loader.tensor_storages_types.find("edm_vpred.sigma_max") != model_loader.tensor_storages_types.end()) {
710-
// CosXL models
711-
// TODO: get sigma_min and sigma_max values from file
712-
is_using_edm_v_parameterization = true;
703+
if (sd_ctx_params->prediction != DEFAULT_PRED) {
704+
switch (sd_ctx_params->prediction) {
705+
case EPS_PRED:
706+
LOG_INFO("running in eps-prediction mode");
707+
break;
708+
case V_PRED:
709+
LOG_INFO("running in v-prediction mode");
710+
denoiser = std::make_shared<CompVisVDenoiser>();
711+
break;
712+
case EDM_V_PRED:
713+
LOG_INFO("running in v-prediction EDM mode");
714+
denoiser = std::make_shared<EDMVDenoiser>();
715+
break;
716+
case SD3_FLOW_PRED: {
717+
LOG_INFO("running in FLOW mode");
718+
float shift = sd_ctx_params->flow_shift;
719+
if (shift == INFINITY) {
720+
shift = 3.0;
721+
}
722+
denoiser = std::make_shared<DiscreteFlowDenoiser>(shift);
723+
break;
724+
}
725+
case FLUX_FLOW_PRED: {
726+
LOG_INFO("running in Flux FLOW mode");
727+
float shift = sd_ctx_params->flow_shift;
728+
if (shift == INFINITY) {
729+
shift = 3.0;
730+
}
731+
denoiser = std::make_shared<FluxFlowDenoiser>(shift);
732+
break;
733+
}
734+
default: {
735+
LOG_ERROR("Unknown parametrization %i", sd_ctx_params->prediction);
736+
return false;
737+
}
713738
}
714-
if (model_loader.tensor_storages_types.find("v_pred") != model_loader.tensor_storages_types.end()) {
739+
} else {
740+
if (sd_version_is_sd2(version)) {
741+
// check is_using_v_parameterization_for_sd2
742+
if (is_using_v_parameterization_for_sd2(ctx, sd_version_is_inpaint(version))) {
743+
is_using_v_parameterization = true;
744+
}
745+
} else if (sd_version_is_sdxl(version)) {
746+
if (model_loader.tensor_storages_types.find("edm_vpred.sigma_max") != model_loader.tensor_storages_types.end()) {
747+
// CosXL models
748+
// TODO: get sigma_min and sigma_max values from file
749+
is_using_edm_v_parameterization = true;
750+
}
751+
if (model_loader.tensor_storages_types.find("v_pred") != model_loader.tensor_storages_types.end()) {
752+
is_using_v_parameterization = true;
753+
}
754+
} else if (version == VERSION_SVD) {
755+
// TODO: V_PREDICTION_EDM
715756
is_using_v_parameterization = true;
716757
}
717-
} else if (version == VERSION_SVD) {
718-
// TODO: V_PREDICTION_EDM
719-
is_using_v_parameterization = true;
720-
}
721758

722-
if (sd_version_is_sd3(version)) {
723-
LOG_INFO("running in FLOW mode");
724-
float shift = sd_ctx_params->flow_shift;
725-
if (shift == INFINITY) {
726-
shift = 3.0;
727-
}
728-
denoiser = std::make_shared<DiscreteFlowDenoiser>(shift);
729-
} else if (sd_version_is_flux(version)) {
730-
LOG_INFO("running in Flux FLOW mode");
731-
float shift = 1.0f; // TODO: validate
732-
for (auto pair : model_loader.tensor_storages_types) {
733-
if (pair.first.find("model.diffusion_model.guidance_in.in_layer.weight") != std::string::npos) {
734-
shift = 1.15f;
735-
break;
759+
if (sd_version_is_sd3(version)) {
760+
LOG_INFO("running in FLOW mode");
761+
float shift = sd_ctx_params->flow_shift;
762+
if (shift == INFINITY) {
763+
shift = 3.0;
736764
}
765+
denoiser = std::make_shared<DiscreteFlowDenoiser>(shift);
766+
} else if (sd_version_is_flux(version)) {
767+
LOG_INFO("running in Flux FLOW mode");
768+
float shift = 1.0f; // TODO: validate
769+
for (auto pair : model_loader.tensor_storages_types) {
770+
if (pair.first.find("model.diffusion_model.guidance_in.in_layer.weight") != std::string::npos) {
771+
shift = 1.15f;
772+
break;
773+
}
774+
}
775+
denoiser = std::make_shared<FluxFlowDenoiser>(shift);
776+
} else if (sd_version_is_wan(version)) {
777+
LOG_INFO("running in FLOW mode");
778+
float shift = sd_ctx_params->flow_shift;
779+
if (shift == INFINITY) {
780+
shift = 5.0;
781+
}
782+
denoiser = std::make_shared<DiscreteFlowDenoiser>(shift);
783+
} else if (sd_version_is_qwen_image(version)) {
784+
LOG_INFO("running in FLOW mode");
785+
float shift = sd_ctx_params->flow_shift;
786+
if (shift == INFINITY) {
787+
shift = 3.0;
788+
}
789+
denoiser = std::make_shared<DiscreteFlowDenoiser>(shift);
790+
} else if (is_using_v_parameterization) {
791+
LOG_INFO("running in v-prediction mode");
792+
denoiser = std::make_shared<CompVisVDenoiser>();
793+
} else if (is_using_edm_v_parameterization) {
794+
LOG_INFO("running in v-prediction EDM mode");
795+
denoiser = std::make_shared<EDMVDenoiser>();
796+
} else {
797+
LOG_INFO("running in eps-prediction mode");
737798
}
738-
denoiser = std::make_shared<FluxFlowDenoiser>(shift);
739-
} else if (sd_version_is_wan(version)) {
740-
LOG_INFO("running in FLOW mode");
741-
float shift = sd_ctx_params->flow_shift;
742-
if (shift == INFINITY) {
743-
shift = 5.0;
744-
}
745-
denoiser = std::make_shared<DiscreteFlowDenoiser>(shift);
746-
} else if (sd_version_is_qwen_image(version)) {
747-
LOG_INFO("running in FLOW mode");
748-
float shift = sd_ctx_params->flow_shift;
749-
if (shift == INFINITY) {
750-
shift = 3.0;
751-
}
752-
denoiser = std::make_shared<DiscreteFlowDenoiser>(shift);
753-
} else if (is_using_v_parameterization) {
754-
LOG_INFO("running in v-prediction mode");
755-
denoiser = std::make_shared<CompVisVDenoiser>();
756-
} else if (is_using_edm_v_parameterization) {
757-
LOG_INFO("running in v-prediction EDM mode");
758-
denoiser = std::make_shared<EDMVDenoiser>();
759-
} else {
760-
LOG_INFO("running in eps-prediction mode");
761799
}
762800

763801
auto comp_vis_denoiser = std::dynamic_pointer_cast<CompVisDenoiser>(denoiser);
@@ -1742,13 +1780,39 @@ enum scheduler_t str_to_schedule(const char* str) {
17421780
return SCHEDULE_COUNT;
17431781
}
17441782

1783+
const char* prediction_to_str[] = {
1784+
"default",
1785+
"eps",
1786+
"v",
1787+
"edm_v",
1788+
"sd3_flow",
1789+
"flux_flow",
1790+
};
1791+
1792+
const char* sd_prediction_name(enum prediction_t prediction) {
1793+
if (prediction < PREDICTION_COUNT) {
1794+
return prediction_to_str[prediction];
1795+
}
1796+
return NONE_STR;
1797+
}
1798+
1799+
enum prediction_t str_to_prediction(const char* str) {
1800+
for (int i = 0; i < PREDICTION_COUNT; i++) {
1801+
if (!strcmp(str, prediction_to_str[i])) {
1802+
return (enum prediction_t)i;
1803+
}
1804+
}
1805+
return PREDICTION_COUNT;
1806+
}
1807+
17451808
void sd_ctx_params_init(sd_ctx_params_t* sd_ctx_params) {
17461809
*sd_ctx_params = {};
17471810
sd_ctx_params->vae_decode_only = true;
17481811
sd_ctx_params->free_params_immediately = true;
17491812
sd_ctx_params->n_threads = get_num_physical_cores();
17501813
sd_ctx_params->wtype = SD_TYPE_COUNT;
17511814
sd_ctx_params->rng_type = CUDA_RNG;
1815+
sd_ctx_params->prediction = DEFAULT_PRED;
17521816
sd_ctx_params->offload_params_to_cpu = false;
17531817
sd_ctx_params->keep_clip_on_cpu = false;
17541818
sd_ctx_params->keep_control_net_on_cpu = false;
@@ -1788,6 +1852,7 @@ char* sd_ctx_params_to_str(const sd_ctx_params_t* sd_ctx_params) {
17881852
"n_threads: %d\n"
17891853
"wtype: %s\n"
17901854
"rng_type: %s\n"
1855+
"prediction: %s\n"
17911856
"offload_params_to_cpu: %s\n"
17921857
"keep_clip_on_cpu: %s\n"
17931858
"keep_control_net_on_cpu: %s\n"
@@ -1816,6 +1881,7 @@ char* sd_ctx_params_to_str(const sd_ctx_params_t* sd_ctx_params) {
18161881
sd_ctx_params->n_threads,
18171882
sd_type_name(sd_ctx_params->wtype),
18181883
sd_rng_type_name(sd_ctx_params->rng_type),
1884+
sd_prediction_name(sd_ctx_params->prediction),
18191885
BOOL_STR(sd_ctx_params->offload_params_to_cpu),
18201886
BOOL_STR(sd_ctx_params->keep_clip_on_cpu),
18211887
BOOL_STR(sd_ctx_params->keep_control_net_on_cpu),

stable-diffusion.h

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,16 @@ enum scheduler_t {
6464
SCHEDULE_COUNT
6565
};
6666

67+
enum prediction_t {
68+
DEFAULT_PRED,
69+
EPS_PRED,
70+
V_PRED,
71+
EDM_V_PRED,
72+
SD3_FLOW_PRED,
73+
FLUX_FLOW_PRED,
74+
PREDICTION_COUNT
75+
};
76+
6777
// same as enum ggml_type
6878
enum sd_type_t {
6979
SD_TYPE_F32 = 0,
@@ -146,6 +156,7 @@ typedef struct {
146156
int n_threads;
147157
enum sd_type_t wtype;
148158
enum rng_type_t rng_type;
159+
enum prediction_t prediction;
149160
bool offload_params_to_cpu;
150161
bool keep_clip_on_cpu;
151162
bool keep_control_net_on_cpu;
@@ -255,6 +266,8 @@ SD_API const char* sd_sample_method_name(enum sample_method_t sample_method);
255266
SD_API enum sample_method_t str_to_sample_method(const char* str);
256267
SD_API const char* sd_schedule_name(enum scheduler_t scheduler);
257268
SD_API enum scheduler_t str_to_schedule(const char* str);
269+
SD_API const char* sd_prediction_name(enum prediction_t prediction);
270+
SD_API enum prediction_t str_to_prediction(const char* str);
258271

259272
SD_API void sd_ctx_params_init(sd_ctx_params_t* sd_ctx_params);
260273
SD_API char* sd_ctx_params_to_str(const sd_ctx_params_t* sd_ctx_params);

0 commit comments

Comments
 (0)