avoid leaking ggml_context on failure
cleanup ggml-ci
This commit is contained in:
parent
8c1828aa6c
commit
537846fce5
1 changed files with 28 additions and 36 deletions
64
llama.cpp
64
llama.cpp
|
@ -5898,7 +5898,6 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s
|
|||
}
|
||||
}
|
||||
|
||||
// TODO: after the GGUF PR, this likely won't work and needs to be updated
|
||||
static int llama_apply_lora_from_file_internal(
|
||||
const struct llama_model & model, const char * path_lora, const char * path_base_model, int n_threads
|
||||
) {
|
||||
|
@ -5906,7 +5905,6 @@ static int llama_apply_lora_from_file_internal(
|
|||
|
||||
const int64_t t_start_lora_us = ggml_time_us();
|
||||
|
||||
// auto fin = std::ifstream(path_lora, std::ios::binary);
|
||||
llama_file fin(path_lora, "rb");
|
||||
|
||||
// verify magic and version
|
||||
|
@ -5931,11 +5929,11 @@ static int llama_apply_lora_from_file_internal(
|
|||
LLAMA_LOG_INFO("%s: r = %d, alpha = %d, scaling = %.2f\n", __func__, lora_r, lora_alpha, scaling);
|
||||
|
||||
// create a name -> tensor map of the model to accelerate lookups
|
||||
// find the max tensor size to estimate the required temporary buffer size
|
||||
size_t max_tensor_size = 0;
|
||||
std::unordered_map<std::string, struct ggml_tensor*> model_tensors;
|
||||
for (const auto & kv : model.tensors_by_name) {
|
||||
model_tensors.insert(kv);
|
||||
// find the max tensor size to estimate the required temporary buffer size
|
||||
// skip input and output layers as they are not often finetuned and can be very large
|
||||
if (kv.first.find("token_embd") != std::string::npos ||
|
||||
kv.first.find("output") != std::string::npos) {
|
||||
|
@ -5950,18 +5948,22 @@ static int llama_apply_lora_from_file_internal(
|
|||
size_t lora_ctx_size = max_tensor_size * 3;
|
||||
LLAMA_LOG_INFO("%s: allocating %.f MB for lora temporary buffer\n", __func__, lora_ctx_size / 1024.0 / 1024.0);
|
||||
std::vector<uint8_t> lora_buf(lora_ctx_size);
|
||||
|
||||
struct ggml_init_params params;
|
||||
params.mem_size = lora_buf.size();
|
||||
params.mem_buffer = lora_buf.data();
|
||||
params.no_alloc = false;
|
||||
|
||||
ggml_context * lora_ctx = ggml_init(params);
|
||||
std::unordered_map<std::string, struct ggml_tensor *> lora_tensors;
|
||||
using unique_context = std::unique_ptr<ggml_context, decltype(&ggml_free)>;
|
||||
|
||||
unique_context lora_ctx(nullptr, ggml_free);
|
||||
lora_ctx.reset(ggml_init(params));
|
||||
std::unordered_map<std::string, struct ggml_tensor *> lora_tensors;
|
||||
|
||||
// load base model
|
||||
std::unique_ptr<llama_model_loader> ml;
|
||||
ggml_context * base_ctx = NULL;
|
||||
|
||||
unique_context base_ctx(nullptr, ggml_free);
|
||||
std::vector<uint8_t> base_buf;
|
||||
if (path_base_model) {
|
||||
LLAMA_LOG_INFO("%s: loading base model from '%s'\n", __func__, path_base_model);
|
||||
|
@ -5970,6 +5972,7 @@ static int llama_apply_lora_from_file_internal(
|
|||
size_t ctx_size;
|
||||
size_t mmapped_size;
|
||||
ml->calc_sizes(ctx_size, mmapped_size);
|
||||
|
||||
base_buf.resize(ctx_size);
|
||||
|
||||
ggml_init_params base_params;
|
||||
|
@ -5977,9 +5980,9 @@ static int llama_apply_lora_from_file_internal(
|
|||
base_params.mem_buffer = base_buf.data();
|
||||
base_params.no_alloc = ml->use_mmap;
|
||||
|
||||
base_ctx = ggml_init(base_params);
|
||||
base_ctx.reset(ggml_init(base_params));
|
||||
|
||||
// maybe this should in llama_model_loader
|
||||
// maybe this should be in llama_model_loader
|
||||
if (ml->use_mmap) {
|
||||
ml->mapping.reset(new llama_mmap(&ml->file, /* prefetch */ 0, ggml_is_numa()));
|
||||
}
|
||||
|
@ -6005,7 +6008,10 @@ static int llama_apply_lora_from_file_internal(
|
|||
fin.read_raw(&name_len, sizeof(name_len));
|
||||
fin.read_raw(&ftype, sizeof(ftype));
|
||||
|
||||
GGML_ASSERT(n_dims <= 2);
|
||||
if (n_dims != 2) {
|
||||
LLAMA_LOG_ERROR("%s: unsupported tensor dimension %d\n", __func__, n_dims);
|
||||
return 1;
|
||||
}
|
||||
|
||||
int32_t ne[2] = { 1, 1 };
|
||||
for (int i = 0; i < n_dims; ++i) {
|
||||
|
@ -6050,15 +6056,8 @@ static int llama_apply_lora_from_file_internal(
|
|||
return false;
|
||||
}
|
||||
}
|
||||
ggml_tensor * lora_tensor;
|
||||
if (n_dims == 2) {
|
||||
lora_tensor = ggml_new_tensor_2d(lora_ctx, wtype, ne[0], ne[1]);
|
||||
}
|
||||
else {
|
||||
LLAMA_LOG_ERROR("%s: unsupported tensor dimension %d\n", __func__, n_dims);
|
||||
return 1;
|
||||
}
|
||||
ggml_set_name(lora_tensor, "lora_tensor");
|
||||
ggml_tensor * lora_tensor = ggml_new_tensor_2d(lora_ctx.get(), wtype, ne[0], ne[1]);
|
||||
ggml_set_name(lora_tensor, name.c_str());
|
||||
|
||||
// load tensor data
|
||||
size_t offset = fin.tell();
|
||||
|
@ -6095,13 +6094,11 @@ static int llama_apply_lora_from_file_internal(
|
|||
|
||||
// load from base model
|
||||
if (gguf_find_tensor(ctx_gguf, base_name.c_str()) < 0) {
|
||||
// TODO: throw
|
||||
LLAMA_LOG_ERROR("%s: error: tensor '%s' not found in base model\n", __func__, base_name.c_str());
|
||||
return 1;
|
||||
}
|
||||
|
||||
// TODO: not tested!! maybe not working!
|
||||
base_t = ml->create_tensor(base_ctx, base_name, { (uint32_t)dest_t->ne[0], (uint32_t)dest_t->ne[1] }, GGML_BACKEND_CPU);
|
||||
base_t = ml->create_tensor(base_ctx.get(), base_name, { dest_t->ne[0], dest_t->ne[1] }, GGML_BACKEND_CPU);
|
||||
ml->load_data_for(base_t);
|
||||
} else {
|
||||
base_t = dest_t;
|
||||
|
@ -6130,31 +6127,31 @@ static int llama_apply_lora_from_file_internal(
|
|||
}
|
||||
|
||||
// w = w + BA*s
|
||||
ggml_tensor * BA = ggml_mul_mat(lora_ctx, loraA, loraB);
|
||||
ggml_tensor * BA = ggml_mul_mat(lora_ctx.get(), loraA, loraB);
|
||||
offload_func(BA);
|
||||
ggml_set_name(BA, "BA");
|
||||
|
||||
if (scaling != 1.0f) {
|
||||
ggml_tensor * scale_tensor = ggml_new_f32(lora_ctx, scaling);
|
||||
ggml_tensor * scale_tensor = ggml_new_f32(lora_ctx.get(), scaling);
|
||||
ggml_set_name(scale_tensor, "scale_tensor");
|
||||
|
||||
BA = ggml_scale_inplace(lora_ctx, BA, scale_tensor);
|
||||
BA = ggml_scale_inplace(lora_ctx.get(), BA, scale_tensor);
|
||||
offload_func(BA);
|
||||
ggml_set_name(BA, "BA_scaled");
|
||||
}
|
||||
|
||||
ggml_tensor * r;
|
||||
if (base_t == dest_t) {
|
||||
r = ggml_add_inplace(lora_ctx, dest_t, BA);
|
||||
r = ggml_add_inplace(lora_ctx.get(), dest_t, BA);
|
||||
offload_func_force_inplace(r);
|
||||
ggml_set_name(r, "r_add_inplace");
|
||||
}
|
||||
else {
|
||||
r = ggml_add(lora_ctx, base_t, BA);
|
||||
r = ggml_add(lora_ctx.get(), base_t, BA);
|
||||
offload_func(r);
|
||||
ggml_set_name(r, "r_add");
|
||||
|
||||
r = ggml_cpy(lora_ctx, r, dest_t);
|
||||
r = ggml_cpy(lora_ctx.get(), r, dest_t);
|
||||
offload_func(r);
|
||||
ggml_set_name(r, "r_cpy");
|
||||
}
|
||||
|
@ -6163,10 +6160,11 @@ static int llama_apply_lora_from_file_internal(
|
|||
|
||||
ggml_graph_compute_helper(work_buffer, &gf, n_threads);
|
||||
|
||||
// we won't need these tensors again, reset the context to save memory
|
||||
// the tensors in the adapter must be sorted such that loraA and loraB of the same tensor are next to each other
|
||||
GGML_ASSERT(lora_tensors.size() == 2);
|
||||
ggml_free(lora_ctx);
|
||||
lora_ctx = ggml_init(params);
|
||||
|
||||
// we won't need these tensors again, reset the context to save memory
|
||||
lora_ctx.reset(ggml_init(params));
|
||||
lora_tensors.clear();
|
||||
|
||||
n_tensors++;
|
||||
|
@ -6176,12 +6174,6 @@ static int llama_apply_lora_from_file_internal(
|
|||
}
|
||||
}
|
||||
|
||||
// TODO: this should be in a destructor, it will leak on failure
|
||||
ggml_free(lora_ctx);
|
||||
if (base_ctx) {
|
||||
ggml_free(base_ctx);
|
||||
}
|
||||
|
||||
const int64_t t_lora_us = ggml_time_us() - t_start_lora_us;
|
||||
LLAMA_LOG_INFO(" done (%.2f ms)\n", t_lora_us / 1000.0);
|
||||
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue