
Commit 55c2e05

rmatif and leejet authored
feat: optimize tensor loading time (#790)
* opt tensor loading
* fix build failure
* revert the changes
* allow the use of n_threads
* fix lora loading
* optimize lora loading
* add mutex
* use atomic
* fix build
* fix potential duplicate issue
* avoid duplicate lookup of lora tensor
* fix progress bar
* remove unused remove_duplicates

Co-authored-by: leejet <leejet714@gmail.com>
1 parent: 52a97b3 · commit: 55c2e05
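Of the bullets above, "add mutex" is visible in the lora.hpp hunk below, while "use atomic" and "fix progress bar" land in the commit's other changed files, which this excerpt does not show. A minimal sketch of the atomic-counter idea, with hypothetical names (tensors_done, report_progress) since the actual implementation is not reproduced here:

#include <atomic>
#include <cstdio>

// Hypothetical sketch: when tensors are loaded from several worker
// threads, a plain `int done++` is a data race; std::atomic makes each
// increment safe, so the progress count stays exact and monotonic.
static std::atomic<int> tensors_done{0};

void report_progress(int total) {
    // fetch_add returns the previous value, so +1 is this tensor's rank
    int done = tensors_done.fetch_add(1, std::memory_order_relaxed) + 1;
    fprintf(stderr, "\rloading tensors: %d/%d", done, total);
}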

File tree

4 files changed: +326 −255 lines changed

lora.hpp

Lines changed: 37 additions & 24 deletions
@@ -1,6 +1,7 @@
 #ifndef __LORA_HPP__
 #define __LORA_HPP__
 
+#include <mutex>
 #include "ggml_extend.hpp"
 
 #define LORA_GRAPH_BASE_SIZE 10240
@@ -115,49 +116,61 @@ struct LoraModel : public GGMLRunner {
         return "lora";
     }
 
-    bool load_from_file(bool filter_tensor = false) {
+    bool load_from_file(bool filter_tensor = false, int n_threads = 0) {
         LOG_INFO("loading LoRA from '%s'", file_path.c_str());
 
         if (load_failed) {
             LOG_ERROR("init lora model loader from file failed: '%s'", file_path.c_str());
             return false;
         }
 
+        std::unordered_map<std::string, TensorStorage> tensors_to_create;
+        std::mutex lora_mutex;
         bool dry_run = true;
         auto on_new_tensor_cb = [&](const TensorStorage& tensor_storage, ggml_tensor** dst_tensor) -> bool {
-            const std::string& name = tensor_storage.name;
+            if (dry_run) {
+                const std::string& name = tensor_storage.name;
 
-            if (filter_tensor && !contains(name, "lora")) {
-                // LOG_INFO("skipping LoRA tesnor '%s'", name.c_str());
-                return true;
-            }
-            // LOG_INFO("lora_tensor %s", name.c_str());
-            for (int i = 0; i < LORA_TYPE_COUNT; i++) {
-                if (name.find(type_fingerprints[i]) != std::string::npos) {
-                    type = (lora_t)i;
-                    break;
+                if (filter_tensor && !contains(name, "lora")) {
+                    return true;
                 }
-            }
 
-            if (dry_run) {
-                struct ggml_tensor* real = ggml_new_tensor(params_ctx,
-                                                           tensor_storage.type,
-                                                           tensor_storage.n_dims,
-                                                           tensor_storage.ne);
-                lora_tensors[name] = real;
+                {
+                    std::lock_guard<std::mutex> lock(lora_mutex);
+                    for (int i = 0; i < LORA_TYPE_COUNT; i++) {
+                        if (name.find(type_fingerprints[i]) != std::string::npos) {
+                            type = (lora_t)i;
+                            break;
+                        }
+                    }
+                    tensors_to_create[name] = tensor_storage;
+                }
             } else {
-                auto real = lora_tensors[name];
-                *dst_tensor = real;
+                const std::string& name = tensor_storage.name;
+                auto iter = lora_tensors.find(name);
+                if (iter != lora_tensors.end()) {
+                    *dst_tensor = iter->second;
+                }
             }
-
             return true;
         };
 
-        model_loader.load_tensors(on_new_tensor_cb);
+        model_loader.load_tensors(on_new_tensor_cb, n_threads);
+
+        for (const auto& pair : tensors_to_create) {
+            const auto& name = pair.first;
+            const auto& ts = pair.second;
+            struct ggml_tensor* real = ggml_new_tensor(params_ctx,
+                                                       ts.type,
+                                                       ts.n_dims,
+                                                       ts.ne);
+            lora_tensors[name] = real;
+        }
+
         alloc_params_buffer();
-        // exit(0);
+
         dry_run = false;
-        model_loader.load_tensors(on_new_tensor_cb);
+        model_loader.load_tensors(on_new_tensor_cb, n_threads);
 
         LOG_DEBUG("lora type: \"%s\"/\"%s\"", lora_downs[type].c_str(), lora_ups[type].c_str());
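For readers outside the codebase, the hunk above boils down to a two-pass pattern: a multithreaded dry run that only records tensor metadata under a mutex, single-threaded tensor creation afterwards, and a second multithreaded pass that resolves destinations by read-only lookup. The sketch below illustrates that pattern under stated assumptions: TensorMeta and for_each_tensor_parallel are illustrative stand-ins for sd.cpp's TensorStorage and model_loader.load_tensors, not the project's actual API.

#include <atomic>
#include <functional>
#include <mutex>
#include <string>
#include <thread>
#include <unordered_map>
#include <utility>
#include <vector>

// Illustrative stand-in for sd.cpp's TensorStorage: just enough
// metadata to "create" a tensor later.
struct TensorMeta {
    int n_dims;
    long ne[4];
};

// Stand-in for model_loader.load_tensors(cb, n_threads): calls cb once
// per tensor entry, from n_threads worker threads.
void for_each_tensor_parallel(const std::vector<std::pair<std::string, TensorMeta>>& entries,
                              int n_threads,
                              const std::function<void(const std::string&, const TensorMeta&)>& cb) {
    std::atomic<size_t> next{0};
    std::vector<std::thread> workers;
    for (int t = 0; t < n_threads; t++) {
        workers.emplace_back([&] {
            for (size_t i = next.fetch_add(1); i < entries.size(); i = next.fetch_add(1)) {
                cb(entries[i].first, entries[i].second);
            }
        });
    }
    for (auto& w : workers) {
        w.join();
    }
}

int main() {
    std::vector<std::pair<std::string, TensorMeta>> entries = {
        {"unet.lora_down.weight", {2, {16, 320, 1, 1}}},
        {"unet.lora_up.weight", {2, {320, 16, 1, 1}}},
        {"unet.lora_up.weight", {2, {320, 16, 1, 1}}},  // duplicate entry
    };

    std::unordered_map<std::string, TensorMeta> tensors_to_create;
    std::mutex mtx;

    // Pass 1 (dry run): callbacks run concurrently, so only record
    // metadata, and do it under the mutex; defer all allocation.
    for_each_tensor_parallel(entries, 4, [&](const std::string& name, const TensorMeta& meta) {
        std::lock_guard<std::mutex> lock(mtx);
        tensors_to_create[name] = meta;
    });

    // Single-threaded: create one tensor per unique name (the real code
    // calls ggml_new_tensor here, then allocates the params buffer).
    std::unordered_map<std::string, TensorMeta*> created;
    for (auto& kv : tensors_to_create) {
        created[kv.first] = &kv.second;
    }

    // Pass 2: resolve each entry's destination by read-only lookup; no
    // shared state is written, so no lock is needed.
    for_each_tensor_parallel(entries, 4, [&](const std::string& name, const TensorMeta&) {
        auto it = created.find(name);
        (void)it;  // the real loader copies tensor data into it->second
    });
    return 0;
}

The key design point is the same as in the diff: ggml_new_tensor mutates a shared ggml_context, so creation is hoisted out of the concurrent callback entirely, and the critical section shrinks to a map insert. A map keyed by name also deduplicates repeated entries, which is presumably what the "fix potential duplicate issue" bullet refers to.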
