3636
3737#include " frontend.cpp"
3838
39- const char * rng_type_to_str[] = {
40- " std_default" ,
41- " cuda" ,
42- };
43-
44- // Names of the sampler method, same order as enum sample_method in stable-diffusion.h
45- const char * sample_method_str[] = {
46- " euler_a" ,
47- " euler" ,
48- " heun" ,
49- " dpm2" ,
50- " dpm++2s_a" ,
51- " dpm++2m" ,
52- " dpm++2mv2" ,
53- " ipndm" ,
54- " ipndm_v" ,
55- " lcm" ,
56- };
57-
58- // Names of the sigma schedule overrides, same order as sample_schedule in stable-diffusion.h
59- const char * schedule_str[] = {
60- " default" ,
61- " discrete" ,
62- " karras" ,
63- " exponential" ,
64- " ays" ,
65- " gits" ,
66- };
67-
68- enum SDMode {
69- TXT2IMG,
70- IMG2IMG,
71- MODE_COUNT
72- };
73-
7439struct SDCtxParams {
7540 std::string model_path;
7641 std::string clip_l_path;
@@ -105,8 +70,6 @@ struct SDRequestParams {
10570 // TODO set to true if esrgan_path is specified in args
10671 bool upscale = false ;
10772
108- SDMode mode = TXT2IMG;
109-
11073 std::string prompt;
11174 std::string negative_prompt;
11275
@@ -195,11 +158,11 @@ void print_params(SDParams params) {
195158 printf (" clip_skip: %d\n " , params.lastRequest .clip_skip );
196159 printf (" width: %d\n " , params.lastRequest .width );
197160 printf (" height: %d\n " , params.lastRequest .height );
198- printf (" sample_method: %s\n " , sample_method_str[ params.lastRequest .sample_method ] );
199- printf (" schedule: %s\n " , schedule_str[ params.ctxParams .schedule ] );
161+ printf (" sample_method: %s\n " , sd_sample_method_name ( params.lastRequest .sample_method ) );
162+ printf (" schedule: %s\n " , sd_schedule_name ( params.ctxParams .schedule ) );
200163 printf (" sample_steps: %d\n " , params.lastRequest .sample_steps );
201164 printf (" strength(img2img): %.2f\n " , params.lastRequest .strength );
202- printf (" rng: %s\n " , rng_type_to_str[ params.ctxParams .rng_type ] );
165+ printf (" rng: %s\n " , sd_rng_type_name ( params.ctxParams .rng_type ) );
203166 printf (" seed: %ld\n " , params.lastRequest .seed );
204167 printf (" batch_count: %d\n " , params.lastRequest .batch_count );
205168 printf (" vae_tiling: %s\n " , params.ctxParams .vae_tiling ? " true" : " false" );
@@ -512,17 +475,12 @@ void parse_args(int argc, const char** argv, SDParams& params) {
512475 break ;
513476 }
514477 const char * schedule_selected = argv[i];
515- int schedule_found = -1 ;
516- for (int d = 0 ; d < N_SCHEDULES; d++) {
517- if (!strcmp (schedule_selected, schedule_str[d])) {
518- schedule_found = d;
519- }
520- }
521- if (schedule_found == -1 ) {
478+ schedule_t schedule_found = str_to_schedule (schedule_selected);
479+ if (schedule_found == SCHEDULE_COUNT) {
522480 invalid_arg = true ;
523481 break ;
524482 }
525- params.ctxParams .schedule = ( schedule_t ) schedule_found;
483+ params.ctxParams .schedule = schedule_found;
526484 } else if (arg == " -s" || arg == " --seed" ) {
527485 if (++i >= argc) {
528486 invalid_arg = true ;
@@ -535,13 +493,8 @@ void parse_args(int argc, const char** argv, SDParams& params) {
535493 break ;
536494 }
537495 const char * sample_method_selected = argv[i];
538- int sample_method_found = -1 ;
539- for (int m = 0 ; m < N_SAMPLE_METHODS; m++) {
540- if (!strcmp (sample_method_selected, sample_method_str[m])) {
541- sample_method_found = m;
542- }
543- }
544- if (sample_method_found == -1 ) {
496+ int sample_method_found = str_to_sample_method (sample_method_selected);
497+ if (sample_method_found == SAMPLE_METHOD_COUNT) {
545498 invalid_arg = true ;
546499 break ;
547500 }
@@ -689,8 +642,8 @@ std::string get_image_params(SDParams params, int64_t seed) {
689642 parameter_string += " Seed: " + std::to_string (seed) + " , " ;
690643 parameter_string += " Size: " + std::to_string (params.lastRequest .width ) + " x" + std::to_string (params.lastRequest .height ) + " , " ;
691644 parameter_string += " Model: " + sd_basename (params.ctxParams .model_path ) + " , " ;
692- parameter_string += " RNG: " + std::string (rng_type_to_str[ params.ctxParams .rng_type ] ) + " , " ;
693- parameter_string += " Sampler: " + std::string (sample_method_str[ params.lastRequest .sample_method ] );
645+ parameter_string += " RNG: " + std::string (sd_rng_type_name ( params.ctxParams .rng_type ) ) + " , " ;
646+ parameter_string += " Sampler: " + std::string (sd_sample_method_name ( params.lastRequest .sample_method ) );
694647 if (params.ctxParams .schedule == KARRAS) {
695648 parameter_string += " karras" ;
696649 }
@@ -807,14 +760,9 @@ bool parseJsonPrompt(std::string json_str, SDParams* params) {
807760 try {
808761 std::string sample_method = payload[" sample_method" ];
809762
810- int sample_method_found = -1 ;
811- for (int m = 0 ; m < N_SAMPLE_METHODS; m++) {
812- if (!strcmp (sample_method.c_str (), sample_method_str[m])) {
813- sample_method_found = m;
814- }
815- }
816- if (sample_method_found >= 0 ) {
817- params->lastRequest .sample_method = (sample_method_t )sample_method_found;
763+ sample_method_t sample_method_found = str_to_sample_method (sample_method.c_str ());
764+ if (sample_method_found != SAMPLE_METHOD_COUNT) {
765+ params->lastRequest .sample_method = sample_method_found;
818766 } else {
819767 sd_log (sd_log_level_t ::SD_LOG_WARN, " Unknown sampling method: %s\n " , sample_method.c_str ());
820768 }
@@ -1011,16 +959,11 @@ bool parseJsonPrompt(std::string json_str, SDParams* params) {
1011959 }
1012960
1013961 try {
1014- std::string schedule = payload[" schedule" ];
1015- int schedule_found = -1 ;
1016- for (int m = 0 ; m < N_SCHEDULES; m++) {
1017- if (!strcmp (schedule.c_str (), schedule_str[m])) {
1018- schedule_found = m;
1019- }
1020- }
1021- if (schedule_found >= 0 ) {
1022- if (params->ctxParams .schedule != (schedule_t )schedule_found) {
1023- params->ctxParams .schedule = (schedule_t )schedule_found;
962+ std::string schedule = payload[" schedule" ];
963+ schedule_t schedule_found = str_to_schedule (schedule.c_str ());
964+ if (schedule_found != SCHEDULE_COUNT) {
965+ if (params->ctxParams .schedule != schedule_found) {
966+ params->ctxParams .schedule = schedule_found;
1024967 updatectx = true ;
1025968 }
1026969 } else {
@@ -1189,30 +1132,31 @@ void start_server(SDParams params) {
11891132 std::lock_guard<std::mutex> results_lock (results_mutex);
11901133 task_results[task_id] = task_json;
11911134 }
1192-
1193- sd_ctx = new_sd_ctx (params.ctxParams .model_path .c_str (),
1194- params.ctxParams .clip_l_path .c_str (),
1195- params.ctxParams .clip_g_path .c_str (),
1196- params.ctxParams .t5xxl_path .c_str (),
1197- params.ctxParams .diffusion_model_path .c_str (),
1198- params.ctxParams .vae_path .c_str (),
1199- params.ctxParams .taesd_path .c_str (),
1200- params.ctxParams .controlnet_path .c_str (),
1201- params.ctxParams .lora_model_dir .c_str (),
1202- params.ctxParams .embeddings_path .c_str (),
1203- params.ctxParams .stacked_id_embeddings_path .c_str (),
1204- params.ctxParams .vae_decode_only ,
1205- params.ctxParams .vae_tiling ,
1206- false ,
1207- params.ctxParams .n_threads ,
1208- params.ctxParams .wtype ,
1209- params.ctxParams .rng_type ,
1210- params.ctxParams .schedule ,
1211- params.ctxParams .clip_on_cpu ,
1212- params.ctxParams .control_net_cpu ,
1213- params.ctxParams .vae_on_cpu ,
1214- params.ctxParams .diffusion_flash_attn ,
1215- true , false , 1 );
1135+ sd_ctx_params_t sd_ctx_params = {
1136+ params.ctxParams .model_path .c_str (),
1137+ params.ctxParams .clip_l_path .c_str (),
1138+ params.ctxParams .clip_g_path .c_str (),
1139+ params.ctxParams .t5xxl_path .c_str (),
1140+ params.ctxParams .diffusion_model_path .c_str (),
1141+ params.ctxParams .vae_path .c_str (),
1142+ params.ctxParams .taesd_path .c_str (),
1143+ params.ctxParams .controlnet_path .c_str (),
1144+ params.ctxParams .lora_model_dir .c_str (),
1145+ params.ctxParams .embeddings_path .c_str (),
1146+ params.ctxParams .stacked_id_embeddings_path .c_str (),
1147+ params.ctxParams .vae_decode_only ,
1148+ params.ctxParams .vae_tiling ,
1149+ false ,
1150+ params.ctxParams .n_threads ,
1151+ params.ctxParams .wtype ,
1152+ params.ctxParams .rng_type ,
1153+ params.ctxParams .schedule ,
1154+ params.ctxParams .clip_on_cpu ,
1155+ params.ctxParams .control_net_cpu ,
1156+ params.ctxParams .vae_on_cpu ,
1157+ params.ctxParams .diffusion_flash_attn ,
1158+ true , false , 1 };
1159+ sd_ctx = new_sd_ctx (&sd_ctx_params);
12161160 if (sd_ctx == NULL ) {
12171161 printf (" new_sd_ctx_t failed\n " );
12181162 std::lock_guard<std::mutex> results_lock (results_mutex);
@@ -1235,29 +1179,47 @@ void start_server(SDParams params) {
12351179
12361180 {
12371181 sd_image_t * results;
1238- results = txt2img (sd_ctx,
1239- params.lastRequest .prompt .c_str (),
1240- params.lastRequest .negative_prompt .c_str (),
1241- params.lastRequest .clip_skip ,
1242- params.lastRequest .cfg_scale ,
1243- params.lastRequest .guidance ,
1244- 0 .f , // eta
1245- params.lastRequest .width ,
1246- params.lastRequest .height ,
1247- params.lastRequest .sample_method ,
1248- params.lastRequest .sample_steps ,
1249- params.lastRequest .seed ,
1250- params.lastRequest .batch_count ,
1251- NULL ,
1252- 1 ,
1253- params.lastRequest .style_ratio ,
1254- params.lastRequest .normalize_input ,
1255- params.input_id_images_path .c_str (),
1256- params.lastRequest .skip_layers .data (),
1257- params.lastRequest .skip_layers .size (),
1258- params.lastRequest .slg_scale ,
1259- params.lastRequest .skip_layer_start ,
1260- params.lastRequest .skip_layer_end );
1182+ sd_slg_params_t slg = {
1183+ params.lastRequest .skip_layers .data (),
1184+ params.lastRequest .skip_layers .size (),
1185+ params.lastRequest .skip_layer_start ,
1186+ params.lastRequest .skip_layer_end ,
1187+ params.lastRequest .slg_scale };
1188+ sd_guidance_params_t guidance = {
1189+ params.lastRequest .cfg_scale ,
1190+ params.lastRequest .cfg_scale ,
1191+ params.lastRequest .cfg_scale ,
1192+ params.lastRequest .guidance ,
1193+ slg};
1194+ sd_image_t input_image = {
1195+ (uint32_t )params.lastRequest .width ,
1196+ (uint32_t )params.lastRequest .height ,
1197+ 3 ,
1198+ NULL };
1199+ sd_image_t mask_img = input_image;
1200+ sd_img_gen_params_t gen_params = {
1201+ params.lastRequest .prompt .c_str (),
1202+ params.lastRequest .negative_prompt .c_str (),
1203+ params.lastRequest .clip_skip ,
1204+ guidance,
1205+ input_image,
1206+ NULL , // ref images
1207+ 0 , // ref images count
1208+ mask_img,
1209+ params.lastRequest .width ,
1210+ params.lastRequest .height ,
1211+ params.lastRequest .sample_method ,
1212+ params.lastRequest .sample_steps ,
1213+ 0 .f , // eta
1214+ 1 .f , // denoise strength
1215+ params.lastRequest .seed ,
1216+ params.lastRequest .batch_count ,
1217+ NULL , // control image ptr
1218+ 1 .f , // control strength
1219+ params.lastRequest .style_ratio ,
1220+ params.lastRequest .normalize_input ,
1221+ params.input_id_images_path .c_str ()};
1222+ results = generate_image (sd_ctx, &gen_params);
12611223
12621224 if (results == NULL ) {
12631225 printf (" generate failed\n " );
@@ -1328,7 +1290,7 @@ void start_server(SDParams params) {
13281290 params_json[" guidance" ] = params.lastRequest .guidance ;
13291291 params_json[" width" ] = params.lastRequest .width ;
13301292 params_json[" height" ] = params.lastRequest .height ;
1331- params_json[" sample_method" ] = sample_method_str[ params.lastRequest .sample_method ] ;
1293+ params_json[" sample_method" ] = sd_sample_method_name ( params.lastRequest .sample_method ) ;
13321294 params_json[" sample_steps" ] = params.lastRequest .sample_steps ;
13331295 params_json[" seed" ] = params.lastRequest .seed ;
13341296 params_json[" batch_count" ] = params.lastRequest .batch_count ;
@@ -1352,7 +1314,7 @@ void start_server(SDParams params) {
13521314 context_params[" n_threads" ] = params.ctxParams .n_threads ;
13531315 context_params[" wtype" ] = params.ctxParams .wtype ;
13541316 context_params[" rng_type" ] = params.ctxParams .rng_type ;
1355- context_params[" schedule" ] = schedule_str[ params.ctxParams .schedule ] ;
1317+ context_params[" schedule" ] = sd_schedule_name ( params.ctxParams .schedule ) ;
13561318 context_params[" clip_on_cpu" ] = params.ctxParams .clip_on_cpu ;
13571319 context_params[" control_net_cpu" ] = params.ctxParams .control_net_cpu ;
13581320 context_params[" vae_on_cpu" ] = params.ctxParams .vae_on_cpu ;
@@ -1390,17 +1352,17 @@ void start_server(SDParams params) {
13901352 svr->Get (" /sample_methods" , [](const httplib::Request& req, httplib::Response& res) {
13911353 using json = nlohmann::json;
13921354 json response;
1393- for (int m = 0 ; m < N_SAMPLE_METHODS ; m++) {
1394- response.push_back (sample_method_str[m] );
1355+ for (int m = 0 ; m < SAMPLE_METHOD_COUNT ; m++) {
1356+ response.push_back (sd_sample_method_name (( sample_method_t )m) );
13951357 }
13961358 res.set_content (response.dump (), " application/json" );
13971359 });
13981360
13991361 svr->Get (" /schedules" , [](const httplib::Request& req, httplib::Response& res) {
14001362 using json = nlohmann::json;
14011363 json response;
1402- for (int s = 0 ; s < N_SCHEDULES ; s++) {
1403- response.push_back (schedule_str[s] );
1364+ for (int s = 0 ; s < SCHEDULE_COUNT ; s++) {
1365+ response.push_back (sd_schedule_name (( schedule_t )s) );
14041366 }
14051367 res.set_content (response.dump (), " application/json" );
14061368 });
0 commit comments