2222#define STB_IMAGE_RESIZE_STATIC
2323#include " stb_image_resize.h"
2424
25+ #define IMATRIX_IMPL
26+ #include " imatrix.hpp"
27+ static IMatrixCollector g_collector;
28+
2529const char * rng_type_to_str[] = {
2630 " std_default" ,
2731 " cuda" ,
@@ -129,6 +133,12 @@ struct SDParams {
129133 float slg_scale = 0 .f;
130134 float skip_layer_start = 0 .01f ;
131135 float skip_layer_end = 0 .2f ;
136+
137+ /* Imatrix params */
138+
139+ std::string imatrix_out = " " ;
140+
141+ std::vector<std::string> imatrix_in = {};
132142};
133143
134144void print_params (SDParams params) {
@@ -204,6 +214,8 @@ void print_usage(int argc, const char* argv[]) {
204214 printf (" --upscale-repeats Run the ESRGAN upscaler this many times (default 1)\n " );
205215 printf (" --type [TYPE] weight type (examples: f32, f16, q4_0, q4_1, q5_0, q5_1, q8_0, q2_K, q3_K, q4_K)\n " );
206216 printf (" If not specified, the default is the type of the weight file\n " );
217+ printf (" --imat-out [PATH] If set, compute the imatrix for this run and save it to the provided path" );
218+ printf (" --imat-in [PATH] Use imatrix for quantization." );
207219 printf (" --lora-model-dir [DIR] lora model directory\n " );
208220 printf (" -i, --init-img [IMAGE] path to the input image, required by img2img\n " );
209221 printf (" --mask [MASK] path to the mask image, required by img2img with mask\n " );
@@ -629,6 +641,18 @@ void parse_args(int argc, const char** argv, SDParams& params) {
629641 break ;
630642 }
631643 params.skip_layer_end = std::stof (argv[i]);
644+ } else if (arg == " --imat-out" ) {
645+ if (++i >= argc) {
646+ invalid_arg = true ;
647+ break ;
648+ }
649+ params.imatrix_out = argv[i];
650+ } else if (arg == " --imat-in" ) {
651+ if (++i >= argc) {
652+ invalid_arg = true ;
653+ break ;
654+ }
655+ params.imatrix_in .push_back (std::string (argv[i]));
632656 } else {
633657 fprintf (stderr, " error: unknown argument: %s\n " , arg.c_str ());
634658 print_usage (argc, argv);
@@ -787,6 +811,10 @@ void sd_log_cb(enum sd_log_level_t level, const char* log, void* data) {
787811 fflush (out_stream);
788812}
789813
814+ static bool collect_imatrix (struct ggml_tensor * t, bool ask, void * user_data) {
815+ return g_collector.collect_imatrix (t, ask, user_data);
816+ }
817+
790818int main (int argc, const char * argv[]) {
791819 SDParams params;
792820
@@ -799,8 +827,21 @@ int main(int argc, const char* argv[]) {
799827 printf (" %s" , sd_get_system_info ());
800828 }
801829
830+ if (params.imatrix_out != " " ) {
831+ sd_set_backend_eval_callback ((sd_graph_eval_callback_t )collect_imatrix, ¶ms);
832+ }
833+ if (params.imatrix_out != " " || params.mode == CONVERT || params.wtype != SD_TYPE_COUNT) {
834+ setConvertImatrixCollector ((void *)&g_collector);
835+ for (const auto & in_file : params.imatrix_in ) {
836+ printf (" loading imatrix from '%s'\n " , in_file.c_str ());
837+ if (!g_collector.load_imatrix (in_file.c_str ())) {
838+ printf (" Failed to load %s\n " , in_file.c_str ());
839+ }
840+ }
841+ }
842+
802843 if (params.mode == CONVERT) {
803- bool success = convert (params.model_path .c_str (), params.vae_path .c_str (), params.output_path .c_str (), params.wtype , NULL );
844+ bool success = convert (params.model_path .c_str (), params.vae_path .c_str (), params.output_path .c_str (), params.wtype );
804845 if (!success) {
805846 fprintf (stderr,
806847 " convert '%s'/'%s' to '%s' failed\n " ,
@@ -1075,19 +1116,19 @@ int main(int argc, const char* argv[]) {
10751116
10761117 std::string dummy_name, ext, lc_ext;
10771118 bool is_jpg;
1078- size_t last = params.output_path .find_last_of (" ." );
1119+ size_t last = params.output_path .find_last_of (" ." );
10791120 size_t last_path = std::min (params.output_path .find_last_of (" /" ),
10801121 params.output_path .find_last_of (" \\ " ));
1081- if (last != std::string::npos // filename has extension
1082- && (last_path == std::string::npos || last > last_path)) {
1122+ if (last != std::string::npos // filename has extension
1123+ && (last_path == std::string::npos || last > last_path)) {
10831124 dummy_name = params.output_path .substr (0 , last);
10841125 ext = lc_ext = params.output_path .substr (last);
10851126 std::transform (ext.begin (), ext.end (), lc_ext.begin (), ::tolower);
10861127 is_jpg = lc_ext == " .jpg" || lc_ext == " .jpeg" || lc_ext == " .jpe" ;
10871128 } else {
10881129 dummy_name = params.output_path ;
10891130 ext = lc_ext = " " ;
1090- is_jpg = false ;
1131+ is_jpg = false ;
10911132 }
10921133 // appending ".png" to absent or unknown extension
10931134 if (!is_jpg && lc_ext != " .png" ) {
@@ -1099,7 +1140,7 @@ int main(int argc, const char* argv[]) {
10991140 continue ;
11001141 }
11011142 std::string final_image_path = i > 0 ? dummy_name + " _" + std::to_string (i + 1 ) + ext : dummy_name + ext;
1102- if (is_jpg) {
1143+ if (is_jpg) {
11031144 stbi_write_jpg (final_image_path.c_str (), results[i].width , results[i].height , results[i].channel ,
11041145 results[i].data , 90 , get_image_params (params, params.seed + i).c_str ());
11051146 printf (" save result JPEG image to '%s'\n " , final_image_path.c_str ());
@@ -1111,6 +1152,9 @@ int main(int argc, const char* argv[]) {
11111152 free (results[i].data );
11121153 results[i].data = NULL ;
11131154 }
1155+ if (params.imatrix_out != " " ) {
1156+ g_collector.save_imatrix (params.imatrix_out );
1157+ }
11141158 free (results);
11151159 free_sd_ctx (sd_ctx);
11161160 free (control_image_buffer);
0 commit comments