llama_control_vector_load_one() and llama_control_vector_load() now break on error
This commit is contained in:
parent
0395a3c4f2
commit
a299709134
1 changed files with 24 additions and 12 deletions
|
@ -2804,7 +2804,6 @@ float llama_embd_similarity_cos(const float * embd1, const float * embd2, int n)
|
|||
//
|
||||
|
||||
static llama_control_vector_data llama_control_vector_load_one(const llama_control_vector_load_info & load_info) {
|
||||
|
||||
llama_control_vector_data result = { -1, {} };
|
||||
|
||||
ggml_context * ctx = nullptr;
|
||||
|
@ -2814,7 +2813,7 @@ static llama_control_vector_data llama_control_vector_load_one(const llama_contr
|
|||
};
|
||||
struct gguf_context * ctx_gguf = gguf_init_from_file(load_info.fname.c_str(), meta_gguf_params);
|
||||
if (!ctx_gguf) {
|
||||
fprintf(stderr, "%s: failed to load control vector from %s\n", __func__, load_info.fname.c_str());
|
||||
fprintf(stderr, "%s: failed to load control vector file from %s\n", __func__, load_info.fname.c_str());
|
||||
ggml_free(ctx);
|
||||
return result;
|
||||
}
|
||||
|
@ -2840,27 +2839,32 @@ static llama_control_vector_data llama_control_vector_load_one(const llama_contr
|
|||
}
|
||||
if (layer_idx < 0) {
|
||||
fprintf(stderr, "%s: invalid/unparsable direction tensor layer index in %s\n", __func__, load_info.fname.c_str());
|
||||
continue;
|
||||
result.n_embd = -1;
|
||||
break;
|
||||
} else if (layer_idx == 0) {
|
||||
fprintf(stderr, "%s: invalid (zero) direction tensor layer index in %s\n", __func__, load_info.fname.c_str());
|
||||
continue;
|
||||
result.n_embd = -1;
|
||||
break;
|
||||
}
|
||||
|
||||
struct ggml_tensor * tensor = ggml_get_tensor(ctx, name.c_str());
|
||||
if (tensor->type != GGML_TYPE_F32) {
|
||||
fprintf(stderr, "%s: invalid (non-F32) direction tensor type in %s\n", __func__, load_info.fname.c_str());
|
||||
continue;
|
||||
result.n_embd = -1;
|
||||
break;
|
||||
}
|
||||
if (ggml_n_dims(tensor) != 1) {
|
||||
fprintf(stderr, "%s: invalid (non-1D) direction tensor shape in %s\n", __func__, load_info.fname.c_str());
|
||||
continue;
|
||||
result.n_embd = -1;
|
||||
break;
|
||||
}
|
||||
|
||||
if (result.n_embd == -1) {
|
||||
result.n_embd = ggml_nelements(tensor);
|
||||
} else if (ggml_nelements(tensor) != result.n_embd) {
|
||||
fprintf(stderr, "%s: direction tensor in %s does not match previous dimensions\n", __func__, load_info.fname.c_str());
|
||||
continue;
|
||||
result.n_embd = -1;
|
||||
break;
|
||||
}
|
||||
|
||||
// extend if necessary - do not store data for layer 0 (it's not used)
|
||||
|
@ -2869,11 +2873,16 @@ static llama_control_vector_data llama_control_vector_load_one(const llama_contr
|
|||
const float * src = (const float *) tensor->data;
|
||||
float * dst = result.data.data() + result.n_embd * (layer_idx - 1); // layer 1 at [0]
|
||||
for (int j = 0; j < result.n_embd; j++) {
|
||||
dst[j] += src[j] * load_info.strength; // allow multiple for same layer in same file
|
||||
dst[j] += src[j] * load_info.strength; // allows multiple directions for same layer in same file
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
if (result.n_embd == -1) {
|
||||
fprintf(stderr, "%s: skipping %s due to invalid direction tensors\n", __func__, load_info.fname.c_str());
|
||||
result.data.clear();
|
||||
}
|
||||
|
||||
gguf_free(ctx_gguf);
|
||||
ggml_free(ctx);
|
||||
|
||||
|
@ -2887,11 +2896,13 @@ llama_control_vector_data llama_control_vector_load(const std::vector<llama_cont
|
|||
auto cur = llama_control_vector_load_one(info);
|
||||
|
||||
if (cur.n_embd == -1) {
|
||||
continue;
|
||||
result.n_embd = -1;
|
||||
break;
|
||||
}
|
||||
if (result.n_embd != -1 && result.n_embd != cur.n_embd) {
|
||||
fprintf(stderr, "%s: control vector in %s does not match previous dimensions\n", __func__, info.fname.c_str());
|
||||
continue;
|
||||
fprintf(stderr, "%s: control vectors in %s does not match previous dimensions\n", __func__, info.fname.c_str());
|
||||
result.n_embd = -1;
|
||||
break;
|
||||
}
|
||||
|
||||
if (result.n_embd == -1) {
|
||||
|
@ -2905,7 +2916,8 @@ llama_control_vector_data llama_control_vector_load(const std::vector<llama_cont
|
|||
}
|
||||
|
||||
if (result.n_embd == -1) {
|
||||
fprintf(stderr, "%s: no valid vectors passed\n", __func__);
|
||||
fprintf(stderr, "%s: no valid control vector files passed\n", __func__);
|
||||
result.data.clear();
|
||||
}
|
||||
|
||||
return result;
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue