@@ -1,6 +1,7 @@
 #ifndef __LORA_HPP__
 #define __LORA_HPP__
 
+#include <mutex>
 #include "ggml_extend.hpp"
 
 #define LORA_GRAPH_BASE_SIZE 10240
@@ -115,49 +116,61 @@ struct LoraModel : public GGMLRunner { |
         return "lora";
     }
 
-    bool load_from_file(bool filter_tensor = false) {
+    bool load_from_file(bool filter_tensor = false, int n_threads = 0) {
         LOG_INFO("loading LoRA from '%s'", file_path.c_str());
 
         if (load_failed) {
             LOG_ERROR("init lora model loader from file failed: '%s'", file_path.c_str());
             return false;
         }
 
+        std::unordered_map<std::string, TensorStorage> tensors_to_create;
+        std::mutex lora_mutex;
         bool dry_run = true;
         auto on_new_tensor_cb = [&](const TensorStorage& tensor_storage, ggml_tensor** dst_tensor) -> bool {
-            const std::string& name = tensor_storage.name;
+            if (dry_run) {
+                const std::string& name = tensor_storage.name;
 
-            if (filter_tensor && !contains(name, "lora")) {
-                // LOG_INFO("skipping LoRA tesnor '%s'", name.c_str());
-                return true;
-            }
-            // LOG_INFO("lora_tensor %s", name.c_str());
-            for (int i = 0; i < LORA_TYPE_COUNT; i++) {
-                if (name.find(type_fingerprints[i]) != std::string::npos) {
-                    type = (lora_t)i;
-                    break;
+                if (filter_tensor && !contains(name, "lora")) {
+                    return true;
                 }
-            }
 
-            if (dry_run) {
-                struct ggml_tensor* real = ggml_new_tensor(params_ctx,
-                                                           tensor_storage.type,
-                                                           tensor_storage.n_dims,
-                                                           tensor_storage.ne);
-                lora_tensors[name] = real;
+                {
+                    std::lock_guard<std::mutex> lock(lora_mutex);
+                    for (int i = 0; i < LORA_TYPE_COUNT; i++) {
+                        if (name.find(type_fingerprints[i]) != std::string::npos) {
+                            type = (lora_t)i;
+                            break;
+                        }
+                    }
+                    tensors_to_create[name] = tensor_storage;
+                }
             } else {
-                auto real = lora_tensors[name];
-                *dst_tensor = real;
+                const std::string& name = tensor_storage.name;
+                auto iter = lora_tensors.find(name);
+                if (iter != lora_tensors.end()) {
+                    *dst_tensor = iter->second;
+                }
             }
-
             return true;
         };
 
-        model_loader.load_tensors(on_new_tensor_cb);
+        model_loader.load_tensors(on_new_tensor_cb, n_threads);
+
+        for (const auto& pair : tensors_to_create) {
+            const auto& name = pair.first;
+            const auto& ts = pair.second;
+            struct ggml_tensor* real = ggml_new_tensor(params_ctx,
+                                                       ts.type,
+                                                       ts.n_dims,
+                                                       ts.ne);
+            lora_tensors[name] = real;
+        }
+
         alloc_params_buffer();
-        // exit(0);
+
         dry_run = false;
-        model_loader.load_tensors(on_new_tensor_cb);
+        model_loader.load_tensors(on_new_tensor_cb, n_threads);
 
         LOG_DEBUG("lora type: \"%s\"/\"%s\"", lora_downs[type].c_str(), lora_ups[type].c_str());
 
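The net effect is a collect-then-create pattern: the first (dry-run) pass over the file may now invoke on_new_tensor_cb from several loader threads, so it only records tensor metadata under lora_mutex; the actual ggml_new_tensor calls against the shared params_ctx are deferred to a single-threaded loop between the passes, and the second pass does read-only lookups in the finished lora_tensors map. Below is a minimal standalone sketch of that pattern, assuming a multithreaded loader; Meta, Tensor, and make_tensor are hypothetical stand-ins for TensorStorage, ggml_tensor, and ggml_new_tensor.

#include <cstdio>
#include <mutex>
#include <string>
#include <thread>
#include <unordered_map>
#include <vector>

// Hypothetical stand-ins: Meta plays the role of TensorStorage and
// make_tensor the role of ggml_new_tensor on a shared, non-thread-safe context.
struct Meta { int n_dims; };
struct Tensor { Meta meta; };

static Tensor* make_tensor(const Meta& m) { return new Tensor{m}; }

int main() {
    std::unordered_map<std::string, Meta> to_create;   // pass 1 output
    std::unordered_map<std::string, Tensor*> tensors;  // created between passes
    std::mutex mtx;
    bool dry_run = true;

    // Mirrors on_new_tensor_cb: shared writes are locked during the dry run;
    // the second pass only reads the already-built map.
    auto on_tensor = [&](const std::string& name, Tensor** dst) {
        if (dry_run) {
            std::lock_guard<std::mutex> lock(mtx);
            to_create[name] = Meta{2};
        } else {
            auto it = tensors.find(name);
            if (it != tensors.end()) {
                *dst = it->second;
            }
        }
    };

    // Pass 1 (dry run): the loader may call the callback from worker threads.
    std::vector<std::thread> workers;
    for (int t = 0; t < 4; t++) {
        workers.emplace_back([&, t] {
            Tensor* unused = nullptr;
            on_tensor("w" + std::to_string(t), &unused);
        });
    }
    for (auto& w : workers) {
        w.join();
    }

    // Creation is deferred so the non-thread-safe allocator runs on one thread.
    for (const auto& pair : to_create) {
        tensors[pair.first] = make_tensor(pair.second);
    }

    // Pass 2: plain lookups, no locking needed on the now read-only map.
    dry_run = false;
    Tensor* dst = nullptr;
    on_tensor("w0", &dst);
    std::printf("found w0: %s\n", dst ? "yes" : "no");

    for (auto& pair : tensors) {
        delete pair.second;
    }
    return 0;
}

Callers would pass the loader's thread count through the new parameter, e.g. lora.load_from_file(true, n_threads); the default of 0 presumably leaves the choice of thread count to load_tensors.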