1+ #include " stable-diffusion.h"
12#include < cstdint>
23#define GGML_MAX_NAME 128
34
2324
2425// Names of the sampler method, same order as enum sample_method in stable-diffusion.h
2526const char * sample_method_str[] = {
26- " default" ,
2727 " euler" ,
28+ " euler_a" ,
2829 " heun" ,
2930 " dpm2" ,
3031 " dpm++2s_a" ,
@@ -35,29 +36,29 @@ const char* sample_method_str[] = {
3536 " lcm" ,
3637 " ddim_trailing" ,
3738 " tcd" ,
38- " euler_a" ,
3939};
4040
4141static_assert (std::size(sample_method_str) == SAMPLE_METHOD_COUNT, " sample method mismatch" );
4242
4343// Names of the sigma schedule overrides, same order as sample_schedule in stable-diffusion.h
4444const char * schedulers[] = {
45- " default" ,
4645 " discrete" ,
4746 " karras" ,
4847 " exponential" ,
4948 " ays" ,
5049 " gits" ,
50+ " sgm_uniform" ,
51+ " simple" ,
5152 " smoothstep" ,
53+ " lcm" ,
5254};
5355
54- static_assert (std::size(schedulers) == SCHEDULE_COUNT , " schedulers mismatch" );
56+ static_assert (std::size(schedulers) == SCHEDULER_COUNT , " schedulers mismatch" );
5557
5658sd_ctx_t * sd_c;
5759// Moved from the context (load time) to generation time params
58- scheduler_t scheduler = scheduler_t ::DEFAULT;
59-
60- sample_method_t sample_method;
60+ scheduler_t scheduler = SCHEDULER_COUNT;
61+ sample_method_t sample_method = SAMPLE_METHOD_COUNT;
6162
6263// Copied from the upstream CLI
6364static void sd_log_cb (enum sd_log_level_t level, const char * log, void * data) {
@@ -159,26 +160,6 @@ int load_model(const char *model, char *model_path, char* options[], int threads
159160
160161 fprintf (stderr, " parsed options\n " );
161162
162- int sample_method_found = -1 ;
163- for (int m = 0 ; m < SAMPLE_METHOD_COUNT; m++) {
164- if (!strcmp (sampler, sample_method_str[m])) {
165- sample_method_found = m;
166- fprintf (stderr, " Found sampler: %s\n " , sampler);
167- }
168- }
169- if (sample_method_found == -1 ) {
170- fprintf (stderr, " Invalid sample method, default to EULER_A!\n " );
171- sample_method_found = sample_method_t ::SAMPLE_METHOD_DEFAULT;
172- }
173- sample_method = (sample_method_t )sample_method_found;
174-
175- for (int d = 0 ; d < SCHEDULE_COUNT; d++) {
176- if (!strcmp (scheduler_str, schedulers[d])) {
177- scheduler = (scheduler_t )d;
178- fprintf (stderr, " Found scheduler: %s\n " , scheduler_str);
179- }
180- }
181-
182163 fprintf (stderr, " Creating context\n " );
183164 sd_ctx_params_t ctx_params;
184165 sd_ctx_params_init (&ctx_params);
@@ -208,6 +189,29 @@ int load_model(const char *model, char *model_path, char* options[], int threads
208189 }
209190 fprintf (stderr, " Created context: OK\n " );
210191
192+ int sample_method_found = -1 ;
193+ for (int m = 0 ; m < SAMPLE_METHOD_COUNT; m++) {
194+ if (!strcmp (sampler, sample_method_str[m])) {
195+ sample_method_found = m;
196+ fprintf (stderr, " Found sampler: %s\n " , sampler);
197+ }
198+ }
199+ if (sample_method_found == -1 ) {
200+ fprintf (stderr, " Invalid sample method, default to EULER_A!\n " );
201+ sample_method_found = sd_get_default_sample_method (sd_ctx);
202+ }
203+ sample_method = (sample_method_t )sample_method_found;
204+
205+ for (int d = 0 ; d < SCHEDULER_COUNT; d++) {
206+ if (!strcmp (scheduler_str, schedulers[d])) {
207+ scheduler = (scheduler_t )d;
208+ fprintf (stderr, " Found scheduler: %s\n " , scheduler_str);
209+ }
210+ }
211+ if (scheduler == SCHEDULER_COUNT) {
212+ scheduler = sd_get_default_scheduler (sd_ctx);
213+ }
214+
211215 sd_c = sd_ctx;
212216
213217 // Clean up allocated memory
0 commit comments