Add lora support
This commit is contained in:
parent 3173a62eb9
commit f52101e889
9 changed files with 335 additions and 3 deletions
convert-lora-to-ggml.py · 101 (new file)
@@ -0,0 +1,101 @@
import os
import re
import struct
import sys
from dataclasses import dataclass
from typing import Any, Sequence

import numpy as np
import torch


# TODO: import this from convert.py once #545 is merged
@dataclass(frozen=True)
class UnquantizedDataType:
    name: str


DT_F16 = UnquantizedDataType('F16')
DT_F32 = UnquantizedDataType('F32')


@dataclass(frozen=True)
class QuantizedDataType:
    groupsize: int
    have_addends: bool
    have_g_idx: bool


DataType = UnquantizedDataType

DATA_TYPE_TO_FTYPE: dict[DataType, int] = {
    DT_F32: 0,
    DT_F16: 1,
}

DATA_TYPE_TO_NUMPY: dict[DataType, np.dtype[Any]] = {
    DT_F16: np.dtype(np.float16),
    DT_F32: np.dtype(np.float32),
}

NUMPY_TYPE_TO_DATA_TYPE: dict[np.dtype[Any], DataType] = {dtype: data_type for (data_type, dtype) in DATA_TYPE_TO_NUMPY.items()}

HF_SUBLAYER_TO_GGML = {
    "self_attn.q_proj": "attention.wq.weight",
    "self_attn.k_proj": "attention.wk.weight",
    "self_attn.v_proj": "attention.wv.weight",
    "self_attn.o_proj": "attention.wo.weight",
}


def translate_tensor_name(t):
    match = re.match(r'.*layers\.(\d+)\.(\w+\.\w+)\.lora_(A|B)\.weight', t)
    if match:
        nn = match.group(1)
        sub_layer = match.group(2)
        lora_type = match.group(3)

        sub_layer_renamed = HF_SUBLAYER_TO_GGML.get(sub_layer)
        if sub_layer_renamed is None:
            print(f"Error: unrecognized sub-layer {sub_layer} in tensor {t}")
            sys.exit(1)

        output_string = f"layers.{nn}.{HF_SUBLAYER_TO_GGML[sub_layer]}.lora{lora_type}"
        return output_string
    else:
        print(f"Error: unrecognized tensor {t}")
        sys.exit(1)


def write_file_header(fout):
    fout.write(b"ggla"[::-1])  # magic (ggml lora)
    fout.write(struct.pack("i", 1))  # file version


def write_tensor_header(fout, name: str, shape: Sequence[int], data_type: np.dtype) -> None:
    sname = name.encode('utf-8')
    fout.write(struct.pack("iii", len(shape), len(sname), DATA_TYPE_TO_FTYPE[NUMPY_TYPE_TO_DATA_TYPE[data_type]]))
    fout.write(struct.pack("i" * len(shape), *shape[::-1]))
    fout.write(sname)
    fout.seek((fout.tell() + 31) & -32)


if len(sys.argv) < 2:
    print(f"Usage: python {sys.argv[0]} adapter_model.bin [ggml_adapter_model.bin]")
    sys.exit(1)

input_path = sys.argv[1]
if len(sys.argv) > 2:
    output_path = sys.argv[2]
else:
    output_filename = f"ggml_{os.path.basename(input_path)}"
    output_path = os.path.join(os.path.dirname(input_path), output_filename)

model = torch.load(input_path, map_location="cpu")

with open(output_path, "wb") as fout:
    write_file_header(fout)
    for k, v in model.items():
        # since ggml doesn't always support other types for the second operand,
        # the tensors are always converted and exported as f32
        t = v.float().numpy()
        print(f"{k} => {translate_tensor_name(k)} {t.shape} {t.dtype} {t.nbytes/1024/1024:.2f}MB")
        write_tensor_header(fout, translate_tensor_name(k), t.shape, t.dtype)
        t.tofile(fout)

print(f"Converted {input_path} to {output_path}")
@@ -139,6 +139,12 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
                break;
            }
            params.model = argv[i];
        } else if (arg == "--lora") {
            if (++i >= argc) {
                invalid_param = true;
                break;
            }
            params.lora_adapter = argv[i];
        } else if (arg == "-i" || arg == "--interactive") {
            params.interactive = true;
        } else if (arg == "--embedding") {
@@ -242,6 +248,7 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
    }
    fprintf(stderr, "  --mtest               compute maximum memory usage\n");
    fprintf(stderr, "  --verbose-prompt      print prompt before generation\n");
    fprintf(stderr, "  --lora FNAME          apply LoRA adapter\n");
    fprintf(stderr, "  -m FNAME, --model FNAME\n");
    fprintf(stderr, "                        model path (default: %s)\n", params.model.c_str());
    fprintf(stderr, "\n");
@@ -31,11 +31,11 @@ struct gpt_params {

    std::string model = "models/lamma-7B/ggml-model.bin"; // model path
    std::string prompt = "";
    std::string input_prefix = ""; // string to prefix user inputs with
    std::vector<std::string> antiprompt; // string upon seeing which more user input is prompted

    std::string lora_adapter = ""; // lora adapter path

    bool memory_f16 = true; // use f16 instead of f32 for memory kv
    bool random_prompt = false; // do not randomize prompt if none provided
    bool use_color = false; // use color to distinguish generations and inputs
@@ -114,6 +114,14 @@ int main(int argc, char ** argv) {
        }
    }

    if (!params.lora_adapter.empty()) {
        int err = llama_apply_lora_from_file(ctx, params.lora_adapter.c_str(), params.n_threads);
        if (err != 0) {
            fprintf(stderr, "%s: error: failed to apply lora adapter\n", __func__);
            return 1;
        }
    }

    // print system information
    {
        fprintf(stderr, "\n");
@@ -134,6 +134,14 @@ int main(int argc, char ** argv) {
        }
    }

    if (!params.lora_adapter.empty()) {
        int err = llama_apply_lora_from_file(ctx, params.lora_adapter.c_str(), params.n_threads);
        if (err != 0) {
            fprintf(stderr, "%s: error: failed to apply lora adapter\n", __func__);
            return 1;
        }
    }

    // print system information
    {
        fprintf(stderr, "\n");
ggml.c · 45
@@ -5813,6 +5813,47 @@ static void ggml_compute_forward_add_f32(
    }
}

static void ggml_compute_forward_add_f16_f32(
        const struct ggml_compute_params * params,
        const struct ggml_tensor * src0,
        const struct ggml_tensor * src1,
        struct ggml_tensor * dst) {
    GGML_ASSERT(ggml_are_same_shape(src0, src1) && ggml_are_same_shape(src0, dst));

    if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
        return;
    }

    const int ith = params->ith;
    const int nth = params->nth;

    const int n  = ggml_nrows(src0);
    const int nc = src0->ne[0];

    const size_t nb00 = src0->nb[0];
    const size_t nb01 = src0->nb[1];

    const size_t nb10 = src1->nb[0];
    const size_t nb11 = src1->nb[1];

    const size_t nb0 = dst->nb[0];
    const size_t nb1 = dst->nb[1];

    GGML_ASSERT(src0->type == GGML_TYPE_F16);
    GGML_ASSERT(src1->type == GGML_TYPE_F32);
    GGML_ASSERT(dst->type  == GGML_TYPE_F16);

    for (int j = ith; j < n; j += nth) {
        ggml_fp16_t * dst_ptr  = (ggml_fp16_t *) ((char *) dst->data  + j*nb1);
        ggml_fp16_t * src0_ptr = (ggml_fp16_t *) ((char *) src0->data + j*nb01);
        for (int i = 0; i < nc; i++) {
            float * src1_ptr = (float *) ((char *) src1->data + j*nb11 + i*nb10);

            dst_ptr[i] = GGML_FP32_TO_FP16(GGML_FP16_TO_FP32(src0_ptr[i]) + *src1_ptr);
        }
    }
}

static void ggml_compute_forward_add(
        const struct ggml_compute_params * params,
        const struct ggml_tensor * src0,
@@ -5823,6 +5864,10 @@ static void ggml_compute_forward_add(
            {
                ggml_compute_forward_add_f32(params, src0, src1, dst);
            } break;
        case GGML_TYPE_F16:
            {
                ggml_compute_forward_add_f16_f32(params, src0, src1, dst);
            } break;
        default:
            {
                GGML_ASSERT(false);
ggml.h · 6
@@ -430,6 +430,12 @@ struct ggml_tensor * ggml_add(
        struct ggml_tensor * a,
        struct ggml_tensor * b);

struct ggml_tensor * ggml_add_inplace(
        struct ggml_context * ctx,
        struct ggml_tensor * a,
        struct ggml_tensor * b);

struct ggml_tensor * ggml_sub(
        struct ggml_context * ctx,
        struct ggml_tensor * a,
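A hedged sketch of how the new ggml_add_inplace declaration is meant to be driven, mirroring the call sequence used in llama_apply_lora_from_file further down in this diff; the helper name add_delta_inplace is illustrative and not part of the commit.

#include "ggml.h"

// Adds delta into w's own buffer (same shape required), then runs the graph.
static void add_delta_inplace(struct ggml_context * ctx,
                              struct ggml_tensor  * w,
                              struct ggml_tensor  * delta,
                              int n_threads) {
    struct ggml_tensor * r = ggml_add_inplace(ctx, w, delta); // r aliases w's data

    struct ggml_cgraph gf = ggml_build_forward(r);
    gf.n_threads = n_threads;
    ggml_graph_compute(ctx, &gf);
}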
llama.cpp · 148
@@ -1758,6 +1758,154 @@ int llama_model_quantize(
    }
}

int llama_apply_lora_from_file(struct llama_context * ctx, const char * path_lora, int n_threads) {
    // TODO: refactor all of this after PR #801
    auto & model = ctx->model;

    auto fin = std::ifstream(path_lora, std::ios::binary);
    if (!fin) {
        fprintf(stderr, "%s: failed to open '%s'\n", __func__, path_lora);
        return 1;
    }

    // verify magic and version
    {
        uint32_t magic;
        fin.read((char *) &magic, sizeof(magic));
        if (magic != 'ggla') {
            fprintf(stderr, "%s: bad file magic\n", __func__);
            return 1;
        }
        uint32_t format_version;
        fin.read((char *) &format_version, sizeof(format_version));

        if (format_version != 1) {
            fprintf(stderr, "%s: unsupported file version\n", __func__);
            return 1;
        }
    }

    // create a temporary ggml context to store the lora tensors
    std::vector<uint8_t> buf(1024 * 1024 * 100);
    struct ggml_init_params params;
    params.mem_size   = buf.size();
    params.mem_buffer = buf.data();
    params.no_alloc   = false;

    ggml_context * lora_ctx = ggml_init(params);
    std::unordered_map<std::string, struct ggml_tensor *> lora_tensors;

    fprintf(stderr, "%s: ", __func__);

    // read tensors and apply
    int n_tensors = 0;
    while (true) {
        int32_t n_dims;
        int32_t length;
        int32_t ftype;

        fin.read(reinterpret_cast<char *>(&n_dims), sizeof(n_dims));
        fin.read(reinterpret_cast<char *>(&length), sizeof(length));
        fin.read(reinterpret_cast<char *>(&ftype),  sizeof(ftype));
        if (fin.eof()) {
            break;
        }

        int32_t nelements = 1;
        int32_t ne[2] = { 1, 1 };
        for (int i = 0; i < n_dims; ++i) {
            fin.read(reinterpret_cast<char *>(&ne[i]), sizeof(ne[i]));
            nelements *= ne[i];
        }

        std::string name(length, 0);
        fin.read(&name[0], length);

        // check for lora suffix and get the type of tensor
        const std::string lora_suffix = ".lora";
        size_t pos = name.rfind(lora_suffix);
        if (pos == std::string::npos) {
            fprintf(stderr, "%s: error: '%s' is not a lora tensor\n", __func__, name.c_str());
            return 1;
        }

        std::string lora_type = name.substr(pos + lora_suffix.length());
        std::string base_name = name;
        base_name.erase(pos);
        // fprintf(stderr, "%s: %s => %s (lora type %s) ", __func__, name.c_str(), base_name.c_str(), lora_type.c_str());

        if (model.tensors.find(base_name.data()) == model.tensors.end()) {
            fprintf(stderr, "%s: unknown tensor '%s' in lora adapter\n", __func__, name.data());
            return 1;
        }

        // create ggml tensor
        ggml_type wtype;
        switch (ftype) {
            case 0: wtype = GGML_TYPE_F32; break;
            case 1: wtype = GGML_TYPE_F16; break;
            default:
                {
                    fprintf(stderr, "%s: invalid tensor data type '%d'\n",
                            __func__, ftype);
                    return 1;
                }
        }
        ggml_tensor * lora_tensor;
        if (n_dims == 2) {
            lora_tensor = ggml_new_tensor_2d(lora_ctx, wtype, ne[0], ne[1]);
        }
        else {
            fprintf(stderr, "%s: unsupported tensor dimension %d\n", __func__, n_dims);
            return 1;
        }

        // load tensor data
        size_t offset = fin.tellg();
        size_t tensor_data_size = ggml_nbytes(lora_tensor);
        offset = (offset + 31) & -32;
        fin.seekg(offset);
        fin.read((char *) lora_tensor->data, tensor_data_size);

        lora_tensors[name] = lora_tensor;

        // check if we have both A and B tensors and apply
        if (lora_tensors.find(base_name + ".loraA") != lora_tensors.end() &&
            lora_tensors.find(base_name + ".loraB") != lora_tensors.end()) {

            ggml_tensor * tensor = model.tensors[base_name];
            ggml_tensor * loraA = ggml_transpose(lora_ctx, lora_tensors[base_name + ".loraA"]);
            ggml_tensor * loraB = lora_tensors[base_name + ".loraB"];

            if (tensor->ne[0] != loraA->ne[1]) {
                fprintf(stderr, "%s: incompatible tensor dimensions (%" PRId64 " and %" PRId64 ");"
                                " are you sure that this adapter is for this model?\n", __func__, tensor->ne[0], loraA->ne[1]);
                return 1;
            }

            // w = w + BA
            ggml_tensor * BA = ggml_mul_mat(lora_ctx, loraB, loraA);
            ggml_tensor * r = ggml_add_inplace(lora_ctx, tensor, BA);

            struct ggml_cgraph gf = ggml_build_forward(r);
            gf.n_threads = n_threads;
            ggml_graph_compute(lora_ctx, &gf);

            // we won't need these tensors again, reset the context to save memory
            ggml_free(lora_ctx);
            lora_ctx = ggml_init(params);
            lora_tensors.clear();

            n_tensors++;
            if (n_tensors % 8 == 0)
                fprintf(stderr, ".");
        }
    }
    fprintf(stderr, " done\n");

    return 0;
}

// Returns the KV cache that will contain the context for the
// ongoing prediction with the model.
const uint8_t * llama_get_kv_cache(struct llama_context * ctx) {
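The "// w = w + BA" step above is the standard LoRA update. With W the base weight, A the loraA matrix (rank r by input dimension) and B the loraB matrix (output dimension by rank r), each matched tensor is rewritten in place as

    W ← W + B·A

The dimension labels are descriptive rather than identifiers from the code; the ne[0]/ne[1] comparison above is what catches an adapter whose dimensions do not match the weights of the loaded model.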
llama.h · 9
@@ -96,6 +96,15 @@ extern "C" {
            const char * fname_out,
            enum llama_ftype ftype);

    // Apply a LoRA adapter to a loaded model
    // The model needs to be reloaded before applying a new adapter, otherwise
    // the adapter will be applied on top of the previous one
    // Returns 0 on success
    LLAMA_API int llama_apply_lora_from_file(
            struct llama_context * ctx,
            const char * path_lora,
            int n_threads);

    // Returns the KV cache that will contain the context for the
    // ongoing prediction with the model.
    LLAMA_API const uint8_t * llama_get_kv_cache(struct llama_context * ctx);
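A hedged end-to-end usage sketch of the new entry point; the paths and the thread count are placeholders, and it assumes the llama_context_default_params, llama_init_from_file and llama_free functions already declared in llama.h at this point.

#include "llama.h"
#include <cstdio>

int main() {
    llama_context_params lparams = llama_context_default_params();

    // load the base model first ...
    llama_context * ctx = llama_init_from_file("models/7B/ggml-model-f16.bin", lparams);
    if (ctx == NULL) {
        fprintf(stderr, "failed to load the model\n");
        return 1;
    }

    // ... then apply the converted adapter on top of it (0 means success)
    if (llama_apply_lora_from_file(ctx, "models/7B/ggml_adapter_model.bin", 4) != 0) {
        fprintf(stderr, "failed to apply the lora adapter\n");
        return 1;
    }

    // run inference as usual from here on

    llama_free(ctx);
    return 0;
}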