Skip to content

Commit 5869987

Browse files
authored
fix: make weight override more robust against ggml changes (#760)
1 parent 48956ff commit 5869987

File tree

2 files changed

+9
-4
lines changed

2 files changed

+9
-4
lines changed

model.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2310,7 +2310,7 @@ std::vector<std::pair<std::string, ggml_type>> parse_tensor_type_rules(const std
23102310
if (type_name == "f32") {
23112311
tensor_type = GGML_TYPE_F32;
23122312
} else {
2313-
for (size_t i = 0; i < SD_TYPE_COUNT; i++) {
2313+
for (size_t i = 0; i < GGML_TYPE_COUNT; i++) {
23142314
auto trait = ggml_get_type_traits((ggml_type)i);
23152315
if (trait->to_float && trait->type_size && type_name == trait->type_name) {
23162316
tensor_type = (ggml_type)i;

stable-diffusion.cpp

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -265,7 +265,9 @@ class StableDiffusionGGML {
265265
}
266266

267267
LOG_INFO("Version: %s ", model_version_to_str[version]);
268-
ggml_type wtype = (ggml_type)sd_ctx_params->wtype;
268+
ggml_type wtype = (int)sd_ctx_params->wtype < std::min<int>(SD_TYPE_COUNT, GGML_TYPE_COUNT)
269+
? (ggml_type)sd_ctx_params->wtype
270+
: GGML_TYPE_COUNT;
269271
if (wtype == GGML_TYPE_COUNT) {
270272
model_wtype = model_loader.get_sd_wtype();
271273
if (model_wtype == GGML_TYPE_COUNT) {
@@ -1465,11 +1467,14 @@ class StableDiffusionGGML {
14651467
#define NONE_STR "NONE"
14661468

14671469
const char* sd_type_name(enum sd_type_t type) {
1468-
return ggml_type_name((ggml_type)type);
1470+
if ((int)type < std::min<int>(SD_TYPE_COUNT, GGML_TYPE_COUNT)) {
1471+
return ggml_type_name((ggml_type)type);
1472+
}
1473+
return NONE_STR;
14691474
}
14701475

14711476
enum sd_type_t str_to_sd_type(const char* str) {
1472-
for (int i = 0; i < SD_TYPE_COUNT; i++) {
1477+
for (int i = 0; i < std::min<int>(SD_TYPE_COUNT, GGML_TYPE_COUNT); i++) {
14731478
auto trait = ggml_get_type_traits((ggml_type)i);
14741479
if (!strcmp(str, trait->type_name)) {
14751480
return (enum sd_type_t)i;

0 commit comments

Comments
 (0)