From f52101e8897ea3acdf7903b398ee4201db1176aa Mon Sep 17 00:00:00 2001
From: Slaren <2141330+slaren@users.noreply.github.com>
Date: Thu, 6 Apr 2023 23:18:59 +0200
Subject: [PATCH] Add lora support

---
 convert-lora-to-ggml.py            | 101 ++++++++++++++++++++
 examples/common.cpp                |   7 ++
 examples/common.h                  |   6 +-
 examples/main/main.cpp             |   8 ++
 examples/perplexity/perplexity.cpp |   8 ++
 ggml.c                             |  45 +++++++++
 ggml.h                             |   6 ++
 llama.cpp                          | 148 +++++++++++++++++++++++++++++
 llama.h                            |   9 ++
 9 files changed, 335 insertions(+), 3 deletions(-)
 create mode 100644 convert-lora-to-ggml.py

diff --git a/convert-lora-to-ggml.py b/convert-lora-to-ggml.py
new file mode 100644
index 000000000..988627181
--- /dev/null
+++ b/convert-lora-to-ggml.py
@@ -0,0 +1,101 @@
+import os
+import re
+import struct
+import sys
+from dataclasses import dataclass
+from typing import Any, Sequence
+
+import numpy as np
+import torch
+
+
+# TODO: import this from convert.py once #545 is merged
+@dataclass(frozen=True)
+class UnquantizedDataType:
+    name: str
+
+DT_F16 = UnquantizedDataType('F16')
+DT_F32 = UnquantizedDataType('F32')
+
+@dataclass(frozen=True)
+class QuantizedDataType:
+    groupsize: int
+    have_addends: bool
+    have_g_idx: bool
+
+DataType = UnquantizedDataType
+
+DATA_TYPE_TO_FTYPE: dict[DataType, int] = {
+    DT_F32: 0,
+    DT_F16: 1,
+}
+
+DATA_TYPE_TO_NUMPY: dict[DataType, np.dtype[Any]] = {
+    DT_F16: np.dtype(np.float16),
+    DT_F32: np.dtype(np.float32),
+}
+
+NUMPY_TYPE_TO_DATA_TYPE: dict[np.dtype[Any], DataType] = {dtype: data_type for (data_type, dtype) in DATA_TYPE_TO_NUMPY.items()}
+
+HF_SUBLAYER_TO_GGML = {
+    "self_attn.q_proj": "attention.wq.weight",
+    "self_attn.k_proj": "attention.wk.weight",
+    "self_attn.v_proj": "attention.wv.weight",
+    "self_attn.o_proj": "attention.wo.weight",
+}
+
+def translate_tensor_name(t):
+    match = re.match(r'.*layers\.(\d+)\.(\w+\.\w+)\.lora_(A|B)\.weight', t)
+    if match:
+        nn = match.group(1)
+        sub_layer = match.group(2)
+        lora_type = match.group(3)
+
+        sub_layer_renamed = HF_SUBLAYER_TO_GGML.get(sub_layer)
+        if sub_layer_renamed is None:
+            print(f"Error: unrecognized sub-layer {sub_layer} in tensor {t}")
+            exit(1)
+
+        output_string = f"layers.{nn}.{sub_layer_renamed}.lora{lora_type}"
+        return output_string
+    else:
+        print(f"Error: unrecognized tensor {t}")
+        exit(1)
+
+def write_file_header(fout):
+    fout.write(b"ggla"[::-1]) # magic (ggml lora)
+    fout.write(struct.pack("i", 1)) # file version
+
+
+def write_tensor_header(fout, name: str, shape: Sequence[int], data_type: np.dtype) -> None:
+    sname = name.encode('utf-8')
+    fout.write(struct.pack("iii", len(shape), len(sname), DATA_TYPE_TO_FTYPE[NUMPY_TYPE_TO_DATA_TYPE[data_type]]))
+    fout.write(struct.pack("i" * len(shape), *shape[::-1]))
+    fout.write(sname)
+    fout.seek((fout.tell() + 31) & -32)
+
+
+if len(sys.argv) < 2:
+    print(f"Usage: python {sys.argv[0]} adapter_model.bin [ggml_adapter_model.bin]")
+    sys.exit(1)
+
+input_path = sys.argv[1]
+if len(sys.argv) > 2:
+    output_path = sys.argv[2]
+else:
+    output_filename = f"ggml_{os.path.basename(input_path)}"
+    output_path = os.path.join(os.path.dirname(input_path), output_filename)
+
+model = torch.load(input_path, map_location="cpu")
+
+with open(output_path, "wb") as fout:
+    write_file_header(fout)
+    for k, v in model.items():
+        # since ggml doesn't always support other types for the second operand,
+        # the tensors are always converted and exported as f32
+        t = v.float().numpy()
+        print(f"{k} => {translate_tensor_name(k)} {t.shape} {t.dtype} {t.nbytes/1024/1024:.2f}MB")
+        write_tensor_header(fout, translate_tensor_name(k), t.shape, t.dtype)
+        t.tofile(fout)
+
+print(f"Converted {input_path} to {output_path}")
\ No newline at end of file
diff --git a/examples/common.cpp b/examples/common.cpp
index 0772dbfe1..403b2cc15 100644
--- a/examples/common.cpp
+++ b/examples/common.cpp
@@ -139,6 +139,12 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
                 break;
             }
             params.model = argv[i];
+        } else if (arg == "--lora") {
+            if (++i >= argc) {
+                invalid_param = true;
+                break;
+            }
+            params.lora_adapter = argv[i];
         } else if (arg == "-i" || arg == "--interactive") {
             params.interactive = true;
         } else if (arg == "--embedding") {
@@ -242,6 +248,7 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
     }
     fprintf(stderr, "  --mtest               compute maximum memory usage\n");
     fprintf(stderr, "  --verbose-prompt      print prompt before generation\n");
+    fprintf(stderr, "  --lora FNAME          apply LoRA adapter\n");
     fprintf(stderr, "  -m FNAME, --model FNAME\n");
     fprintf(stderr, "                        model path (default: %s)\n", params.model.c_str());
     fprintf(stderr, "\n");
diff --git a/examples/common.h b/examples/common.h
index 1ea6f7445..ba825f306 100644
--- a/examples/common.h
+++ b/examples/common.h
@@ -31,11 +31,11 @@ struct gpt_params {
     std::string model = "models/lamma-7B/ggml-model.bin"; // model path
     std::string prompt = "";
 
-    std::string input_prefix = ""; // string to prefix user inputs with
-
-
+    std::string input_prefix = ""; // string to prefix user inputs with
     std::vector<std::string> antiprompt; // string upon seeing which more user input is prompted
 
+    std::string lora_adapter = ""; // lora adapter path
+
     bool memory_f16 = true; // use f16 instead of f32 for memory kv
     bool random_prompt = false; // do not randomize prompt if none provided
     bool use_color = false; // use color to distinguish generations and inputs
diff --git a/examples/main/main.cpp b/examples/main/main.cpp
index 3e4b0034e..a50fc641c 100644
--- a/examples/main/main.cpp
+++ b/examples/main/main.cpp
@@ -114,6 +114,14 @@ int main(int argc, char ** argv) {
         }
     }
 
+    if (!params.lora_adapter.empty()) {
+        int err = llama_apply_lora_from_file(ctx, params.lora_adapter.c_str(), params.n_threads);
+        if (err != 0) {
+            fprintf(stderr, "%s: error: failed to apply lora adapter\n", __func__);
+            return 1;
+        }
+    }
+
     // print system information
     {
         fprintf(stderr, "\n");
diff --git a/examples/perplexity/perplexity.cpp b/examples/perplexity/perplexity.cpp
index 19449e16e..716c5e0e4 100644
--- a/examples/perplexity/perplexity.cpp
+++ b/examples/perplexity/perplexity.cpp
@@ -134,6 +134,14 @@ int main(int argc, char ** argv) {
         }
     }
 
+    if (!params.lora_adapter.empty()) {
+        int err = llama_apply_lora_from_file(ctx, params.lora_adapter.c_str(), params.n_threads);
+        if (err != 0) {
+            fprintf(stderr, "%s: error: failed to apply lora adapter\n", __func__);
+            return 1;
+        }
+    }
+
     // print system information
     {
         fprintf(stderr, "\n");
diff --git a/ggml.c b/ggml.c
index 69974989c..a486cad67 100644
--- a/ggml.c
+++ b/ggml.c
@@ -5813,6 +5813,47 @@ static void ggml_compute_forward_add_f32(
     }
 }
 
+static void ggml_compute_forward_add_f16_f32(
+        const struct ggml_compute_params * params,
+        const struct ggml_tensor * src0,
+        const struct ggml_tensor * src1,
+        struct ggml_tensor * dst) {
+    GGML_ASSERT(ggml_are_same_shape(src0, src1) && ggml_are_same_shape(src0, dst));
+
+    if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
+        return;
+    }
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    const int n  = ggml_nrows(src0);
+    const int nc = src0->ne[0];
+
+    const size_t nb00 = src0->nb[0];
+    const size_t nb01 = src0->nb[1];
+
+    const size_t nb10 = src1->nb[0];
+    const size_t nb11 = src1->nb[1];
+
+    const size_t nb0 = dst->nb[0];
+    const size_t nb1 = dst->nb[1];
+
+    GGML_ASSERT(src0->type == GGML_TYPE_F16);
+    GGML_ASSERT(src1->type == GGML_TYPE_F32);
+    GGML_ASSERT(dst->type  == GGML_TYPE_F16);
+
+    for (int j = ith; j < n; j += nth) {
+        ggml_fp16_t * dst_ptr  = (ggml_fp16_t *) ((char *) dst->data  + j*nb1);
+        ggml_fp16_t * src0_ptr = (ggml_fp16_t *) ((char *) src0->data + j*nb01);
+        for (int i = 0; i < nc; i++) {
+            float * src1_ptr = (float *) ((char *) src1->data + j*nb11 + i*nb10);
+
+            dst_ptr[i] = GGML_FP32_TO_FP16(GGML_FP16_TO_FP32(src0_ptr[i]) + *src1_ptr);
+        }
+    }
+}
+
 static void ggml_compute_forward_add(
         const struct ggml_compute_params * params,
         const struct ggml_tensor * src0,
@@ -5823,6 +5864,10 @@ static void ggml_compute_forward_add(
             {
                 ggml_compute_forward_add_f32(params, src0, src1, dst);
             } break;
+        case GGML_TYPE_F16:
+            {
+                ggml_compute_forward_add_f16_f32(params, src0, src1, dst);
+            } break;
         default:
            {
                 GGML_ASSERT(false);
diff --git a/ggml.h b/ggml.h
index 241e96a19..add002581 100644
--- a/ggml.h
+++ b/ggml.h
@@ -430,6 +430,12 @@ struct ggml_tensor * ggml_add(
         struct ggml_tensor * a,
         struct ggml_tensor * b);
 
+
+struct ggml_tensor * ggml_add_inplace(
+        struct ggml_context * ctx,
+        struct ggml_tensor * a,
+        struct ggml_tensor * b);
+
 struct ggml_tensor * ggml_sub(
         struct ggml_context * ctx,
         struct ggml_tensor * a,
diff --git a/llama.cpp b/llama.cpp
index a6429a4e7..ba1f089b8 100644
--- a/llama.cpp
+++ b/llama.cpp
@@ -1758,6 +1758,154 @@ int llama_model_quantize(
     }
 }
 
+int llama_apply_lora_from_file(struct llama_context * ctx, const char * path_lora, int n_threads) {
+    // TODO: refactor all of this after PR #801
+    auto & model = ctx->model;
+
+    auto fin = std::ifstream(path_lora, std::ios::binary);
+    if (!fin) {
+        fprintf(stderr, "%s: failed to open '%s'\n", __func__, path_lora);
+        return 1;
+    }
+
+    // verify magic and version
+    {
+        uint32_t magic;
+        fin.read((char *) &magic, sizeof(magic));
+        if (magic != 'ggla') {
+            fprintf(stderr, "%s: bad file magic\n", __func__);
+            return 1;
+        }
+        uint32_t format_version;
+        fin.read((char *) &format_version, sizeof(format_version));
+
+        if (format_version != 1) {
+            fprintf(stderr, "%s: unsupported file version\n", __func__);
+            return 1;
+        }
+    }
+
+    // create a temporary ggml context to store the lora tensors
+    std::vector<uint8_t> buf(1024 * 1024 * 100);
+    struct ggml_init_params params;
+    params.mem_size   = buf.size();
+    params.mem_buffer = buf.data();
+    params.no_alloc   = false;
+
+    ggml_context * lora_ctx = ggml_init(params);
+    std::unordered_map<std::string, struct ggml_tensor *> lora_tensors;
+
+    fprintf(stderr, "%s: ", __func__);
+
+    // read tensors and apply
+    int n_tensors = 0;
+    while (true) {
+        int32_t n_dims;
+        int32_t length;
+        int32_t ftype;
+
+        fin.read(reinterpret_cast<char *>(&n_dims), sizeof(n_dims));
+        fin.read(reinterpret_cast<char *>(&length), sizeof(length));
+        fin.read(reinterpret_cast<char *>(&ftype),  sizeof(ftype));
+        if (fin.eof()) {
+            break;
+        }
+
+        int32_t nelements = 1;
+        int32_t ne[2] = { 1, 1 };
+        for (int i = 0; i < n_dims; ++i) {
+            fin.read(reinterpret_cast<char *>(&ne[i]), sizeof(ne[i]));
+            nelements *= ne[i];
+        }
+
+        std::string name(length, 0);
+        fin.read(&name[0], length);
+
+        // check for lora suffix and get the type of tensor
+        const std::string lora_suffix = ".lora";
+        size_t pos = name.rfind(lora_suffix);
+        if (pos == std::string::npos) {
+            fprintf(stderr, "%s: error: '%s' is not a lora tensor\n", __func__, name.c_str());
+            return 1;
+        }
+
+        std::string lora_type = name.substr(pos + lora_suffix.length());
+        std::string base_name = name;
+        base_name.erase(pos);
+        // fprintf(stderr, "%s: %s => %s (lora type %s) ", __func__, name.c_str(), base_name.c_str(), lora_type.c_str());
+
+        if (model.tensors.find(base_name.data()) == model.tensors.end()) {
+            fprintf(stderr, "%s: unknown tensor '%s' in lora adapter\n", __func__, name.data());
+            return 1;
+        }
+
+        // create ggml tensor
+        ggml_type wtype;
+        switch (ftype) {
+            case 0: wtype = GGML_TYPE_F32; break;
+            case 1: wtype = GGML_TYPE_F16; break;
+            default:
+                {
+                    fprintf(stderr, "%s: invalid tensor data type '%d'\n",
+                            __func__, ftype);
+                    return 1;
+                }
+        }
+        ggml_tensor * lora_tensor;
+        if (n_dims == 2) {
+            lora_tensor = ggml_new_tensor_2d(lora_ctx, wtype, ne[0], ne[1]);
+        }
+        else {
+            fprintf(stderr, "%s: unsupported tensor dimension %d\n", __func__, n_dims);
+            return 1;
+        }
+
+        // load tensor data
+        size_t offset = fin.tellg();
+        size_t tensor_data_size = ggml_nbytes(lora_tensor);
+        offset = (offset + 31) & -32;
+        fin.seekg(offset);
+        fin.read((char *) lora_tensor->data, tensor_data_size);
+
+        lora_tensors[name] = lora_tensor;
+
+        // check if we have both A and B tensors and apply
+        if (lora_tensors.find(base_name + ".loraA") != lora_tensors.end() &&
+            lora_tensors.find(base_name + ".loraB") != lora_tensors.end()) {
+
+            ggml_tensor * tensor = model.tensors[base_name];
+            ggml_tensor * loraA = ggml_transpose(lora_ctx, lora_tensors[base_name + ".loraA"]);
+            ggml_tensor * loraB = lora_tensors[base_name + ".loraB"];
+
+            if (tensor->ne[0] != loraA->ne[1]) {
+                fprintf(stderr, "%s: incompatible tensor dimensions (%" PRId64 " and %" PRId64 ");"
+                                " are you sure that this adapter is for this model?\n", __func__, tensor->ne[0], loraA->ne[1]);
+                return 1;
+            }
+
+            // w = w + BA
+            ggml_tensor * BA = ggml_mul_mat(lora_ctx, loraB, loraA);
+            ggml_tensor * r = ggml_add_inplace(lora_ctx, tensor, BA);
+
+            struct ggml_cgraph gf = ggml_build_forward(r);
+            gf.n_threads = n_threads;
+            ggml_graph_compute(lora_ctx, &gf);
+
+            // we won't need these tensors again, reset the context to save memory
+            ggml_free(lora_ctx);
+            lora_ctx = ggml_init(params);
+            lora_tensors.clear();
+
+            n_tensors++;
+            if (n_tensors % 8 == 0)
+                fprintf(stderr, ".");
+        }
+    }
+    fprintf(stderr, " done\n");
+
+    return 0;
+}
+
 // Returns the KV cache that will contain the context for the
 // ongoing prediction with the model.
 const uint8_t * llama_get_kv_cache(struct llama_context * ctx) {
diff --git a/llama.h b/llama.h
index 192217593..535f1b18e 100644
--- a/llama.h
+++ b/llama.h
@@ -96,6 +96,15 @@ extern "C" {
             const char * fname_out,
       enum llama_ftype   ftype);
 
+    // Apply a LoRA adapter to a loaded model
+    // The model needs to be reloaded before applying a new adapter, otherwise
+    // the adapter will be applied on top of the previous one
+    // Returns 0 on success
+    LLAMA_API int llama_apply_lora_from_file(
+            struct llama_context * ctx,
+            const char * path_lora,
+            int n_threads);
+
     // Returns the KV cache that will contain the context for the
     // ongoing prediction with the model.
     LLAMA_API const uint8_t * llama_get_kv_cache(struct llama_context * ctx);
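
Example usage of the conversion script and the new --lora flag, as a sketch only: the adapter checkpoint and model paths below are illustrative, and ggml_adapter_model.bin is simply the script's default output name (ggml_ prefixed to the input file name).

    python convert-lora-to-ggml.py lora/adapter_model.bin
    ./main -m models/7B/ggml-model-f16.bin --lora lora/ggml_adapter_model.bin -p "Hello"

The adapter must target the same base model that -m loads; llama_apply_lora_from_file() compares the tensor dimensions and fails with an "incompatible tensor dimensions" error otherwise.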
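For programmatic use, a minimal C++ sketch of calling the new entry point. Only llama_apply_lora_from_file() is introduced by this patch; the surrounding calls (llama_context_default_params, llama_init_from_file, llama_free) are assumed from the llama.h of this period, and the paths and thread count are placeholders.

    #include "llama.h"
    #include <cstdio>

    int main() {
        // load the base model first; the adapter is merged into the loaded weights
        llama_context_params lparams = llama_context_default_params();
        llama_context * ctx = llama_init_from_file("models/7B/ggml-model-f16.bin", lparams);
        if (ctx == NULL) {
            fprintf(stderr, "failed to load the base model\n");
            return 1;
        }

        // returns 0 on success; reload the model before applying a different adapter
        if (llama_apply_lora_from_file(ctx, "lora/ggml_adapter_model.bin", /*n_threads =*/ 4) != 0) {
            fprintf(stderr, "failed to apply lora adapter\n");
            llama_free(ctx);
            return 1;
        }

        // ... tokenize, evaluate and sample as usual ...

        llama_free(ctx);
        return 0;
    }

Because the adapter is folded into the weights with ggml_add_inplace() (w = w + BA), applying it is a one-time cost at load and adds no per-token overhead during inference.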