44#include < stdio.h>
55#include < string.h>
66#include < time.h>
7- #include < iostream>
8- #include < random>
97#include < string>
108#include < vector>
119#include < filesystem>
1210#include " gosd.h"
1311
14- // #include "preprocessing.hpp"
15- #include " flux.hpp"
16- #include " stable-diffusion.h"
17-
1812#define STB_IMAGE_IMPLEMENTATION
1913#define STB_IMAGE_STATIC
2014#include " stb_image.h"
2923
3024// Names of the sampler method, same order as enum sample_method in stable-diffusion.h
3125const char * sample_method_str[] = {
32- " euler_a " ,
26+ " default " ,
3327 " euler" ,
3428 " heun" ,
3529 " dpm2" ,
@@ -41,6 +35,7 @@ const char* sample_method_str[] = {
4135 " lcm" ,
4236 " ddim_trailing" ,
4337 " tcd" ,
38+ " euler_a" ,
4439};
4540
4641static_assert (std::size(sample_method_str) == SAMPLE_METHOD_COUNT, " sample method mismatch" );
@@ -173,7 +168,7 @@ int load_model(const char *model, char *model_path, char* options[], int threads
173168 }
174169 if (sample_method_found == -1 ) {
175170 fprintf (stderr, " Invalid sample method, default to EULER_A!\n " );
176- sample_method_found = EULER_A ;
171+ sample_method_found = sample_method_t ::SAMPLE_METHOD_DEFAULT ;
177172 }
178173 sample_method = (sample_method_t )sample_method_found;
179174
@@ -197,9 +192,7 @@ int load_model(const char *model, char *model_path, char* options[], int threads
197192 ctx_params.control_net_path = " " ;
198193 ctx_params.lora_model_dir = lora_dir;
199194 ctx_params.embedding_dir = " " ;
200- ctx_params.stacked_id_embed_dir = " " ;
201195 ctx_params.vae_decode_only = false ;
202- ctx_params.vae_tiling = false ;
203196 ctx_params.free_params_immediately = false ;
204197 ctx_params.n_threads = threads;
205198 ctx_params.rng_type = STD_DEFAULT_RNG;
@@ -225,29 +218,65 @@ int load_model(const char *model, char *model_path, char* options[], int threads
225218 return 0 ;
226219}
227220
228- int gen_image (char *text, char *negativeText, int width, int height, int steps, int64_t seed, char *dst, float cfg_scale, char *src_image, float strength, char *mask_image, char **ref_images, int ref_images_count) {
221+ void sd_tiling_params_set_enabled (sd_tiling_params_t *params, bool enabled) {
222+ params->enabled = enabled;
223+ }
224+
225+ void sd_tiling_params_set_tile_sizes (sd_tiling_params_t *params, int tile_size_x, int tile_size_y) {
226+ params->tile_size_x = tile_size_x;
227+ params->tile_size_y = tile_size_y;
228+ }
229+
230+ void sd_tiling_params_set_rel_sizes (sd_tiling_params_t *params, float rel_size_x, float rel_size_y) {
231+ params->rel_size_x = rel_size_x;
232+ params->rel_size_y = rel_size_y;
233+ }
234+
235+ void sd_tiling_params_set_target_overlap (sd_tiling_params_t *params, float target_overlap) {
236+ params->target_overlap = target_overlap;
237+ }
238+
239+ sd_tiling_params_t * sd_img_gen_params_get_vae_tiling_params (sd_img_gen_params_t *params) {
240+ return ¶ms->vae_tiling_params ;
241+ }
242+
243+ sd_img_gen_params_t * sd_img_gen_params_new (void ) {
244+ sd_img_gen_params_t *params = (sd_img_gen_params_t *)std::malloc (sizeof (sd_img_gen_params_t ));
245+ sd_img_gen_params_init (params);
246+ return params;
247+ }
248+
249+ void sd_img_gen_params_set_prompts (sd_img_gen_params_t *params, const char *prompt, const char *negative_prompt) {
250+ params->prompt = prompt;
251+ params->negative_prompt = negative_prompt;
252+ }
253+
254+ void sd_img_gen_params_set_dimensions (sd_img_gen_params_t *params, int width, int height) {
255+ params->width = width;
256+ params->height = height;
257+ }
258+
259+ void sd_img_gen_params_set_seed (sd_img_gen_params_t *params, int64_t seed) {
260+ params->seed = seed;
261+ }
262+
263+ int gen_image (sd_img_gen_params_t *p, int steps, char *dst, float cfg_scale, char *src_image, float strength, char *mask_image, char **ref_images, int ref_images_count) {
229264
230265 sd_image_t * results;
231266
232267 std::vector<int > skip_layers = {7 , 8 , 9 };
233268
234269 fprintf (stderr, " Generating image\n " );
235270
236- sd_img_gen_params_t p;
237- sd_img_gen_params_init (&p);
238-
239- p.prompt = text;
240- p.negative_prompt = negativeText;
241- p.sample_params .guidance .txt_cfg = cfg_scale;
242- p.sample_params .guidance .slg .layers = skip_layers.data ();
243- p.sample_params .guidance .slg .layer_count = skip_layers.size ();
244- p.width = width;
245- p.height = height;
246- p.sample_params .sample_method = sample_method;
247- p.sample_params .sample_steps = steps;
248- p.seed = seed;
249- p.input_id_images_path = " " ;
250- p.sample_params .scheduler = scheduler;
271+ p->sample_params .guidance .txt_cfg = cfg_scale;
272+ p->sample_params .guidance .slg .layers = skip_layers.data ();
273+ p->sample_params .guidance .slg .layer_count = skip_layers.size ();
274+ p->sample_params .sample_method = sample_method;
275+ p->sample_params .sample_steps = steps;
276+ p->sample_params .scheduler = scheduler;
277+
278+ int width = p->width ;
279+ int height = p->height ;
251280
252281 // Handle input image for img2img
253282 bool has_input_image = (src_image != NULL && strlen (src_image) > 0 );
@@ -296,13 +325,13 @@ int gen_image(char *text, char *negativeText, int width, int height, int steps,
296325 input_image_buffer = resized_image_buffer;
297326 }
298327
299- p. init_image = {(uint32_t )width, (uint32_t )height, 3 , input_image_buffer};
300- p. strength = strength;
328+ p-> init_image = {(uint32_t )width, (uint32_t )height, 3 , input_image_buffer};
329+ p-> strength = strength;
301330 fprintf (stderr, " Using img2img with strength: %.2f\n " , strength);
302331 } else {
303332 // No input image, use empty image for text-to-image
304- p. init_image = {(uint32_t )width, (uint32_t )height, 3 , NULL };
305- p. strength = 0 .0f ;
333+ p-> init_image = {(uint32_t )width, (uint32_t )height, 3 , NULL };
334+ p-> strength = 0 .0f ;
306335 }
307336
308337 // Handle mask image for inpainting
@@ -342,12 +371,12 @@ int gen_image(char *text, char *negativeText, int width, int height, int steps,
342371 mask_image_buffer = resized_mask_buffer;
343372 }
344373
345- p. mask_image = {(uint32_t )width, (uint32_t )height, 1 , mask_image_buffer};
374+ p-> mask_image = {(uint32_t )width, (uint32_t )height, 1 , mask_image_buffer};
346375 fprintf (stderr, " Using inpainting with mask\n " );
347376 } else {
348377 // No mask image, create default full mask
349378 default_mask_image_vec.resize (width * height, 255 );
350- p. mask_image = {(uint32_t )width, (uint32_t )height, 1 , default_mask_image_vec.data ()};
379+ p-> mask_image = {(uint32_t )width, (uint32_t )height, 1 , default_mask_image_vec.data ()};
351380 }
352381
353382 // Handle reference images
@@ -405,13 +434,15 @@ int gen_image(char *text, char *negativeText, int width, int height, int steps,
405434 }
406435
407436 if (!ref_images_vec.empty ()) {
408- p. ref_images = ref_images_vec.data ();
409- p. ref_images_count = ref_images_vec.size ();
437+ p-> ref_images = ref_images_vec.data ();
438+ p-> ref_images_count = ref_images_vec.size ();
410439 fprintf (stderr, " Using %zu reference images\n " , ref_images_vec.size ());
411440 }
412441 }
413442
414- results = generate_image (sd_c, &p);
443+ results = generate_image (sd_c, p);
444+
445+ std::free (p);
415446
416447 if (results == NULL ) {
417448 fprintf (stderr, " NO results\n " );
0 commit comments