From f607bd12176f11c19a1c26b3272558dbc453b1c5 Mon Sep 17 00:00:00 2001
From: Howard Su
Date: Thu, 6 Jul 2023 21:12:44 +0800
Subject: [PATCH] Add new APIs

---
 llama.cpp | 62 ++++++++++++++++++++++++++++++++++++++++++++++++++++++-
 llama.h   | 15 ++++++++++++++
 2 files changed, 76 insertions(+), 1 deletion(-)

diff --git a/llama.cpp b/llama.cpp
index 7419b03b6..aa588c77d 100644
--- a/llama.cpp
+++ b/llama.cpp
@@ -193,6 +193,14 @@ struct llama_layer {
     struct ggml_tensor * w3;
 };
 
+struct llama_lora_layers {
+    // LoRA optional
+    struct ggml_tensor * wq_a;
+    struct ggml_tensor * wq_b;
+    struct ggml_tensor * wv_a;
+    struct ggml_tensor * wv_b;
+};
+
 struct llama_kv_cache {
     struct ggml_tensor * k = NULL;
     struct ggml_tensor * v = NULL;
@@ -303,6 +311,7 @@ struct llama_context {
 
     const llama_model & model;
     const llama_vocab & vocab;
+    std::vector<llama_lora_layers> lora_layers;
 
     bool model_owner = false;
 
@@ -2709,7 +2718,7 @@ int llama_model_quantize(
     }
 }
 
-int llama_apply_lora_from_file_internal(const struct llama_model & model, const char * path_lora, const char * path_base_model, int n_threads) {
+static int llama_apply_lora_from_file_internal(const struct llama_model & model, const char * path_lora, const char * path_base_model, int n_threads) {
     fprintf(stderr, "%s: applying lora adapter from '%s' - please wait ...\n", __func__, path_lora);
 
     const int64_t t_start_lora_us = ggml_time_us();
@@ -3525,3 +3534,54 @@ const char * llama_print_system_info(void) {
 const std::vector<std::pair<std::string, struct ggml_tensor *>>& llama_internal_get_tensor_map(struct llama_context * ctx) {
     return ctx->model.tensors_by_name;
 }
+
+// finetune related code
+int llama_enable_finetune(struct llama_context * ctx, enum llama_finetune_type flags, int n_lora) {
+    auto model = &ctx->model;
+    const auto & hparams = model->hparams;
+
+    const uint32_t n_layer = hparams.n_layer;
+    const uint32_t n_embd  = hparams.n_embd;
+
+    struct ggml_context * ctx0 = model->ctx;
+
+    if (flags & LLAMA_FINETUNE_FULL) {
+        ggml_set_param(ctx0, model->tok_embeddings);
+        ggml_set_param(ctx0, model->norm);
+
+        for (uint32_t i = 0; i < n_layer; ++i) {
+            auto & layer = model->layers[i];
+
+            ggml_set_param(ctx0, layer.attention_norm);
+            ggml_set_param(ctx0, layer.wq);
+            ggml_set_param(ctx0, layer.wk);
+            ggml_set_param(ctx0, layer.wv);
+            ggml_set_param(ctx0, layer.wo);
+            ggml_set_param(ctx0, layer.ffn_norm);
+            ggml_set_param(ctx0, layer.w1);
+            ggml_set_param(ctx0, layer.w2);
+            ggml_set_param(ctx0, layer.w3);
+        }
+    } else if (flags & LLAMA_FINETUNE_LORA) {
+        // create the LoRA A/B tensors if they are not present
+        for (uint32_t i = 0; i < n_layer; ++i) {
+            llama_lora_layers layer = {};
+
+            if (flags & LLAMA_FINETUNE_LORA_Q) {
+                if (layer.wq_a == nullptr || layer.wq_b == nullptr) {
+                    layer.wq_a = ggml_new_tensor_2d(ctx0, GGML_TYPE_F16, n_lora, n_embd);
+                    layer.wq_b = ggml_new_tensor_2d(ctx0, GGML_TYPE_F16, n_embd, n_lora);
+                    // initialize
+                }
+                ggml_set_param(ctx0, layer.wq_a);
+                ggml_set_param(ctx0, layer.wq_b);
+            }
+
+            if (flags & LLAMA_FINETUNE_LORA_V) {
+                if (layer.wv_a == nullptr || layer.wv_b == nullptr) {
+                    layer.wv_a = ggml_new_tensor_2d(ctx0, GGML_TYPE_F16, n_lora, n_embd);
+                    layer.wv_b = ggml_new_tensor_2d(ctx0, GGML_TYPE_F16, n_embd, n_lora);
+                    // initialize
+                }
+                ggml_set_param(ctx0, layer.wv_a);
+                ggml_set_param(ctx0, layer.wv_b);
+            }
+
+            ctx->lora_layers.push_back(layer);
+        }
+    }
+
+    return 0;
+}
diff --git a/llama.h b/llama.h
index 5bb1964bd..31bceafcd 100644
--- a/llama.h
+++ b/llama.h
@@ -126,6 +126,16 @@ extern "C" {
         LLAMA_FTYPE_MOSTLY_Q6_K = 18,// except 1d tensors
     };
 
+    enum llama_finetune_type {
+        LLAMA_FINETUNE_FULL = 0x01,
+        LLAMA_FINETUNE_LORA = 0x10,
+
+        LLAMA_FINETUNE_LORA_W = 0x1000, // valid only with LLAMA_FINETUNE_LORA
+        LLAMA_FINETUNE_LORA_K = 0x2000,
+        LLAMA_FINETUNE_LORA_Q = 0x4000,
+        LLAMA_FINETUNE_LORA_V = 0x8000,
+    };
+
     // model quantization parameters
     typedef struct llama_model_quantize_params {
         int nthread;                 // number of threads to use for quantizing, if <=0 will use std::thread::hardware_concurrency()
@@ -242,6 +252,11 @@ extern "C" {
     // IMPORTANT: do not use for anything else other than debugging and testing!
    LLAMA_API int llama_eval_export(struct llama_context * ctx, const char * fname);
 
+    // Enable finetuning on the context; flags select the finetuning type, n_lora is the LoRA rank (used only with LLAMA_FINETUNE_LORA)
+    LLAMA_API int llama_enable_finetune(struct llama_context * ctx, enum llama_finetune_type flags, int n_lora);
+
+    LLAMA_API int llama_finetune(struct llama_context * ctx, void * input, void * output);
+
     // Convert the provided text into tokens.
     // The tokens pointer must be large enough to hold the resulting tokens.
     // Returns the number of tokens on success, no more than n_max_tokens
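
Usage sketch (not part of the patch): the snippet below shows how a caller might exercise the proposed entry points once the patch is applied. It assumes ctx is an already-initialized llama_context; the layout of the input/output buffers taken by llama_finetune is not defined by this patch, so they are passed through as opaque placeholders, and the rank value 8 is an arbitrary example.

// Sketch only: exercises the proposed finetune API under the assumptions above.
#include "llama.h"

static int finetune_lora_example(struct llama_context * ctx, void * train_input, void * train_output) {
    // Train only the Q and V LoRA adapters; the flag bits are distinct, so they can be OR-ed.
    // The combined value is cast back to the enum type for a C++ caller.
    const int flags = LLAMA_FINETUNE_LORA | LLAMA_FINETUNE_LORA_Q | LLAMA_FINETUNE_LORA_V;

    // n_lora = 8 is an arbitrary example rank for the A/B tensors created by llama_enable_finetune.
    if (llama_enable_finetune(ctx, (enum llama_finetune_type) flags, /*n_lora =*/ 8) != 0) {
        return 1;
    }

    // Run a finetuning pass over the caller-provided buffers (format not specified by the patch).
    return llama_finetune(ctx, train_input, train_output);
}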