Make sure base model tensor data cannot be used in viewable operations

Without this change, the memory allocator would try to apply LoRA in place on the base model tensors.
Since those tensors are memory-mapped, this would result in memory access violations.
This commit is contained in:
xaedes 2023-08-18 15:03:17 +02:00
parent 0bb897c82a
commit 44526cb261
No known key found for this signature in database
GPG key ID: 30030EDD817EA2B1

View file

@ -1224,6 +1224,24 @@ struct ggml_tensor * llama_build_lora_finetune_graphs(
// output tensors // output tensors
ggml_build_forward_expand(gb, ggml_scale_inplace(ctx, t35, one)); ggml_build_forward_expand(gb, ggml_scale_inplace(ctx, t35, one));
ggml_build_forward_expand(gb, ggml_scale_inplace(ctx, t36, one)); ggml_build_forward_expand(gb, ggml_scale_inplace(ctx, t36, one));
// make sure base model tensors data cannot be used in viewable operations
ggml_build_forward_expand(gb, ggml_scale_inplace(ctx, model->tok_embeddings, one));
ggml_build_forward_expand(gb, ggml_scale_inplace(ctx, model->norm, one));
ggml_build_forward_expand(gb, ggml_scale_inplace(ctx, model->output, one));
for (int il = 0; il < n_layer; ++il) {
struct my_llama_layer & layer = model->layers[il];
ggml_build_forward_expand(gb, ggml_scale_inplace(ctx, layer.attention_norm, one));
ggml_build_forward_expand(gb, ggml_scale_inplace(ctx, layer.ffn_norm, one));
ggml_build_forward_expand(gb, ggml_scale_inplace(ctx, layer.wq, one));
ggml_build_forward_expand(gb, ggml_scale_inplace(ctx, layer.wk, one));
ggml_build_forward_expand(gb, ggml_scale_inplace(ctx, layer.wv, one));
ggml_build_forward_expand(gb, ggml_scale_inplace(ctx, layer.wo, one));
ggml_build_forward_expand(gb, ggml_scale_inplace(ctx, layer.w1, one));
ggml_build_forward_expand(gb, ggml_scale_inplace(ctx, layer.w2, one));
ggml_build_forward_expand(gb, ggml_scale_inplace(ctx, layer.w3, one));
}
// gradient tensors (will be set to zero by ggml_graph_reset) // gradient tensors (will be set to zero by ggml_graph_reset)
for (int i = 0; i < gf->n_nodes; ++i) { for (int i = 0; i < gf->n_nodes; ++i) {
if (!gf->grads[i]) continue; if (!gf->grads[i]) continue;