Add basic cpu setup
This commit is contained in:
parent
fd5ea0f897
commit
12112bfa48
8 changed files with 423 additions and 3 deletions
48
BRANCH_SETUP.md
Normal file
48
BRANCH_SETUP.md
Normal file
|
@ -0,0 +1,48 @@
|
||||||
|
# Setup this branch
|
||||||
|
|
||||||
|
## Create a lora adpter bin file
|
||||||
|
|
||||||
|
0. `mkdir models/open-llama` and download [Open-llama (all files)](https://huggingface.co/openlm-research/open_llama_3b_v2/tree/main) in the folder `./models/open-llama`
|
||||||
|
|
||||||
|
2. `mkdir data && touch data/hot-lora.txt` and write a couple of words in it.
|
||||||
|
|
||||||
|
3. Run:
|
||||||
|
```bash
|
||||||
|
# Convert base model to gguf
|
||||||
|
python3 convert-hf-to-gguf.py models/open-llama/
|
||||||
|
# Quantize base model
|
||||||
|
./quantize ./models/open-llama/ggml-model-f16.gguf ./models/open-llama/ggml-model-q8_0.gguf Q8_0
|
||||||
|
# Obtain Lora adapter
|
||||||
|
./finetune --model-base models/open-llama/ggml-model-q8_0.gguf \
|
||||||
|
--checkpoint-in models/open-llama/chk-lora-ggml-model-q8_0-hot-lora-LATEST.gguf \
|
||||||
|
--checkpoint-out models/open-llama/chk-lora-ggml-model-q8_0-hot-lora-ITERATION.gguf \
|
||||||
|
--lora-out models/open-llama/lora-ggml-model-q8_0-hot-lora-ITERATION.bin \
|
||||||
|
--train-data "data/hot-lora.txt" \
|
||||||
|
--save-every 1 \
|
||||||
|
--threads 1 \
|
||||||
|
--adam-iter 1 \
|
||||||
|
--batch 1 \
|
||||||
|
--ctx 16 \
|
||||||
|
--use-checkpointing
|
||||||
|
```
|
||||||
|
|
||||||
|
## Run main with adapter
|
||||||
|
|
||||||
|
Run main with base model and lora adapter to hot-swap
|
||||||
|
```bash
|
||||||
|
./main ./models/open-llama/ggml-model-f16.gguf \
|
||||||
|
--hot-lora models/open-llama/lora-ggml-model-q8_0-hot-lora-ITERATION.bin \
|
||||||
|
-ngl 0 \
|
||||||
|
-n 128
|
||||||
|
```
|
||||||
|
|
||||||
|
With `ngl > 0` the code breaks. Probably because the Lora tensors try to interact with the base tensors (`lora_mul_mat`), but they are not moved to the buffer of the base tensors.
|
||||||
|
|
||||||
|
# Logic
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
# Current status
|
||||||
|
|
||||||
|
- Only ony Lora adapter can be passed.
|
||||||
|
- GPU not supported
|
|
@ -789,6 +789,14 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa
|
||||||
params.model = argv[i];
|
params.model = argv[i];
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
if (arg == "-hl" || arg == "--hot-lora") {
|
||||||
|
if (++i >= argc) {
|
||||||
|
invalid_param = true;
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
params.hot_lora = argv[i];
|
||||||
|
return true;
|
||||||
|
}
|
||||||
if (arg == "-md" || arg == "--model-draft") {
|
if (arg == "-md" || arg == "--model-draft") {
|
||||||
if (++i >= argc) {
|
if (++i >= argc) {
|
||||||
invalid_param = true;
|
invalid_param = true;
|
||||||
|
|
|
@ -100,6 +100,7 @@ struct gpt_params {
|
||||||
|
|
||||||
std::string model = ""; // model path
|
std::string model = ""; // model path
|
||||||
std::string model_draft = ""; // draft model for speculative decoding
|
std::string model_draft = ""; // draft model for speculative decoding
|
||||||
|
std::string hot_lora = ""; // lora model path for hot swapping
|
||||||
std::string model_alias = "unknown"; // model alias
|
std::string model_alias = "unknown"; // model alias
|
||||||
std::string model_url = ""; // model url to download
|
std::string model_url = ""; // model url to download
|
||||||
std::string hf_repo = ""; // HF repo
|
std::string hf_repo = ""; // HF repo
|
||||||
|
|
2
data/hot-lora.txt
Normal file
2
data/hot-lora.txt
Normal file
|
@ -0,0 +1,2 @@
|
||||||
|
|
||||||
|
how are you?
|
46
ggml.c
46
ggml.c
|
@ -4313,6 +4313,52 @@ struct ggml_tensor * ggml_get_tensor(struct ggml_context * ctx, const char * nam
|
||||||
return NULL;
|
return NULL;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
//////// LORA
|
||||||
|
|
||||||
|
struct lora_tensor_pair* build_lora_weights_map(struct ggml_context* ctx) {
|
||||||
|
struct lora_tensor_pair* pair = malloc(sizeof(struct lora_tensor_pair));
|
||||||
|
if (!pair) return NULL;
|
||||||
|
pair->pairs = NULL;
|
||||||
|
pair->count = 0;
|
||||||
|
pair->capacity = 0;
|
||||||
|
|
||||||
|
struct ggml_object * obj = ctx->objects_begin;
|
||||||
|
char * const mem_buffer = ctx->mem_buffer;
|
||||||
|
|
||||||
|
while (obj != NULL) {
|
||||||
|
if (obj->type == GGML_OBJECT_TYPE_TENSOR) {
|
||||||
|
struct ggml_tensor * tensor = (struct ggml_tensor *)(mem_buffer + obj->offs);
|
||||||
|
char * tensor_name = tensor->name;
|
||||||
|
|
||||||
|
if (strlen(tensor_name) > 6 && (strcmp(tensor_name + strlen(tensor_name) - 6, ".loraA") == 0 ||
|
||||||
|
strcmp(tensor_name + strlen(tensor_name) - 6, ".loraB") == 0)) {
|
||||||
|
if (pair->count == pair->capacity) {
|
||||||
|
pair->capacity = pair->capacity > 0 ? pair->capacity * 2 : 4;
|
||||||
|
pair->pairs = realloc(pair->pairs, pair->capacity * sizeof(struct lora_tensor_info));
|
||||||
|
}
|
||||||
|
|
||||||
|
pair->pairs[pair->count].name = strdup(tensor_name);
|
||||||
|
pair->pairs[pair->count].tensor = tensor;
|
||||||
|
pair->count++;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
obj = obj->next;
|
||||||
|
}
|
||||||
|
|
||||||
|
return pair;
|
||||||
|
}
|
||||||
|
|
||||||
|
void free_lora_tensor_pair(struct lora_tensor_pair* pair) {
|
||||||
|
if (!pair) return;
|
||||||
|
for (int i = 0; i < pair->count; i++) {
|
||||||
|
free(pair->pairs[i].name);
|
||||||
|
}
|
||||||
|
free(pair->pairs);
|
||||||
|
free(pair);
|
||||||
|
}
|
||||||
|
|
||||||
|
//////// LORA
|
||||||
|
|
||||||
////////////////////////////////////////////////////////////////////////////////
|
////////////////////////////////////////////////////////////////////////////////
|
||||||
|
|
||||||
// ggml_dup
|
// ggml_dup
|
||||||
|
|
19
ggml.h
19
ggml.h
|
@ -835,6 +835,25 @@ extern "C" {
|
||||||
GGML_API struct ggml_tensor * ggml_get_next_tensor (const struct ggml_context * ctx, struct ggml_tensor * tensor);
|
GGML_API struct ggml_tensor * ggml_get_next_tensor (const struct ggml_context * ctx, struct ggml_tensor * tensor);
|
||||||
GGML_API struct ggml_tensor * ggml_get_tensor(struct ggml_context * ctx, const char * name);
|
GGML_API struct ggml_tensor * ggml_get_tensor(struct ggml_context * ctx, const char * name);
|
||||||
|
|
||||||
|
struct lora_tensor_info {
|
||||||
|
char* name;
|
||||||
|
struct ggml_tensor* tensor;
|
||||||
|
};
|
||||||
|
|
||||||
|
struct lora_tensor_pair {
|
||||||
|
struct lora_tensor_info* pairs; // Dynamic array of tensor pairs
|
||||||
|
int count;
|
||||||
|
int capacity;
|
||||||
|
};
|
||||||
|
|
||||||
|
// Function to build tensor pairs
|
||||||
|
struct lora_tensor_pair* build_lora_weights_map(struct ggml_context* ctx);
|
||||||
|
|
||||||
|
// Cleanup function for lora_tensor_pair
|
||||||
|
void free_lora_tensor_pair(struct lora_tensor_pair* pair);
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
GGML_API struct ggml_tensor * ggml_set_zero(struct ggml_tensor * tensor);
|
GGML_API struct ggml_tensor * ggml_set_zero(struct ggml_tensor * tensor);
|
||||||
GGML_API struct ggml_tensor * ggml_set_i32 (struct ggml_tensor * tensor, int32_t value);
|
GGML_API struct ggml_tensor * ggml_set_i32 (struct ggml_tensor * tensor, int32_t value);
|
||||||
GGML_API struct ggml_tensor * ggml_set_f32 (struct ggml_tensor * tensor, float value);
|
GGML_API struct ggml_tensor * ggml_set_f32 (struct ggml_tensor * tensor, float value);
|
||||||
|
|
299
llama.cpp
299
llama.cpp
|
@ -119,6 +119,212 @@ static void llama_log_callback_default(ggml_log_level level, const char * text,
|
||||||
// helpers
|
// helpers
|
||||||
//
|
//
|
||||||
|
|
||||||
|
///////// LORA
|
||||||
|
|
||||||
|
struct lora_weights {
|
||||||
|
ggml_tensor* loraA;
|
||||||
|
ggml_tensor* loraB;
|
||||||
|
};
|
||||||
|
|
||||||
|
struct export_lora_params {
|
||||||
|
std::string fn_model_base;
|
||||||
|
std::string fn_model_out;
|
||||||
|
std::vector<struct lora_info> lora;
|
||||||
|
int n_threads;
|
||||||
|
};
|
||||||
|
|
||||||
|
static struct export_lora_params get_default_export_lora_params() {
|
||||||
|
struct export_lora_params result;
|
||||||
|
result.fn_model_base = "";
|
||||||
|
result.fn_model_out = "";
|
||||||
|
result.n_threads = GGML_DEFAULT_N_THREADS;
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
|
||||||
|
struct lora_info {
|
||||||
|
std::string filename;
|
||||||
|
float scale;
|
||||||
|
};
|
||||||
|
// TODO lora_data should maybe sub lora_weights in llama.cpp
|
||||||
|
struct lora_data {
|
||||||
|
struct lora_info info;
|
||||||
|
std::vector<uint8_t> data;
|
||||||
|
struct ggml_context * ctx;
|
||||||
|
|
||||||
|
uint32_t lora_r;
|
||||||
|
uint32_t lora_alpha;
|
||||||
|
};
|
||||||
|
|
||||||
|
struct llama_file_lora {
|
||||||
|
// use FILE * so we don't have to re-open the file to mmap
|
||||||
|
FILE * fp;
|
||||||
|
size_t size;
|
||||||
|
|
||||||
|
llama_file_lora(const char * fname, const char * mode) {
|
||||||
|
fp = std::fopen(fname, mode);
|
||||||
|
if (fp == NULL) {
|
||||||
|
size = 0;
|
||||||
|
} else {
|
||||||
|
seek(0, SEEK_END);
|
||||||
|
size = tell();
|
||||||
|
seek(0, SEEK_SET);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
size_t tell() const {
|
||||||
|
#ifdef _WIN32
|
||||||
|
__int64 ret = _ftelli64(fp);
|
||||||
|
#else
|
||||||
|
long ret = std::ftell(fp);
|
||||||
|
#endif
|
||||||
|
GGML_ASSERT(ret != -1); // this really shouldn't fail
|
||||||
|
return (size_t) ret;
|
||||||
|
}
|
||||||
|
|
||||||
|
void seek(size_t offset, int whence) {
|
||||||
|
#ifdef _WIN32
|
||||||
|
int ret = _fseeki64(fp, (__int64) offset, whence);
|
||||||
|
#else
|
||||||
|
int ret = std::fseek(fp, (long) offset, whence);
|
||||||
|
#endif
|
||||||
|
GGML_ASSERT(ret == 0); // same
|
||||||
|
}
|
||||||
|
|
||||||
|
void read_raw(void * ptr, size_t size) {
|
||||||
|
if (size == 0) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
errno = 0;
|
||||||
|
std::size_t ret = std::fread(ptr, size, 1, fp);
|
||||||
|
if (ferror(fp)) {
|
||||||
|
die_fmt("read error: %s", strerror(errno));
|
||||||
|
}
|
||||||
|
if (ret != 1) {
|
||||||
|
die("unexpectedly reached end of file");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
std::uint32_t read_u32() {
|
||||||
|
std::uint32_t ret;
|
||||||
|
read_raw(&ret, sizeof(ret));
|
||||||
|
return ret;
|
||||||
|
}
|
||||||
|
|
||||||
|
std::string read_string(std::uint32_t len) {
|
||||||
|
std::vector<char> chars(len);
|
||||||
|
read_raw(chars.data(), len);
|
||||||
|
return std::string(chars.data(), len);
|
||||||
|
}
|
||||||
|
|
||||||
|
void write_raw(const void * ptr, size_t size) {
|
||||||
|
if (size == 0) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
errno = 0;
|
||||||
|
size_t ret = std::fwrite(ptr, size, 1, fp);
|
||||||
|
if (ret != 1) {
|
||||||
|
die_fmt("write error: %s", strerror(errno));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void write_u32(std::uint32_t val) {
|
||||||
|
write_raw(&val, sizeof(val));
|
||||||
|
}
|
||||||
|
|
||||||
|
bool eof() {
|
||||||
|
return tell() >= size;
|
||||||
|
}
|
||||||
|
|
||||||
|
~llama_file_lora() {
|
||||||
|
if (fp) {
|
||||||
|
std::fclose(fp);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
static void free_lora(struct lora_data * lora) {
|
||||||
|
if (lora->ctx != NULL) {
|
||||||
|
ggml_free(lora->ctx);
|
||||||
|
}
|
||||||
|
delete lora;
|
||||||
|
}
|
||||||
|
|
||||||
|
static struct lora_data * load_lora(struct lora_info * info) {
|
||||||
|
struct lora_data * result = new struct lora_data;
|
||||||
|
result->info = *info;
|
||||||
|
result->ctx = NULL;
|
||||||
|
result->lora_r = 1;
|
||||||
|
result->lora_alpha = 1;
|
||||||
|
|
||||||
|
struct llama_file_lora file(info->filename.c_str(), "rb");
|
||||||
|
if (file.fp == NULL) {
|
||||||
|
fprintf(stderr, "warning: Could not open lora adapter '%s'. Ignoring this adapter.\n",
|
||||||
|
info->filename.c_str());
|
||||||
|
free_lora(result);
|
||||||
|
return NULL;
|
||||||
|
}
|
||||||
|
|
||||||
|
struct ggml_init_params params_ggml;
|
||||||
|
params_ggml.mem_size = ggml_tensor_overhead() * GGML_DEFAULT_GRAPH_SIZE;
|
||||||
|
params_ggml.mem_buffer = NULL;
|
||||||
|
params_ggml.no_alloc = true;
|
||||||
|
result->ctx = ggml_init(params_ggml);
|
||||||
|
|
||||||
|
uint32_t magic = file.read_u32();
|
||||||
|
if (magic != LLAMA_FILE_MAGIC_GGLA) {
|
||||||
|
die_fmt("unexpected lora header file magic in '%s'", info->filename.c_str());
|
||||||
|
}
|
||||||
|
uint32_t version = file.read_u32();
|
||||||
|
if (version != 1) {
|
||||||
|
die_fmt("unexpected lora file version '%u' in '%s'", (unsigned) version, info->filename.c_str());
|
||||||
|
}
|
||||||
|
result->lora_r = file.read_u32();
|
||||||
|
result->lora_alpha = file.read_u32();
|
||||||
|
// read tensor infos from file
|
||||||
|
std::vector<char> name_buf;
|
||||||
|
std::vector<struct ggml_tensor *> tensors;
|
||||||
|
std::vector<size_t> tensors_offset;
|
||||||
|
size_t total_nbytes_pad = 0;
|
||||||
|
while(!file.eof()) {
|
||||||
|
int64_t ne[4] = {1,1,1,1};
|
||||||
|
uint32_t n_dims = file.read_u32();
|
||||||
|
uint32_t namelen = file.read_u32();
|
||||||
|
uint32_t type = file.read_u32();
|
||||||
|
for (uint32_t k = 0; k < n_dims; ++k) {
|
||||||
|
ne[k] = (int64_t)file.read_u32();
|
||||||
|
}
|
||||||
|
name_buf.clear();
|
||||||
|
name_buf.resize(namelen + 1, '\0');
|
||||||
|
file.read_raw(name_buf.data(), namelen);
|
||||||
|
file.seek((0-file.tell()) & 31, SEEK_CUR);
|
||||||
|
size_t offset = file.tell();
|
||||||
|
struct ggml_tensor * tensor = ggml_new_tensor(result->ctx, (enum ggml_type) type, n_dims, ne);
|
||||||
|
ggml_set_name(tensor, name_buf.data());
|
||||||
|
size_t nbytes = ggml_nbytes(tensor);
|
||||||
|
size_t nbytes_pad = ggml_nbytes_pad(tensor);
|
||||||
|
total_nbytes_pad += nbytes_pad;
|
||||||
|
tensors.push_back(tensor);
|
||||||
|
tensors_offset.push_back(offset);
|
||||||
|
file.seek(nbytes, SEEK_CUR);
|
||||||
|
}
|
||||||
|
// read tensor data
|
||||||
|
result->data.resize(total_nbytes_pad);
|
||||||
|
size_t data_offset = 0;
|
||||||
|
for (size_t i = 0; i < tensors.size(); ++i) {
|
||||||
|
struct ggml_tensor * tensor = tensors[i];
|
||||||
|
size_t offset = tensors_offset[i];
|
||||||
|
size_t nbytes = ggml_nbytes(tensor);
|
||||||
|
size_t nbytes_pad = ggml_nbytes_pad(tensor);
|
||||||
|
file.seek(offset, SEEK_SET);
|
||||||
|
tensor->data = result->data.data() + data_offset;
|
||||||
|
file.read_raw(tensor->data, nbytes);
|
||||||
|
data_offset += nbytes_pad;
|
||||||
|
}
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
|
||||||
|
///////// LORA
|
||||||
|
|
||||||
static size_t utf8_len(char src) {
|
static size_t utf8_len(char src) {
|
||||||
const size_t lookup[] = { 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 3, 4 };
|
const size_t lookup[] = { 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 3, 4 };
|
||||||
uint8_t highbits = static_cast<uint8_t>(src) >> 4;
|
uint8_t highbits = static_cast<uint8_t>(src) >> 4;
|
||||||
|
@ -2295,6 +2501,10 @@ struct llama_context {
|
||||||
}
|
}
|
||||||
|
|
||||||
llama_cparams cparams;
|
llama_cparams cparams;
|
||||||
|
bool lora_loaded = false;
|
||||||
|
std::map<std::string, lora_weights> lora_weights_map;
|
||||||
|
lora_data llora_data;
|
||||||
|
float lora_scale = 1.0f;
|
||||||
|
|
||||||
std::vector<ggml_backend_t> backends;
|
std::vector<ggml_backend_t> backends;
|
||||||
#ifdef GGML_USE_METAL
|
#ifdef GGML_USE_METAL
|
||||||
|
@ -7447,21 +7657,21 @@ struct llm_build_context {
|
||||||
// self-attention
|
// self-attention
|
||||||
{
|
{
|
||||||
// compute Q and K and RoPE them
|
// compute Q and K and RoPE them
|
||||||
struct ggml_tensor * Qcur = ggml_mul_mat(ctx0, model.layers[il].wq, cur);
|
struct ggml_tensor * Qcur = lora_mul_mat(lctx, ctx0, model.layers[il].wq, cur);
|
||||||
cb(Qcur, "Qcur", il);
|
cb(Qcur, "Qcur", il);
|
||||||
if (model.layers[il].bq) {
|
if (model.layers[il].bq) {
|
||||||
Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq);
|
Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq);
|
||||||
cb(Qcur, "Qcur", il);
|
cb(Qcur, "Qcur", il);
|
||||||
}
|
}
|
||||||
|
|
||||||
struct ggml_tensor * Kcur = ggml_mul_mat(ctx0, model.layers[il].wk, cur);
|
struct ggml_tensor * Kcur = lora_mul_mat(lctx, ctx0, model.layers[il].wk, cur);
|
||||||
cb(Kcur, "Kcur", il);
|
cb(Kcur, "Kcur", il);
|
||||||
if (model.layers[il].bk) {
|
if (model.layers[il].bk) {
|
||||||
Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk);
|
Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk);
|
||||||
cb(Kcur, "Kcur", il);
|
cb(Kcur, "Kcur", il);
|
||||||
}
|
}
|
||||||
|
|
||||||
struct ggml_tensor * Vcur = ggml_mul_mat(ctx0, model.layers[il].wv, cur);
|
struct ggml_tensor * Vcur = lora_mul_mat(lctx, ctx0, model.layers[il].wv, cur);
|
||||||
cb(Vcur, "Vcur", il);
|
cb(Vcur, "Vcur", il);
|
||||||
if (model.layers[il].bv) {
|
if (model.layers[il].bv) {
|
||||||
Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv);
|
Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv);
|
||||||
|
@ -9470,6 +9680,35 @@ struct llm_build_context {
|
||||||
return gf;
|
return gf;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
static ggml_tensor * lora_mul_mat(
|
||||||
|
llama_context & lctx,
|
||||||
|
ggml_context * ctx0,
|
||||||
|
ggml_tensor * weight,
|
||||||
|
ggml_tensor * cur) {
|
||||||
|
ggml_tensor * mm = ggml_mul_mat(ctx0, weight, cur);
|
||||||
|
|
||||||
|
auto it = lctx.lora_weights_map.find(weight->name);
|
||||||
|
if (it == lctx.lora_weights_map.end()) {
|
||||||
|
return mm;
|
||||||
|
}
|
||||||
|
|
||||||
|
ggml_tensor * loraA = it->second.loraA;
|
||||||
|
ggml_tensor * loraB = it->second.loraB;
|
||||||
|
|
||||||
|
ggml_tensor * t_lora = ggml_mul_mat(ctx0,
|
||||||
|
ggml_mul_mat(ctx0, loraA, loraB),
|
||||||
|
cur
|
||||||
|
);
|
||||||
|
|
||||||
|
if (lctx.lora_scale != 1.0f) {
|
||||||
|
t_lora = ggml_scale(ctx0, t_lora, lctx.lora_scale);
|
||||||
|
}
|
||||||
|
|
||||||
|
ggml_tensor * t_patch = ggml_add(ctx0, mm, t_lora);
|
||||||
|
return t_patch;
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
struct ggml_cgraph * build_phi3() {
|
struct ggml_cgraph * build_phi3() {
|
||||||
struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, LLAMA_MAX_NODES, false);
|
struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, LLAMA_MAX_NODES, false);
|
||||||
|
|
||||||
|
@ -16025,6 +16264,29 @@ void llama_free_model(struct llama_model * model) {
|
||||||
delete model;
|
delete model;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
static std::map<std::string, lora_weights> get_lora_weights_map_cpp(struct ggml_context* ctx) {
|
||||||
|
struct lora_tensor_pair* pair = build_lora_weights_map(ctx);
|
||||||
|
std::map<std::string, lora_weights> map;
|
||||||
|
|
||||||
|
if (pair) {
|
||||||
|
for (int i = 0; i < pair->count; i++) {
|
||||||
|
std::string name(pair->pairs[i].name);
|
||||||
|
std::string base_name = name.substr(0, name.size() - 6);
|
||||||
|
std::string suffix = name.substr(name.size() - 6);
|
||||||
|
|
||||||
|
if (suffix == ".loraA") {
|
||||||
|
map[base_name].loraA = pair->pairs[i].tensor;
|
||||||
|
} else if (suffix == ".loraB") {
|
||||||
|
map[base_name].loraB = pair->pairs[i].tensor;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
free_lora_tensor_pair(pair);
|
||||||
|
}
|
||||||
|
|
||||||
|
return map;
|
||||||
|
}
|
||||||
|
|
||||||
struct llama_context * llama_new_context_with_model(
|
struct llama_context * llama_new_context_with_model(
|
||||||
struct llama_model * model,
|
struct llama_model * model,
|
||||||
struct llama_context_params params) {
|
struct llama_context_params params) {
|
||||||
|
@ -16056,6 +16318,37 @@ struct llama_context * llama_new_context_with_model(
|
||||||
|
|
||||||
llama_context * ctx = new llama_context(*model);
|
llama_context * ctx = new llama_context(*model);
|
||||||
|
|
||||||
|
/// LORA
|
||||||
|
struct export_lora_params * lora_params = new struct export_lora_params;
|
||||||
|
struct lora_info lora;
|
||||||
|
lora.filename = "./models/open-llama/lora-ggml-model-q8_0-shakespeare-LATEST.bin";
|
||||||
|
lora.scale = 1.0f; // redundant as already inside lora_context, but should be here for multiple loras
|
||||||
|
lora_params->lora.push_back(lora);
|
||||||
|
// load all loras
|
||||||
|
std::vector<struct lora_data *> loras;
|
||||||
|
for (size_t i = 0; i < lora_params->lora.size(); ++i) {
|
||||||
|
struct lora_data * llora_data = load_lora(&lora_params->lora[i]);
|
||||||
|
if (llora_data != NULL) {
|
||||||
|
loras.push_back(llora_data);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if (loras.size() == 0) {
|
||||||
|
fprintf(stderr, "warning: no lora adapters will be applied.\n");
|
||||||
|
}
|
||||||
|
// Assign data
|
||||||
|
ctx->llora_data = *loras[0];
|
||||||
|
|
||||||
|
// build the map?
|
||||||
|
ctx->lora_weights_map = get_lora_weights_map_cpp((ctx->llora_data).ctx);
|
||||||
|
std::vector<std::string> keys;
|
||||||
|
for (const auto& pair : ctx->lora_weights_map) {
|
||||||
|
keys.push_back(pair.first);
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
/// END LORA
|
||||||
|
|
||||||
const auto & hparams = model->hparams;
|
const auto & hparams = model->hparams;
|
||||||
auto & cparams = ctx->cparams;
|
auto & cparams = ctx->cparams;
|
||||||
|
|
||||||
|
|
3
llama.h
3
llama.h
|
@ -45,6 +45,9 @@
|
||||||
#define LLAMA_STATE_SEQ_MAGIC LLAMA_FILE_MAGIC_GGSQ
|
#define LLAMA_STATE_SEQ_MAGIC LLAMA_FILE_MAGIC_GGSQ
|
||||||
#define LLAMA_STATE_SEQ_VERSION 1
|
#define LLAMA_STATE_SEQ_VERSION 1
|
||||||
|
|
||||||
|
#define die(msg) do { fputs("error: " msg "\n", stderr); exit(1); } while (0)
|
||||||
|
#define die_fmt(fmt, ...) do { fprintf(stderr, "error: " fmt "\n", __VA_ARGS__); exit(1); } while (0)
|
||||||
|
|
||||||
#ifdef __cplusplus
|
#ifdef __cplusplus
|
||||||
extern "C" {
|
extern "C" {
|
||||||
#endif
|
#endif
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue