@@ -700,64 +700,102 @@ class StableDiffusionGGML {
700700 ggml_backend_is_cpu (clip_backend) ? " RAM" : " VRAM" );
701701 }
702702
703- // check is_using_v_parameterization_for_sd2
704- if (sd_version_is_sd2 (version)) {
705- if (is_using_v_parameterization_for_sd2 (ctx, sd_version_is_inpaint (version))) {
706- is_using_v_parameterization = true ;
707- }
708- } else if (sd_version_is_sdxl (version)) {
709- if (model_loader.tensor_storages_types .find (" edm_vpred.sigma_max" ) != model_loader.tensor_storages_types .end ()) {
710- // CosXL models
711- // TODO: get sigma_min and sigma_max values from file
712- is_using_edm_v_parameterization = true ;
703+ if (sd_ctx_params->prediction != DEFAULT_PRED) {
704+ switch (sd_ctx_params->prediction ) {
705+ case EPS_PRED:
706+ LOG_INFO (" running in eps-prediction mode" );
707+ break ;
708+ case V_PRED:
709+ LOG_INFO (" running in v-prediction mode" );
710+ denoiser = std::make_shared<CompVisVDenoiser>();
711+ break ;
712+ case EDM_V_PRED:
713+ LOG_INFO (" running in v-prediction EDM mode" );
714+ denoiser = std::make_shared<EDMVDenoiser>();
715+ break ;
716+ case SD3_FLOW_PRED: {
717+ LOG_INFO (" running in FLOW mode" );
718+ float shift = sd_ctx_params->flow_shift ;
719+ if (shift == INFINITY) {
720+ shift = 3.0 ;
721+ }
722+ denoiser = std::make_shared<DiscreteFlowDenoiser>(shift);
723+ break ;
724+ }
725+ case FLUX_FLOW_PRED: {
726+ LOG_INFO (" running in Flux FLOW mode" );
727+ float shift = sd_ctx_params->flow_shift ;
728+ if (shift == INFINITY) {
729+ shift = 3.0 ;
730+ }
731+ denoiser = std::make_shared<FluxFlowDenoiser>(shift);
732+ break ;
733+ }
734+ default : {
735+ LOG_ERROR (" Unknown parametrization %i" , sd_ctx_params->prediction );
736+ return false ;
737+ }
713738 }
714- if (model_loader.tensor_storages_types .find (" v_pred" ) != model_loader.tensor_storages_types .end ()) {
739+ } else {
740+ if (sd_version_is_sd2 (version)) {
741+ // check is_using_v_parameterization_for_sd2
742+ if (is_using_v_parameterization_for_sd2 (ctx, sd_version_is_inpaint (version))) {
743+ is_using_v_parameterization = true ;
744+ }
745+ } else if (sd_version_is_sdxl (version)) {
746+ if (model_loader.tensor_storages_types .find (" edm_vpred.sigma_max" ) != model_loader.tensor_storages_types .end ()) {
747+ // CosXL models
748+ // TODO: get sigma_min and sigma_max values from file
749+ is_using_edm_v_parameterization = true ;
750+ }
751+ if (model_loader.tensor_storages_types .find (" v_pred" ) != model_loader.tensor_storages_types .end ()) {
752+ is_using_v_parameterization = true ;
753+ }
754+ } else if (version == VERSION_SVD) {
755+ // TODO: V_PREDICTION_EDM
715756 is_using_v_parameterization = true ;
716757 }
717- } else if (version == VERSION_SVD) {
718- // TODO: V_PREDICTION_EDM
719- is_using_v_parameterization = true ;
720- }
721758
722- if (sd_version_is_sd3 (version)) {
723- LOG_INFO (" running in FLOW mode" );
724- float shift = sd_ctx_params->flow_shift ;
725- if (shift == INFINITY) {
726- shift = 3.0 ;
727- }
728- denoiser = std::make_shared<DiscreteFlowDenoiser>(shift);
729- } else if (sd_version_is_flux (version)) {
730- LOG_INFO (" running in Flux FLOW mode" );
731- float shift = 1 .0f ; // TODO: validate
732- for (auto pair : model_loader.tensor_storages_types ) {
733- if (pair.first .find (" model.diffusion_model.guidance_in.in_layer.weight" ) != std::string::npos) {
734- shift = 1 .15f ;
735- break ;
759+ if (sd_version_is_sd3 (version)) {
760+ LOG_INFO (" running in FLOW mode" );
761+ float shift = sd_ctx_params->flow_shift ;
762+ if (shift == INFINITY) {
763+ shift = 3.0 ;
736764 }
765+ denoiser = std::make_shared<DiscreteFlowDenoiser>(shift);
766+ } else if (sd_version_is_flux (version)) {
767+ LOG_INFO (" running in Flux FLOW mode" );
768+ float shift = 1 .0f ; // TODO: validate
769+ for (auto pair : model_loader.tensor_storages_types ) {
770+ if (pair.first .find (" model.diffusion_model.guidance_in.in_layer.weight" ) != std::string::npos) {
771+ shift = 1 .15f ;
772+ break ;
773+ }
774+ }
775+ denoiser = std::make_shared<FluxFlowDenoiser>(shift);
776+ } else if (sd_version_is_wan (version)) {
777+ LOG_INFO (" running in FLOW mode" );
778+ float shift = sd_ctx_params->flow_shift ;
779+ if (shift == INFINITY) {
780+ shift = 5.0 ;
781+ }
782+ denoiser = std::make_shared<DiscreteFlowDenoiser>(shift);
783+ } else if (sd_version_is_qwen_image (version)) {
784+ LOG_INFO (" running in FLOW mode" );
785+ float shift = sd_ctx_params->flow_shift ;
786+ if (shift == INFINITY) {
787+ shift = 3.0 ;
788+ }
789+ denoiser = std::make_shared<DiscreteFlowDenoiser>(shift);
790+ } else if (is_using_v_parameterization) {
791+ LOG_INFO (" running in v-prediction mode" );
792+ denoiser = std::make_shared<CompVisVDenoiser>();
793+ } else if (is_using_edm_v_parameterization) {
794+ LOG_INFO (" running in v-prediction EDM mode" );
795+ denoiser = std::make_shared<EDMVDenoiser>();
796+ } else {
797+ LOG_INFO (" running in eps-prediction mode" );
737798 }
738- denoiser = std::make_shared<FluxFlowDenoiser>(shift);
739- } else if (sd_version_is_wan (version)) {
740- LOG_INFO (" running in FLOW mode" );
741- float shift = sd_ctx_params->flow_shift ;
742- if (shift == INFINITY) {
743- shift = 5.0 ;
744- }
745- denoiser = std::make_shared<DiscreteFlowDenoiser>(shift);
746- } else if (sd_version_is_qwen_image (version)) {
747- LOG_INFO (" running in FLOW mode" );
748- float shift = sd_ctx_params->flow_shift ;
749- if (shift == INFINITY) {
750- shift = 3.0 ;
751- }
752- denoiser = std::make_shared<DiscreteFlowDenoiser>(shift);
753- } else if (is_using_v_parameterization) {
754- LOG_INFO (" running in v-prediction mode" );
755- denoiser = std::make_shared<CompVisVDenoiser>();
756- } else if (is_using_edm_v_parameterization) {
757- LOG_INFO (" running in v-prediction EDM mode" );
758- denoiser = std::make_shared<EDMVDenoiser>();
759- } else {
760- LOG_INFO (" running in eps-prediction mode" );
761799 }
762800
763801 auto comp_vis_denoiser = std::dynamic_pointer_cast<CompVisDenoiser>(denoiser);
@@ -1742,13 +1780,39 @@ enum scheduler_t str_to_schedule(const char* str) {
17421780 return SCHEDULE_COUNT;
17431781}
17441782
// Display/parse names for prediction modes, one string per slot.
// sd_prediction_name() indexes this table directly with a prediction_t value and
// str_to_prediction() scans slots [0, PREDICTION_COUNT), so the table needs at
// least PREDICTION_COUNT entries in the same order as the enum.
// NOTE(review): the enum definition is not visible here — confirm the order
// (default, eps, v, edm_v, sd3_flow, flux_flow) matches prediction_t.
1783+ const char * prediction_to_str[] = {
1784+ " default" ,
1785+ " eps" ,
1786+ " v" ,
1787+ " edm_v" ,
1788+ " sd3_flow" ,
1789+ " flux_flow" ,
1790+ };
1791+
1792+ const char * sd_prediction_name (enum prediction_t prediction) {
1793+ if (prediction < PREDICTION_COUNT) {
1794+ return prediction_to_str[prediction];
1795+ }
1796+ return NONE_STR;
1797+ }
1798+
1799+ enum prediction_t str_to_prediction (const char * str) {
1800+ for (int i = 0 ; i < PREDICTION_COUNT; i++) {
1801+ if (!strcmp (str, prediction_to_str[i])) {
1802+ return (enum prediction_t )i;
1803+ }
1804+ }
1805+ return PREDICTION_COUNT;
1806+ }
1807+
17451808void sd_ctx_params_init (sd_ctx_params_t * sd_ctx_params) {
17461809 *sd_ctx_params = {};
17471810 sd_ctx_params->vae_decode_only = true ;
17481811 sd_ctx_params->free_params_immediately = true ;
17491812 sd_ctx_params->n_threads = get_num_physical_cores ();
17501813 sd_ctx_params->wtype = SD_TYPE_COUNT;
17511814 sd_ctx_params->rng_type = CUDA_RNG;
1815+ sd_ctx_params->prediction = DEFAULT_PRED;
17521816 sd_ctx_params->offload_params_to_cpu = false ;
17531817 sd_ctx_params->keep_clip_on_cpu = false ;
17541818 sd_ctx_params->keep_control_net_on_cpu = false ;
@@ -1788,6 +1852,7 @@ char* sd_ctx_params_to_str(const sd_ctx_params_t* sd_ctx_params) {
17881852 " n_threads: %d\n "
17891853 " wtype: %s\n "
17901854 " rng_type: %s\n "
1855+ " prediction: %s\n "
17911856 " offload_params_to_cpu: %s\n "
17921857 " keep_clip_on_cpu: %s\n "
17931858 " keep_control_net_on_cpu: %s\n "
@@ -1816,6 +1881,7 @@ char* sd_ctx_params_to_str(const sd_ctx_params_t* sd_ctx_params) {
18161881 sd_ctx_params->n_threads ,
18171882 sd_type_name (sd_ctx_params->wtype ),
18181883 sd_rng_type_name (sd_ctx_params->rng_type ),
1884+ sd_prediction_name (sd_ctx_params->prediction ),
18191885 BOOL_STR (sd_ctx_params->offload_params_to_cpu ),
18201886 BOOL_STR (sd_ctx_params->keep_clip_on_cpu ),
18211887 BOOL_STR (sd_ctx_params->keep_control_net_on_cpu ),
0 commit comments