Merge branch 'master' into gg/flash-attn
This commit is contained in:
commit
31109ca00a
87 changed files with 5115 additions and 1531 deletions
|
@ -38,6 +38,7 @@ else()
|
|||
add_subdirectory(speculative)
|
||||
add_subdirectory(lookahead)
|
||||
add_subdirectory(lookup)
|
||||
add_subdirectory(gguf)
|
||||
add_subdirectory(train-text-from-scratch)
|
||||
add_subdirectory(imatrix)
|
||||
if (LLAMA_BUILD_SERVER)
|
||||
|
|
|
@ -1533,16 +1533,17 @@ int main(int argc, char ** argv) {
|
|||
|
||||
int n_past = 0;
|
||||
|
||||
ggml_cgraph gf = {};
|
||||
struct ggml_cgraph * gf = NULL;
|
||||
gf = ggml_new_graph_custom(ctx0, LLAMA_TRAIN_MAX_NODES, true);
|
||||
|
||||
get_example_targets_batch(ctx0, 64*ex+0, tokens_input, targets);
|
||||
|
||||
struct ggml_tensor * logits = forward_batch(&model, &kv_self, ctx0, &gf, tokens_input, n_tokens, n_past, n_batch);
|
||||
struct ggml_tensor * logits = forward_batch(&model, &kv_self, ctx0, gf, tokens_input, n_tokens, n_past, n_batch);
|
||||
// struct ggml_tensor * e = cross_entropy_loss(ctx0, targets, logits);
|
||||
struct ggml_tensor * e = square_error_loss(ctx0, targets, logits);
|
||||
|
||||
ggml_build_forward_expand(&gf, e);
|
||||
ggml_graph_compute_helper(work_buffer, &gf, /*n_threads*/ 1);
|
||||
ggml_build_forward_expand(gf, e);
|
||||
ggml_graph_compute_helper(work_buffer, gf, /*n_threads*/ 1);
|
||||
|
||||
float error_before_opt = ggml_get_f32_1d(e, 0);
|
||||
|
||||
|
@ -1552,8 +1553,8 @@ int main(int argc, char ** argv) {
|
|||
opt_params_lbfgs.lbfgs.n_iter = 16;
|
||||
ggml_opt(ctx0, opt_params_lbfgs, e);
|
||||
//
|
||||
ggml_build_forward_expand(&gf, e);
|
||||
ggml_graph_compute_helper(work_buffer, &gf, /*n_threads*/ 1);
|
||||
ggml_build_forward_expand(gf, e);
|
||||
ggml_graph_compute_helper(work_buffer, gf, /*n_threads*/ 1);
|
||||
|
||||
float error_after_opt = ggml_get_f32_1d(e, 0);
|
||||
|
||||
|
@ -1600,13 +1601,14 @@ int main(int argc, char ** argv) {
|
|||
};
|
||||
struct ggml_context * ctx0 = ggml_init(params);
|
||||
|
||||
ggml_cgraph gf = {};
|
||||
struct ggml_cgraph * gf = NULL;
|
||||
gf = ggml_new_graph_custom(ctx0, LLAMA_TRAIN_MAX_NODES, true);
|
||||
|
||||
int n_past = 0;
|
||||
struct ggml_tensor * logits = forward(&model, &kv_self, ctx0, &gf, tokens_input, sample_ctx, n_past);
|
||||
struct ggml_tensor * logits = forward(&model, &kv_self, ctx0, gf, tokens_input, sample_ctx, n_past);
|
||||
|
||||
ggml_build_forward_expand(&gf, logits);
|
||||
ggml_graph_compute_helper(work_buffer, &gf, /*n_threads*/ 1);
|
||||
ggml_build_forward_expand(gf, logits);
|
||||
ggml_graph_compute_helper(work_buffer, gf, /*n_threads*/ 1);
|
||||
|
||||
struct ggml_tensor * best_samples = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, sample_ctx);
|
||||
struct ggml_tensor * probs = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_vocab, sample_ctx);
|
||||
|
|
|
@ -82,7 +82,8 @@ int main(int argc, char ** argv) {
|
|||
|
||||
// init LLM
|
||||
|
||||
llama_backend_init(params.numa);
|
||||
llama_backend_init();
|
||||
llama_numa_init(params.numa);
|
||||
|
||||
// initialize the model
|
||||
|
||||
|
@ -158,7 +159,7 @@ int main(int argc, char ** argv) {
|
|||
}
|
||||
|
||||
LOG_TEE("\n");
|
||||
LOG_TEE("%s: n_kv_max = %d, is_pp_shared = %d, n_gpu_layers = %d, mmq = %d, n_threads = %d, n_threads_batch = %d\n", __func__, n_kv_max, is_pp_shared, n_gpu_layers, mmq, ctx_params.n_threads, ctx_params.n_threads_batch);
|
||||
LOG_TEE("%s: n_kv_max = %d, is_pp_shared = %d, n_gpu_layers = %d, mmq = %d, n_threads = %u, n_threads_batch = %u\n", __func__, n_kv_max, is_pp_shared, n_gpu_layers, mmq, ctx_params.n_threads, ctx_params.n_threads_batch);
|
||||
LOG_TEE("\n");
|
||||
|
||||
LOG_TEE("|%6s | %6s | %4s | %6s | %8s | %8s | %8s | %8s | %8s | %8s |\n", "PP", "TG", "B", "N_KV", "T_PP s", "S_PP t/s", "T_TG s", "S_TG t/s", "T s", "S t/s");
|
||||
|
|
|
@ -17,7 +17,7 @@ let n_parallel: Int = arguments.count > 3 && Int(arguments[3]) != nil ? Int(argu
|
|||
let n_len: Int = 32
|
||||
|
||||
// init LLM
|
||||
llama_backend_init(false)
|
||||
llama_backend_init()
|
||||
defer {
|
||||
llama_backend_free()
|
||||
}
|
||||
|
|
|
@ -50,7 +50,8 @@ int main(int argc, char ** argv) {
|
|||
|
||||
// init LLM
|
||||
|
||||
llama_backend_init(params.numa);
|
||||
llama_backend_init();
|
||||
llama_numa_init(params.numa);
|
||||
|
||||
// initialize the model
|
||||
|
||||
|
@ -91,7 +92,7 @@ int main(int argc, char ** argv) {
|
|||
|
||||
const int n_ctx = llama_n_ctx(ctx);
|
||||
|
||||
LOG_TEE("\n%s: n_len = %d, n_ctx = %d, n_batch = %d, n_parallel = %d, n_kv_req = %d\n", __func__, n_len, n_ctx, ctx_params.n_batch, n_parallel, n_kv_req);
|
||||
LOG_TEE("\n%s: n_len = %d, n_ctx = %d, n_batch = %u, n_parallel = %d, n_kv_req = %d\n", __func__, n_len, n_ctx, ctx_params.n_batch, n_parallel, n_kv_req);
|
||||
|
||||
// make sure the KV cache is big enough to hold all the prompt and generated tokens
|
||||
if (n_kv_req > n_ctx) {
|
||||
|
|
|
@ -119,7 +119,8 @@ int main(int argc, char ** argv)
|
|||
// Init LLM :
|
||||
//---------------------------------
|
||||
|
||||
llama_backend_init(params.numa);
|
||||
llama_backend_init();
|
||||
llama_numa_init(params.numa);
|
||||
|
||||
llama_model * model;
|
||||
llama_context * ctx;
|
||||
|
|
|
@ -325,14 +325,14 @@ struct train_params {
|
|||
};
|
||||
|
||||
static void print_params(struct my_llama_hparams * params) {
|
||||
printf("%s: n_vocab: %d\n", __func__, params->n_vocab);
|
||||
printf("%s: n_ctx: %d\n", __func__, params->n_ctx);
|
||||
printf("%s: n_embd: %d\n", __func__, params->n_embd);
|
||||
printf("%s: n_mult: %d\n", __func__, params->n_mult);
|
||||
printf("%s: n_head: %d\n", __func__, params->n_head);
|
||||
printf("%s: n_ff: %d\n", __func__, params->n_ff);
|
||||
printf("%s: n_layer: %d\n", __func__, params->n_layer);
|
||||
printf("%s: n_rot: %d\n", __func__, params->n_rot);
|
||||
printf("%s: n_vocab: %u\n", __func__, params->n_vocab);
|
||||
printf("%s: n_ctx: %u\n", __func__, params->n_ctx);
|
||||
printf("%s: n_embd: %u\n", __func__, params->n_embd);
|
||||
printf("%s: n_mult: %u\n", __func__, params->n_mult);
|
||||
printf("%s: n_head: %u\n", __func__, params->n_head);
|
||||
printf("%s: n_ff: %u\n", __func__, params->n_ff);
|
||||
printf("%s: n_layer: %u\n", __func__, params->n_layer);
|
||||
printf("%s: n_rot: %u\n", __func__, params->n_rot);
|
||||
}
|
||||
|
||||
static void init_model(struct my_llama_model * model) {
|
||||
|
@ -350,25 +350,25 @@ static void init_model(struct my_llama_model * model) {
|
|||
model->train_tokens = 0;
|
||||
|
||||
model->tok_embeddings = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_embd, n_vocab);
|
||||
printf("[%s:GG] Allocating [%d] x [%d] = [%d] float space for model->tok_embeddings\n",__func__,n_embd , n_vocab, n_embd * n_vocab);
|
||||
printf("[%s:GG] Allocating [%u] x [%u] = [%u] float space for model->tok_embeddings\n",__func__,n_embd , n_vocab, n_embd * n_vocab);
|
||||
|
||||
model->norm = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_embd);
|
||||
printf("[%s:GG] Allocating [%d] float space for model->norm\n",__func__,n_embd);
|
||||
printf("[%s:GG] Allocating [%u] float space for model->norm\n",__func__,n_embd);
|
||||
|
||||
model->output = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_embd, n_vocab);
|
||||
printf("[%s:GG] Allocating [%d] x[%d] = [%d] float space for model->output\n",__func__,n_embd, n_vocab, n_embd * n_vocab);
|
||||
printf("[%s:GG] Allocating [%u] x[%u] = [%u] float space for model->output\n",__func__,n_embd, n_vocab, n_embd * n_vocab);
|
||||
|
||||
// printing the per-layer allocations here so we dont print in the for loop.
|
||||
printf("[%s:GG] Allocating [%d] x[%d] = [%d] float space for layer.wq for [%d] layers\n",__func__, n_embd, n_embd, n_embd * n_embd, n_layer);
|
||||
printf("[%s:GG] Allocating [%d] x[%d] = [%d] float space for layer.wk for [%d] layers\n",__func__, n_embd, n_embd, n_embd * n_embd, n_layer);
|
||||
printf("[%s:GG] Allocating [%d] x[%d] = [%d] float space for layer.wv for [%d] layers\n",__func__, n_embd, n_embd, n_embd * n_embd, n_layer);
|
||||
printf("[%s:GG] Allocating [%d] x[%d] = [%d] float space for layer.wo for [%d] layers\n",__func__, n_embd, n_embd, n_embd * n_embd, n_layer);
|
||||
printf("[%s:GG] Allocating [%u] x[%u] = [%u] float space for layer.wq for [%u] layers\n",__func__, n_embd, n_embd, n_embd * n_embd, n_layer);
|
||||
printf("[%s:GG] Allocating [%u] x[%u] = [%u] float space for layer.wk for [%u] layers\n",__func__, n_embd, n_embd, n_embd * n_embd, n_layer);
|
||||
printf("[%s:GG] Allocating [%u] x[%u] = [%u] float space for layer.wv for [%u] layers\n",__func__, n_embd, n_embd, n_embd * n_embd, n_layer);
|
||||
printf("[%s:GG] Allocating [%u] x[%u] = [%u] float space for layer.wo for [%u] layers\n",__func__, n_embd, n_embd, n_embd * n_embd, n_layer);
|
||||
|
||||
printf("[%s:GG] Allocating [%d] float space for layer.ffn_norm for [%d] layers\n",__func__,n_embd, n_layer);
|
||||
printf("[%s:GG] Allocating [%u] float space for layer.ffn_norm for [%u] layers\n",__func__,n_embd, n_layer);
|
||||
|
||||
printf("[%s:GG] Allocating [%d] x[%d] = [%d] float space for layer.w1 for [%d] layers\n",__func__, n_ff, n_embd, n_embd * n_ff, n_layer);
|
||||
printf("[%s:GG] Allocating [%d] x[%d] = [%d] float space for layer.w2 for [%d] layers\n",__func__, n_embd, n_ff, n_ff * n_embd, n_layer);
|
||||
printf("[%s:GG] Allocating [%d] x[%d] = [%d] float space for layer.w3 for [%d] layers\n",__func__, n_ff, n_embd, n_embd * n_ff, n_layer);
|
||||
printf("[%s:GG] Allocating [%u] x[%u] = [%u] float space for layer.w1 for [%u] layers\n",__func__, n_ff, n_embd, n_embd * n_ff, n_layer);
|
||||
printf("[%s:GG] Allocating [%u] x[%u] = [%u] float space for layer.w2 for [%u] layers\n",__func__, n_embd, n_ff, n_ff * n_embd, n_layer);
|
||||
printf("[%s:GG] Allocating [%u] x[%u] = [%u] float space for layer.w3 for [%u] layers\n",__func__, n_ff, n_embd, n_embd * n_ff, n_layer);
|
||||
|
||||
ggml_set_name(model->tok_embeddings, "tok_embeddings.weight");
|
||||
ggml_set_name(model->norm, "norm.weight");
|
||||
|
|
|
@ -7,6 +7,51 @@
|
|||
#pragma warning(disable: 4244 4267) // possible loss of data
|
||||
#endif
|
||||
|
||||
static std::vector<std::string> split_lines(const std::string & s) {
|
||||
std::string line;
|
||||
std::vector<std::string> lines;
|
||||
std::stringstream ss(s);
|
||||
while (std::getline(ss, line)) {
|
||||
lines.push_back(line);
|
||||
}
|
||||
return lines;
|
||||
}
|
||||
|
||||
static void batch_add_seq(llama_batch & batch, const std::vector<int32_t> & tokens, int seq_id) {
|
||||
for (size_t i = 0; i < tokens.size(); i++) {
|
||||
llama_batch_add(batch, tokens[i], i, { seq_id }, false);
|
||||
}
|
||||
}
|
||||
|
||||
static void normalize(float * vec, float * out, int n) {
|
||||
float norm = 0;
|
||||
for (int i = 0; i < n; i++) {
|
||||
norm += vec[i] * vec[i];
|
||||
}
|
||||
norm = sqrt(norm);
|
||||
for (int i = 0; i < n; i++) {
|
||||
out[i] = vec[i] / norm;
|
||||
}
|
||||
}
|
||||
|
||||
static void batch_decode(llama_context * ctx, llama_batch & batch, float * output, int n_seq, int n_embd) {
|
||||
// clear previous kv_cache values (irrelevant for embeddings)
|
||||
llama_kv_cache_clear(ctx);
|
||||
|
||||
// run model
|
||||
fprintf(stderr, "%s: n_tokens = %d, n_seq = %d\n", __func__, batch.n_tokens, n_seq);
|
||||
if (llama_decode(ctx, batch) < 0) {
|
||||
fprintf(stderr, "%s : failed to decode\n", __func__);
|
||||
}
|
||||
|
||||
// normalize on copy
|
||||
for (int k = 0; k < n_seq; k++) {
|
||||
float * emb = llama_get_embeddings_ith(ctx, k);
|
||||
float * out = output + k * n_embd;
|
||||
normalize(emb, out, n_embd);
|
||||
}
|
||||
}
|
||||
|
||||
int main(int argc, char ** argv) {
|
||||
gpt_params params;
|
||||
|
||||
|
@ -29,7 +74,8 @@ int main(int argc, char ** argv) {
|
|||
params.prompt = gpt_random_prompt(rng);
|
||||
}
|
||||
|
||||
llama_backend_init(params.numa);
|
||||
llama_backend_init();
|
||||
llama_numa_init(params.numa);
|
||||
|
||||
llama_model * model;
|
||||
llama_context * ctx;
|
||||
|
@ -55,59 +101,84 @@ int main(int argc, char ** argv) {
|
|||
fprintf(stderr, "%s\n", get_system_info(params).c_str());
|
||||
}
|
||||
|
||||
int n_past = 0;
|
||||
// split the prompt into lines
|
||||
std::vector<std::string> prompts = split_lines(params.prompt);
|
||||
|
||||
// tokenize the prompt
|
||||
auto embd_inp = ::llama_tokenize(ctx, params.prompt, true);
|
||||
// max batch size
|
||||
const uint64_t n_batch = params.n_batch;
|
||||
GGML_ASSERT(params.n_batch == params.n_ctx);
|
||||
|
||||
// tokenize the prompts and trim
|
||||
std::vector<std::vector<int32_t>> inputs;
|
||||
for (const auto & prompt : prompts) {
|
||||
auto inp = ::llama_tokenize(ctx, prompt, true);
|
||||
if (inp.size() > n_batch) {
|
||||
inp.resize(n_batch);
|
||||
}
|
||||
inputs.push_back(inp);
|
||||
}
|
||||
|
||||
// tokenization stats
|
||||
if (params.verbose_prompt) {
|
||||
fprintf(stderr, "\n");
|
||||
fprintf(stderr, "%s: prompt: '%s'\n", __func__, params.prompt.c_str());
|
||||
fprintf(stderr, "%s: number of tokens in prompt = %zu\n", __func__, embd_inp.size());
|
||||
for (int i = 0; i < (int) embd_inp.size(); i++) {
|
||||
fprintf(stderr, "%6d -> '%s'\n", embd_inp[i], llama_token_to_piece(ctx, embd_inp[i]).c_str());
|
||||
for (int i = 0; i < (int) inputs.size(); i++) {
|
||||
fprintf(stderr, "%s: prompt %d: '%s'\n", __func__, i, prompts[i].c_str());
|
||||
fprintf(stderr, "%s: number of tokens in prompt = %zu\n", __func__, inputs[i].size());
|
||||
for (int j = 0; j < (int) inputs[i].size(); j++) {
|
||||
fprintf(stderr, "%6d -> '%s'\n", inputs[i][j], llama_token_to_piece(ctx, inputs[i][j]).c_str());
|
||||
}
|
||||
fprintf(stderr, "\n\n");
|
||||
}
|
||||
fprintf(stderr, "\n");
|
||||
}
|
||||
|
||||
if (embd_inp.size() > (size_t)n_ctx) {
|
||||
fprintf(stderr, "%s: error: prompt is longer than the context window (%zu tokens, n_ctx = %d)\n",
|
||||
__func__, embd_inp.size(), n_ctx);
|
||||
return 1;
|
||||
}
|
||||
|
||||
while (!embd_inp.empty()) {
|
||||
int n_tokens = std::min(params.n_batch, (int) embd_inp.size());
|
||||
if (llama_decode(ctx, llama_batch_get_one(embd_inp.data(), n_tokens, n_past, 0))) {
|
||||
fprintf(stderr, "%s : failed to eval\n", __func__);
|
||||
return 1;
|
||||
}
|
||||
n_past += n_tokens;
|
||||
embd_inp.erase(embd_inp.begin(), embd_inp.begin() + n_tokens);
|
||||
}
|
||||
// initialize batch
|
||||
const int n_prompts = prompts.size();
|
||||
struct llama_batch batch = llama_batch_init(n_batch, 0, n_prompts);
|
||||
|
||||
// allocate output
|
||||
const int n_embd = llama_n_embd(model);
|
||||
auto * embeddings = llama_get_embeddings(ctx);
|
||||
std::vector<float> embeddings(n_prompts * n_embd, 0);
|
||||
float * emb = embeddings.data();
|
||||
|
||||
// l2-normalize embeddings
|
||||
float norm = 0;
|
||||
for (int i = 0; i < n_embd; i++) {
|
||||
norm += embeddings[i] * embeddings[i];
|
||||
}
|
||||
norm = sqrt(norm);
|
||||
for (int i = 0; i < n_embd; i++) {
|
||||
embeddings[i] /= norm;
|
||||
// break into batches
|
||||
int p = 0; // number of prompts processed already
|
||||
int s = 0; // number of prompts in current batch
|
||||
for (int k = 0; k < n_prompts; k++) {
|
||||
// clamp to n_batch tokens
|
||||
auto & inp = inputs[k];
|
||||
const uint64_t n_toks = inp.size();
|
||||
|
||||
// encode if at capacity
|
||||
if (batch.n_tokens + n_toks > n_batch) {
|
||||
float * out = emb + p * n_embd;
|
||||
batch_decode(ctx, batch, out, s, n_embd);
|
||||
llama_batch_clear(batch);
|
||||
p += s;
|
||||
s = 0;
|
||||
}
|
||||
|
||||
// add to batch
|
||||
batch_add_seq(batch, inp, s);
|
||||
s += 1;
|
||||
}
|
||||
|
||||
for (int i = 0; i < n_embd; i++) {
|
||||
printf("%f ", embeddings[i]);
|
||||
}
|
||||
printf("\n");
|
||||
// final batch
|
||||
float * out = emb + p * n_embd;
|
||||
batch_decode(ctx, batch, out, s, n_embd);
|
||||
|
||||
// print first 3 embeddings
|
||||
for (int j = 0; j < std::min(3, n_prompts); j++) {
|
||||
fprintf(stderr, "embedding %d: ", j);
|
||||
for (int i = 0; i < n_embd; i++) {
|
||||
fprintf(stderr, "%f ", emb[j * n_embd + i]);
|
||||
}
|
||||
fprintf(stderr, "\n\n");
|
||||
}
|
||||
fprintf(stderr, "\n");
|
||||
|
||||
// clean up
|
||||
llama_print_timings(ctx);
|
||||
llama_free(ctx);
|
||||
llama_free_model(model);
|
||||
|
||||
llama_backend_free();
|
||||
|
||||
return 0;
|
||||
|
|
|
@ -7,8 +7,6 @@
|
|||
#include <string>
|
||||
#include <thread>
|
||||
|
||||
static const size_t tensor_alignment = 32;
|
||||
|
||||
struct lora_info {
|
||||
std::string filename;
|
||||
float scale;
|
||||
|
|
|
@ -80,9 +80,9 @@ The LORA rank can be configured for each model tensor type separately with these
|
|||
--rank-wk N LORA rank for wk tensor (default 4)
|
||||
--rank-wv N LORA rank for wv tensor (default 4)
|
||||
--rank-wo N LORA rank for wo tensor (default 4)
|
||||
--rank-w1 N LORA rank for w1 tensor (default 4)
|
||||
--rank-w2 N LORA rank for w2 tensor (default 4)
|
||||
--rank-w3 N LORA rank for w3 tensor (default 4)
|
||||
--rank-ffn_gate N LORA rank for ffn_gate tensor (default 4)
|
||||
--rank-ffn_down N LORA rank for ffn_down tensor (default 4)
|
||||
--rank-ffn_up N LORA rank for ffn_up tensor (default 4)
|
||||
```
|
||||
|
||||
The LORA rank of 'norm' tensors should always be 1.
|
||||
|
|
|
@ -60,9 +60,9 @@ struct my_llama_layer {
|
|||
struct ggml_tensor * ffn_norm;
|
||||
|
||||
// ff
|
||||
struct ggml_tensor * w1;
|
||||
struct ggml_tensor * w2;
|
||||
struct ggml_tensor * w3;
|
||||
struct ggml_tensor * ffn_gate; // w1
|
||||
struct ggml_tensor * ffn_down; // w2
|
||||
struct ggml_tensor * ffn_up; // w3
|
||||
};
|
||||
|
||||
struct my_llama_model {
|
||||
|
@ -85,9 +85,9 @@ struct my_llama_lora_hparams {
|
|||
uint32_t n_rank_wv = 4;
|
||||
uint32_t n_rank_wo = 4;
|
||||
uint32_t n_rank_ffn_norm = 1;
|
||||
uint32_t n_rank_w1 = 4;
|
||||
uint32_t n_rank_w2 = 4;
|
||||
uint32_t n_rank_w3 = 4;
|
||||
uint32_t n_rank_ffn_gate = 4;
|
||||
uint32_t n_rank_ffn_down = 4;
|
||||
uint32_t n_rank_ffn_up = 4;
|
||||
uint32_t n_rank_tok_embeddings = 4;
|
||||
uint32_t n_rank_norm = 1;
|
||||
uint32_t n_rank_output = 4;
|
||||
|
@ -117,12 +117,12 @@ struct my_llama_lora_layer {
|
|||
struct ggml_tensor * ffn_norm_b;
|
||||
|
||||
// ff
|
||||
struct ggml_tensor * w1_a;
|
||||
struct ggml_tensor * w1_b;
|
||||
struct ggml_tensor * w2_a;
|
||||
struct ggml_tensor * w2_b;
|
||||
struct ggml_tensor * w3_a;
|
||||
struct ggml_tensor * w3_b;
|
||||
struct ggml_tensor * ffn_gate_a;
|
||||
struct ggml_tensor * ffn_gate_b;
|
||||
struct ggml_tensor * ffn_down_a;
|
||||
struct ggml_tensor * ffn_down_b;
|
||||
struct ggml_tensor * ffn_up_a;
|
||||
struct ggml_tensor * ffn_up_b;
|
||||
};
|
||||
|
||||
struct my_llama_lora {
|
||||
|
@ -208,9 +208,9 @@ static void print_lora_params(struct my_llama_lora_hparams * params) {
|
|||
printf("%s: n_rank_wv : %u\n", __func__, params->n_rank_wv);
|
||||
printf("%s: n_rank_wo : %u\n", __func__, params->n_rank_wo);
|
||||
printf("%s: n_rank_ffn_norm : %u\n", __func__, params->n_rank_ffn_norm);
|
||||
printf("%s: n_rank_w1 : %u\n", __func__, params->n_rank_w1);
|
||||
printf("%s: n_rank_w2 : %u\n", __func__, params->n_rank_w2);
|
||||
printf("%s: n_rank_w3 : %u\n", __func__, params->n_rank_w3);
|
||||
printf("%s: n_rank_ffn_gate : %u\n", __func__, params->n_rank_ffn_gate);
|
||||
printf("%s: n_rank_ffn_down : %u\n", __func__, params->n_rank_ffn_down);
|
||||
printf("%s: n_rank_ffn_up : %u\n", __func__, params->n_rank_ffn_up);
|
||||
printf("%s: n_rank_tok_embeddings : %u\n", __func__, params->n_rank_tok_embeddings);
|
||||
printf("%s: n_rank_norm : %u\n", __func__, params->n_rank_norm);
|
||||
printf("%s: n_rank_output : %u\n", __func__, params->n_rank_output);
|
||||
|
@ -319,9 +319,9 @@ static void init_model(struct llama_model * input, struct my_llama_model * model
|
|||
layer.wv = llama_get_model_tensor(input, tni(LLM_TENSOR_ATTN_V, i));
|
||||
layer.wo = llama_get_model_tensor(input, tni(LLM_TENSOR_ATTN_OUT, i));
|
||||
layer.ffn_norm = llama_get_model_tensor(input, tni(LLM_TENSOR_FFN_NORM, i));
|
||||
layer.w1 = llama_get_model_tensor(input, tni(LLM_TENSOR_FFN_GATE, i));
|
||||
layer.w2 = llama_get_model_tensor(input, tni(LLM_TENSOR_FFN_DOWN, i));
|
||||
layer.w3 = llama_get_model_tensor(input, tni(LLM_TENSOR_FFN_UP, i));
|
||||
layer.ffn_gate = llama_get_model_tensor(input, tni(LLM_TENSOR_FFN_GATE, i));
|
||||
layer.ffn_down = llama_get_model_tensor(input, tni(LLM_TENSOR_FFN_DOWN, i));
|
||||
layer.ffn_up = llama_get_model_tensor(input, tni(LLM_TENSOR_FFN_UP, i));
|
||||
|
||||
assert_shape_1d(layer.attention_norm, hparams.n_embd);
|
||||
assert_shape_2d(layer.wq, hparams.n_embd, hparams.n_embd);
|
||||
|
@ -329,9 +329,9 @@ static void init_model(struct llama_model * input, struct my_llama_model * model
|
|||
assert_shape_2d(layer.wv, hparams.n_embd, hparams.n_embd_gqa());
|
||||
assert_shape_2d(layer.wo, hparams.n_embd, hparams.n_embd);
|
||||
assert_shape_1d(layer.ffn_norm, hparams.n_embd);
|
||||
assert_shape_2d(layer.w1, hparams.n_embd, hparams.n_ff);
|
||||
assert_shape_2d(layer.w2, hparams.n_ff, hparams.n_embd);
|
||||
assert_shape_2d(layer.w3, hparams.n_embd, hparams.n_ff);
|
||||
assert_shape_2d(layer.ffn_gate, hparams.n_embd, hparams.n_ff);
|
||||
assert_shape_2d(layer.ffn_down, hparams.n_ff, hparams.n_embd);
|
||||
assert_shape_2d(layer.ffn_up, hparams.n_embd, hparams.n_ff);
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -362,12 +362,12 @@ static void set_param_lora(struct my_llama_lora * lora) {
|
|||
ggml_set_param(ctx, layer.wo_b);
|
||||
ggml_set_param(ctx, layer.ffn_norm_a);
|
||||
ggml_set_param(ctx, layer.ffn_norm_b);
|
||||
ggml_set_param(ctx, layer.w1_a);
|
||||
ggml_set_param(ctx, layer.w1_b);
|
||||
ggml_set_param(ctx, layer.w2_a);
|
||||
ggml_set_param(ctx, layer.w2_b);
|
||||
ggml_set_param(ctx, layer.w3_a);
|
||||
ggml_set_param(ctx, layer.w3_b);
|
||||
ggml_set_param(ctx, layer.ffn_gate_a);
|
||||
ggml_set_param(ctx, layer.ffn_gate_b);
|
||||
ggml_set_param(ctx, layer.ffn_down_a);
|
||||
ggml_set_param(ctx, layer.ffn_down_b);
|
||||
ggml_set_param(ctx, layer.ffn_up_a);
|
||||
ggml_set_param(ctx, layer.ffn_up_b);
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -435,12 +435,12 @@ static void init_lora(const struct my_llama_model * model, struct my_llama_lora
|
|||
layer.ffn_norm_a = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, lparams.n_rank_ffn_norm, n_embd);
|
||||
layer.ffn_norm_b = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, lparams.n_rank_ffn_norm, 1);
|
||||
|
||||
layer.w1_a = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, lparams.n_rank_w1, n_embd);
|
||||
layer.w1_b = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, lparams.n_rank_w1, n_ff);
|
||||
layer.w2_a = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, lparams.n_rank_w2, n_ff);
|
||||
layer.w2_b = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, lparams.n_rank_w2, n_embd);
|
||||
layer.w3_a = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, lparams.n_rank_w3, n_embd);
|
||||
layer.w3_b = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, lparams.n_rank_w3, n_ff);
|
||||
layer.ffn_gate_a = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, lparams.n_rank_ffn_gate, n_embd);
|
||||
layer.ffn_gate_b = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, lparams.n_rank_ffn_gate, n_ff);
|
||||
layer.ffn_down_a = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, lparams.n_rank_ffn_down, n_ff);
|
||||
layer.ffn_down_b = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, lparams.n_rank_ffn_down, n_embd);
|
||||
layer.ffn_up_a = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, lparams.n_rank_ffn_up, n_embd);
|
||||
layer.ffn_up_b = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, lparams.n_rank_ffn_up, n_ff);
|
||||
|
||||
ggml_set_name(layer.attention_norm_a, tni(LLM_TENSOR_ATTN_NORM, ".weight.lora_a", i));
|
||||
ggml_set_name(layer.attention_norm_b, tni(LLM_TENSOR_ATTN_NORM, ".weight.lora_b", i));
|
||||
|
@ -454,12 +454,12 @@ static void init_lora(const struct my_llama_model * model, struct my_llama_lora
|
|||
ggml_set_name(layer.wo_b, tni(LLM_TENSOR_ATTN_OUT, ".weight.lora_b", i));
|
||||
ggml_set_name(layer.ffn_norm_a, tni(LLM_TENSOR_FFN_NORM, ".weight.lora_a", i));
|
||||
ggml_set_name(layer.ffn_norm_b, tni(LLM_TENSOR_FFN_NORM, ".weight.lora_b", i));
|
||||
ggml_set_name(layer.w1_a, tni(LLM_TENSOR_FFN_GATE, ".weight.lora_a", i));
|
||||
ggml_set_name(layer.w1_b, tni(LLM_TENSOR_FFN_GATE, ".weight.lora_b", i));
|
||||
ggml_set_name(layer.w2_a, tni(LLM_TENSOR_FFN_DOWN, ".weight.lora_a", i));
|
||||
ggml_set_name(layer.w2_b, tni(LLM_TENSOR_FFN_DOWN, ".weight.lora_b", i));
|
||||
ggml_set_name(layer.w3_a, tni(LLM_TENSOR_FFN_UP, ".weight.lora_a", i));
|
||||
ggml_set_name(layer.w3_b, tni(LLM_TENSOR_FFN_UP, ".weight.lora_b", i));
|
||||
ggml_set_name(layer.ffn_gate_a, tni(LLM_TENSOR_FFN_GATE, ".weight.lora_a", i));
|
||||
ggml_set_name(layer.ffn_gate_b, tni(LLM_TENSOR_FFN_GATE, ".weight.lora_b", i));
|
||||
ggml_set_name(layer.ffn_down_a, tni(LLM_TENSOR_FFN_DOWN, ".weight.lora_a", i));
|
||||
ggml_set_name(layer.ffn_down_b, tni(LLM_TENSOR_FFN_DOWN, ".weight.lora_b", i));
|
||||
ggml_set_name(layer.ffn_up_a, tni(LLM_TENSOR_FFN_UP, ".weight.lora_a", i));
|
||||
ggml_set_name(layer.ffn_up_b, tni(LLM_TENSOR_FFN_UP, ".weight.lora_b", i));
|
||||
}
|
||||
|
||||
set_param_lora(lora);
|
||||
|
@ -497,12 +497,12 @@ static void randomize_lora(struct my_llama_lora * lora, int seed, float mean, fl
|
|||
randomize_tensor_normal(layer.ffn_norm_a, rnd);
|
||||
ggml_set_zero(layer.ffn_norm_b);
|
||||
|
||||
randomize_tensor_normal(layer.w1_a, rnd);
|
||||
ggml_set_zero(layer.w1_b);
|
||||
randomize_tensor_normal(layer.w2_a, rnd);
|
||||
ggml_set_zero(layer.w2_b);
|
||||
randomize_tensor_normal(layer.w3_a, rnd);
|
||||
ggml_set_zero(layer.w3_b);
|
||||
randomize_tensor_normal(layer.ffn_gate_a, rnd);
|
||||
ggml_set_zero(layer.ffn_gate_b);
|
||||
randomize_tensor_normal(layer.ffn_down_a, rnd);
|
||||
ggml_set_zero(layer.ffn_down_b);
|
||||
randomize_tensor_normal(layer.ffn_up_a, rnd);
|
||||
ggml_set_zero(layer.ffn_up_b);
|
||||
}
|
||||
|
||||
free_random_normal_distribution(rnd);
|
||||
|
@ -610,13 +610,13 @@ static struct ggml_tensor * llama_build_lora_finetune_graphs(
|
|||
|
||||
struct ggml_tensor * attention_norm = add_to_f32(ctx, layer.attention_norm, ggml_mul_mat(ctx, llayer.attention_norm_a, llayer.attention_norm_b));
|
||||
struct ggml_tensor * ffn_norm = add_to_f32(ctx, layer.ffn_norm, ggml_mul_mat(ctx, llayer.ffn_norm_a, llayer.ffn_norm_b));
|
||||
struct ggml_tensor * wq = add_to_f32(ctx, layer.wq, ggml_mul_mat(ctx, llayer.wq_a, llayer.wq_b));
|
||||
struct ggml_tensor * wk = add_to_f32(ctx, layer.wk, ggml_mul_mat(ctx, llayer.wk_a, llayer.wk_b));
|
||||
struct ggml_tensor * wv = add_to_f32(ctx, layer.wv, ggml_mul_mat(ctx, llayer.wv_a, llayer.wv_b));
|
||||
struct ggml_tensor * wo = add_to_f32(ctx, layer.wo, ggml_mul_mat(ctx, llayer.wo_a, llayer.wo_b));
|
||||
struct ggml_tensor * w1 = add_to_f32(ctx, layer.w1, ggml_mul_mat(ctx, llayer.w1_a, llayer.w1_b));
|
||||
struct ggml_tensor * w2 = add_to_f32(ctx, layer.w2, ggml_mul_mat(ctx, llayer.w2_a, llayer.w2_b));
|
||||
struct ggml_tensor * w3 = add_to_f32(ctx, layer.w3, ggml_mul_mat(ctx, llayer.w3_a, llayer.w3_b));
|
||||
struct ggml_tensor * wq = add_to_f32(ctx, layer.wq, ggml_mul_mat(ctx, llayer.wq_a, llayer.wq_b));
|
||||
struct ggml_tensor * wk = add_to_f32(ctx, layer.wk, ggml_mul_mat(ctx, llayer.wk_a, llayer.wk_b));
|
||||
struct ggml_tensor * wv = add_to_f32(ctx, layer.wv, ggml_mul_mat(ctx, llayer.wv_a, llayer.wv_b));
|
||||
struct ggml_tensor * wo = add_to_f32(ctx, layer.wo, ggml_mul_mat(ctx, llayer.wo_a, llayer.wo_b));
|
||||
struct ggml_tensor * ffn_gate = add_to_f32(ctx, layer.ffn_gate, ggml_mul_mat(ctx, llayer.ffn_gate_a, llayer.ffn_gate_b));
|
||||
struct ggml_tensor * ffn_down = add_to_f32(ctx, layer.ffn_down, ggml_mul_mat(ctx, llayer.ffn_down_a, llayer.ffn_down_b));
|
||||
struct ggml_tensor * ffn_up = add_to_f32(ctx, layer.ffn_up, ggml_mul_mat(ctx, llayer.ffn_up_a, llayer.ffn_up_b));
|
||||
|
||||
struct ggml_tensor * t02 = ggml_rms_norm (ctx, cur, rms_norm_eps); set_name(t02, "t02"); assert_shape_2d(t02, n_embd, N*n_batch);
|
||||
struct ggml_tensor * t03 = ggml_repeat (ctx, attention_norm, t02); set_name(t03, "t03"); assert_shape_2d(t03, n_embd, N*n_batch);
|
||||
|
@ -659,11 +659,11 @@ static struct ggml_tensor * llama_build_lora_finetune_graphs(
|
|||
struct ggml_tensor * t22 = ggml_rms_norm (ctx, t21, rms_norm_eps); set_name(t22, "t22"); assert_shape_2d(t22, n_embd, N*n_batch);
|
||||
struct ggml_tensor * t23 = ggml_repeat (ctx, ffn_norm, t22); set_name(t23, "t23"); assert_shape_2d(t23, n_embd, N*n_batch);
|
||||
struct ggml_tensor * t24 = ggml_mul (ctx, t23, t22); set_name(t24, "t24"); assert_shape_2d(t24, n_embd, N*n_batch);
|
||||
struct ggml_tensor * t25 = ggml_mul_mat (ctx, w3, t24); set_name(t25, "t25"); assert_shape_2d(t25, n_ff, N*n_batch);
|
||||
struct ggml_tensor * t26 = ggml_mul_mat (ctx, w1, t24); set_name(t26, "t26"); assert_shape_2d(t26, n_ff, N*n_batch);
|
||||
struct ggml_tensor * t25 = ggml_mul_mat (ctx, ffn_up, t24); set_name(t25, "t25"); assert_shape_2d(t25, n_ff, N*n_batch);
|
||||
struct ggml_tensor * t26 = ggml_mul_mat (ctx, ffn_gate, t24); set_name(t26, "t26"); assert_shape_2d(t26, n_ff, N*n_batch);
|
||||
struct ggml_tensor * t27 = ggml_silu (ctx, t26); set_name(t27, "t27"); assert_shape_2d(t27, n_ff, N*n_batch);
|
||||
struct ggml_tensor * t28 = ggml_mul (ctx, t27, t25); set_name(t28, "t28"); assert_shape_2d(t28, n_ff, N*n_batch);
|
||||
struct ggml_tensor * t29 = ggml_mul_mat (ctx, w2, t28); set_name(t29, "t29"); assert_shape_2d(t29, n_embd, N*n_batch);
|
||||
struct ggml_tensor * t29 = ggml_mul_mat (ctx, ffn_down, t28); set_name(t29, "t29"); assert_shape_2d(t29, n_embd, N*n_batch);
|
||||
struct ggml_tensor * t30 = ggml_add (ctx, t29, t21); set_name(t30, "t30"); assert_shape_2d(t30, n_embd, N*n_batch);
|
||||
cur = t30;
|
||||
if (enable_checkpointing) {
|
||||
|
@ -723,9 +723,9 @@ static struct ggml_tensor * llama_build_lora_finetune_graphs(
|
|||
ggml_build_forward_expand(gb, ggml_scale_inplace(ctx, layer.wk, 1.0f));
|
||||
ggml_build_forward_expand(gb, ggml_scale_inplace(ctx, layer.wv, 1.0f));
|
||||
ggml_build_forward_expand(gb, ggml_scale_inplace(ctx, layer.wo, 1.0f));
|
||||
ggml_build_forward_expand(gb, ggml_scale_inplace(ctx, layer.w1, 1.0f));
|
||||
ggml_build_forward_expand(gb, ggml_scale_inplace(ctx, layer.w2, 1.0f));
|
||||
ggml_build_forward_expand(gb, ggml_scale_inplace(ctx, layer.w3, 1.0f));
|
||||
ggml_build_forward_expand(gb, ggml_scale_inplace(ctx, layer.ffn_gate, 1.0f));
|
||||
ggml_build_forward_expand(gb, ggml_scale_inplace(ctx, layer.ffn_down, 1.0f));
|
||||
ggml_build_forward_expand(gb, ggml_scale_inplace(ctx, layer.ffn_up, 1.0f));
|
||||
}
|
||||
|
||||
// allocating checkpoints in one block to reduce memory fragmentation
|
||||
|
@ -798,9 +798,9 @@ static void load_llama_lora_gguf(struct gguf_context * fctx, struct ggml_context
|
|||
GGUF_GET_KEY(fctx, lora->hparams.n_rank_wv, gguf_get_val_u32, GGUF_TYPE_UINT32, true, LLM_KV_TRAINING_LORA_RANK_ATTN_V);
|
||||
GGUF_GET_KEY(fctx, lora->hparams.n_rank_wo, gguf_get_val_u32, GGUF_TYPE_UINT32, true, LLM_KV_TRAINING_LORA_RANK_ATTN_OUT);
|
||||
GGUF_GET_KEY(fctx, lora->hparams.n_rank_ffn_norm, gguf_get_val_u32, GGUF_TYPE_UINT32, true, LLM_KV_TRAINING_LORA_RANK_FFN_NORM);
|
||||
GGUF_GET_KEY(fctx, lora->hparams.n_rank_w1, gguf_get_val_u32, GGUF_TYPE_UINT32, true, LLM_KV_TRAINING_LORA_RANK_FFN_GATE);
|
||||
GGUF_GET_KEY(fctx, lora->hparams.n_rank_w2, gguf_get_val_u32, GGUF_TYPE_UINT32, true, LLM_KV_TRAINING_LORA_RANK_FFN_DOWN);
|
||||
GGUF_GET_KEY(fctx, lora->hparams.n_rank_w3, gguf_get_val_u32, GGUF_TYPE_UINT32, true, LLM_KV_TRAINING_LORA_RANK_FFN_UP);
|
||||
GGUF_GET_KEY(fctx, lora->hparams.n_rank_ffn_gate, gguf_get_val_u32, GGUF_TYPE_UINT32, true, LLM_KV_TRAINING_LORA_RANK_FFN_GATE);
|
||||
GGUF_GET_KEY(fctx, lora->hparams.n_rank_ffn_down, gguf_get_val_u32, GGUF_TYPE_UINT32, true, LLM_KV_TRAINING_LORA_RANK_FFN_DOWN);
|
||||
GGUF_GET_KEY(fctx, lora->hparams.n_rank_ffn_up, gguf_get_val_u32, GGUF_TYPE_UINT32, true, LLM_KV_TRAINING_LORA_RANK_FFN_UP);
|
||||
|
||||
init_lora(model, lora);
|
||||
|
||||
|
@ -825,12 +825,12 @@ static void load_llama_lora_gguf(struct gguf_context * fctx, struct ggml_context
|
|||
copy_tensor_by_name(layer.wo_b, f_ggml_ctx, ggml_get_name(layer.wo_b));
|
||||
copy_tensor_by_name(layer.ffn_norm_a, f_ggml_ctx, ggml_get_name(layer.ffn_norm_a));
|
||||
copy_tensor_by_name(layer.ffn_norm_b, f_ggml_ctx, ggml_get_name(layer.ffn_norm_b));
|
||||
copy_tensor_by_name(layer.w1_a, f_ggml_ctx, ggml_get_name(layer.w1_a));
|
||||
copy_tensor_by_name(layer.w1_b, f_ggml_ctx, ggml_get_name(layer.w1_b));
|
||||
copy_tensor_by_name(layer.w2_a, f_ggml_ctx, ggml_get_name(layer.w2_a));
|
||||
copy_tensor_by_name(layer.w2_b, f_ggml_ctx, ggml_get_name(layer.w2_b));
|
||||
copy_tensor_by_name(layer.w3_a, f_ggml_ctx, ggml_get_name(layer.w3_a));
|
||||
copy_tensor_by_name(layer.w3_b, f_ggml_ctx, ggml_get_name(layer.w3_b));
|
||||
copy_tensor_by_name(layer.ffn_gate_a, f_ggml_ctx, ggml_get_name(layer.ffn_gate_a));
|
||||
copy_tensor_by_name(layer.ffn_gate_b, f_ggml_ctx, ggml_get_name(layer.ffn_gate_b));
|
||||
copy_tensor_by_name(layer.ffn_down_a, f_ggml_ctx, ggml_get_name(layer.ffn_down_a));
|
||||
copy_tensor_by_name(layer.ffn_down_b, f_ggml_ctx, ggml_get_name(layer.ffn_down_b));
|
||||
copy_tensor_by_name(layer.ffn_up_a, f_ggml_ctx, ggml_get_name(layer.ffn_up_a));
|
||||
copy_tensor_by_name(layer.ffn_up_b, f_ggml_ctx, ggml_get_name(layer.ffn_up_b));
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -868,9 +868,9 @@ static void save_llama_lora_gguf(struct gguf_context * fctx, struct my_llama_mod
|
|||
gguf_set_val_u32(fctx, LLM_KV_TRAINING_LORA_RANK_ATTN_V, lora->hparams.n_rank_wv);
|
||||
gguf_set_val_u32(fctx, LLM_KV_TRAINING_LORA_RANK_ATTN_OUT, lora->hparams.n_rank_wo);
|
||||
gguf_set_val_u32(fctx, LLM_KV_TRAINING_LORA_RANK_FFN_NORM, lora->hparams.n_rank_ffn_norm);
|
||||
gguf_set_val_u32(fctx, LLM_KV_TRAINING_LORA_RANK_FFN_GATE, lora->hparams.n_rank_w1);
|
||||
gguf_set_val_u32(fctx, LLM_KV_TRAINING_LORA_RANK_FFN_DOWN, lora->hparams.n_rank_w2);
|
||||
gguf_set_val_u32(fctx, LLM_KV_TRAINING_LORA_RANK_FFN_UP, lora->hparams.n_rank_w3);
|
||||
gguf_set_val_u32(fctx, LLM_KV_TRAINING_LORA_RANK_FFN_GATE, lora->hparams.n_rank_ffn_gate);
|
||||
gguf_set_val_u32(fctx, LLM_KV_TRAINING_LORA_RANK_FFN_DOWN, lora->hparams.n_rank_ffn_down);
|
||||
gguf_set_val_u32(fctx, LLM_KV_TRAINING_LORA_RANK_FFN_UP, lora->hparams.n_rank_ffn_up);
|
||||
|
||||
gguf_add_tensor(fctx, lora->tok_embeddings_a);
|
||||
gguf_add_tensor(fctx, lora->tok_embeddings_b);
|
||||
|
@ -894,12 +894,12 @@ static void save_llama_lora_gguf(struct gguf_context * fctx, struct my_llama_mod
|
|||
gguf_add_tensor(fctx, layer.wo_b);
|
||||
gguf_add_tensor(fctx, layer.ffn_norm_a);
|
||||
gguf_add_tensor(fctx, layer.ffn_norm_b);
|
||||
gguf_add_tensor(fctx, layer.w1_a);
|
||||
gguf_add_tensor(fctx, layer.w1_b);
|
||||
gguf_add_tensor(fctx, layer.w2_a);
|
||||
gguf_add_tensor(fctx, layer.w2_b);
|
||||
gguf_add_tensor(fctx, layer.w3_a);
|
||||
gguf_add_tensor(fctx, layer.w3_b);
|
||||
gguf_add_tensor(fctx, layer.ffn_gate_a);
|
||||
gguf_add_tensor(fctx, layer.ffn_gate_b);
|
||||
gguf_add_tensor(fctx, layer.ffn_down_a);
|
||||
gguf_add_tensor(fctx, layer.ffn_down_b);
|
||||
gguf_add_tensor(fctx, layer.ffn_up_a);
|
||||
gguf_add_tensor(fctx, layer.ffn_up_b);
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -1104,12 +1104,12 @@ static void save_as_llama_lora(const char * filename, struct my_llama_lora * lor
|
|||
write_tensor(&file, layer.wo_b, tni(LLM_TENSOR_ATTN_OUT, i, ".weight.loraB"));
|
||||
write_tensor(&file, layer.ffn_norm_a, tni(LLM_TENSOR_FFN_NORM, i, ".weight.loraA"));
|
||||
write_tensor(&file, layer.ffn_norm_b, tni(LLM_TENSOR_FFN_NORM, i, ".weight.loraB"));
|
||||
write_tensor(&file, layer.w1_a, tni(LLM_TENSOR_FFN_GATE, i, ".weight.loraA"));
|
||||
write_tensor(&file, layer.w1_b, tni(LLM_TENSOR_FFN_GATE, i, ".weight.loraB"));
|
||||
write_tensor(&file, layer.w2_a, tni(LLM_TENSOR_FFN_DOWN, i, ".weight.loraA"));
|
||||
write_tensor(&file, layer.w2_b, tni(LLM_TENSOR_FFN_DOWN, i, ".weight.loraB"));
|
||||
write_tensor(&file, layer.w3_a, tni(LLM_TENSOR_FFN_UP, i, ".weight.loraA"));
|
||||
write_tensor(&file, layer.w3_b, tni(LLM_TENSOR_FFN_UP, i, ".weight.loraB"));
|
||||
write_tensor(&file, layer.ffn_gate_a, tni(LLM_TENSOR_FFN_GATE, i, ".weight.loraA"));
|
||||
write_tensor(&file, layer.ffn_gate_b, tni(LLM_TENSOR_FFN_GATE, i, ".weight.loraB"));
|
||||
write_tensor(&file, layer.ffn_down_a, tni(LLM_TENSOR_FFN_DOWN, i, ".weight.loraA"));
|
||||
write_tensor(&file, layer.ffn_down_b, tni(LLM_TENSOR_FFN_DOWN, i, ".weight.loraB"));
|
||||
write_tensor(&file, layer.ffn_up_a, tni(LLM_TENSOR_FFN_UP, i, ".weight.loraA"));
|
||||
write_tensor(&file, layer.ffn_up_b, tni(LLM_TENSOR_FFN_UP, i, ".weight.loraB"));
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -1139,9 +1139,9 @@ struct train_params {
|
|||
uint32_t n_rank_wv;
|
||||
uint32_t n_rank_wo;
|
||||
uint32_t n_rank_ffn_norm;
|
||||
uint32_t n_rank_w1;
|
||||
uint32_t n_rank_w2;
|
||||
uint32_t n_rank_w3;
|
||||
uint32_t n_rank_ffn_gate;
|
||||
uint32_t n_rank_ffn_down;
|
||||
uint32_t n_rank_ffn_up;
|
||||
uint32_t n_rank_tok_embeddings;
|
||||
uint32_t n_rank_norm;
|
||||
uint32_t n_rank_output;
|
||||
|
@ -1152,9 +1152,9 @@ struct train_params {
|
|||
bool custom_n_rank_wv;
|
||||
bool custom_n_rank_wo;
|
||||
bool custom_n_rank_ffn_norm;
|
||||
bool custom_n_rank_w1;
|
||||
bool custom_n_rank_w2;
|
||||
bool custom_n_rank_w3;
|
||||
bool custom_n_rank_ffn_gate;
|
||||
bool custom_n_rank_ffn_down;
|
||||
bool custom_n_rank_ffn_up;
|
||||
bool custom_n_rank_tok_embeddings;
|
||||
bool custom_n_rank_norm;
|
||||
bool custom_n_rank_output;
|
||||
|
@ -1186,9 +1186,9 @@ static struct train_params get_default_train_params() {
|
|||
params.n_rank_wv = 4;
|
||||
params.n_rank_wo = 4;
|
||||
params.n_rank_ffn_norm = 1;
|
||||
params.n_rank_w1 = 4;
|
||||
params.n_rank_w2 = 4;
|
||||
params.n_rank_w3 = 4;
|
||||
params.n_rank_ffn_gate = 4;
|
||||
params.n_rank_ffn_down = 4;
|
||||
params.n_rank_ffn_up = 4;
|
||||
params.n_rank_tok_embeddings = 4;
|
||||
params.n_rank_norm = 1;
|
||||
params.n_rank_output = 4;
|
||||
|
@ -1199,9 +1199,9 @@ static struct train_params get_default_train_params() {
|
|||
params.custom_n_rank_wv = false;
|
||||
params.custom_n_rank_wo = false;
|
||||
params.custom_n_rank_ffn_norm = false;
|
||||
params.custom_n_rank_w1 = false;
|
||||
params.custom_n_rank_w2 = false;
|
||||
params.custom_n_rank_w3 = false;
|
||||
params.custom_n_rank_ffn_gate = false;
|
||||
params.custom_n_rank_ffn_down = false;
|
||||
params.custom_n_rank_ffn_up = false;
|
||||
params.custom_n_rank_tok_embeddings = false;
|
||||
params.custom_n_rank_norm = false;
|
||||
params.custom_n_rank_output = false;
|
||||
|
@ -1232,9 +1232,9 @@ static void train_print_usage(int argc, char ** argv, const struct train_params
|
|||
fprintf(stderr, " --rank-wk N LORA rank for wk tensor, overrides default rank.\n");
|
||||
fprintf(stderr, " --rank-wv N LORA rank for wv tensor, overrides default rank.\n");
|
||||
fprintf(stderr, " --rank-wo N LORA rank for wo tensor, overrides default rank.\n");
|
||||
fprintf(stderr, " --rank-w1 N LORA rank for w1 tensor, overrides default rank.\n");
|
||||
fprintf(stderr, " --rank-w2 N LORA rank for w2 tensor, overrides default rank.\n");
|
||||
fprintf(stderr, " --rank-w3 N LORA rank for w3 tensor, overrides default rank.\n");
|
||||
fprintf(stderr, " --rank-ffn_gate N LORA rank for ffn_gate tensor, overrides default rank.\n");
|
||||
fprintf(stderr, " --rank-ffn_down N LORA rank for ffn_down tensor, overrides default rank.\n");
|
||||
fprintf(stderr, " --rank-ffn_up N LORA rank for ffn_up tensor, overrides default rank.\n");
|
||||
|
||||
print_common_train_usage(argc, argv, ¶ms->common);
|
||||
}
|
||||
|
@ -1369,27 +1369,27 @@ static bool train_params_parse(int argc, char ** argv, struct train_params * par
|
|||
}
|
||||
params->n_rank_wo = std::stoi(argv[i]);
|
||||
params->custom_n_rank_wo = true;
|
||||
} else if (arg == "--rank-w1") {
|
||||
} else if (arg == "--rank-ffn_gate") {
|
||||
if (++i >= argc) {
|
||||
invalid_param = true;
|
||||
break;
|
||||
}
|
||||
params->n_rank_w1 = std::stoi(argv[i]);
|
||||
params->custom_n_rank_w1 = true;
|
||||
} else if (arg == "--rank-w2") {
|
||||
params->n_rank_ffn_gate = std::stoi(argv[i]);
|
||||
params->custom_n_rank_ffn_gate = true;
|
||||
} else if (arg == "--rank-ffn_down") {
|
||||
if (++i >= argc) {
|
||||
invalid_param = true;
|
||||
break;
|
||||
}
|
||||
params->n_rank_w2 = std::stoi(argv[i]);
|
||||
params->custom_n_rank_w2 = true;
|
||||
} else if (arg == "--rank-w3") {
|
||||
params->n_rank_ffn_down = std::stoi(argv[i]);
|
||||
params->custom_n_rank_ffn_down = true;
|
||||
} else if (arg == "--rank-ffn_up") {
|
||||
if (++i >= argc) {
|
||||
invalid_param = true;
|
||||
break;
|
||||
}
|
||||
params->n_rank_w3 = std::stoi(argv[i]);
|
||||
params->custom_n_rank_w3 = true;
|
||||
params->n_rank_ffn_up = std::stoi(argv[i]);
|
||||
params->custom_n_rank_ffn_up = true;
|
||||
} else {
|
||||
fprintf(stderr, "error: unknown argument: %s\n", arg.c_str());
|
||||
train_print_usage(argc, argv, &default_params);
|
||||
|
@ -1452,12 +1452,12 @@ static int64_t get_parameter_count(struct my_llama_lora* lora) {
|
|||
nx += ggml_nelements(layer.wo_b);
|
||||
nx += ggml_nelements(layer.ffn_norm_a);
|
||||
nx += ggml_nelements(layer.ffn_norm_b);
|
||||
nx += ggml_nelements(layer.w1_a);
|
||||
nx += ggml_nelements(layer.w1_b);
|
||||
nx += ggml_nelements(layer.w2_a);
|
||||
nx += ggml_nelements(layer.w2_b);
|
||||
nx += ggml_nelements(layer.w3_a);
|
||||
nx += ggml_nelements(layer.w3_b);
|
||||
nx += ggml_nelements(layer.ffn_gate_a);
|
||||
nx += ggml_nelements(layer.ffn_gate_b);
|
||||
nx += ggml_nelements(layer.ffn_down_a);
|
||||
nx += ggml_nelements(layer.ffn_down_b);
|
||||
nx += ggml_nelements(layer.ffn_up_a);
|
||||
nx += ggml_nelements(layer.ffn_up_b);
|
||||
}
|
||||
return nx;
|
||||
}
|
||||
|
@ -1511,9 +1511,9 @@ int main(int argc, char ** argv) {
|
|||
uint32_t n_rank_wv = params.custom_n_rank_wv ? params.n_rank_wv : params.lora_r;
|
||||
uint32_t n_rank_wo = params.custom_n_rank_wo ? params.n_rank_wo : params.lora_r;
|
||||
uint32_t n_rank_ffn_norm = params.custom_n_rank_ffn_norm ? params.n_rank_ffn_norm : 1;
|
||||
uint32_t n_rank_w1 = params.custom_n_rank_w1 ? params.n_rank_w1 : params.lora_r;
|
||||
uint32_t n_rank_w2 = params.custom_n_rank_w2 ? params.n_rank_w2 : params.lora_r;
|
||||
uint32_t n_rank_w3 = params.custom_n_rank_w3 ? params.n_rank_w3 : params.lora_r;
|
||||
uint32_t n_rank_ffn_gate = params.custom_n_rank_ffn_gate ? params.n_rank_ffn_gate : params.lora_r;
|
||||
uint32_t n_rank_ffn_down = params.custom_n_rank_ffn_down ? params.n_rank_ffn_down : params.lora_r;
|
||||
uint32_t n_rank_ffn_up = params.custom_n_rank_ffn_up ? params.n_rank_ffn_up : params.lora_r;
|
||||
uint32_t n_rank_tok_embeddings = params.custom_n_rank_tok_embeddings ? params.n_rank_tok_embeddings : params.lora_r;
|
||||
uint32_t n_rank_norm = params.custom_n_rank_norm ? params.n_rank_norm : 1;
|
||||
uint32_t n_rank_output = params.custom_n_rank_output ? params.n_rank_output : params.lora_r;
|
||||
|
@ -1523,9 +1523,9 @@ int main(int argc, char ** argv) {
|
|||
lora.hparams.n_rank_wv = n_rank_wv;
|
||||
lora.hparams.n_rank_wo = n_rank_wo;
|
||||
lora.hparams.n_rank_ffn_norm = n_rank_ffn_norm;
|
||||
lora.hparams.n_rank_w1 = n_rank_w1;
|
||||
lora.hparams.n_rank_w2 = n_rank_w2;
|
||||
lora.hparams.n_rank_w3 = n_rank_w3;
|
||||
lora.hparams.n_rank_ffn_gate = n_rank_ffn_gate;
|
||||
lora.hparams.n_rank_ffn_down = n_rank_ffn_down;
|
||||
lora.hparams.n_rank_ffn_up = n_rank_ffn_up;
|
||||
lora.hparams.n_rank_tok_embeddings = n_rank_tok_embeddings;
|
||||
lora.hparams.n_rank_norm = n_rank_norm;
|
||||
lora.hparams.n_rank_output = n_rank_output;
|
||||
|
@ -1566,9 +1566,9 @@ int main(int argc, char ** argv) {
|
|||
|| (lora.hparams.n_rank_wv != n_rank_wv)
|
||||
|| (lora.hparams.n_rank_wo != n_rank_wo)
|
||||
|| (lora.hparams.n_rank_ffn_norm != n_rank_ffn_norm)
|
||||
|| (lora.hparams.n_rank_w1 != n_rank_w1)
|
||||
|| (lora.hparams.n_rank_w2 != n_rank_w2)
|
||||
|| (lora.hparams.n_rank_w3 != n_rank_w3)
|
||||
|| (lora.hparams.n_rank_ffn_gate != n_rank_ffn_gate)
|
||||
|| (lora.hparams.n_rank_ffn_down != n_rank_ffn_down)
|
||||
|| (lora.hparams.n_rank_ffn_up != n_rank_ffn_up)
|
||||
|| (lora.hparams.n_rank_tok_embeddings != n_rank_tok_embeddings)
|
||||
|| (lora.hparams.n_rank_norm != n_rank_norm)
|
||||
|| (lora.hparams.n_rank_output != n_rank_output)
|
||||
|
|
|
@ -568,7 +568,8 @@ int main(int argc, char ** argv) {
|
|||
params.prompt = gpt_random_prompt(rng);
|
||||
}
|
||||
|
||||
llama_backend_init(params.numa);
|
||||
llama_backend_init();
|
||||
llama_numa_init(params.numa);
|
||||
|
||||
llama_model_params mparams = llama_model_params_from_gpt_params(params);
|
||||
|
||||
|
|
|
@ -202,7 +202,8 @@ int main(int argc, char ** argv) {
|
|||
std::mt19937 rng(params.seed);
|
||||
|
||||
LOG("%s: llama backend init\n", __func__);
|
||||
llama_backend_init(params.numa);
|
||||
llama_backend_init();
|
||||
llama_numa_init(params.numa);
|
||||
|
||||
llama_model * model;
|
||||
llama_context * ctx;
|
||||
|
|
|
@ -1151,8 +1151,7 @@ int main(int argc, char ** argv) {
|
|||
if (!params.verbose) {
|
||||
llama_log_set(llama_null_log_callback, NULL);
|
||||
}
|
||||
bool numa = false;
|
||||
llama_backend_init(numa);
|
||||
llama_backend_init();
|
||||
|
||||
// initialize printer
|
||||
std::unique_ptr<printer> p;
|
||||
|
|
|
@ -274,8 +274,8 @@ Java_com_example_llama_Llm_new_1batch(JNIEnv *, jobject, jint n_tokens, jint emb
|
|||
|
||||
extern "C"
|
||||
JNIEXPORT void JNICALL
|
||||
Java_com_example_llama_Llm_backend_1init(JNIEnv *, jobject, jboolean numa) {
|
||||
llama_backend_init(numa);
|
||||
Java_com_example_llama_Llm_backend_1init(JNIEnv *, jobject) {
|
||||
llama_backend_init();
|
||||
}
|
||||
|
||||
extern "C"
|
||||
|
|
|
@ -51,7 +51,7 @@ actor LlamaContext {
|
|||
}
|
||||
|
||||
static func create_context(path: String) throws -> LlamaContext {
|
||||
llama_backend_init(false)
|
||||
llama_backend_init()
|
||||
var model_params = llama_model_default_params()
|
||||
|
||||
#if targetEnvironment(simulator)
|
||||
|
|
|
@ -1,10 +1,12 @@
|
|||
# LLaVA
|
||||
|
||||
Currently this implementation supports [llava-v1.5](https://huggingface.co/liuhaotian/llava-v1.5-7b) variants.
|
||||
Currently this implementation supports [llava-v1.5](https://huggingface.co/liuhaotian/llava-v1.5-7b) variants,
|
||||
as well as llava-1.6 [llava-v1.6](https://huggingface.co/collections/liuhaotian/llava-16-65b9e40155f60fd046a5ccf2) variants.
|
||||
|
||||
The pre-converted [7b](https://huggingface.co/mys/ggml_llava-v1.5-7b)
|
||||
and [13b](https://huggingface.co/mys/ggml_llava-v1.5-13b)
|
||||
models are available.
|
||||
For llava-1.6 a variety of prepared gguf models are available as well [7b-34b](https://huggingface.co/cmp-nct/llava-1.6-gguf)
|
||||
|
||||
After API is confirmed, more models will be supported / uploaded.
|
||||
|
||||
|
@ -18,10 +20,11 @@ After building, run: `./llava-cli` to see the usage. For example:
|
|||
```
|
||||
|
||||
**note**: A lower temperature like 0.1 is recommended for better quality. add `--temp 0.1` to the command to do so.
|
||||
**note**: For GPU offloading ensure to use the `-ngl` flag just like usual
|
||||
|
||||
## Model conversion
|
||||
## LLaVA 1.5
|
||||
|
||||
- Clone `llava-v15-7b` and `clip-vit-large-patch14-336` locally:
|
||||
- Clone a LLaVA and a CLIP model ([available options](https://github.com/haotian-liu/LLaVA/blob/main/docs/MODEL_ZOO.md)). For example:
|
||||
|
||||
```sh
|
||||
git clone https://huggingface.co/liuhaotian/llava-v1.5-7b
|
||||
|
@ -50,13 +53,54 @@ python ./examples/llava/convert-image-encoder-to-gguf.py -m ../clip-vit-large-pa
|
|||
5. Use `convert.py` to convert the LLaMA part of LLaVA to GGUF:
|
||||
|
||||
```sh
|
||||
python ./convert.py ../llava-v1.5-7b
|
||||
python ./convert.py ../llava-v1.5-7b --skip-unknown
|
||||
```
|
||||
|
||||
Now both the LLaMA part and the image encoder is in the `llava-v1.5-7b` directory.
|
||||
|
||||
## LLaVA 1.6 gguf conversion
|
||||
|
||||
1) Backup your pth/safetensor model files as llava-surgery modifies them
|
||||
2) Use `python llava-surgery-v2.py -C -m /path/to/hf-model` which also supports llava-1.5 variants pytorch as well as safetensor models:
|
||||
- you will find a llava.projector and a llava.clip file in your model directory
|
||||
3) Copy the llava.clip file into a subdirectory (like vit), rename it to pytorch_model.bin and add a fitting vit configuration to the directory (https://huggingface.co/cmp-nct/llava-1.6-gguf/blob/main/config_vit.json) and rename it to config.json.
|
||||
4) Create the visual gguf model: `python ./examples/llava/convert-image-encoder-to-gguf.py -m ../path/to/vit --llava-projector ../path/to/llava.projector --output-dir ../path/to/output --clip-model-is-vision`
|
||||
- This is similar to llava-1.5, the difference is that we tell the encoder that we are working with the pure vision model part of CLIP
|
||||
5) Everything else as usual: convert.py the hf model, quantize as needed
|
||||
**note** llava-1.6 needs more context than llava-1.5, at least 3000 is needed (just run it at -c 4096)
|
||||
**note** llava-1.6 greatly benefits from batched prompt processing (defaults work)
|
||||
|
||||
## llava-cli templating and llava-1.6 prompting
|
||||
|
||||
llava-1.5 models all use the same vicuna prompt, here you can just add your image question like `-p "Provide a full description."`
|
||||
For llava-1.5 models which are not vicuna (mistral and Yi) you need to adapt system prompt as well as user prompt, for this purpose llava-cli has a basic templating system:
|
||||
|
||||
**For Mistral and using llava-cli binary:**
|
||||
Add this: `-p "<image>\nUSER:\nProvide a full description.\nASSISTANT:\n"`
|
||||
The mistral template for llava-1.6 seems to be no system print and a USER/ASSISTANT role
|
||||
|
||||
**For the 34B this should work:**
|
||||
Add this: `-e -p <|im_start|>system\nAnswer the questions.<|im_end|><|im_start|>user\n<image>\nProvide a full description.<|im_end|><|im_start|>assistant\n`
|
||||
|
||||
|
||||
## How to know if you are running in llava-1.5 or llava-1.6 mode
|
||||
|
||||
When running llava-cli you will see a visual information right before the prompt is being processed:
|
||||
|
||||
**Llava-1.5:**
|
||||
`encode_image_with_clip: image embedding created: 576 tokens`
|
||||
|
||||
**Llava-1.6 (anything above 576):**
|
||||
`encode_image_with_clip: image embedding created: 2880 tokens`
|
||||
|
||||
|
||||
Alternatively just pay notice to how many "tokens" have been used for your prompt, it will also show 1000+ tokens for llava-1.6
|
||||
|
||||
|
||||
|
||||
|
||||
## TODO
|
||||
|
||||
- [ ] Support non-CPU backend for the image encoding part.
|
||||
- [x] Support non-CPU backend for the image encoding part.
|
||||
- [ ] Support different sampling methods.
|
||||
- [ ] Support more model variants.
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
// NOTE: This is modified from clip.cpp only for LLaVA,
|
||||
// so there might be still unnecessary artifacts hanging around
|
||||
// I'll gradually clean and extend it
|
||||
|
||||
// Note: Even when using identical normalized image inputs (see normalize_image_u8_to_f32()) we have a significant difference in resulting embeddings compared to pytorch
|
||||
#include "clip.h"
|
||||
#include "ggml.h"
|
||||
#include "ggml-alloc.h"
|
||||
|
@ -30,6 +30,26 @@
|
|||
#include <vector>
|
||||
#include <sstream>
|
||||
#include <cinttypes>
|
||||
#include <limits>
|
||||
|
||||
//#define CLIP_DEBUG_FUNCTIONS
|
||||
|
||||
// RGB uint8 image
|
||||
struct clip_image_u8 {
|
||||
int nx;
|
||||
int ny;
|
||||
|
||||
std::vector<uint8_t> buf;
|
||||
};
|
||||
|
||||
// RGB float32 image (NHWC)
|
||||
// Memory layout: RGBRGBRGB...
|
||||
struct clip_image_f32 {
|
||||
int nx;
|
||||
int ny;
|
||||
|
||||
std::vector<float> buf;
|
||||
};
|
||||
|
||||
static std::string format(const char * fmt, ...) {
|
||||
va_list ap;
|
||||
|
@ -50,50 +70,56 @@ static std::string format(const char * fmt, ...) {
|
|||
// key constants
|
||||
//
|
||||
|
||||
#define KEY_FTYPE "general.file_type"
|
||||
#define KEY_NAME "general.name"
|
||||
#define KEY_DESCRIPTION "general.description"
|
||||
#define KEY_HAS_TEXT_ENC "clip.has_text_encoder"
|
||||
#define KEY_HAS_VIS_ENC "clip.has_vision_encoder"
|
||||
#define KEY_FTYPE "general.file_type"
|
||||
#define KEY_NAME "general.name"
|
||||
#define KEY_DESCRIPTION "general.description"
|
||||
#define KEY_HAS_TEXT_ENC "clip.has_text_encoder"
|
||||
#define KEY_HAS_VIS_ENC "clip.has_vision_encoder"
|
||||
#define KEY_HAS_LLAVA_PROJ "clip.has_llava_projector"
|
||||
#define KEY_USE_GELU "clip.use_gelu"
|
||||
#define KEY_N_EMBD "clip.%s.embedding_length"
|
||||
#define KEY_N_FF "clip.%s.feed_forward_length"
|
||||
#define KEY_N_BLOCK "clip.%s.block_count"
|
||||
#define KEY_N_HEAD "clip.%s.attention.head_count"
|
||||
#define KEY_USE_GELU "clip.use_gelu"
|
||||
#define KEY_N_EMBD "clip.%s.embedding_length"
|
||||
#define KEY_N_FF "clip.%s.feed_forward_length"
|
||||
#define KEY_N_BLOCK "clip.%s.block_count"
|
||||
#define KEY_N_HEAD "clip.%s.attention.head_count"
|
||||
#define KEY_LAYER_NORM_EPS "clip.%s.attention.layer_norm_epsilon"
|
||||
#define KEY_PROJ_DIM "clip.%s.projection_dim"
|
||||
#define KEY_TOKENS "tokenizer.ggml.tokens"
|
||||
#define KEY_N_POSITIONS "clip.text.context_length"
|
||||
#define KEY_IMAGE_SIZE "clip.vision.image_size"
|
||||
#define KEY_PATCH_SIZE "clip.vision.patch_size"
|
||||
#define KEY_IMAGE_MEAN "clip.vision.image_mean"
|
||||
#define KEY_IMAGE_STD "clip.vision.image_std"
|
||||
#define KEY_PROJ_TYPE "clip.projector_type"
|
||||
#define KEY_PROJ_DIM "clip.%s.projection_dim"
|
||||
#define KEY_TOKENS "tokenizer.ggml.tokens"
|
||||
#define KEY_N_POSITIONS "clip.text.context_length"
|
||||
#define KEY_IMAGE_SIZE "clip.vision.image_size"
|
||||
#define KEY_PATCH_SIZE "clip.vision.patch_size"
|
||||
#define KEY_IMAGE_MEAN "clip.vision.image_mean"
|
||||
#define KEY_IMAGE_STD "clip.vision.image_std"
|
||||
#define KEY_PROJ_TYPE "clip.projector_type"
|
||||
|
||||
#define KEY_MM_PATCH_MERGE_TYPE "clip.vision.mm_patch_merge_type"
|
||||
#define KEY_IMAGE_GRID_PINPOINTS "clip.vision.image_grid_pinpoints"
|
||||
#define KEY_IMAGE_CROP_RESOLUTION "clip.vision.image_crop_resolution"
|
||||
|
||||
|
||||
//
|
||||
// tensor name constants
|
||||
//
|
||||
|
||||
#define TN_TOKEN_EMBD "%s.token_embd.weight"
|
||||
#define TN_POS_EMBD "%s.position_embd.weight"
|
||||
#define TN_CLASS_EMBD "v.class_embd"
|
||||
#define TN_PATCH_EMBD "v.patch_embd.weight"
|
||||
#define TN_ATTN_K "%s.blk.%d.attn_k.%s"
|
||||
#define TN_ATTN_Q "%s.blk.%d.attn_q.%s"
|
||||
#define TN_ATTN_V "%s.blk.%d.attn_v.%s"
|
||||
#define TN_ATTN_OUTPUT "%s.blk.%d.attn_out.%s"
|
||||
#define TN_FFN_DOWN "%s.blk.%d.ffn_down.%s"
|
||||
#define TN_FFN_UP "%s.blk.%d.ffn_up.%s"
|
||||
#define TN_LN_1 "%s.blk.%d.ln1.%s"
|
||||
#define TN_LN_2 "%s.blk.%d.ln2.%s"
|
||||
#define TN_LN_PRE "%s.pre_ln.%s"
|
||||
#define TN_LN_POST "%s.post_ln.%s"
|
||||
#define TN_TEXT_PROJ "text_projection.weight"
|
||||
#define TN_VIS_PROJ "visual_projection.weight"
|
||||
#define TN_LLAVA_PROJ "mm.%d.%s"
|
||||
#define TN_MVLM_PROJ_MLP "mm.model.mlp.%d.%s"
|
||||
#define TN_TOKEN_EMBD "%s.token_embd.weight"
|
||||
#define TN_POS_EMBD "%s.position_embd.weight"
|
||||
#define TN_CLASS_EMBD "v.class_embd"
|
||||
#define TN_PATCH_EMBD "v.patch_embd.weight"
|
||||
#define TN_ATTN_K "%s.blk.%d.attn_k.%s"
|
||||
#define TN_ATTN_Q "%s.blk.%d.attn_q.%s"
|
||||
#define TN_ATTN_V "%s.blk.%d.attn_v.%s"
|
||||
#define TN_ATTN_OUTPUT "%s.blk.%d.attn_out.%s"
|
||||
#define TN_FFN_DOWN "%s.blk.%d.ffn_down.%s"
|
||||
#define TN_FFN_UP "%s.blk.%d.ffn_up.%s"
|
||||
#define TN_LN_1 "%s.blk.%d.ln1.%s"
|
||||
#define TN_LN_2 "%s.blk.%d.ln2.%s"
|
||||
#define TN_LN_PRE "%s.pre_ln.%s"
|
||||
#define TN_LN_POST "%s.post_ln.%s"
|
||||
#define TN_TEXT_PROJ "text_projection.weight"
|
||||
#define TN_VIS_PROJ "visual_projection.weight"
|
||||
#define TN_LLAVA_PROJ "mm.%d.%s"
|
||||
#define TN_MVLM_PROJ_MLP "mm.model.mlp.%d.%s"
|
||||
#define TN_MVLM_PROJ_BLOCK "mm.model.mb_block.%d.block.%d.%s"
|
||||
#define TN_IMAGE_NEWLINE "model.image_newline"
|
||||
|
||||
|
||||
enum projector_type {
|
||||
|
@ -104,8 +130,8 @@ enum projector_type {
|
|||
};
|
||||
|
||||
static std::map<projector_type, std::string> PROJECTOR_TYPE_NAMES = {
|
||||
{ PROJECTOR_TYPE_MLP, "mlp" },
|
||||
{ PROJECTOR_TYPE_LDP, "ldp" },
|
||||
{ PROJECTOR_TYPE_MLP, "mlp" },
|
||||
{ PROJECTOR_TYPE_LDP, "ldp" },
|
||||
};
|
||||
|
||||
|
||||
|
@ -165,7 +191,6 @@ static std::string gguf_data_to_str(enum gguf_type type, const void * data, int
|
|||
}
|
||||
}
|
||||
|
||||
|
||||
static void replace_all(std::string & s, const std::string & search, const std::string & replace) {
|
||||
std::string result;
|
||||
for (size_t pos = 0; ; pos += search.length()) {
|
||||
|
@ -217,7 +242,7 @@ static std::string gguf_kv_to_str(const struct gguf_context * ctx_gguf, int i) {
|
|||
}
|
||||
}
|
||||
|
||||
static void print_tensor_info(const ggml_tensor* tensor, const char* prefix = "") {
|
||||
static void print_tensor_info(const ggml_tensor * tensor, const char * prefix = "") {
|
||||
size_t tensor_size = ggml_nbytes(tensor);
|
||||
printf("%s: n_dims = %d, name = %s, tensor_size=%zu, shape:[%" PRId64 ", %" PRId64 ", %" PRId64 ", %" PRId64 "], type = %s\n",
|
||||
prefix, ggml_n_dims(tensor), tensor->name, tensor_size,
|
||||
|
@ -233,31 +258,136 @@ static projector_type clip_projector_type_from_string(const std::string & name)
|
|||
return PROJECTOR_TYPE_UNKNOWN;
|
||||
}
|
||||
|
||||
//
|
||||
// image data
|
||||
//
|
||||
#ifdef CLIP_DEBUG_FUNCTIONS
|
||||
static void clip_image_write_image_to_ppm(const clip_image_u8& img, const std::string& filename) {
|
||||
std::ofstream file(filename, std::ios::binary);
|
||||
if (!file.is_open()) {
|
||||
std::cerr << "Failed to open file for writing: " << filename << std::endl;
|
||||
return;
|
||||
}
|
||||
|
||||
// RGB uint8 image
|
||||
struct clip_image_u8 {
|
||||
int nx;
|
||||
int ny;
|
||||
// PPM header: P6 format, width, height, and max color value
|
||||
file << "P6\n" << img.nx << " " << img.ny << "\n255\n";
|
||||
|
||||
std::vector<uint8_t> buf;
|
||||
};
|
||||
// Write pixel data
|
||||
for (size_t i = 0; i < img.buf.size(); i += 3) {
|
||||
// PPM expects binary data in RGB format, which matches our image buffer
|
||||
file.write(reinterpret_cast<const char*>(&img.buf[i]), 3);
|
||||
}
|
||||
|
||||
// RGB float32 image (NHWC)
|
||||
// Memory layout: RGBRGBRGB...
|
||||
struct clip_image_f32 {
|
||||
int nx;
|
||||
int ny;
|
||||
file.close();
|
||||
}
|
||||
|
||||
static void clip_image_save_to_bmp(const clip_image_u8& img, const std::string& filename) {
|
||||
std::ofstream file(filename, std::ios::binary);
|
||||
if (!file.is_open()) {
|
||||
std::cerr << "Failed to open file for writing: " << filename << std::endl;
|
||||
return;
|
||||
}
|
||||
|
||||
int fileSize = 54 + 3 * img.nx * img.ny; // File header + info header + pixel data
|
||||
int bytesPerPixel = 3;
|
||||
int widthInBytes = img.nx * bytesPerPixel;
|
||||
int paddingAmount = (4 - (widthInBytes % 4)) % 4;
|
||||
int stride = widthInBytes + paddingAmount;
|
||||
|
||||
// Bitmap file header
|
||||
unsigned char fileHeader[14] = {
|
||||
'B','M', // Signature
|
||||
0,0,0,0, // Image file size in bytes
|
||||
0,0,0,0, // Reserved
|
||||
54,0,0,0 // Start of pixel array
|
||||
};
|
||||
|
||||
// Total file size
|
||||
fileSize = 54 + (stride * img.ny);
|
||||
fileHeader[2] = (unsigned char)(fileSize);
|
||||
fileHeader[3] = (unsigned char)(fileSize >> 8);
|
||||
fileHeader[4] = (unsigned char)(fileSize >> 16);
|
||||
fileHeader[5] = (unsigned char)(fileSize >> 24);
|
||||
|
||||
// Bitmap information header (BITMAPINFOHEADER)
|
||||
unsigned char infoHeader[40] = {
|
||||
40,0,0,0, // Size of this header (40 bytes)
|
||||
0,0,0,0, // Image width
|
||||
0,0,0,0, // Image height
|
||||
1,0, // Number of color planes
|
||||
24,0, // Bits per pixel
|
||||
0,0,0,0, // No compression
|
||||
0,0,0,0, // Image size (can be 0 for no compression)
|
||||
0,0,0,0, // X pixels per meter (not specified)
|
||||
0,0,0,0, // Y pixels per meter (not specified)
|
||||
0,0,0,0, // Total colors (color table not used)
|
||||
0,0,0,0 // Important colors (all are important)
|
||||
};
|
||||
|
||||
// Width and height in the information header
|
||||
infoHeader[4] = (unsigned char)(img.nx);
|
||||
infoHeader[5] = (unsigned char)(img.nx >> 8);
|
||||
infoHeader[6] = (unsigned char)(img.nx >> 16);
|
||||
infoHeader[7] = (unsigned char)(img.nx >> 24);
|
||||
infoHeader[8] = (unsigned char)(img.ny);
|
||||
infoHeader[9] = (unsigned char)(img.ny >> 8);
|
||||
infoHeader[10] = (unsigned char)(img.ny >> 16);
|
||||
infoHeader[11] = (unsigned char)(img.ny >> 24);
|
||||
|
||||
// Write file headers
|
||||
file.write(reinterpret_cast<char*>(fileHeader), sizeof(fileHeader));
|
||||
file.write(reinterpret_cast<char*>(infoHeader), sizeof(infoHeader));
|
||||
|
||||
// Pixel data
|
||||
std::vector<unsigned char> padding(3, 0); // Max padding size to be added to each row
|
||||
for (int y = img.ny - 1; y >= 0; --y) { // BMP files are stored bottom-to-top
|
||||
for (int x = 0; x < img.nx; ++x) {
|
||||
// Each pixel
|
||||
size_t pixelIndex = (y * img.nx + x) * 3;
|
||||
unsigned char pixel[3] = {
|
||||
img.buf[pixelIndex + 2], // BMP stores pixels in BGR format
|
||||
img.buf[pixelIndex + 1],
|
||||
img.buf[pixelIndex]
|
||||
};
|
||||
file.write(reinterpret_cast<char*>(pixel), 3);
|
||||
}
|
||||
// Write padding for the row
|
||||
file.write(reinterpret_cast<char*>(padding.data()), paddingAmount);
|
||||
}
|
||||
|
||||
file.close();
|
||||
}
|
||||
|
||||
// debug function to convert f32 to u8
|
||||
static void clip_image_convert_f32_to_u8(const clip_image_f32& src, clip_image_u8& dst) {
|
||||
dst.nx = src.nx;
|
||||
dst.ny = src.ny;
|
||||
dst.buf.resize(3 * src.nx * src.ny);
|
||||
for (size_t i = 0; i < src.buf.size(); ++i) {
|
||||
dst.buf[i] = static_cast<uint8_t>(std::min(std::max(int(src.buf[i] * 255.0f), 0), 255));
|
||||
}
|
||||
}
|
||||
#endif
|
||||
|
||||
std::vector<float> buf;
|
||||
};
|
||||
|
||||
//
|
||||
// clip layers
|
||||
//
|
||||
|
||||
struct clip_hparams {
|
||||
int32_t image_size;
|
||||
int32_t patch_size;
|
||||
int32_t hidden_size;
|
||||
int32_t n_intermediate;
|
||||
int32_t projection_dim;
|
||||
int32_t n_head;
|
||||
int32_t n_layer;
|
||||
|
||||
float eps;
|
||||
|
||||
char mm_patch_merge_type[32] = "flat"; // spatial_unpad or flat (default)
|
||||
|
||||
int32_t image_grid_pinpoints[32];
|
||||
int32_t image_crop_resolution;
|
||||
};
|
||||
|
||||
struct clip_layer {
|
||||
// attention
|
||||
struct ggml_tensor * k_w;
|
||||
|
@ -287,7 +417,7 @@ struct clip_layer {
|
|||
};
|
||||
|
||||
struct clip_vision_model {
|
||||
struct clip_vision_hparams hparams;
|
||||
struct clip_hparams hparams;
|
||||
|
||||
// embeddings
|
||||
struct ggml_tensor * class_embedding;
|
||||
|
@ -310,6 +440,8 @@ struct clip_vision_model {
|
|||
struct ggml_tensor * mm_2_w = NULL;
|
||||
struct ggml_tensor * mm_2_b = NULL;
|
||||
|
||||
struct ggml_tensor * image_newline = NULL;
|
||||
|
||||
// Yi type models with mlp+normalization projection
|
||||
struct ggml_tensor * mm_1_w = NULL; // Yi type models have 0, 1, 3, 4
|
||||
struct ggml_tensor * mm_1_b = NULL;
|
||||
|
@ -364,9 +496,10 @@ struct clip_ctx {
|
|||
std::vector<uint8_t> buf_compute_meta;
|
||||
|
||||
// memory buffers to evaluate the model
|
||||
ggml_backend_buffer_t params_buffer = NULL;
|
||||
ggml_backend_buffer_t params_buffer = NULL;
|
||||
ggml_backend_buffer_t compute_buffer = NULL;
|
||||
ggml_backend_t backend = NULL;
|
||||
|
||||
ggml_backend_t backend = NULL;
|
||||
ggml_gallocr_t compute_alloc = NULL;
|
||||
};
|
||||
|
||||
|
@ -379,18 +512,19 @@ static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32
|
|||
const auto & model = ctx->vision_model;
|
||||
const auto & hparams = model.hparams;
|
||||
|
||||
const int image_size = hparams.image_size;
|
||||
const int patch_size = hparams.patch_size;
|
||||
const int num_patches = ((image_size / patch_size) * (image_size / patch_size));
|
||||
const int num_positions = num_patches + 1;
|
||||
const int hidden_size = hparams.hidden_size;
|
||||
const int n_head = hparams.n_head;
|
||||
const int d_head = hidden_size / n_head;
|
||||
const int n_layer = hparams.n_layer;
|
||||
//const int n_intermediate = hparams.n_intermediate;
|
||||
//const int projection_dim = hparams.projection_dim;
|
||||
const float eps = hparams.eps;
|
||||
int batch_size = imgs->size;
|
||||
const int image_size = hparams.image_size;
|
||||
const int patch_size = hparams.patch_size;
|
||||
const int num_patches = ((image_size / patch_size) * (image_size / patch_size));
|
||||
const int num_patches_per_side = image_size / patch_size; GGML_UNUSED(num_patches_per_side);
|
||||
const int num_positions = num_patches + 1;
|
||||
const int hidden_size = hparams.hidden_size;
|
||||
const int n_head = hparams.n_head;
|
||||
const int d_head = hidden_size / n_head;
|
||||
const int n_layer = hparams.n_layer;
|
||||
const float eps = hparams.eps;
|
||||
|
||||
const int batch_size = imgs->size;
|
||||
|
||||
if (ctx->has_llava_projector) {
|
||||
GGML_ASSERT(batch_size == 1);
|
||||
}
|
||||
|
@ -540,7 +674,6 @@ static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32
|
|||
embeddings = ggml_add(ctx0, embeddings, model.mm_0_b);
|
||||
|
||||
embeddings = ggml_gelu(ctx0, embeddings);
|
||||
|
||||
embeddings = ggml_mul_mat(ctx0, model.mm_2_w, embeddings);
|
||||
embeddings = ggml_add(ctx0, embeddings, model.mm_2_b);
|
||||
|
||||
|
@ -791,10 +924,10 @@ struct clip_ctx * clip_model_load(const char * fname, const int verbosity = 1) {
|
|||
if (idx != -1) {
|
||||
const std::string proj_type = gguf_get_val_str(ctx, idx);
|
||||
new_clip->proj_type = clip_projector_type_from_string(proj_type);
|
||||
}
|
||||
else {
|
||||
} else {
|
||||
new_clip->proj_type = PROJECTOR_TYPE_MLP;
|
||||
}
|
||||
|
||||
if (new_clip->proj_type == PROJECTOR_TYPE_MLP) {
|
||||
if (gguf_find_tensor(ctx, format(TN_LLAVA_PROJ, 3, "weight").c_str()) != -1) {
|
||||
new_clip->proj_type = PROJECTOR_TYPE_MLP_NORM;
|
||||
|
@ -920,11 +1053,41 @@ struct clip_ctx * clip_model_load(const char * fname, const int verbosity = 1) {
|
|||
hparams.projection_dim = get_u32(ctx, format(KEY_PROJ_DIM, "vision"));
|
||||
hparams.eps = get_f32(ctx, format(KEY_LAYER_NORM_EPS, "vision"));
|
||||
|
||||
try {
|
||||
int idx = get_key_idx(ctx, KEY_IMAGE_GRID_PINPOINTS);
|
||||
int n = gguf_get_arr_n(ctx, idx);
|
||||
const int32_t * pinpoints = (const int32_t *)gguf_get_arr_data(ctx, idx);
|
||||
for (int i = 0; i < 32 && i < n && pinpoints[i] != 0; ++i) {
|
||||
hparams.image_grid_pinpoints[i] = pinpoints[i];
|
||||
}
|
||||
if (n < 32)
|
||||
hparams.image_grid_pinpoints[n] = 0;
|
||||
} catch (std::runtime_error & e) {
|
||||
hparams.image_grid_pinpoints[0]=0;
|
||||
}
|
||||
|
||||
try {
|
||||
int idx = get_key_idx(ctx, KEY_MM_PATCH_MERGE_TYPE);
|
||||
strcpy(hparams.mm_patch_merge_type, gguf_get_val_str(ctx, idx));
|
||||
} catch (std::runtime_error & e) {
|
||||
strcpy(hparams.mm_patch_merge_type, "flat");
|
||||
}
|
||||
|
||||
try {
|
||||
hparams.image_crop_resolution = get_u32(ctx, KEY_IMAGE_CROP_RESOLUTION); // llava-1.6
|
||||
} catch(const std::exception& e) {
|
||||
hparams.image_crop_resolution = hparams.image_size;
|
||||
}
|
||||
|
||||
int idx_mean = get_key_idx(ctx, KEY_IMAGE_MEAN);
|
||||
int idx_std = get_key_idx(ctx, KEY_IMAGE_STD);
|
||||
|
||||
const float * mean_data = (const float *)gguf_get_arr_data(ctx, idx_mean);
|
||||
const float * std_data = (const float *)gguf_get_arr_data(ctx, idx_std);
|
||||
|
||||
for (int i = 0; i < 3; ++i) {
|
||||
new_clip->image_mean[i] = *((const float *)gguf_get_arr_data(ctx, idx_mean));
|
||||
new_clip->image_std[i] = *((const float *)gguf_get_arr_data(ctx, idx_std));
|
||||
new_clip->image_mean[i] = mean_data[i];
|
||||
new_clip->image_std[i] = std_data[i];
|
||||
}
|
||||
|
||||
if (verbosity >= 2) {
|
||||
|
@ -936,13 +1099,27 @@ struct clip_ctx * clip_model_load(const char * fname, const int verbosity = 1) {
|
|||
printf("v_projection_dim %d\n", hparams.projection_dim);
|
||||
printf("v_n_head %d\n", hparams.n_head);
|
||||
printf("v_n_layer %d\n", hparams.n_layer);
|
||||
printf("v_eps %f\n", hparams.eps);
|
||||
printf("v_image_mean %f %f %f\n", new_clip->image_mean[0], new_clip->image_mean[1], new_clip->image_mean[2]);
|
||||
printf("v_image_std %f %f %f\n", new_clip->image_std[0], new_clip->image_std[1], new_clip->image_std[2]);
|
||||
printf("v_image_grid_pinpoints: ");
|
||||
for (int i = 0; i < 32 && (hparams.image_grid_pinpoints[i] != 0); ++i) {
|
||||
printf("%d ", hparams.image_grid_pinpoints[i]);
|
||||
}
|
||||
printf("\n");
|
||||
printf("v_mm_patch_merge_type: %s\n", hparams.mm_patch_merge_type);
|
||||
|
||||
}
|
||||
|
||||
vision_model.patch_embeddings = get_tensor(new_clip->ctx_data, TN_PATCH_EMBD);
|
||||
vision_model.class_embedding = get_tensor(new_clip->ctx_data, TN_CLASS_EMBD);
|
||||
vision_model.position_embeddings = get_tensor(new_clip->ctx_data, format(TN_POS_EMBD, "v"));
|
||||
vision_model.pre_ln_w = get_tensor(new_clip->ctx_data, format(TN_LN_PRE, "v", "weight"));
|
||||
vision_model.pre_ln_b = get_tensor(new_clip->ctx_data, format(TN_LN_PRE, "v", "bias"));
|
||||
try {
|
||||
vision_model.patch_embeddings = get_tensor(new_clip->ctx_data, TN_PATCH_EMBD);
|
||||
vision_model.class_embedding = get_tensor(new_clip->ctx_data, TN_CLASS_EMBD);
|
||||
vision_model.position_embeddings = get_tensor(new_clip->ctx_data, format(TN_POS_EMBD, "v"));
|
||||
vision_model.pre_ln_w = get_tensor(new_clip->ctx_data, format(TN_LN_PRE, "v", "weight"));
|
||||
vision_model.pre_ln_b = get_tensor(new_clip->ctx_data, format(TN_LN_PRE, "v", "bias"));
|
||||
} catch(const std::exception& e) {
|
||||
fprintf(stderr, "%s: failed to load vision model tensors\n", __func__);
|
||||
}
|
||||
|
||||
// LLaVA projection
|
||||
if (new_clip->proj_type == PROJECTOR_TYPE_MLP || new_clip->proj_type == PROJECTOR_TYPE_MLP_NORM) {
|
||||
|
@ -968,40 +1145,43 @@ struct clip_ctx * clip_model_load(const char * fname, const int verbosity = 1) {
|
|||
vision_model.mm_4_w = get_tensor(new_clip->ctx_data, format(TN_LLAVA_PROJ, 4, "weight"));
|
||||
vision_model.mm_4_b = get_tensor(new_clip->ctx_data, format(TN_LLAVA_PROJ, 4, "bias"));
|
||||
} catch (std::runtime_error & e) { }
|
||||
}
|
||||
else if (new_clip->proj_type == PROJECTOR_TYPE_LDP) {
|
||||
try {
|
||||
vision_model.image_newline = get_tensor(new_clip->ctx_data, TN_IMAGE_NEWLINE);
|
||||
// fprintf(stderr, "%s: image_newline tensor (llava-1.6) found\n", __func__);
|
||||
} catch (std::runtime_error & e) { }
|
||||
} else if (new_clip->proj_type == PROJECTOR_TYPE_LDP) {
|
||||
// MobileVLM projection
|
||||
vision_model.mm_model_mlp_1_w = get_tensor(new_clip->ctx_data, format(TN_MVLM_PROJ_MLP, 1, "weight"));
|
||||
vision_model.mm_model_mlp_1_b = get_tensor(new_clip->ctx_data, format(TN_MVLM_PROJ_MLP, 1, "bias"));
|
||||
vision_model.mm_model_mlp_3_w = get_tensor(new_clip->ctx_data, format(TN_MVLM_PROJ_MLP, 3, "weight"));
|
||||
vision_model.mm_model_mlp_3_b = get_tensor(new_clip->ctx_data, format(TN_MVLM_PROJ_MLP, 3, "bias"));
|
||||
vision_model.mm_model_block_1_block_0_0_w = get_tensor(new_clip->ctx_data, format(TN_MVLM_PROJ_BLOCK, 1, 0, "0.weight"));
|
||||
vision_model.mm_model_block_1_block_0_1_w = get_tensor(new_clip->ctx_data, format(TN_MVLM_PROJ_BLOCK, 1, 0, "1.weight"));
|
||||
vision_model.mm_model_block_1_block_0_1_b = get_tensor(new_clip->ctx_data, format(TN_MVLM_PROJ_BLOCK, 1, 0, "1.bias"));
|
||||
vision_model.mm_model_mlp_1_w = get_tensor(new_clip->ctx_data, format(TN_MVLM_PROJ_MLP, 1, "weight"));
|
||||
vision_model.mm_model_mlp_1_b = get_tensor(new_clip->ctx_data, format(TN_MVLM_PROJ_MLP, 1, "bias"));
|
||||
vision_model.mm_model_mlp_3_w = get_tensor(new_clip->ctx_data, format(TN_MVLM_PROJ_MLP, 3, "weight"));
|
||||
vision_model.mm_model_mlp_3_b = get_tensor(new_clip->ctx_data, format(TN_MVLM_PROJ_MLP, 3, "bias"));
|
||||
vision_model.mm_model_block_1_block_0_0_w = get_tensor(new_clip->ctx_data, format(TN_MVLM_PROJ_BLOCK, 1, 0, "0.weight"));
|
||||
vision_model.mm_model_block_1_block_0_1_w = get_tensor(new_clip->ctx_data, format(TN_MVLM_PROJ_BLOCK, 1, 0, "1.weight"));
|
||||
vision_model.mm_model_block_1_block_0_1_b = get_tensor(new_clip->ctx_data, format(TN_MVLM_PROJ_BLOCK, 1, 0, "1.bias"));
|
||||
vision_model.mm_model_block_1_block_1_fc1_w = get_tensor(new_clip->ctx_data, format(TN_MVLM_PROJ_BLOCK, 1, 1, "fc1.weight"));
|
||||
vision_model.mm_model_block_1_block_1_fc1_b = get_tensor(new_clip->ctx_data, format(TN_MVLM_PROJ_BLOCK, 1, 1, "fc1.bias"));
|
||||
vision_model.mm_model_block_1_block_1_fc2_w = get_tensor(new_clip->ctx_data, format(TN_MVLM_PROJ_BLOCK, 1, 1, "fc2.weight"));
|
||||
vision_model.mm_model_block_1_block_1_fc2_b = get_tensor(new_clip->ctx_data, format(TN_MVLM_PROJ_BLOCK, 1, 1, "fc2.bias"));
|
||||
vision_model.mm_model_block_1_block_2_0_w = get_tensor(new_clip->ctx_data, format(TN_MVLM_PROJ_BLOCK, 1, 2, "0.weight"));
|
||||
vision_model.mm_model_block_1_block_2_1_w = get_tensor(new_clip->ctx_data, format(TN_MVLM_PROJ_BLOCK, 1, 2, "1.weight"));
|
||||
vision_model.mm_model_block_1_block_2_1_b = get_tensor(new_clip->ctx_data, format(TN_MVLM_PROJ_BLOCK, 1, 2, "1.bias"));
|
||||
vision_model.mm_model_block_2_block_0_0_w = get_tensor(new_clip->ctx_data, format(TN_MVLM_PROJ_BLOCK, 2, 0, "0.weight"));
|
||||
vision_model.mm_model_block_2_block_0_1_w = get_tensor(new_clip->ctx_data, format(TN_MVLM_PROJ_BLOCK, 2, 0, "1.weight"));
|
||||
vision_model.mm_model_block_2_block_0_1_b = get_tensor(new_clip->ctx_data, format(TN_MVLM_PROJ_BLOCK, 2, 0, "1.bias"));
|
||||
vision_model.mm_model_block_1_block_2_0_w = get_tensor(new_clip->ctx_data, format(TN_MVLM_PROJ_BLOCK, 1, 2, "0.weight"));
|
||||
vision_model.mm_model_block_1_block_2_1_w = get_tensor(new_clip->ctx_data, format(TN_MVLM_PROJ_BLOCK, 1, 2, "1.weight"));
|
||||
vision_model.mm_model_block_1_block_2_1_b = get_tensor(new_clip->ctx_data, format(TN_MVLM_PROJ_BLOCK, 1, 2, "1.bias"));
|
||||
vision_model.mm_model_block_2_block_0_0_w = get_tensor(new_clip->ctx_data, format(TN_MVLM_PROJ_BLOCK, 2, 0, "0.weight"));
|
||||
vision_model.mm_model_block_2_block_0_1_w = get_tensor(new_clip->ctx_data, format(TN_MVLM_PROJ_BLOCK, 2, 0, "1.weight"));
|
||||
vision_model.mm_model_block_2_block_0_1_b = get_tensor(new_clip->ctx_data, format(TN_MVLM_PROJ_BLOCK, 2, 0, "1.bias"));
|
||||
vision_model.mm_model_block_2_block_1_fc1_w = get_tensor(new_clip->ctx_data, format(TN_MVLM_PROJ_BLOCK, 2, 1, "fc1.weight"));
|
||||
vision_model.mm_model_block_2_block_1_fc1_b = get_tensor(new_clip->ctx_data, format(TN_MVLM_PROJ_BLOCK, 2, 1, "fc1.bias"));
|
||||
vision_model.mm_model_block_2_block_1_fc2_w = get_tensor(new_clip->ctx_data, format(TN_MVLM_PROJ_BLOCK, 2, 1, "fc2.weight"));
|
||||
vision_model.mm_model_block_2_block_1_fc2_b = get_tensor(new_clip->ctx_data, format(TN_MVLM_PROJ_BLOCK, 2, 1, "fc2.bias"));
|
||||
vision_model.mm_model_block_2_block_2_0_w = get_tensor(new_clip->ctx_data, format(TN_MVLM_PROJ_BLOCK, 2, 2, "0.weight"));
|
||||
vision_model.mm_model_block_2_block_2_1_w = get_tensor(new_clip->ctx_data, format(TN_MVLM_PROJ_BLOCK, 2, 2, "1.weight"));
|
||||
vision_model.mm_model_block_2_block_2_1_b = get_tensor(new_clip->ctx_data, format(TN_MVLM_PROJ_BLOCK, 2, 2, "1.bias"));
|
||||
}
|
||||
else {
|
||||
vision_model.mm_model_block_2_block_2_0_w = get_tensor(new_clip->ctx_data, format(TN_MVLM_PROJ_BLOCK, 2, 2, "0.weight"));
|
||||
vision_model.mm_model_block_2_block_2_1_w = get_tensor(new_clip->ctx_data, format(TN_MVLM_PROJ_BLOCK, 2, 2, "1.weight"));
|
||||
vision_model.mm_model_block_2_block_2_1_b = get_tensor(new_clip->ctx_data, format(TN_MVLM_PROJ_BLOCK, 2, 2, "1.bias"));
|
||||
} else {
|
||||
std::string proj_type = PROJECTOR_TYPE_NAMES[new_clip->proj_type];
|
||||
throw std::runtime_error(format("%s: don't support projector with: %s currently\n", __func__, proj_type.c_str()));
|
||||
}
|
||||
|
||||
vision_model.layers.resize(hparams.n_layer);
|
||||
|
||||
for (int il = 0; il < hparams.n_layer; ++il) {
|
||||
auto & layer = vision_model.layers[il];
|
||||
layer.k_w = get_tensor(new_clip->ctx_data, format(TN_ATTN_K, "v", il, "weight"));
|
||||
|
@ -1050,8 +1230,20 @@ struct clip_image_f32 * clip_image_f32_init() {
|
|||
return new clip_image_f32();
|
||||
}
|
||||
|
||||
void clip_image_u8_free (struct clip_image_u8 * img) { delete img; }
|
||||
void clip_image_u8_free(struct clip_image_u8 * img) { delete img; }
|
||||
void clip_image_f32_free(struct clip_image_f32 * img) { delete img; }
|
||||
void clip_image_u8_batch_free(struct clip_image_u8_batch & batch) {
|
||||
if (batch.size > 0) {
|
||||
delete[] batch.data;
|
||||
batch.size = 0;
|
||||
}
|
||||
}
|
||||
void clip_image_f32_batch_free(struct clip_image_f32_batch & batch) {
|
||||
if (batch.size > 0) {
|
||||
delete[] batch.data;
|
||||
batch.size = 0;
|
||||
}
|
||||
}
|
||||
|
||||
static void build_clip_img_from_data(const stbi_uc * data, int nx, int ny, clip_image_u8 * img) {
|
||||
img->nx = nx;
|
||||
|
@ -1084,24 +1276,252 @@ bool clip_image_load_from_bytes(const unsigned char * bytes, size_t bytes_length
|
|||
return true;
|
||||
}
|
||||
|
||||
// normalize: x = (x - mean) / std
|
||||
// TODO: implement bicubic interpolation instead of linear.
|
||||
bool clip_image_preprocess(struct clip_ctx * ctx, const clip_image_u8 * img, clip_image_f32 * res, const bool pad2square) {
|
||||
// Linear interpolation between two points
|
||||
inline float lerp(float s, float e, float t) {
|
||||
return s + (e - s) * t;
|
||||
}
|
||||
// Bilinear resize function
|
||||
static void bilinear_resize(const clip_image_u8& src, clip_image_u8& dst, int target_width, int target_height) {
|
||||
dst.nx = target_width;
|
||||
dst.ny = target_height;
|
||||
dst.buf.resize(3 * target_width * target_height);
|
||||
|
||||
float x_ratio = static_cast<float>(src.nx - 1) / target_width;
|
||||
float y_ratio = static_cast<float>(src.ny - 1) / target_height;
|
||||
|
||||
for (int y = 0; y < target_height; y++) {
|
||||
for (int x = 0; x < target_width; x++) {
|
||||
float px = x_ratio * x;
|
||||
float py = y_ratio * y;
|
||||
int x_floor = static_cast<int>(px);
|
||||
int y_floor = static_cast<int>(py);
|
||||
float x_lerp = px - x_floor;
|
||||
float y_lerp = py - y_floor;
|
||||
|
||||
for (int c = 0; c < 3; c++) {
|
||||
float top = lerp(
|
||||
static_cast<float>(src.buf[3 * (y_floor * src.nx + x_floor) + c]),
|
||||
static_cast<float>(src.buf[3 * (y_floor * src.nx + (x_floor + 1)) + c]),
|
||||
x_lerp
|
||||
);
|
||||
float bottom = lerp(
|
||||
static_cast<float>(src.buf[3 * ((y_floor + 1) * src.nx + x_floor) + c]),
|
||||
static_cast<float>(src.buf[3 * ((y_floor + 1) * src.nx + (x_floor + 1)) + c]),
|
||||
x_lerp
|
||||
);
|
||||
dst.buf[3 * (y * target_width + x) + c] = static_cast<uint8_t>(lerp(top, bottom, y_lerp));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Normalize image to float32 - careful with pytorch .to(model.device, dtype=torch.float16) - this sometimes reduces precision (32>16>32), sometimes not
|
||||
static void normalize_image_u8_to_f32(const clip_image_u8* src, clip_image_f32* dst, const float mean[3], const float std[3]) {
|
||||
dst->nx = src->nx;
|
||||
dst->ny = src->ny;
|
||||
dst->buf.resize(src->buf.size());
|
||||
|
||||
for (size_t i = 0; i < src->buf.size(); ++i) {
|
||||
int c = i % 3; // rgb
|
||||
dst->buf[i] = (static_cast<float>(src->buf[i]) / 255.0f - mean[c]) / std[c];
|
||||
}
|
||||
}
|
||||
|
||||
inline float clip(float x, float lower, float upper) {
|
||||
return std::max(lower, std::min(x, upper));
|
||||
}
|
||||
|
||||
static bool bicubic_resize(const clip_image_u8 &img, clip_image_u8 &dst, int target_width, int target_height) {
|
||||
const int nx = img.nx;
|
||||
const int ny = img.ny;
|
||||
|
||||
dst.nx = target_width;
|
||||
dst.ny = target_height;
|
||||
dst.buf.resize(3 * target_width * target_height);
|
||||
|
||||
float Cc;
|
||||
float C[5];
|
||||
float d0, d2, d3, a0, a1, a2, a3;
|
||||
int i, j, k, jj;
|
||||
int x, y;
|
||||
float dx, dy;
|
||||
float tx, ty;
|
||||
|
||||
tx = (float)nx / (float)target_width;
|
||||
ty = (float)ny / (float)target_height;
|
||||
|
||||
// Bicubic interpolation; adapted from ViT.cpp, inspired from :
|
||||
// -> https://github.com/yglukhov/bicubic-interpolation-image-processing/blob/master/libimage.c#L36
|
||||
// -> https://en.wikipedia.org/wiki/Bicubic_interpolation
|
||||
|
||||
for (i = 0; i < target_height; i++) {
|
||||
for (j = 0; j < target_width; j++) {
|
||||
x = (int)(tx * j);
|
||||
y = (int)(ty * i);
|
||||
|
||||
dx = tx * j - x;
|
||||
dy = ty * i - y;
|
||||
|
||||
for (k = 0; k < 3; k++) {
|
||||
for (jj = 0; jj <= 3; jj++) {
|
||||
d0 = img.buf[(clip(y - 1 + jj, 0, ny - 1) * nx + clip(x - 1, 0, nx - 1)) * 3 + k] - img.buf[(clip(y - 1 + jj, 0, ny - 1) * nx + clip(x, 0, nx - 1)) * 3 + k];
|
||||
d2 = img.buf[(clip(y - 1 + jj, 0, ny - 1) * nx + clip(x + 1, 0, nx - 1)) * 3 + k] - img.buf[(clip(y - 1 + jj, 0, ny - 1) * nx + clip(x, 0, nx - 1)) * 3 + k];
|
||||
d3 = img.buf[(clip(y - 1 + jj, 0, ny - 1) * nx + clip(x + 2, 0, nx - 1)) * 3 + k] - img.buf[(clip(y - 1 + jj, 0, ny - 1) * nx + clip(x, 0, nx - 1)) * 3 + k];
|
||||
a0 = img.buf[(clip(y - 1 + jj, 0, ny - 1) * nx + clip(x, 0, nx - 1)) * 3 + k];
|
||||
|
||||
a1 = -1.0 / 3 * d0 + d2 - 1.0 / 6 * d3;
|
||||
a2 = 1.0 / 2 * d0 + 1.0 / 2 * d2;
|
||||
a3 = -1.0 / 6 * d0 - 1.0 / 2 * d2 + 1.0 / 6 * d3;
|
||||
|
||||
C[jj] = a0 + a1 * dx + a2 * dx * dx + a3 * dx * dx * dx;
|
||||
|
||||
d0 = C[0] - C[1];
|
||||
d2 = C[2] - C[1];
|
||||
d3 = C[3] - C[1];
|
||||
a0 = C[1];
|
||||
a1 = -1.0 / 3 * d0 + d2 - 1.0 / 6 * d3;
|
||||
a2 = 1.0 / 2 * d0 + 1.0 / 2 * d2;
|
||||
a3 = -1.0 / 6 * d0 - 1.0 / 2 * d2 + 1.0 / 6 * d3;
|
||||
Cc = a0 + a1 * dy + a2 * dy * dy + a3 * dy * dy * dy;
|
||||
|
||||
const uint8_t Cc2 = std::min(std::max(std::round(Cc), 0.0f), 255.0f);
|
||||
dst.buf[(i * target_width + j) * 3 + k] = float(Cc2);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
// llava-1.6 type of resize_and_pad (black)
|
||||
static void resize_and_pad_image(const clip_image_u8& image, clip_image_u8 &image_output, const std::pair<int, int>& target_resolution) {
|
||||
int target_width = target_resolution.first;
|
||||
int target_height = target_resolution.second;
|
||||
|
||||
float scale_w = static_cast<float>(target_width) / image.nx;
|
||||
float scale_h = static_cast<float>(target_height) / image.ny;
|
||||
|
||||
int new_width, new_height;
|
||||
|
||||
if (scale_w < scale_h) {
|
||||
new_width = target_width;
|
||||
new_height = std::min(static_cast<int>(std::ceil(image.ny * scale_w)), target_height);
|
||||
} else {
|
||||
new_height = target_height;
|
||||
new_width = std::min(static_cast<int>(std::ceil(image.nx * scale_h)), target_width);
|
||||
}
|
||||
|
||||
clip_image_u8 resized_image;
|
||||
// bilinear_resize(image, resized_image, new_width, new_height);
|
||||
bicubic_resize(image, resized_image, new_width, new_height);
|
||||
|
||||
clip_image_u8 padded_image;
|
||||
padded_image.nx = target_width;
|
||||
padded_image.ny = target_height;
|
||||
padded_image.buf.resize(3 * target_width * target_height, 0); // Initialize with black
|
||||
|
||||
// Calculate padding offsets
|
||||
int pad_x = (target_width - new_width) / 2;
|
||||
int pad_y = (target_height - new_height) / 2;
|
||||
|
||||
// Copy the resized image into the center of the padded buffer
|
||||
for (int y = 0; y < new_height; ++y) {
|
||||
for (int x = 0; x < new_width; ++x) {
|
||||
for (int c = 0; c < 3; ++c) {
|
||||
padded_image.buf[3 * ((y + pad_y) * target_width + (x + pad_x)) + c] = resized_image.buf[3 * (y * new_width + x) + c];
|
||||
}
|
||||
}
|
||||
}
|
||||
image_output = std::move(padded_image);
|
||||
}
|
||||
|
||||
/**
|
||||
* Selects the best resolution from a list of possible resolutions based on the original size.
|
||||
*
|
||||
* @param original_size The original size of the image in the format (width, height).
|
||||
* @param possible_resolutions A list of possible resolutions in the format [(width1, height1), (width2, height2), ...].
|
||||
* @return The best fit resolution in the format (width, height).
|
||||
*/
|
||||
static std::pair<int, int> select_best_resolution(const std::pair<int, int> & original_size, const std::vector<std::pair<int, int>> & possible_resolutions) {
|
||||
int original_width = original_size.first;
|
||||
int original_height = original_size.second;
|
||||
std::pair<int, int> best_fit;
|
||||
int max_effective_resolution = 0;
|
||||
int min_wasted_resolution = std::numeric_limits<int>::max();
|
||||
|
||||
for (const auto& resolution : possible_resolutions) {
|
||||
int width = resolution.first;
|
||||
int height = resolution.second;
|
||||
float scale = std::min(static_cast<float>(width) / original_width, static_cast<float>(height) / original_height);
|
||||
int downscaled_width = static_cast<int>(original_width * scale);
|
||||
int downscaled_height = static_cast<int>(original_height * scale);
|
||||
int effective_resolution = std::min(downscaled_width * downscaled_height, original_width * original_height);
|
||||
int wasted_resolution = (width * height) - effective_resolution;
|
||||
// fprintf(stderr, "resolution: %d %d, scale: %f, downscaled: %d %d, effective: %d, wasted: %d\n", width, height, scale, downscaled_width, downscaled_height, effective_resolution, wasted_resolution);
|
||||
if (effective_resolution > max_effective_resolution || (effective_resolution == max_effective_resolution && wasted_resolution < min_wasted_resolution)) {
|
||||
max_effective_resolution = effective_resolution;
|
||||
min_wasted_resolution = wasted_resolution;
|
||||
best_fit = resolution;
|
||||
}
|
||||
}
|
||||
|
||||
return best_fit;
|
||||
}
|
||||
|
||||
static std::vector<clip_image_u8*> divide_to_patches_u8(const clip_image_u8 & image, int patch_size) {
|
||||
std::vector<clip_image_u8*> patches;
|
||||
int width = image.nx;
|
||||
int height = image.ny;
|
||||
for (int i = 0; i < height; i += patch_size) {
|
||||
for (int j = 0; j < width; j += patch_size) {
|
||||
clip_image_u8 *patch = clip_image_u8_init();
|
||||
patch->nx = std::min(patch_size, width - j);
|
||||
patch->ny = std::min(patch_size, height - i);
|
||||
patch->buf.resize(3 * patch->nx * patch->ny);
|
||||
for (int y = 0; y < patch->ny; ++y) {
|
||||
for (int x = 0; x < patch->nx; ++x) {
|
||||
for (int c = 0; c < 3; ++c) {
|
||||
patch->buf[3 * (y * patch->nx + x) + c] = image.buf[3 * ((i + y) * width + (j + x)) + c];
|
||||
}
|
||||
}
|
||||
}
|
||||
patches.push_back(patch);
|
||||
}
|
||||
}
|
||||
return patches;
|
||||
}
|
||||
|
||||
// returns the normalized float tensor for llava-1.5, for spatial_unpad with anyres processing for llava-1.6 it returns the normalized image patch tensors as a vector
|
||||
// res_imgs memory is being allocated here, previous allocations will be freed if found
|
||||
bool clip_image_preprocess(struct clip_ctx * ctx, const clip_image_u8 * img, clip_image_f32_batch & res_imgs) {
|
||||
bool pad_to_square = true;
|
||||
if (!ctx->has_vision_encoder) {
|
||||
printf("This gguf file seems to have no vision encoder\n");
|
||||
return false;
|
||||
}
|
||||
auto & params = ctx->vision_model.hparams;
|
||||
// The model config actually contains all we need to decide on how to preprocess, here we automatically switch to the new llava-1.6 preprocessing
|
||||
if (strcmp(params.mm_patch_merge_type, "spatial_unpad") == 0) {
|
||||
pad_to_square = false;
|
||||
}
|
||||
// free the previous res_imgs if any set
|
||||
if (res_imgs.size > 0) {
|
||||
clip_image_f32_batch_free(res_imgs);
|
||||
}
|
||||
res_imgs.data = nullptr;
|
||||
res_imgs.size = 0;
|
||||
|
||||
// the logic below is to pad the shorter side to the longer side with a background color: rgb(122, 116, 104)
|
||||
// see https://github.com/haotian-liu/LLaVA/blob/e854a2bf85118c504f6f16bf5c3c7c92f8fa8c6b/llava/conversation.py#L113-L156
|
||||
|
||||
clip_image_u8 * temp = clip_image_u8_init(); // we will keep the input image data here temporarily
|
||||
if (pad2square && img->nx != img->ny) {
|
||||
if (pad_to_square && img->nx != img->ny) {
|
||||
int longer_side = std::max(img->nx, img->ny);
|
||||
temp->nx = longer_side;
|
||||
temp->ny = longer_side;
|
||||
temp->buf.resize(3 * longer_side * longer_side);
|
||||
const uint8_t bc[3] = {122, 116, 104}; // background color in RGB from LLaVA
|
||||
const uint8_t bc[3] = {122, 116, 104}; // background color in RGB from LLaVA (this is the mean rgb color * 255)
|
||||
|
||||
// fill with background color
|
||||
for (size_t i = 0; i < temp->buf.size(); i++) {
|
||||
|
@ -1119,18 +1539,63 @@ bool clip_image_preprocess(struct clip_ctx * ctx, const clip_image_u8 * img, cli
|
|||
}
|
||||
}
|
||||
} else {
|
||||
temp->nx = img->nx;
|
||||
temp->ny = img->ny;
|
||||
temp->buf.resize(img->buf.size());
|
||||
memcpy(temp->buf.data(), img->buf.data(), temp->buf.size());
|
||||
if (params.image_grid_pinpoints[0] != 0) {
|
||||
// "spatial_unpad" with "anyres" processing for llava-1.6
|
||||
std::vector<std::pair<int, int>> possible_resolutions;
|
||||
for (int i = 0; i < 32 && params.image_grid_pinpoints[i] != 0; i+=2) {
|
||||
possible_resolutions.push_back({params.image_grid_pinpoints[i], params.image_grid_pinpoints[i+1]});
|
||||
}
|
||||
std::pair<int, int> best_resolution = select_best_resolution({img->nx, img->ny}, possible_resolutions);
|
||||
// clip_image_save_to_bmp(*img, "input.bmp");
|
||||
resize_and_pad_image(*img, *temp, best_resolution); // we do not pad with mean-bg color anymore in llava-1.6
|
||||
// clip_image_save_to_bmp(*temp, "resized.bmp");
|
||||
// visually verify normalized image:
|
||||
// normalize_image_u8_to_f32(*temp, *res, ctx->image_mean, ctx->image_std);
|
||||
// {
|
||||
// clip_image_u8 * temp2 = clip_image_u8_init();
|
||||
// clip_image_convert_f32_to_u8(*res, *temp2);
|
||||
// clip_image_save_to_bmp(*temp2, "resized_normalized_f32.bmp");
|
||||
// clip_image_u8_free(temp2);
|
||||
// }
|
||||
|
||||
std::vector<clip_image_u8 *> patches = divide_to_patches_u8(*temp, params.image_size); // prepare spatial sorted main patches of image_size each (336 in llava-1.6)
|
||||
|
||||
clip_image_u8 *image_original_resize = clip_image_u8_init();
|
||||
// bilinear_resize(*img, *image_original_resize, params.image_size, params.image_size); // in python this is "shortest_edge", but all CLIP are square
|
||||
bicubic_resize(*img, *image_original_resize, params.image_size, params.image_size); // in python this is "shortest_edge", but all CLIP are square
|
||||
patches.insert(patches.begin(), image_original_resize);
|
||||
// clip_image_f32_batch_init(patches.size());
|
||||
res_imgs.size = patches.size();
|
||||
res_imgs.data = new clip_image_f32[res_imgs.size];
|
||||
int num=0;
|
||||
for (auto& patch : patches) {
|
||||
normalize_image_u8_to_f32(patch, &res_imgs.data[num], ctx->image_mean, ctx->image_std);
|
||||
num++;
|
||||
}
|
||||
|
||||
for (size_t i = 0; i < patches.size(); i++) {
|
||||
// printf("patch %d: %d %d\n", i, patches[i]->nx, patches[i]->ny);
|
||||
clip_image_u8_free(patches[i]);
|
||||
}
|
||||
|
||||
clip_image_u8_free(temp);
|
||||
|
||||
return true;
|
||||
} else {
|
||||
temp->nx = img->nx;
|
||||
temp->ny = img->ny;
|
||||
temp->buf.resize(img->buf.size());
|
||||
memcpy(temp->buf.data(), img->buf.data(), temp->buf.size());
|
||||
}
|
||||
}
|
||||
|
||||
const int nx = temp->nx;
|
||||
const int ny = temp->ny;
|
||||
// clip_image_save_to_bmp(*temp, "resized_vanilla.bmp");
|
||||
|
||||
const int nx2 = ctx->vision_model.hparams.image_size;
|
||||
const int ny2 = ctx->vision_model.hparams.image_size;
|
||||
|
||||
clip_image_f32 * res = clip_image_f32_init();
|
||||
res->nx = nx2;
|
||||
res->ny = ny2;
|
||||
res->buf.resize(3 * nx2 * ny2);
|
||||
|
@ -1184,9 +1649,26 @@ bool clip_image_preprocess(struct clip_ctx * ctx, const clip_image_u8 * img, cli
|
|||
}
|
||||
clip_image_u8_free(temp);
|
||||
|
||||
// {
|
||||
// clip_image_u8 * temp2 = clip_image_u8_init();
|
||||
// clip_image_convert_f32_to_u8(*res, *temp2);
|
||||
// clip_image_save_to_bmp(*temp2, "resized_normalized_f32_vanilla.bmp");
|
||||
// clip_image_u8_free(temp2);
|
||||
// }
|
||||
// res_imgs.push_back(res);
|
||||
|
||||
res_imgs.size = 1;
|
||||
res_imgs.data = new clip_image_f32[res_imgs.size];
|
||||
res_imgs.data[0] = *res;
|
||||
clip_image_f32_free(res);
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
ggml_tensor * clip_get_newline_tensor(const struct clip_ctx * ctx) {
|
||||
return ctx->vision_model.image_newline;
|
||||
}
|
||||
|
||||
void clip_free(clip_ctx * ctx) {
|
||||
ggml_free(ctx->ctx_data);
|
||||
gguf_free(ctx->ctx_gguf);
|
||||
|
@ -1194,6 +1676,42 @@ void clip_free(clip_ctx * ctx) {
|
|||
delete ctx;
|
||||
}
|
||||
|
||||
size_t clip_embd_nbytes(const struct clip_ctx * ctx) {
|
||||
return clip_n_patches(ctx) * clip_n_mmproj_embd(ctx) * sizeof(float);
|
||||
}
|
||||
|
||||
int32_t clip_image_size(const struct clip_ctx * ctx) {
|
||||
return ctx->vision_model.hparams.image_size;
|
||||
}
|
||||
|
||||
int32_t clip_patch_size(const struct clip_ctx * ctx) {
|
||||
return ctx->vision_model.hparams.patch_size;
|
||||
}
|
||||
|
||||
int32_t clip_hidden_size(const struct clip_ctx * ctx) {
|
||||
return ctx->vision_model.hparams.hidden_size;
|
||||
}
|
||||
|
||||
const char * clip_patch_merge_type(const struct clip_ctx * ctx) {
|
||||
return ctx->vision_model.hparams.mm_patch_merge_type;
|
||||
}
|
||||
|
||||
const int32_t * clip_image_grid(const struct clip_ctx * ctx) {
|
||||
return ctx->vision_model.hparams.image_grid_pinpoints;
|
||||
}
|
||||
|
||||
int clip_n_patches(const struct clip_ctx * ctx) {
|
||||
const auto & params = ctx->vision_model.hparams;
|
||||
|
||||
int n_patches = (params.image_size / params.patch_size) * (params.image_size / params.patch_size);
|
||||
|
||||
if (ctx->proj_type == PROJECTOR_TYPE_LDP) {
|
||||
n_patches /= 4;
|
||||
}
|
||||
|
||||
return n_patches;
|
||||
}
|
||||
|
||||
bool clip_image_encode(struct clip_ctx * ctx, const int n_threads, clip_image_f32 * img, float * vec) {
|
||||
if (!ctx->has_vision_encoder) {
|
||||
printf("This gguf file seems to have no vision encoder\n");
|
||||
|
@ -1213,7 +1731,7 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima
|
|||
}
|
||||
|
||||
int batch_size = imgs->size;
|
||||
if(ctx->has_llava_projector) {
|
||||
if (ctx->has_llava_projector) {
|
||||
GGML_ASSERT(batch_size == 1); // TODO: support multiple images
|
||||
}
|
||||
|
||||
|
@ -1224,9 +1742,10 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima
|
|||
// set inputs
|
||||
const auto & model = ctx->vision_model;
|
||||
const auto & hparams = model.hparams;
|
||||
const int image_size = hparams.image_size;
|
||||
const int patch_size = hparams.patch_size;
|
||||
const int num_patches = ((image_size / patch_size) * (image_size / patch_size));
|
||||
|
||||
const int image_size = hparams.image_size;
|
||||
const int patch_size = hparams.patch_size;
|
||||
const int num_patches = ((image_size / patch_size) * (image_size / patch_size));
|
||||
const int num_positions = num_patches + 1;
|
||||
|
||||
{
|
||||
|
@ -1301,11 +1820,11 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima
|
|||
|
||||
// copy the embeddings to the location passed by the user
|
||||
ggml_backend_tensor_get(embeddings, vec, 0, ggml_nbytes(embeddings));
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
bool clip_model_quantize(const char * fname_inp, const char * fname_out, const int itype) {
|
||||
|
||||
ggml_type type = GGML_TYPE_Q4_1;
|
||||
|
||||
assert(itype < GGML_TYPE_COUNT);
|
||||
|
@ -1494,26 +2013,13 @@ int clip_n_mmproj_embd(const struct clip_ctx * ctx) {
|
|||
if (ctx->proj_type == PROJECTOR_TYPE_LDP) {
|
||||
return ctx->vision_model.mm_model_block_1_block_2_1_b->ne[0];
|
||||
}
|
||||
else if (ctx->proj_type == PROJECTOR_TYPE_MLP) {
|
||||
if (ctx->proj_type == PROJECTOR_TYPE_MLP) {
|
||||
return ctx->vision_model.mm_2_b->ne[0];
|
||||
} else if (ctx->proj_type == PROJECTOR_TYPE_MLP_NORM) {
|
||||
}
|
||||
if (ctx->proj_type == PROJECTOR_TYPE_MLP_NORM) {
|
||||
return ctx->vision_model.mm_3_b->ne[0];
|
||||
}
|
||||
else {
|
||||
std::string proj_type = PROJECTOR_TYPE_NAMES[ctx->proj_type];
|
||||
throw std::runtime_error(format("%s: don't support projector with: %s currently\n", __func__, proj_type.c_str()));
|
||||
}
|
||||
}
|
||||
|
||||
int clip_n_patches(const struct clip_ctx * ctx) {
|
||||
auto & params = ctx->vision_model.hparams;
|
||||
int n_patches = (params.image_size / params.patch_size) * (params.image_size / params.patch_size);
|
||||
if (ctx->proj_type == PROJECTOR_TYPE_LDP) {
|
||||
n_patches /= 4;
|
||||
}
|
||||
return n_patches;
|
||||
}
|
||||
|
||||
size_t clip_embd_nbytes(const struct clip_ctx * ctx) {
|
||||
return clip_n_patches(ctx) * clip_n_mmproj_embd(ctx) * sizeof(float);
|
||||
std::string proj_type = PROJECTOR_TYPE_NAMES[ctx->proj_type];
|
||||
throw std::runtime_error(format("%s: don't support projector with: %s currently\n", __func__, proj_type.c_str()));
|
||||
}
|
||||
|
|
|
@ -24,25 +24,7 @@ struct clip_ctx;
|
|||
extern "C" {
|
||||
#endif
|
||||
|
||||
struct clip_vision_hparams {
|
||||
int32_t image_size;
|
||||
int32_t patch_size;
|
||||
int32_t hidden_size;
|
||||
int32_t n_intermediate;
|
||||
int32_t projection_dim;
|
||||
int32_t n_head;
|
||||
int32_t n_layer;
|
||||
float eps;
|
||||
};
|
||||
|
||||
CLIP_API struct clip_ctx * clip_model_load(const char * fname, int verbosity);
|
||||
|
||||
CLIP_API void clip_free(struct clip_ctx * ctx);
|
||||
|
||||
CLIP_API size_t clip_embd_nbytes(const struct clip_ctx * ctx);
|
||||
|
||||
CLIP_API int clip_n_patches (const struct clip_ctx * ctx);
|
||||
CLIP_API int clip_n_mmproj_embd(const struct clip_ctx * ctx);
|
||||
struct clip_ctx;
|
||||
|
||||
struct clip_image_u8_batch {
|
||||
struct clip_image_u8 * data;
|
||||
|
@ -54,18 +36,43 @@ struct clip_image_f32_batch {
|
|||
size_t size;
|
||||
};
|
||||
|
||||
CLIP_API struct clip_ctx * clip_model_load (const char * fname, int verbosity);
|
||||
CLIP_API struct clip_ctx * clip_model_load_cpu(const char * fname, int verbosity);
|
||||
|
||||
CLIP_API void clip_free(struct clip_ctx * ctx);
|
||||
|
||||
CLIP_API size_t clip_embd_nbytes(const struct clip_ctx * ctx);
|
||||
|
||||
CLIP_API int32_t clip_image_size (const struct clip_ctx * ctx);
|
||||
CLIP_API int32_t clip_patch_size (const struct clip_ctx * ctx);
|
||||
CLIP_API int32_t clip_hidden_size(const struct clip_ctx * ctx);
|
||||
|
||||
// TODO: should be enum, not string
|
||||
CLIP_API const char * clip_patch_merge_type(const struct clip_ctx * ctx);
|
||||
|
||||
CLIP_API const int32_t * clip_image_grid(const struct clip_ctx * ctx);
|
||||
|
||||
CLIP_API int clip_n_patches (const struct clip_ctx * ctx);
|
||||
CLIP_API int clip_n_mmproj_embd(const struct clip_ctx * ctx);
|
||||
|
||||
CLIP_API struct clip_image_u8 * clip_image_u8_init ();
|
||||
CLIP_API struct clip_image_f32 * clip_image_f32_init();
|
||||
|
||||
CLIP_API void clip_image_u8_free (struct clip_image_u8 * img);
|
||||
CLIP_API void clip_image_u8_free (struct clip_image_u8 * img);
|
||||
CLIP_API void clip_image_f32_free(struct clip_image_f32 * img);
|
||||
CLIP_API void clip_image_u8_batch_free (struct clip_image_u8_batch & batch);
|
||||
CLIP_API void clip_image_f32_batch_free(struct clip_image_f32_batch & batch);
|
||||
|
||||
CLIP_API bool clip_image_load_from_file(const char * fname, struct clip_image_u8 * img);
|
||||
|
||||
/** interpret bytes as an image file with length bytes_length, and use the result to populate img */
|
||||
CLIP_API bool clip_image_load_from_bytes(const unsigned char * bytes, size_t bytes_length, struct clip_image_u8 * img);
|
||||
|
||||
CLIP_API bool clip_image_preprocess (struct clip_ctx * ctx, const struct clip_image_u8 * img, struct clip_image_f32 * res, bool pad2square);
|
||||
/** preprocess img and store the result in res_imgs, pad_to_square may be overriden to false depending on model configuration */
|
||||
CLIP_API bool clip_image_preprocess(struct clip_ctx * ctx, const clip_image_u8 * img, clip_image_f32_batch & res_imgs );
|
||||
|
||||
CLIP_API struct ggml_tensor * clip_get_newline_tensor(const struct clip_ctx * ctx);
|
||||
|
||||
CLIP_API bool clip_image_encode (struct clip_ctx * ctx, int n_threads, struct clip_image_f32 * img, float * vec);
|
||||
CLIP_API bool clip_image_batch_encode(struct clip_ctx * ctx, int n_threads, const struct clip_image_f32_batch * imgs, float * vec);
|
||||
|
||||
|
|
|
@ -78,18 +78,19 @@ ap.add_argument("--text-only", action="store_true", required=False,
|
|||
help="Save a text-only model. It can't be used to encode images")
|
||||
ap.add_argument("--vision-only", action="store_true", required=False,
|
||||
help="Save a vision-only model. It can't be used to encode texts")
|
||||
ap.add_argument("--clip_model_is_vision", action="store_true", required=False,
|
||||
ap.add_argument("--clip-model-is-vision", action="store_true", required=False,
|
||||
help="The clip model is a pure vision model (ShareGPT4V vision extract for example)")
|
||||
ap.add_argument("--clip-model-is-openclip", action="store_true", required=False,
|
||||
help="The clip model is from openclip (for ViT-SO400M type))")
|
||||
ap.add_argument("--llava-projector", help="Path to llava.projector file. If specified, save an image encoder for LLaVA models.")
|
||||
ap.add_argument("--projector-type", help="Type of projector. Possible values: mlp, ldp", choices=["mlp", "ldp"], default="mlp")
|
||||
ap.add_argument("--image-mean", nargs=3, type=float, required=False, help="Override image mean values")
|
||||
ap.add_argument("--image-std", nargs=3, type=float, required=False, help="Override image std values")
|
||||
ap.add_argument("-o", "--output-dir", help="Directory to save GGUF files. Default is the original model directory", default=None)
|
||||
# Example --image_mean 0.48145466 0.4578275 0.40821073 --image_std 0.26862954 0.26130258 0.27577711
|
||||
# Example --image_mean 0.5 0.5 0.5 --image_std 0.5 0.5 0.5
|
||||
default_image_mean = [0.48145466, 0.4578275, 0.40821073]
|
||||
default_image_std = [0.26862954, 0.26130258, 0.27577711]
|
||||
ap.add_argument('--image_mean', type=float, nargs='+', help='Mean of the images for normalization (overrides processor) ', default=None)
|
||||
ap.add_argument('--image_std', type=float, nargs='+', help='Standard deviation of the images for normalization (overrides processor)', default=None)
|
||||
ap.add_argument('--image-mean', type=float, nargs='+', help='Mean of the images for normalization (overrides processor) ', default=None)
|
||||
ap.add_argument('--image-std', type=float, nargs='+', help='Standard deviation of the images for normalization (overrides processor)', default=None)
|
||||
|
||||
# with proper
|
||||
args = ap.parse_args()
|
||||
|
@ -105,7 +106,7 @@ if args.use_f32:
|
|||
# output in the same directory as the model if output_dir is None
|
||||
dir_model = args.model_dir
|
||||
|
||||
if args.clip_model_is_vision:
|
||||
if args.clip_model_is_vision or not os.path.exists(dir_model + "/vocab.json") or args.clip_model_is_openclip:
|
||||
vocab = None
|
||||
tokens = None
|
||||
else:
|
||||
|
@ -133,7 +134,7 @@ ftype = 1
|
|||
if args.use_f32:
|
||||
ftype = 0
|
||||
|
||||
if args.clip_model_is_vision:
|
||||
if args.clip_model_is_vision or args.clip_model_is_openclip:
|
||||
model = CLIPVisionModel.from_pretrained(dir_model)
|
||||
processor = None
|
||||
else:
|
||||
|
@ -202,6 +203,57 @@ if has_vision_encoder:
|
|||
fout.add_float32(k(KEY_ATTENTION_LAYERNORM_EPS, VISION), v_hparams["layer_norm_eps"])
|
||||
block_count = v_hparams["num_hidden_layers"] - 1 if has_llava_projector else v_hparams["num_hidden_layers"]
|
||||
fout.add_uint32(k(KEY_BLOCK_COUNT, VISION), block_count)
|
||||
# /**
|
||||
# "image_grid_pinpoints": [
|
||||
# [
|
||||
# 336,
|
||||
# 672
|
||||
# ],
|
||||
# [
|
||||
# 672,
|
||||
# 336
|
||||
# ],
|
||||
# [
|
||||
# 672,
|
||||
# 672
|
||||
# ],
|
||||
# [
|
||||
# 1008,
|
||||
# 336
|
||||
# ],
|
||||
# [
|
||||
# 336,
|
||||
# 1008
|
||||
# ]
|
||||
# ],
|
||||
# Flattened:
|
||||
# [
|
||||
# 336, 672,
|
||||
# 672, 336,
|
||||
# 672, 672,
|
||||
# 1008, 336,
|
||||
# 336, 1008
|
||||
# ]
|
||||
# *
|
||||
# */
|
||||
if "image_grid_pinpoints" in v_hparams:
|
||||
# flatten it
|
||||
image_grid_pinpoints = []
|
||||
for pinpoint in v_hparams["image_grid_pinpoints"]:
|
||||
for p in pinpoint:
|
||||
image_grid_pinpoints.append(p)
|
||||
fout.add_array("clip.vision.image_grid_pinpoints", image_grid_pinpoints)
|
||||
if "image_crop_resolution" in v_hparams:
|
||||
fout.add_uint32("clip.vision.image_crop_resolution", v_hparams["image_crop_resolution"])
|
||||
if "image_aspect_ratio" in v_hparams:
|
||||
fout.add_string("clip.vision.image_aspect_ratio", v_hparams["image_aspect_ratio"])
|
||||
if "image_split_resolution" in v_hparams:
|
||||
fout.add_uint32("clip.vision.image_split_resolution", v_hparams["image_split_resolution"])
|
||||
if "mm_patch_merge_type" in v_hparams:
|
||||
fout.add_string("clip.vision.mm_patch_merge_type", v_hparams["mm_patch_merge_type"])
|
||||
if "mm_projector_type" in v_hparams:
|
||||
fout.add_string("clip.vision.mm_projector_type", v_hparams["mm_projector_type"])
|
||||
|
||||
|
||||
if processor is not None:
|
||||
image_mean = processor.image_processor.image_mean if args.image_mean is None or args.image_mean == default_image_mean else args.image_mean
|
||||
|
|
|
@ -155,11 +155,29 @@ static void process_prompt(struct llava_context * ctx_llava, struct llava_image_
|
|||
system_prompt = prompt.substr(0, image_pos);
|
||||
user_prompt = prompt.substr(image_pos + std::string("<image>").length());
|
||||
printf("system_prompt: %s\n", system_prompt.c_str());
|
||||
if (params->verbose_prompt) {
|
||||
auto tmp = ::llama_tokenize(ctx_llava->ctx_llama, system_prompt, true, true);
|
||||
for (int i = 0; i < (int) tmp.size(); i++) {
|
||||
printf("%6d -> '%s'\n", tmp[i], llama_token_to_piece(ctx_llava->ctx_llama, tmp[i]).c_str());
|
||||
}
|
||||
}
|
||||
printf("user_prompt: %s\n", user_prompt.c_str());
|
||||
if (params->verbose_prompt) {
|
||||
auto tmp = ::llama_tokenize(ctx_llava->ctx_llama, user_prompt, true, true);
|
||||
for (int i = 0; i < (int) tmp.size(); i++) {
|
||||
printf("%6d -> '%s'\n", tmp[i], llama_token_to_piece(ctx_llava->ctx_llama, tmp[i]).c_str());
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// llava-1.5 native mode
|
||||
system_prompt = "A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions.\nUSER:";
|
||||
user_prompt = prompt + "\nASSISTANT:";
|
||||
if (params->verbose_prompt) {
|
||||
auto tmp = ::llama_tokenize(ctx_llava->ctx_llama, user_prompt, true, true);
|
||||
for (int i = 0; i < (int) tmp.size(); i++) {
|
||||
printf("%6d -> '%s'\n", tmp[i], llama_token_to_piece(ctx_llava->ctx_llama, tmp[i]).c_str());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
eval_string(ctx_llava->ctx_llama, system_prompt.c_str(), params->n_batch, &n_past, add_bos);
|
||||
|
@ -171,13 +189,17 @@ static void process_prompt(struct llava_context * ctx_llava, struct llava_image_
|
|||
fprintf(stderr, "\n");
|
||||
|
||||
struct llama_sampling_context * ctx_sampling = llama_sampling_init(params->sparams);
|
||||
|
||||
std::string response = "";
|
||||
for (int i = 0; i < max_tgt_len; i++) {
|
||||
const char * tmp = sample(ctx_sampling, ctx_llava->ctx_llama, &n_past);
|
||||
response += tmp;
|
||||
if (strcmp(tmp, "</s>") == 0) break;
|
||||
if (strstr(tmp, "###")) break; // Yi-VL behavior
|
||||
|
||||
printf("%s", tmp);
|
||||
if (strstr(response.c_str(), "<|im_end|>")) break; // Yi-34B llava-1.6 - for some reason those decode not as the correct token (tokenizer works)
|
||||
if (strstr(response.c_str(), "<|im_start|>")) break; // Yi-34B llava-1.6
|
||||
if (strstr(response.c_str(), "USER:")) break; // mistral llava-1.6
|
||||
|
||||
fflush(stdout);
|
||||
}
|
||||
|
||||
|
@ -196,7 +218,8 @@ static struct llava_context * llava_init(gpt_params * params) {
|
|||
|
||||
auto ctx_clip = clip_model_load(clip_path, /*verbosity=*/ 1);
|
||||
|
||||
llama_backend_init(params->numa);
|
||||
llama_backend_init();
|
||||
llama_numa_init(params->numa);
|
||||
|
||||
llama_model_params model_params = llama_model_params_from_gpt_params(*params);
|
||||
|
||||
|
|
167
examples/llava/llava-surgery-v2.py
Normal file
167
examples/llava/llava-surgery-v2.py
Normal file
|
@ -0,0 +1,167 @@
|
|||
import argparse
|
||||
import glob
|
||||
import os
|
||||
import torch
|
||||
from safetensors.torch import load as safe_load, save as safe_save, safe_open, save_file
|
||||
|
||||
# Function to determine if file is a SafeTensor file
|
||||
def is_safetensor_file(file_path):
|
||||
return file_path.endswith('.safetensors')
|
||||
|
||||
|
||||
# Unified loading function
|
||||
def load_model(file_path):
|
||||
if is_safetensor_file(file_path):
|
||||
tensors = {}
|
||||
with safe_open(file_path, framework="pt", device="cpu") as f:
|
||||
for key in f.keys():
|
||||
tensors[key] = f.get_tensor(key).clone()
|
||||
# output shape
|
||||
print(f"{key} : {tensors[key].shape}")
|
||||
return tensors, 'safetensor'
|
||||
else:
|
||||
return torch.load(file_path, map_location=torch.device('cpu')), 'pytorch'
|
||||
|
||||
|
||||
# Unified saving function
|
||||
def save_model(model, file_path, file_type):
|
||||
if file_type == 'safetensor':
|
||||
# safe_save(model, file_path)
|
||||
save_file(model, file_path)
|
||||
else:
|
||||
torch.save(model, file_path)
|
||||
|
||||
|
||||
# Adapted function to clean vision tower from checkpoint
|
||||
def clean_vision_tower_from_checkpoint(checkpoint_path):
|
||||
checkpoint, file_type = load_model(checkpoint_path)
|
||||
# file_type = 'pytorch'
|
||||
model_path = os.path.dirname(checkpoint_path)
|
||||
print(f"Searching for vision tower tensors in {checkpoint_path}")
|
||||
clip_tensors = [k for k, v in checkpoint.items() if (k.startswith("model.vision_tower") or k.startswith("vit."))]
|
||||
|
||||
if len(clip_tensors) > 0:
|
||||
print(f"Found {len(clip_tensors)} tensors to extract from {checkpoint_path}")
|
||||
# Adapted for file type
|
||||
clip_path = os.path.join(model_path, "llava.clip")
|
||||
|
||||
if os.path.exists(clip_path):
|
||||
print(f"Loading existing llava.clip from {clip_path}")
|
||||
existing_clip, _ = load_model(clip_path)
|
||||
else:
|
||||
print(f"Creating new llava.clip at {clip_path}")
|
||||
existing_clip = {}
|
||||
# Update existing_clip with new tensors, avoid duplicates
|
||||
for name in clip_tensors:
|
||||
simple_name = name[name.index('vision_model.'):] if 'vision_model.' in name else name
|
||||
print(f"Adding {simple_name} to llava.clip")
|
||||
if simple_name not in existing_clip:
|
||||
existing_clip[simple_name] = checkpoint[name]
|
||||
|
||||
# Save the updated clip tensors back to llava.clip
|
||||
save_model(existing_clip, clip_path, 'pytorch')
|
||||
|
||||
# Remove the tensors from the original checkpoint
|
||||
for name in clip_tensors:
|
||||
del checkpoint[name]
|
||||
|
||||
# Save the updated checkpoint
|
||||
checkpoint_path = checkpoint_path
|
||||
save_model(checkpoint, checkpoint_path, file_type)
|
||||
return True
|
||||
return False
|
||||
|
||||
def find_relevant_checkpoints(checkpoint_paths, newline_criteria, projector):
|
||||
newline_checkpoint_path = None
|
||||
projector_checkpoint_path = None
|
||||
|
||||
for path in checkpoint_paths:
|
||||
checkpoint, _ = load_model(path)
|
||||
if newline_criteria(checkpoint) and newline_checkpoint_path is None:
|
||||
newline_checkpoint_path = path
|
||||
if projector(checkpoint):
|
||||
projector_checkpoint_path = path
|
||||
|
||||
return newline_checkpoint_path, projector_checkpoint_path
|
||||
|
||||
def newline_criteria(checkpoint):
|
||||
return any(k.startswith("model.image_newline") for k in checkpoint.keys())
|
||||
|
||||
def proj_criteria(checkpoint):
|
||||
return any(k.startswith("model.mm_projector") or k.startswith("vision_proj.") for k in checkpoint.keys())
|
||||
|
||||
|
||||
# Command-line interface setup
|
||||
ap = argparse.ArgumentParser()
|
||||
ap.add_argument("-m", "--model", required=True, help="Path to LLaVA v1.5+ model")
|
||||
ap.add_argument("-C", "--clean-vision-tower", action="store_true", help="Remove any vision tower from the model files")
|
||||
args = ap.parse_args()
|
||||
|
||||
if args.clean_vision_tower:
|
||||
# Generalized to handle both PyTorch and SafeTensors models
|
||||
model_files = sorted(glob.glob(f"{args.model}/*"), key=os.path.getmtime, reverse=True)
|
||||
# checkpoint_paths = [path for path in model_files if (path.endswith('.bin') and path.startswith('pytorch')) or (path.endswith('.safetensors') and path.startswith('model'))]
|
||||
checkpoint_paths = [path for path in model_files if (path.endswith('.bin') and 'pytorch' in path.split('/')[-1].split('\\')[-1]) or (path.endswith('.safetensors') and 'model' in path.split('/')[-1].split('\\')[-1])]
|
||||
for projector_checkpoint_path in checkpoint_paths:
|
||||
print(f"Cleaning {projector_checkpoint_path}")
|
||||
if not clean_vision_tower_from_checkpoint(projector_checkpoint_path):
|
||||
print(f"No vision tower found in {projector_checkpoint_path}")
|
||||
# we break once none is found, so far all models append them at the end
|
||||
# break
|
||||
print("Done! All vision tower tensors are removed from the model files and stored in llava.clip file.")
|
||||
|
||||
# Now we look for the projector in the last checkpoint
|
||||
model_files = sorted(glob.glob(f"{args.model}/*"), key=os.path.getmtime, reverse=True)
|
||||
checkpoint_paths = [path for path in model_files if (path.endswith('.bin') and 'pytorch' in path.split('/')[-1].split('\\')[-1]) or (path.endswith('.safetensors') and 'model' in path.split('/')[-1].split('\\')[-1])]
|
||||
# last_checkpoint_path = checkpoint_paths[0]
|
||||
# first_checkpoint_path = checkpoint_paths[-1]
|
||||
newline_checkpoint_path, projector_checkpoint_path = find_relevant_checkpoints(checkpoint_paths, newline_criteria, proj_criteria)
|
||||
|
||||
print(f"Taking projector from {projector_checkpoint_path}")
|
||||
first_mm_tensors = []
|
||||
first_checkpoint = None
|
||||
if newline_checkpoint_path is not None:
|
||||
print(f"Taking newline from {newline_checkpoint_path}")
|
||||
first_checkpoint, file_type = load_model(newline_checkpoint_path)
|
||||
first_mm_tensors = [k for k, v in first_checkpoint.items() if k.startswith("model.image_newline")]
|
||||
|
||||
# Load the checkpoint
|
||||
mm_tensors = []
|
||||
last_checkpoint = None
|
||||
if projector_checkpoint_path is not None:
|
||||
last_checkpoint, file_type = load_model(projector_checkpoint_path)
|
||||
mm_tensors = [k for k, v in last_checkpoint.items() if k.startswith("model.mm_projector") or k.startswith("vision_proj.")]
|
||||
|
||||
if len(mm_tensors) == 0:
|
||||
if last_checkpoint is not None:
|
||||
for k, v in last_checkpoint.items():
|
||||
print(k)
|
||||
print(f"Found {len(mm_tensors)} tensors to extract out of {len(last_checkpoint)} tensors.")
|
||||
print("No tensors found. Is this a LLaVA model?")
|
||||
exit()
|
||||
|
||||
print(f"Found {len(mm_tensors)} tensors to extract.")
|
||||
print(f"Found additional {len(first_mm_tensors)} tensors to extract.")
|
||||
# projector = {name: checkpoint.[name].float() for name in mm_tensors}
|
||||
projector = {}
|
||||
for name in mm_tensors:
|
||||
projector[name] = last_checkpoint[name].float()
|
||||
for name in first_mm_tensors:
|
||||
projector[name] = first_checkpoint[name].float()
|
||||
|
||||
if len(projector) > 0:
|
||||
save_model(projector, f"{args.model}/llava.projector", 'pytorch')
|
||||
|
||||
for name in mm_tensors:
|
||||
del last_checkpoint[name]
|
||||
for name in first_mm_tensors:
|
||||
del first_checkpoint[name]
|
||||
|
||||
if len(mm_tensors) > 0:
|
||||
save_model(last_checkpoint, projector_checkpoint_path, file_type)
|
||||
if len(first_mm_tensors) > 0:
|
||||
save_model(first_checkpoint, newline_checkpoint_path, file_type)
|
||||
|
||||
print("Done!")
|
||||
print(f"Now you can convert {args.model} to a a regular LLaMA GGUF file.")
|
||||
print(f"Also, use {args.model}/llava.projector to prepare a llava-encoder.gguf file.")
|
|
@ -19,19 +19,12 @@ mm_tensors = [k for k, v in checkpoint.items() if k.startswith("model.mm_project
|
|||
projector = {name: checkpoint[name].float() for name in mm_tensors}
|
||||
torch.save(projector, f"{args.model}/llava.projector")
|
||||
|
||||
# remove these tensors from the checkpoint and save it again
|
||||
for name in mm_tensors:
|
||||
del checkpoint[name]
|
||||
|
||||
# BakLLaVA models contain CLIP tensors in it
|
||||
clip_tensors = [k for k, v in checkpoint.items() if k.startswith("model.vision_tower")]
|
||||
if len(clip_tensors) > 0:
|
||||
clip = {name.replace("vision_tower.vision_tower.", ""): checkpoint[name].float() for name in clip_tensors}
|
||||
torch.save(clip, f"{args.model}/llava.clip")
|
||||
|
||||
# remove these tensors
|
||||
for name in clip_tensors:
|
||||
del checkpoint[name]
|
||||
|
||||
# added tokens should be removed to be able to convert Mistral models
|
||||
if os.path.exists(f"{args.model}/added_tokens.json"):
|
||||
|
@ -39,7 +32,6 @@ if len(clip_tensors) > 0:
|
|||
f.write("{}\n")
|
||||
|
||||
|
||||
torch.save(checkpoint, path)
|
||||
|
||||
print("Done!")
|
||||
print(f"Now you can convert {args.model} to a regular LLaMA GGUF file.")
|
||||
|
|
|
@ -2,32 +2,296 @@
|
|||
#include "common.h"
|
||||
#include "llama.h"
|
||||
#include "llava.h"
|
||||
#include "base64.hpp"
|
||||
|
||||
#include <cstdio>
|
||||
#include <cstdlib>
|
||||
#include <vector>
|
||||
#include <numeric>
|
||||
|
||||
// RGB uint8 image
|
||||
struct clip_image_u8 {
|
||||
int nx;
|
||||
int ny;
|
||||
|
||||
std::vector<uint8_t> buf;
|
||||
};
|
||||
|
||||
// RGB float32 image (NHWC)
|
||||
// Memory layout: RGBRGBRGB...
|
||||
struct clip_image_f32 {
|
||||
int nx;
|
||||
int ny;
|
||||
|
||||
std::vector<float> buf;
|
||||
};
|
||||
|
||||
struct clip_image_grid_shape {
|
||||
int first;
|
||||
int second;
|
||||
};
|
||||
|
||||
/**
|
||||
* Selects the best resolution from a list of possible resolutions based on the original size.
|
||||
*
|
||||
* @param original_size The original size of the image in the format (width, height).
|
||||
* @param possible_resolutions A list of possible resolutions in the format [(width1, height1), (width2, height2), ...].
|
||||
* @return The best fit resolution in the format (width, height).
|
||||
*/
|
||||
static std::pair<int, int> select_best_resolution(const std::pair<int, int>& original_size, const std::vector<std::pair<int, int>>& possible_resolutions) {
|
||||
int original_width = original_size.first;
|
||||
int original_height = original_size.second;
|
||||
|
||||
std::pair<int, int> best_fit;
|
||||
int max_effective_resolution = 0;
|
||||
int min_wasted_resolution = std::numeric_limits<int>::max();
|
||||
|
||||
for (const auto& resolution : possible_resolutions) {
|
||||
int width = resolution.first;
|
||||
int height = resolution.second;
|
||||
float scale = std::min(static_cast<float>(width) / original_width, static_cast<float>(height) / original_height);
|
||||
int downscaled_width = static_cast<int>(original_width * scale);
|
||||
int downscaled_height = static_cast<int>(original_height * scale);
|
||||
int effective_resolution = std::min(downscaled_width * downscaled_height, original_width * original_height);
|
||||
int wasted_resolution = (width * height) - effective_resolution;
|
||||
// fprintf(stderr, "resolution: %d %d, scale: %f, downscaled: %d %d, effective: %d, wasted: %d\n", width, height, scale, downscaled_width, downscaled_height, effective_resolution, wasted_resolution);
|
||||
if (effective_resolution > max_effective_resolution || (effective_resolution == max_effective_resolution && wasted_resolution < min_wasted_resolution)) {
|
||||
max_effective_resolution = effective_resolution;
|
||||
min_wasted_resolution = wasted_resolution;
|
||||
best_fit = resolution;
|
||||
}
|
||||
}
|
||||
|
||||
return best_fit;
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Get the anyres image grid shape object
|
||||
*
|
||||
* @param image_size
|
||||
* @param grid_pinpoints
|
||||
* @param image_patch_size
|
||||
* @return <int, int>
|
||||
*/
|
||||
static struct clip_image_grid_shape get_anyres_image_grid_shape(const std::pair<int, int> & image_size, const std::vector<std::pair<int, int>> & grid_pinpoints, int image_patch_size) {
|
||||
/**
|
||||
Conversion from gguf flat array to vector:
|
||||
std::vector<std::pair<int, int>> possible_resolutions;
|
||||
for (int i = 0; i < 32 && params.image_grid_pinpoints[i] != 0; i+=2) {
|
||||
possible_resolutions.push_back({params.image_grid_pinpoints[i], params.image_grid_pinpoints[i+1]});
|
||||
}
|
||||
*/
|
||||
auto best_resolution = select_best_resolution(image_size, grid_pinpoints);
|
||||
return {best_resolution.first / image_patch_size, best_resolution.second / image_patch_size};
|
||||
}
|
||||
|
||||
// Take the image segments in a grid configuration and return the embeddings and the number of embeddings into preallocated memory (image_embd_out)
|
||||
static bool clip_llava_handle_patches(clip_ctx * ctx_clip, std::vector<float *> & image_embd_v, struct clip_image_grid_shape grid_shape, float * image_embd_out, int * n_img_pos_out) {
|
||||
struct {
|
||||
struct ggml_tensor * newline;
|
||||
struct ggml_context * ctx;
|
||||
} model;
|
||||
|
||||
const int32_t image_size = clip_image_size(ctx_clip);
|
||||
const int32_t patch_size = clip_patch_size(ctx_clip);
|
||||
|
||||
int32_t num_patches_per_side = image_size / patch_size; // 336 / 14 = 24 - used for embedding-patching boxes (24*24 = 576 patches)
|
||||
|
||||
int num_patches_width = grid_shape.first; // grid 1-4
|
||||
int num_patches_height = grid_shape.second; // grid 1-4
|
||||
|
||||
const size_t num_images = num_patches_width * num_patches_height + 1;
|
||||
|
||||
// TODO: size calculation is not calculated - it's only tens of MB
|
||||
size_t ctx_size = 0;
|
||||
|
||||
{
|
||||
ctx_size += clip_embd_nbytes(ctx_clip) * num_images * 8; // image_features
|
||||
ctx_size += 1024*1024 * ggml_type_size(GGML_TYPE_F32);
|
||||
}
|
||||
|
||||
struct ggml_init_params params {
|
||||
/*.mem_size =*/ ctx_size,
|
||||
/*.mem_buffer =*/ NULL,
|
||||
/*.no_alloc =*/ false, // NOTE: this should be false when using the legacy API
|
||||
};
|
||||
|
||||
// Python reference code for full unpad:
|
||||
/*
|
||||
base_image_feature = image_feature[0]
|
||||
image_feature = image_feature[1:]
|
||||
image_feature = image_feature.permute(4, 0, 2, 1, 3).contiguous()
|
||||
image_feature = image_feature.flatten(1, 2).flatten(2, 3)
|
||||
image_feature = unpad_image(image_feature, image_sizes[image_idx])
|
||||
image_feature = torch.cat((
|
||||
image_feature,
|
||||
self.model.image_newline[:, None, None].expand(*image_feature.shape[:-1], 1)
|
||||
), dim=-1)
|
||||
image_feature = image_feature.flatten(1, 2).transpose(0, 1)
|
||||
image_feature = torch.cat((base_image_feature, image_feature), dim=0)
|
||||
*/
|
||||
// We now have two options: unpad or no unpad. Unpad removes tokens for faster llm eval.
|
||||
// In terms of result quality it appears to make no difference, so we'll start with the easier approach given 5D tensors are not supported in ggml yet.
|
||||
// Without unpad we have to split the sub-image embeddings into patches of 24 features each and permute them.
|
||||
// Once all images are processed to prepended the base_image_features without any changes.
|
||||
|
||||
// Pytorch reference simplified, modified for ggml compatibility - confirmed identical output in python (for a 2x2 grid image (676x676 scaling))
|
||||
/*
|
||||
image_feature = image_feature.view(2, 2, 24, 24, 4096)
|
||||
image_feature = image_feature.permute(0, 2, 1, 3, 4).contiguous()
|
||||
image_feature = image_feature.view(2, 24, 2, 24, 4096)
|
||||
image_feature = image_feature.flatten(0, 3)
|
||||
|
||||
// Reshape to 4D tensor by merging the last two dimensions
|
||||
image_feature = image_feature.view(2, 2, 24, 24*4096)
|
||||
image_feature = image_feature.permute(0, 2, 1, 3).contiguous()
|
||||
image_feature = image_feature.view(-1, 4096)
|
||||
*/
|
||||
|
||||
model.ctx = ggml_init(params);
|
||||
|
||||
ggml_tensor * newline_tmp = clip_get_newline_tensor(ctx_clip);
|
||||
model.newline = ggml_new_tensor_1d(model.ctx, GGML_TYPE_F32, newline_tmp->ne[0]);
|
||||
if (newline_tmp->backend != GGML_BACKEND_CPU) {
|
||||
if (newline_tmp->buffer == NULL) {
|
||||
printf("newline_tmp tensor buffer is NULL\n");
|
||||
}
|
||||
ggml_backend_tensor_get(newline_tmp, model.newline->data, 0, ggml_nbytes(newline_tmp));
|
||||
} else {
|
||||
model.newline->data = newline_tmp->data;
|
||||
if (model.newline->data == NULL) {
|
||||
printf("newline_tmp tensor data is NULL\n");
|
||||
}
|
||||
}
|
||||
|
||||
struct ggml_tensor * image_features = ggml_new_tensor_3d(model.ctx, GGML_TYPE_F32, clip_n_mmproj_embd(ctx_clip), clip_n_patches(ctx_clip), num_images - 1); // example: 4096 x 576 x 4
|
||||
// ggml_tensor_printf(image_features,"image_features",__LINE__,false,false);
|
||||
// fill it with the image embeddings, ignoring the base
|
||||
for (size_t i = 1; i < num_images; i++) {
|
||||
size_t offset = (i-1) * clip_embd_nbytes(ctx_clip);
|
||||
memcpy((uint8_t *)(image_features->data) + offset, image_embd_v[i], clip_embd_nbytes(ctx_clip));
|
||||
}
|
||||
|
||||
struct ggml_cgraph * gf = ggml_new_graph(model.ctx);
|
||||
size_t size_ele = ggml_type_size(GGML_TYPE_F32);
|
||||
|
||||
struct ggml_tensor *image_features_patchview = ggml_view_4d(model.ctx, image_features,
|
||||
num_patches_per_side * clip_n_mmproj_embd(ctx_clip),
|
||||
num_patches_per_side,
|
||||
num_patches_width,
|
||||
num_patches_height,
|
||||
size_ele * num_patches_per_side * clip_n_mmproj_embd(ctx_clip),
|
||||
size_ele * num_patches_per_side * clip_n_mmproj_embd(ctx_clip) * num_patches_per_side,
|
||||
size_ele * num_patches_per_side * clip_n_mmproj_embd(ctx_clip) * num_patches_per_side * num_patches_width, 0);
|
||||
// ggml_tensor_printf(image_features_patchview,"image_features_patchview",__LINE__,false,false);
|
||||
struct ggml_tensor *permuted_cont = ggml_cont(model.ctx, ggml_permute(model.ctx, image_features_patchview, 0, 2, 1, 3));
|
||||
/**
|
||||
At the end of each row we have to add the row_end embeddings, which are the same as the newline embeddings
|
||||
image_feature = torch.cat((
|
||||
image_feature,
|
||||
self.model.image_newline[:, None, None].expand(*image_feature.shape[:-1], 1).to(image_feature.device)
|
||||
), dim=-1)
|
||||
*
|
||||
*/
|
||||
|
||||
// ggml_tensor_printf(permuted_cont,"permuted_cont",__LINE__,false,false);
|
||||
struct ggml_tensor *flatten = ggml_view_2d(model.ctx, permuted_cont, clip_n_mmproj_embd(ctx_clip), num_patches_height * num_patches_width * num_patches_per_side * num_patches_per_side, size_ele * clip_n_mmproj_embd(ctx_clip), 0);
|
||||
// ggml_tensor_printf(flatten,"flatten",__LINE__,false,false);
|
||||
ggml_build_forward_expand(gf, flatten);
|
||||
ggml_graph_compute_with_ctx(model.ctx, gf, 1);
|
||||
struct ggml_tensor* result = gf->nodes[gf->n_nodes - 1];
|
||||
|
||||
memcpy(image_embd_out, image_embd_v[0], clip_embd_nbytes(ctx_clip)); // main image as global context
|
||||
// append without newline tokens (default behavior in llava_arch when not using unpad ):
|
||||
memcpy(image_embd_out + clip_n_patches(ctx_clip) * clip_n_mmproj_embd(ctx_clip), (float*)result->data, clip_embd_nbytes(ctx_clip) * (num_images-1)); // grid patches
|
||||
*n_img_pos_out = static_cast<int>(result->ne[1]+clip_n_patches(ctx_clip));
|
||||
|
||||
// Debug: Test single segments
|
||||
// Current findings: sending base image, sending a segment embedding all works similar to python
|
||||
// However, permuted embeddings do not work yet (stride issue?)
|
||||
// memcpy(image_embd_out, image_embd_v[0], clip_embd_nbytes(ctx_clip)); // main image as context
|
||||
// memcpy(image_embd_out, (float*)prepared_cont->data, clip_embd_nbytes(ctx_clip)); // main image as context
|
||||
// *n_img_pos_out=576;
|
||||
|
||||
ggml_free(model.ctx);
|
||||
return true;
|
||||
}
|
||||
|
||||
#include "base64.hpp"
|
||||
|
||||
static bool encode_image_with_clip(clip_ctx * ctx_clip, int n_threads, const clip_image_u8 * img, float * image_embd, int * n_img_pos) {
|
||||
clip_image_f32 * img_res = clip_image_f32_init();
|
||||
if (!clip_image_preprocess(ctx_clip, img, img_res, /*pad2square =*/ true)) {
|
||||
// std::vector<clip_image_f32*> img_res_v; // format VectN x H x W x RGB (N x 336 x 336 x 3), so interleaved RGB - different to the python implementation which is N x 3 x 336 x 336
|
||||
clip_image_f32_batch img_res_v;
|
||||
img_res_v.size = 0;
|
||||
img_res_v.data = nullptr;
|
||||
if (!clip_image_preprocess(ctx_clip, img, img_res_v)) {
|
||||
fprintf(stderr, "%s: unable to preprocess image\n", __func__);
|
||||
clip_image_f32_free(img_res);
|
||||
delete[] img_res_v.data;
|
||||
return false;
|
||||
}
|
||||
|
||||
*n_img_pos = clip_n_patches(ctx_clip);
|
||||
|
||||
const int64_t t_img_enc_start_us = ggml_time_us();
|
||||
bool encoded = clip_image_encode(ctx_clip, n_threads, img_res, image_embd);
|
||||
clip_image_f32_free(img_res);
|
||||
if (!encoded) {
|
||||
fprintf(stderr, "Unable to encode image\n");
|
||||
|
||||
return false;
|
||||
const char * mm_patch_merge_type = clip_patch_merge_type(ctx_clip);
|
||||
|
||||
if (strcmp(mm_patch_merge_type, "spatial_unpad") != 0) {
|
||||
// flat / default llava-1.5 type embedding
|
||||
*n_img_pos = clip_n_patches(ctx_clip);
|
||||
bool encoded = clip_image_encode(ctx_clip, n_threads, &img_res_v.data[0], image_embd); // image_embd shape is 576 x 4096
|
||||
delete[] img_res_v.data;
|
||||
if (!encoded) {
|
||||
fprintf(stderr, "Unable to encode image\n");
|
||||
|
||||
return false;
|
||||
}
|
||||
} else {
|
||||
// spatial_unpad llava-1.6 type embedding
|
||||
// TODO: CLIP needs batching support - in HF the llm projection is separate after encoding, which might be a solution to quickly get batching working
|
||||
std::vector<float *> image_embd_v;
|
||||
image_embd_v.resize(img_res_v.size);
|
||||
for (size_t i = 0; i < img_res_v.size; i++) {
|
||||
image_embd_v[i] = (float *)malloc(clip_embd_nbytes(ctx_clip)); // 576 patches * 4096 embeddings * 4 bytes = 9437184
|
||||
const bool encoded = clip_image_encode(ctx_clip, n_threads, &img_res_v.data[i], image_embd_v[i]); // image data is in 3x336x336 format and will be converted to 336x336x3 inside
|
||||
if (!encoded) {
|
||||
fprintf(stderr, "Unable to encode image - spatial_unpad - subimage %d of %d\n", (int) i+1, (int) img_res_v.size);
|
||||
return false;
|
||||
}
|
||||
}
|
||||
const int64_t t_img_enc_batch_us = ggml_time_us();
|
||||
printf("%s: %d segments encoded in %8.2f ms\n", __func__, (int)img_res_v.size, (t_img_enc_batch_us - t_img_enc_start_us) / 1000.0);
|
||||
|
||||
const int32_t * image_grid = clip_image_grid(ctx_clip);
|
||||
|
||||
std::vector<std::pair<int, int>> grid_pinpoints;
|
||||
for (int i = 0; i < 32 && image_grid[i] != 0; i += 2) {
|
||||
grid_pinpoints.push_back({image_grid[i], image_grid[i+1]});
|
||||
}
|
||||
|
||||
// free all img_res_v - not needed anymore
|
||||
delete[] img_res_v.data;
|
||||
img_res_v.size = 0;
|
||||
img_res_v.data = nullptr;
|
||||
|
||||
const int32_t image_size = clip_image_size(ctx_clip);
|
||||
|
||||
struct clip_image_grid_shape grid_shape = get_anyres_image_grid_shape({img->nx,img->ny}, grid_pinpoints, image_size);
|
||||
|
||||
int n_img_pos_out;
|
||||
clip_llava_handle_patches(ctx_clip, image_embd_v, grid_shape, image_embd, &n_img_pos_out);
|
||||
*n_img_pos = n_img_pos_out;
|
||||
|
||||
for (size_t i = 0; i < image_embd_v.size(); i++) {
|
||||
free(image_embd_v[i]);
|
||||
}
|
||||
image_embd_v.clear();
|
||||
|
||||
// debug image/segment/normalization content:
|
||||
// clip_image_u8 * tmp = clip_image_u8_init();
|
||||
// clip_image_convert_f32_to_u8(*image_feature, *tmp);
|
||||
// clip_image_save_to_bmp(*tmp, "image_feature.bmp");
|
||||
}
|
||||
|
||||
printf("%s: image embedding created: %d tokens\n", __func__, *n_img_pos);
|
||||
|
||||
const int64_t t_img_enc_end_us = ggml_time_us();
|
||||
float t_img_enc_ms = (t_img_enc_end_us - t_img_enc_start_us) / 1000.0;
|
||||
|
||||
|
@ -48,10 +312,9 @@ bool llava_validate_embed_size(const llama_context * ctx_llama, const clip_ctx *
|
|||
}
|
||||
|
||||
static bool llava_image_embed_make_with_clip_img(clip_ctx * ctx_clip, int n_threads, const clip_image_u8 * img, float ** image_embd_out, int * n_img_pos_out) {
|
||||
float * image_embd = (float *)malloc(clip_embd_nbytes(ctx_clip));
|
||||
float * image_embd = (float *)malloc(clip_embd_nbytes(ctx_clip)*6); // TODO: base on gridsize/llava model
|
||||
if (!image_embd) {
|
||||
fprintf(stderr, "Unable to allocate memory for image embeddings\n");
|
||||
free(image_embd);
|
||||
return false;
|
||||
}
|
||||
|
||||
|
@ -85,7 +348,7 @@ bool llava_eval_image_embed(llama_context * ctx_llama, const struct llava_image_
|
|||
return true;
|
||||
}
|
||||
|
||||
LLAVA_API struct llava_image_embed * llava_image_embed_make_with_bytes(struct clip_ctx * ctx_clip, int n_threads, const unsigned char * image_bytes, int image_bytes_length) {
|
||||
struct llava_image_embed * llava_image_embed_make_with_bytes(struct clip_ctx * ctx_clip, int n_threads, const unsigned char * image_bytes, int image_bytes_length) {
|
||||
clip_image_u8 * img = clip_image_u8_init();
|
||||
if (!clip_image_load_from_bytes(image_bytes, image_bytes_length, img)) {
|
||||
clip_image_u8_free(img);
|
||||
|
@ -142,7 +405,7 @@ static bool load_file_to_bytes(const char* path, unsigned char** bytesOut, long
|
|||
return true;
|
||||
}
|
||||
|
||||
LLAVA_API struct llava_image_embed * llava_image_embed_make_with_filename(struct clip_ctx * ctx_clip, int n_threads, const char * image_path) {
|
||||
struct llava_image_embed * llava_image_embed_make_with_filename(struct clip_ctx * ctx_clip, int n_threads, const char * image_path) {
|
||||
unsigned char* image_bytes;
|
||||
long image_bytes_length;
|
||||
auto loaded = load_file_to_bytes(image_path, &image_bytes, &image_bytes_length);
|
||||
|
@ -151,13 +414,13 @@ LLAVA_API struct llava_image_embed * llava_image_embed_make_with_filename(struct
|
|||
return NULL;
|
||||
}
|
||||
|
||||
auto embed = llava_image_embed_make_with_bytes(ctx_clip, n_threads, image_bytes, image_bytes_length);
|
||||
llava_image_embed *embed = llava_image_embed_make_with_bytes(ctx_clip, n_threads, image_bytes, image_bytes_length);
|
||||
free(image_bytes);
|
||||
|
||||
return embed;
|
||||
}
|
||||
|
||||
LLAVA_API void llava_image_embed_free(struct llava_image_embed * embed) {
|
||||
void llava_image_embed_free(struct llava_image_embed * embed) {
|
||||
free(embed->embed);
|
||||
free(embed);
|
||||
}
|
||||
|
|
|
@ -3,7 +3,6 @@
|
|||
|
||||
#include "ggml.h"
|
||||
|
||||
|
||||
#ifdef LLAMA_SHARED
|
||||
# if defined(_WIN32) && !defined(__MINGW32__)
|
||||
# ifdef LLAMA_BUILD
|
||||
|
@ -42,7 +41,6 @@ LLAVA_API void llava_image_embed_free(struct llava_image_embed * embed);
|
|||
/** write the image represented by embed into the llama context with batch size n_batch, starting at context pos n_past. on completion, n_past points to the next position in the context after the image embed. */
|
||||
LLAVA_API bool llava_eval_image_embed(struct llama_context * ctx_llama, const struct llava_image_embed * embed, int n_batch, int * n_past);
|
||||
|
||||
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
#endif
|
||||
|
|
|
@ -54,7 +54,8 @@ int main(int argc, char ** argv) {
|
|||
#endif // LOG_DISABLE_LOGS
|
||||
|
||||
// init llama.cpp
|
||||
llama_backend_init(params.numa);
|
||||
llama_backend_init();
|
||||
llama_numa_init(params.numa);
|
||||
|
||||
llama_model * model = NULL;
|
||||
llama_context * ctx = NULL;
|
||||
|
|
|
@ -31,7 +31,8 @@ int main(int argc, char ** argv){
|
|||
#endif // LOG_DISABLE_LOGS
|
||||
|
||||
// init llama.cpp
|
||||
llama_backend_init(params.numa);
|
||||
llama_backend_init();
|
||||
llama_numa_init(params.numa);
|
||||
|
||||
llama_model * model = NULL;
|
||||
llama_context * ctx = NULL;
|
||||
|
|
|
@ -283,7 +283,11 @@ These options help improve the performance and memory usage of the LLaMA models.
|
|||
|
||||
### NUMA support
|
||||
|
||||
- `--numa`: Attempt optimizations that help on some systems with non-uniform memory access. This currently consists of pinning an equal proportion of the threads to the cores on each NUMA node, and disabling prefetch and readahead for mmap. The latter causes mapped pages to be faulted in on first access instead of all at once, and in combination with pinning threads to NUMA nodes, more of the pages end up on the NUMA node where they are used. Note that if the model is already in the system page cache, for example because of a previous run without this option, this will have little effect unless you drop the page cache first. This can be done by rebooting the system or on Linux by writing '3' to '/proc/sys/vm/drop_caches' as root.
|
||||
- `--numa distribute`: Pin an equal proportion of the threads to the cores on each NUMA node. This will spread the load amongst all cores on the system, utilitizing all memory channels at the expense of potentially requiring memory to travel over the slow links between nodes.
|
||||
- `--numa isolate`: Pin all threads to the NUMA node that the program starts on. This limits the number of cores and amount of memory that can be used, but guarantees all memory access remains local to the NUMA node.
|
||||
- `--numa numactl`: Pin threads to the CPUMAP that is passed to the program by starting it with the numactl utility. This is the most flexible mode, and allow arbitraty core usage patterns, for example a map that uses all the cores on one NUMA nodes, and just enough cores on a second node to saturate the inter-node memory bus.
|
||||
|
||||
These flags attempt optimizations that help on some systems with non-uniform memory access. This currently consists of one of the above strategies, and disabling prefetch and readahead for mmap. The latter causes mapped pages to be faulted in on first access instead of all at once, and in combination with pinning threads to NUMA nodes, more of the pages end up on the NUMA node where they are used. Note that if the model is already in the system page cache, for example because of a previous run without this option, this will have little effect unless you drop the page cache first. This can be done by rebooting the system or on Linux by writing '3' to '/proc/sys/vm/drop_caches' as root.
|
||||
|
||||
### Memory Float 32
|
||||
|
||||
|
|
|
@ -185,7 +185,8 @@ int main(int argc, char ** argv) {
|
|||
}
|
||||
|
||||
LOG("%s: llama backend init\n", __func__);
|
||||
llama_backend_init(params.numa);
|
||||
llama_backend_init();
|
||||
llama_numa_init(params.numa);
|
||||
|
||||
llama_model * model;
|
||||
llama_context * ctx;
|
||||
|
|
|
@ -122,7 +122,8 @@ int main(int argc, char ** argv) {
|
|||
#endif // LOG_DISABLE_LOGS
|
||||
|
||||
// init llama.cpp
|
||||
llama_backend_init(params.numa);
|
||||
llama_backend_init();
|
||||
llama_numa_init(params.numa);
|
||||
|
||||
llama_model * model = NULL;
|
||||
llama_context * ctx = NULL;
|
||||
|
|
|
@ -71,7 +71,8 @@ int main(int argc, char ** argv) {
|
|||
|
||||
// init LLM
|
||||
|
||||
llama_backend_init(params.numa);
|
||||
llama_backend_init();
|
||||
llama_numa_init(params.numa);
|
||||
|
||||
// initialize the model
|
||||
|
||||
|
|
|
@ -309,7 +309,7 @@ static void process_logits(int n_vocab, const float * logits, const int * tokens
|
|||
}
|
||||
|
||||
static results_perplexity perplexity_v2(llama_context * ctx, const gpt_params & params) {
|
||||
// Download: https://s3.amazonaws.com/research.metamind.io/wikitext/wikitext-2-raw-v1.zip?ref=salesforce-research
|
||||
// Download: https://huggingface.co/datasets/ggml-org/ci/resolve/main/wikitext-2-raw-v1.zip
|
||||
// Run `./perplexity -m models/7B/ggml-model-q4_0.bin -f wiki.test.raw`
|
||||
// Output: `perplexity: 13.5106 [114/114]`
|
||||
// BOS tokens will be added for each chunk before eval
|
||||
|
@ -447,7 +447,7 @@ static results_perplexity perplexity(llama_context * ctx, const gpt_params & par
|
|||
return perplexity_v2(ctx, params);
|
||||
}
|
||||
|
||||
// Download: https://s3.amazonaws.com/research.metamind.io/wikitext/wikitext-2-raw-v1.zip?ref=salesforce-research
|
||||
// Download: https://huggingface.co/datasets/ggml-org/ci/resolve/main/wikitext-2-raw-v1.zip
|
||||
// Run `./perplexity -m models/7B/ggml-model-q4_0.bin -f wiki.test.raw`
|
||||
// Output: `perplexity: 13.5106 [114/114]`
|
||||
// BOS tokens will be added for each chunk before eval
|
||||
|
@ -1623,7 +1623,7 @@ static void kl_divergence(llama_context * ctx, const gpt_params & params) {
|
|||
uint32_t n_ctx;
|
||||
in.read((char *)&n_ctx, sizeof(n_ctx));
|
||||
if (n_ctx > llama_n_ctx(ctx)) {
|
||||
fprintf(stderr, "%s: %s has been computed with %d, while the current context is %d. Increase it with -c and retry\n",
|
||||
fprintf(stderr, "%s: %s has been computed with %u, while the current context is %d. Increase it with -c and retry\n",
|
||||
__func__, params.logits_file.c_str(), n_ctx, params.n_ctx);
|
||||
}
|
||||
|
||||
|
@ -1809,7 +1809,8 @@ int main(int argc, char ** argv) {
|
|||
params.prompt = gpt_random_prompt(rng);
|
||||
}
|
||||
|
||||
llama_backend_init(params.numa);
|
||||
llama_backend_init();
|
||||
llama_numa_init(params.numa);
|
||||
|
||||
llama_model * model;
|
||||
llama_context * ctx;
|
||||
|
|
|
@ -23,6 +23,7 @@ static const std::vector<struct quant_option> QUANT_OPTIONS = {
|
|||
{ "Q5_1", LLAMA_FTYPE_MOSTLY_Q5_1, " 4.70G, +0.0349 ppl @ LLaMA-v1-7B", },
|
||||
{ "IQ2_XXS",LLAMA_FTYPE_MOSTLY_IQ2_XXS," 2.06 bpw quantization", },
|
||||
{ "IQ2_XS", LLAMA_FTYPE_MOSTLY_IQ2_XS, " 2.31 bpw quantization", },
|
||||
{ "IQ1_S", LLAMA_FTYPE_MOSTLY_IQ1_S, " 1.56 bpw quantization", },
|
||||
{ "Q2_K", LLAMA_FTYPE_MOSTLY_Q2_K, " 2.63G, +0.6717 ppl @ LLaMA-v1-7B", },
|
||||
{ "Q2_K_S", LLAMA_FTYPE_MOSTLY_Q2_K_S, " 2.16G, +9.0634 ppl @ LLaMA-v1-7B", },
|
||||
{ "IQ3_XXS",LLAMA_FTYPE_MOSTLY_IQ3_XXS," 3.06 bpw quantization", },
|
||||
|
@ -237,7 +238,7 @@ int main(int argc, char ** argv) {
|
|||
params.imatrix = &imatrix_data;
|
||||
}
|
||||
|
||||
llama_backend_init(false);
|
||||
llama_backend_init();
|
||||
|
||||
// parse command line arguments
|
||||
const std::string fname_inp = argv[arg_idx];
|
||||
|
@ -287,9 +288,10 @@ int main(int argc, char ** argv) {
|
|||
}
|
||||
}
|
||||
|
||||
if ((params.ftype == LLAMA_FTYPE_MOSTLY_IQ2_XS || params.ftype == LLAMA_FTYPE_MOSTLY_IQ2_XXS || params.ftype == LLAMA_FTYPE_MOSTLY_Q2_K_S) && imatrix_data.empty()) {
|
||||
if ((params.ftype == LLAMA_FTYPE_MOSTLY_IQ2_XS || params.ftype == LLAMA_FTYPE_MOSTLY_IQ2_XXS ||
|
||||
params.ftype == LLAMA_FTYPE_MOSTLY_Q2_K_S || params.ftype == LLAMA_FTYPE_MOSTLY_IQ1_S) && imatrix_data.empty()) {
|
||||
fprintf(stderr, "\n===============================================================================================\n");
|
||||
fprintf(stderr, "Please do not use IQ2_XXS, IQ2_XS or Q2_K_S quantization without an importance matrix\n");
|
||||
fprintf(stderr, "Please do not use IQ1_S, IQ2_XXS, IQ2_XS or Q2_K_S quantization without an importance matrix\n");
|
||||
fprintf(stderr, "===============================================================================================\n\n\n");
|
||||
return 1;
|
||||
}
|
||||
|
|
|
@ -16,6 +16,13 @@ Command line options:
|
|||
- `--memory-f32`: Use 32-bit floats instead of 16-bit floats for memory key+value. Not recommended.
|
||||
- `--mlock`: Lock the model in memory, preventing it from being swapped out when memory-mapped.
|
||||
- `--no-mmap`: Do not memory-map the model. By default, models are mapped into memory, which allows the system to load only the necessary parts of the model as needed.
|
||||
- `--numa STRATEGY`: Attempt one of the below optimization strategies that help on some NUMA systems
|
||||
- `--numa distribute`: Spread execution evenly over all nodes
|
||||
- `--numa isolate`: Only spawn threads on CPUs on the node that execution started on
|
||||
- `--numa numactl`: Use the CPU map provided by numactl
|
||||
if run without this previously, it is recommended to drop the system page cache before using this
|
||||
see https://github.com/ggerganov/llama.cpp/issues/1437
|
||||
|
||||
- `--numa`: Attempt optimizations that help on some NUMA systems.
|
||||
- `--lora FNAME`: Apply a LoRA (Low-Rank Adaptation) adapter to the model (implies --no-mmap). This allows you to adapt the pretrained model to specific tasks or domains.
|
||||
- `--lora-base FNAME`: Optional model to use as a base for the layers modified by the LoRA adapter. This flag is used in conjunction with the `--lora` flag, and specifies the base model for the adaptation.
|
||||
|
@ -32,6 +39,8 @@ Command line options:
|
|||
- `--mmproj MMPROJ_FILE`: Path to a multimodal projector file for LLaVA.
|
||||
- `--grp-attn-n`: Set the group attention factor to extend context size through self-extend(default: 1=disabled), used together with group attention width `--grp-attn-w`
|
||||
- `--grp-attn-w`: Set the group attention width to extend context size through self-extend(default: 512), used together with group attention factor `--grp-attn-n`
|
||||
- `-n, --n-predict`: Set the maximum tokens to predict (default: -1)
|
||||
- `--slots-endpoint-disable`: To disable slots state monitoring endpoint. Slots state may contain user data, prompts included.
|
||||
|
||||
## Build
|
||||
|
||||
|
@ -128,6 +137,7 @@ node index.js
|
|||
- `{"status": "loading model"}` if the model is still being loaded.
|
||||
- `{"status": "error"}` if the model failed to load.
|
||||
- `{"status": "ok"}` if the model is successfully loaded and the server is ready for further requests mentioned below.
|
||||
- `{"status": "no slot available", "slots_idle": 0, "slots_processing": 32}` if no slot are currently available
|
||||
|
||||
- **POST** `/completion`: Given a `prompt`, it returns the predicted completion.
|
||||
|
||||
|
@ -189,6 +199,8 @@ node index.js
|
|||
|
||||
`n_probs`: If greater than 0, the response also contains the probabilities of top N tokens for each generated token (default: 0)
|
||||
|
||||
`min_keep`: If greater than 0, force samplers to return N possible tokens at minimum (default: 0)
|
||||
|
||||
`image_data`: An array of objects to hold base64-encoded image `data` and its `id`s to be reference in `prompt`. You can determine the place of the image in the prompt as in the following: `USER:[img-12]Describe the image in detail.\nASSISTANT:`. In this case, `[img-12]` will be replaced by the embeddings of the image with id `12` in the following `image_data` array: `{..., "image_data": [{"data": "<BASE64_STRING>", "id": 12}]}`. Use `image_data` only with multimodal models, e.g., LLaVA.
|
||||
|
||||
`slot_id`: Assign the completion task to an specific slot. If is -1 the task will be assigned to a Idle slot (default: -1)
|
||||
|
@ -197,6 +209,8 @@ node index.js
|
|||
|
||||
`system_prompt`: Change the system prompt (initial prompt of all slots), this is useful for chat applications. [See more](#change-system-prompt-on-runtime)
|
||||
|
||||
`samplers`: The order the samplers should be applied in. An array of strings representing sampler type names. If a sampler is not set, it will not be used. If a sampler is specified more than once, it will be applied multiple times. (default: `["top_k", "tfs_z", "typical_p", "top_p", "min_p", "temperature"]` - these are all the available values)
|
||||
|
||||
### Result JSON
|
||||
|
||||
- Note: When using streaming mode (`stream`) only `content` and `stop` will be returned until end of completion.
|
||||
|
@ -370,6 +384,69 @@ Notice that each `probs` is an array of length `n_probs`.
|
|||
}'
|
||||
```
|
||||
|
||||
- **GET** `/slots`: Returns the current slots processing state. Can be disabled with `--slots-endpoint-disable`.
|
||||
|
||||
### Result JSON
|
||||
|
||||
```json
|
||||
[
|
||||
{
|
||||
"dynatemp_exponent": 1.0,
|
||||
"dynatemp_range": 0.0,
|
||||
"frequency_penalty": 0.0,
|
||||
"grammar": "",
|
||||
"id": 0,
|
||||
"ignore_eos": false,
|
||||
"logit_bias": [],
|
||||
"min_p": 0.05000000074505806,
|
||||
"mirostat": 0,
|
||||
"mirostat_eta": 0.10000000149011612,
|
||||
"mirostat_tau": 5.0,
|
||||
"model": "llama-2-7b-32k-instruct.Q2_K.gguf",
|
||||
"n_ctx": 2048,
|
||||
"n_keep": 0,
|
||||
"n_predict": 100000,
|
||||
"n_probs": 0,
|
||||
"next_token": {
|
||||
"has_next_token": true,
|
||||
"n_remain": -1,
|
||||
"num_tokens_predicted": 0,
|
||||
"stopped_eos": false,
|
||||
"stopped_limit": false,
|
||||
"stopped_word": false,
|
||||
"stopping_word": ""
|
||||
},
|
||||
"penalize_nl": true,
|
||||
"penalty_prompt_tokens": [],
|
||||
"presence_penalty": 0.0,
|
||||
"prompt": "Say hello to llama.cpp",
|
||||
"repeat_last_n": 64,
|
||||
"repeat_penalty": 1.100000023841858,
|
||||
"samplers": [
|
||||
"top_k",
|
||||
"tfs_z",
|
||||
"typical_p",
|
||||
"top_p",
|
||||
"min_p",
|
||||
"temperature"
|
||||
],
|
||||
"seed": 42,
|
||||
"state": 1,
|
||||
"stop": [
|
||||
"\n"
|
||||
],
|
||||
"stream": false,
|
||||
"task_id": 0,
|
||||
"temperature": 0.0,
|
||||
"tfs_z": 1.0,
|
||||
"top_k": 40,
|
||||
"top_p": 0.949999988079071,
|
||||
"typical_p": 1.0,
|
||||
"use_penalty_prompt_tokens": false
|
||||
}
|
||||
]
|
||||
```
|
||||
|
||||
## More examples
|
||||
|
||||
### Change system prompt on runtime
|
||||
|
|
|
@ -234,6 +234,7 @@
|
|||
mirostat_eta: 0.1, // learning rate
|
||||
grammar: '',
|
||||
n_probs: 0, // no completion_probabilities,
|
||||
min_keep: 0, // min probs from each sampler,
|
||||
image_data: [],
|
||||
cache_prompt: true,
|
||||
api_key: ''
|
||||
|
@ -791,6 +792,9 @@
|
|||
<fieldset>
|
||||
${IntField({ label: "Show Probabilities", max: 10, min: 0, name: "n_probs", value: params.value.n_probs })}
|
||||
</fieldset>
|
||||
<fieldset>
|
||||
${IntField({ label: "Min Probabilities from each Sampler", max: 10, min: 0, name: "min_keep", value: params.value.min_keep })}
|
||||
</fieldset>
|
||||
<fieldset>
|
||||
<label for="api_key">API Key</label>
|
||||
<input type="text" name="api_key" value="${params.value.api_key}" placeholder="Enter API key" oninput=${updateParams} />
|
||||
|
|
|
@ -28,6 +28,7 @@
|
|||
#include <chrono>
|
||||
#include <condition_variable>
|
||||
#include <atomic>
|
||||
#include <signal.h>
|
||||
|
||||
using json = nlohmann::json;
|
||||
|
||||
|
@ -40,6 +41,7 @@ struct server_params
|
|||
int32_t port = 8080;
|
||||
int32_t read_timeout = 600;
|
||||
int32_t write_timeout = 600;
|
||||
bool slots_endpoint = true;
|
||||
};
|
||||
|
||||
bool server_verbose = false;
|
||||
|
@ -158,6 +160,7 @@ struct llama_client_slot
|
|||
int32_t n_decoded = 0;
|
||||
int32_t n_remaining = -1;
|
||||
int32_t i_batch = -1;
|
||||
int32_t n_predict = -1;
|
||||
|
||||
int32_t num_prompt_tokens = 0;
|
||||
int32_t num_prompt_tokens_processed = 0;
|
||||
|
@ -409,6 +412,7 @@ struct llama_server_context
|
|||
|
||||
slot.id = i;
|
||||
slot.n_ctx = n_ctx_slot;
|
||||
slot.n_predict = params.n_predict;
|
||||
|
||||
LOG_TEE(" -> Slot %i - max context: %i\n", slot.id, n_ctx_slot);
|
||||
|
||||
|
@ -436,10 +440,6 @@ struct llama_server_context
|
|||
default_generation_settings_for_props["seed"] = -1;
|
||||
|
||||
batch = llama_batch_init(n_ctx, 0, params.n_parallel);
|
||||
|
||||
// empty system prompt
|
||||
system_prompt = "";
|
||||
system_tokens.clear();
|
||||
}
|
||||
|
||||
std::vector<llama_token> tokenize(const json & json_prompt, bool add_bos) const
|
||||
|
@ -548,6 +548,16 @@ struct llama_server_context
|
|||
slot->params.seed = json_value(data, "seed", default_params.seed);
|
||||
slot->sparams.grammar = json_value(data, "grammar", default_sparams.grammar);
|
||||
slot->sparams.n_probs = json_value(data, "n_probs", default_sparams.n_probs);
|
||||
slot->sparams.min_keep = json_value(data, "min_keep", default_sparams.min_keep);
|
||||
|
||||
if (slot->n_predict > 0 && slot->params.n_predict > slot->n_predict) {
|
||||
// Might be better to reject the request with a 400 ?
|
||||
LOG_WARNING("Max tokens to predict exceeds server configuration", {
|
||||
{"params.n_predict", slot->params.n_predict},
|
||||
{"slot.n_predict", slot->n_predict},
|
||||
});
|
||||
slot->params.n_predict = slot->n_predict;
|
||||
}
|
||||
|
||||
// infill
|
||||
if (data.count("input_prefix") != 0)
|
||||
|
@ -676,6 +686,24 @@ struct llama_server_context
|
|||
}
|
||||
}
|
||||
|
||||
const auto &samplers_sequence = data.find("samplers");
|
||||
if (samplers_sequence != data.end() && samplers_sequence->is_array())
|
||||
{
|
||||
std::vector<std::string> sampler_names;
|
||||
for (const auto &sampler_name : *samplers_sequence)
|
||||
{
|
||||
if (sampler_name.is_string())
|
||||
{
|
||||
sampler_names.emplace_back(sampler_name);
|
||||
}
|
||||
}
|
||||
slot->sparams.samplers_sequence = sampler_types_from_names(sampler_names, false);
|
||||
}
|
||||
else
|
||||
{
|
||||
slot->sparams.samplers_sequence = default_sparams.samplers_sequence;
|
||||
}
|
||||
|
||||
if (multimodal)
|
||||
{
|
||||
const auto &images_data = data.find("image_data");
|
||||
|
@ -765,27 +793,30 @@ struct llama_server_context
|
|||
}
|
||||
|
||||
void update_system_prompt() {
|
||||
system_tokens = ::llama_tokenize(ctx, system_prompt, add_bos_token);
|
||||
|
||||
llama_batch_clear(batch);
|
||||
|
||||
kv_cache_clear();
|
||||
system_tokens.clear();
|
||||
|
||||
for (int i = 0; i < (int) system_tokens.size(); ++i)
|
||||
{
|
||||
llama_batch_add(batch, system_tokens[i], i, { 0 }, false);
|
||||
}
|
||||
if (!system_prompt.empty()) {
|
||||
system_tokens = ::llama_tokenize(ctx, system_prompt, add_bos_token);
|
||||
|
||||
if (llama_decode(ctx, batch) != 0)
|
||||
{
|
||||
LOG_TEE("%s: llama_decode() failed\n", __func__);
|
||||
return;
|
||||
}
|
||||
llama_batch_clear(batch);
|
||||
|
||||
// assign the system KV cache to all parallel sequences
|
||||
for (int32_t i = 1; i < params.n_parallel; ++i)
|
||||
{
|
||||
llama_kv_cache_seq_cp(ctx, 0, i, 0, system_tokens.size());
|
||||
for (int i = 0; i < (int)system_tokens.size(); ++i)
|
||||
{
|
||||
llama_batch_add(batch, system_tokens[i], i, { 0 }, false);
|
||||
}
|
||||
|
||||
if (llama_decode(ctx, batch) != 0)
|
||||
{
|
||||
LOG_TEE("%s: llama_decode() failed\n", __func__);
|
||||
return;
|
||||
}
|
||||
|
||||
// assign the system KV cache to all parallel sequences
|
||||
for (int32_t i = 1; i < params.n_parallel; ++i)
|
||||
{
|
||||
llama_kv_cache_seq_cp(ctx, 0, i, 0, system_tokens.size());
|
||||
}
|
||||
}
|
||||
|
||||
LOG_TEE("system prompt updated\n");
|
||||
|
@ -807,10 +838,8 @@ struct llama_server_context
|
|||
name_user = sys_props.value("anti_prompt", "");
|
||||
name_assistant = sys_props.value("assistant_name", "");
|
||||
|
||||
if (slots.size() > 0)
|
||||
{
|
||||
notify_system_prompt_changed();
|
||||
}
|
||||
|
||||
notify_system_prompt_changed();
|
||||
}
|
||||
|
||||
static size_t find_stopping_strings(const std::string &text, const size_t last_token_size,
|
||||
|
@ -968,18 +997,31 @@ struct llama_server_context
|
|||
{
|
||||
continue;
|
||||
}
|
||||
clip_image_f32 * img_res = clip_image_f32_init();
|
||||
if (!clip_image_preprocess(clp_ctx, img.img_data, img_res, /*pad2square =*/ true))
|
||||
clip_image_f32_batch img_res_v;
|
||||
img_res_v.size = 0;
|
||||
img_res_v.data = nullptr;
|
||||
if (!clip_image_preprocess(clp_ctx, img.img_data, img_res_v))
|
||||
{
|
||||
LOG_TEE("Error processing the given image");
|
||||
clip_free(clp_ctx);
|
||||
clip_image_f32_batch_free(img_res_v);
|
||||
return false;
|
||||
}
|
||||
if (img_res_v.size == 0)
|
||||
{
|
||||
LOG_TEE("Error processing the given image");
|
||||
return false;
|
||||
}
|
||||
|
||||
// note: assumes only one image was returned by clip_image_preprocess
|
||||
clip_image_f32 * img_res = img_res_v.data;
|
||||
|
||||
img.image_tokens = clip_n_patches(clp_ctx);
|
||||
img.image_embedding = (float *)malloc(clip_embd_nbytes(clp_ctx));
|
||||
if (!img.image_embedding)
|
||||
{
|
||||
LOG_TEE("Unable to allocate memory for image embeddings\n");
|
||||
clip_image_f32_batch_free(img_res_v);
|
||||
clip_free(clp_ctx);
|
||||
return false;
|
||||
}
|
||||
|
@ -987,9 +1029,12 @@ struct llama_server_context
|
|||
if (!clip_image_encode(clp_ctx, params.n_threads, img_res, img.image_embedding))
|
||||
{
|
||||
LOG_TEE("Unable to encode image\n");
|
||||
clip_image_f32_batch_free(img_res_v);
|
||||
return false;
|
||||
}
|
||||
clip_image_f32_free(img_res);
|
||||
|
||||
clip_image_f32_batch_free(img_res_v);
|
||||
|
||||
img.request_encode_image = false;
|
||||
}
|
||||
|
||||
|
@ -1013,8 +1058,15 @@ struct llama_server_context
|
|||
const auto eos_bias = slot.sparams.logit_bias.find(llama_token_eos(model));
|
||||
const bool ignore_eos = eos_bias != slot.sparams.logit_bias.end() &&
|
||||
eos_bias->second < 0.0f && std::isinf(eos_bias->second);
|
||||
std::vector<std::string> samplers_sequence;
|
||||
for (const auto &sampler_type : slot.sparams.samplers_sequence)
|
||||
{
|
||||
samplers_sequence.emplace_back(sampler_type_to_name_string(sampler_type));
|
||||
}
|
||||
|
||||
return json {
|
||||
{"n_ctx", slot.n_ctx},
|
||||
{"n_predict", slot.n_predict},
|
||||
{"model", params.model_alias},
|
||||
{"seed", slot.params.seed},
|
||||
{"temperature", slot.sparams.temp},
|
||||
|
@ -1042,7 +1094,9 @@ struct llama_server_context
|
|||
{"stream", slot.params.stream},
|
||||
{"logit_bias", slot.sparams.logit_bias},
|
||||
{"n_probs", slot.sparams.n_probs},
|
||||
{"min_keep", slot.sparams.min_keep},
|
||||
{"grammar", slot.sparams.grammar},
|
||||
{"samplers", samplers_sequence}
|
||||
};
|
||||
}
|
||||
|
||||
|
@ -1839,7 +1893,10 @@ static void server_print_usage(const char *argv0, const gpt_params ¶ms,
|
|||
{
|
||||
printf(" --no-mmap do not memory-map model (slower load but may reduce pageouts if not using mlock)\n");
|
||||
}
|
||||
printf(" --numa attempt optimizations that help on some NUMA systems\n");
|
||||
printf(" --numa TYPE attempt optimizations that help on some NUMA systems\n");
|
||||
printf(" - distribute: spread execution evenly over all nodes\n");
|
||||
printf(" - isolate: only spawn threads on CPUs on the node that execution started on\n");
|
||||
printf(" - numactl: use the CPU map provided my numactl\n");
|
||||
if (llama_supports_gpu_offload()) {
|
||||
printf(" -ngl N, --n-gpu-layers N\n");
|
||||
printf(" number of layers to store in VRAM\n");
|
||||
|
@ -1872,14 +1929,16 @@ static void server_print_usage(const char *argv0, const gpt_params ¶ms,
|
|||
printf(" set a file to load a system prompt (initial prompt of all slots), this is useful for chat applications.\n");
|
||||
printf(" --mmproj MMPROJ_FILE path to a multimodal projector file for LLaVA.\n");
|
||||
printf(" --log-disable disables logging to a file.\n");
|
||||
printf(" --slots-endpoint-disable disables slots monitoring endpoint.\n");
|
||||
printf("\n");
|
||||
printf(" -n, --n-predict maximum tokens to predict (default: %d)\n", params.n_predict);
|
||||
printf(" --override-kv KEY=TYPE:VALUE\n");
|
||||
printf(" advanced option to override model metadata by key. may be specified multiple times.\n");
|
||||
printf(" types: int, float, bool. example: --override-kv tokenizer.ggml.add_bos_token=bool:false\n");
|
||||
printf(" -gan N, --grp-attn-n N set the group attention factor to extend context size through self-extend(default: 1=disabled), used together with group attention width `--grp-attn-w`");
|
||||
printf(" -gaw N, --grp-attn-w N set the group attention width to extend context size through self-extend(default: 512), used together with group attention factor `--grp-attn-n`");
|
||||
printf(" --chat-template FORMAT_NAME");
|
||||
printf(" set chat template, possible valus is: llama2, chatml (default %s)", sparams.chat_template.c_str());
|
||||
printf(" set chat template, possible value is: llama2, chatml (default %s)", sparams.chat_template.c_str());
|
||||
printf("\n");
|
||||
}
|
||||
|
||||
|
@ -2248,9 +2307,17 @@ static void server_params_parse(int argc, char **argv, server_params &sparams,
|
|||
{
|
||||
params.use_mmap = false;
|
||||
}
|
||||
else if (arg == "--numa")
|
||||
{
|
||||
params.numa = true;
|
||||
else if (arg == "--numa") {
|
||||
if (++i >= argc) {
|
||||
invalid_param = true;
|
||||
break;
|
||||
} else {
|
||||
std::string value(argv[i]);
|
||||
/**/ if (value == "distribute" || value == "" ) { params.numa = GGML_NUMA_STRATEGY_DISTRIBUTE; }
|
||||
else if (value == "isolate") { params.numa = GGML_NUMA_STRATEGY_ISOLATE; }
|
||||
else if (value == "numactl") { params.numa = GGML_NUMA_STRATEGY_NUMACTL; }
|
||||
else { invalid_param = true; break; }
|
||||
}
|
||||
}
|
||||
else if (arg == "--embedding")
|
||||
{
|
||||
|
@ -2311,6 +2378,10 @@ static void server_params_parse(int argc, char **argv, server_params &sparams,
|
|||
log_set_target(stdout);
|
||||
LOG_INFO("logging to file is disabled.", {});
|
||||
}
|
||||
else if (arg == "--slots-endpoint-disable")
|
||||
{
|
||||
sparams.slots_endpoint = false;
|
||||
}
|
||||
else if (arg == "--chat-template")
|
||||
{
|
||||
if (++i >= argc)
|
||||
|
@ -2462,6 +2533,9 @@ static void append_to_generated_text_from_generated_token_probs(llama_server_con
|
|||
}
|
||||
}
|
||||
|
||||
std::function<void(int)> shutdown_handler;
|
||||
inline void signal_handler(int signal) { shutdown_handler(signal); }
|
||||
|
||||
int main(int argc, char **argv)
|
||||
{
|
||||
#if SERVER_VERBOSE != 1
|
||||
|
@ -2481,7 +2555,8 @@ int main(int argc, char **argv)
|
|||
params.model_alias = params.model;
|
||||
}
|
||||
|
||||
llama_backend_init(params.numa);
|
||||
llama_backend_init();
|
||||
llama_numa_init(params.numa);
|
||||
|
||||
LOG_INFO("build info", {{"build", LLAMA_BUILD_NUMBER},
|
||||
{"commit", LLAMA_COMMIT}});
|
||||
|
@ -2511,8 +2586,35 @@ int main(int argc, char **argv)
|
|||
server_state current_state = state.load();
|
||||
switch(current_state) {
|
||||
case SERVER_STATE_READY:
|
||||
res.set_content(R"({"status": "ok"})", "application/json");
|
||||
res.status = 200; // HTTP OK
|
||||
if (llama.all_slots_are_idle) {
|
||||
res.set_content(R"({"status": "ok"})", "application/json");
|
||||
res.status = 200; // HTTP OK
|
||||
} else {
|
||||
int available_slots = 0;
|
||||
int processing_slots = 0;
|
||||
for (llama_client_slot & slot : llama.slots) {
|
||||
if (slot.available()) {
|
||||
available_slots++;
|
||||
} else {
|
||||
processing_slots++;
|
||||
}
|
||||
}
|
||||
if (available_slots > 0) {
|
||||
json health = {
|
||||
{"status", "ok"},
|
||||
{"slots_idle", available_slots},
|
||||
{"slots_processing", processing_slots}};
|
||||
res.set_content(health.dump(), "application/json");
|
||||
res.status = 200; // HTTP OK
|
||||
} else {
|
||||
json health = {
|
||||
{"status", "no slot available"},
|
||||
{"slots_idle", available_slots},
|
||||
{"slots_processing", processing_slots}};
|
||||
res.set_content(health.dump(), "application/json");
|
||||
res.status = 503; // HTTP Service Unavailable
|
||||
}
|
||||
}
|
||||
break;
|
||||
case SERVER_STATE_LOADING_MODEL:
|
||||
res.set_content(R"({"status": "loading model"})", "application/json");
|
||||
|
@ -2525,6 +2627,32 @@ int main(int argc, char **argv)
|
|||
}
|
||||
});
|
||||
|
||||
if (sparams.slots_endpoint) {
|
||||
svr.Get("/slots", [&](const httplib::Request&, httplib::Response& res) {
|
||||
json slots;
|
||||
for (llama_client_slot & slot : llama.slots) {
|
||||
json slot_data = llama.get_formated_generation(slot);
|
||||
slot_data["id"] = slot.id;
|
||||
slot_data["task_id"] = slot.task_id;
|
||||
slot_data["state"] = slot.state;
|
||||
slot_data["prompt"] = slot.prompt;
|
||||
slot_data["next_token"] = {
|
||||
{"has_next_token", slot.has_next_token},
|
||||
{"n_remain", slot.n_remaining},
|
||||
{"num_tokens_predicted", slot.n_decoded},
|
||||
{"stopped_eos", slot.stopped_eos},
|
||||
{"stopped_word", slot.stopped_word},
|
||||
{"stopped_limit", slot.stopped_limit},
|
||||
{"stopping_word", slot.stopping_word},
|
||||
};
|
||||
|
||||
slots.push_back(slot_data);
|
||||
}
|
||||
res.set_content(slots.dump(), "application/json");
|
||||
res.status = 200; // HTTP OK
|
||||
});
|
||||
}
|
||||
|
||||
svr.set_logger(log_server_request);
|
||||
|
||||
svr.set_exception_handler([](const httplib::Request &, httplib::Response &res, std::exception_ptr ep)
|
||||
|
@ -3078,8 +3206,25 @@ int main(int argc, char **argv)
|
|||
std::placeholders::_2,
|
||||
std::placeholders::_3
|
||||
));
|
||||
llama.queue_tasks.start_loop();
|
||||
|
||||
shutdown_handler = [&](int) {
|
||||
llama.queue_tasks.terminate();
|
||||
};
|
||||
|
||||
#if defined (__unix__) || (defined (__APPLE__) && defined (__MACH__))
|
||||
struct sigaction sigint_action;
|
||||
sigint_action.sa_handler = signal_handler;
|
||||
sigemptyset (&sigint_action.sa_mask);
|
||||
sigint_action.sa_flags = 0;
|
||||
sigaction(SIGINT, &sigint_action, NULL);
|
||||
#elif defined (_WIN32)
|
||||
auto console_ctrl_handler = +[](DWORD ctrl_type) -> BOOL {
|
||||
return (ctrl_type == CTRL_C_EVENT) ? (signal_handler(SIGINT), true) : false;
|
||||
};
|
||||
SetConsoleCtrlHandler(reinterpret_cast<PHANDLER_ROUTINE>(console_ctrl_handler), true);
|
||||
#endif
|
||||
llama.queue_tasks.start_loop();
|
||||
svr.stop();
|
||||
t.join();
|
||||
|
||||
llama_backend_free();
|
||||
|
|
|
@ -220,6 +220,7 @@ inline std::string format_chatml(std::vector<json> messages)
|
|||
struct llama_server_queue {
|
||||
int id = 0;
|
||||
std::mutex mutex_tasks;
|
||||
bool running;
|
||||
// queues
|
||||
std::vector<task_server> queue_tasks;
|
||||
std::vector<task_server> queue_tasks_deferred;
|
||||
|
@ -278,9 +279,18 @@ struct llama_server_queue {
|
|||
queue_tasks_deferred.clear();
|
||||
}
|
||||
|
||||
// Start the main loop. This call is blocking
|
||||
[[noreturn]]
|
||||
// end the start_loop routine
|
||||
void terminate() {
|
||||
{
|
||||
std::unique_lock<std::mutex> lock(mutex_tasks);
|
||||
running = false;
|
||||
}
|
||||
condition_tasks.notify_all();
|
||||
}
|
||||
|
||||
// Start the main loop.
|
||||
void start_loop() {
|
||||
running = true;
|
||||
while (true) {
|
||||
// new task arrived
|
||||
LOG_VERBOSE("have new task", {});
|
||||
|
@ -324,8 +334,12 @@ struct llama_server_queue {
|
|||
{
|
||||
std::unique_lock<std::mutex> lock(mutex_tasks);
|
||||
if (queue_tasks.empty()) {
|
||||
if (!running) {
|
||||
LOG_VERBOSE("ending start_loop", {});
|
||||
return;
|
||||
}
|
||||
condition_tasks.wait(lock, [&]{
|
||||
return !queue_tasks.empty();
|
||||
return (!queue_tasks.empty() || !running);
|
||||
});
|
||||
}
|
||||
}
|
||||
|
|
|
@ -31,7 +31,8 @@ int main(int argc, char ** argv) {
|
|||
|
||||
// init LLM
|
||||
|
||||
llama_backend_init(params.numa);
|
||||
llama_backend_init();
|
||||
llama_numa_init(params.numa);
|
||||
|
||||
// initialize the model
|
||||
|
||||
|
|
|
@ -50,7 +50,8 @@ int main(int argc, char ** argv) {
|
|||
#endif // LOG_DISABLE_LOGS
|
||||
|
||||
// init llama.cpp
|
||||
llama_backend_init(params.numa);
|
||||
llama_backend_init();
|
||||
llama_numa_init(params.numa);
|
||||
|
||||
llama_model * model_tgt = NULL;
|
||||
llama_model * model_dft = NULL;
|
||||
|
|
|
@ -17,7 +17,7 @@ int main(int argc, char ** argv) {
|
|||
|
||||
const bool printing_ids = argc > 3 && std::string(argv[3]) == "--ids";
|
||||
|
||||
llama_backend_init(false);
|
||||
llama_backend_init();
|
||||
|
||||
llama_model_params model_params = llama_model_default_params();
|
||||
model_params.vocab_only = true;
|
||||
|
|
|
@ -50,9 +50,9 @@ struct my_llama_layer {
|
|||
struct ggml_tensor * ffn_norm;
|
||||
|
||||
// ff
|
||||
struct ggml_tensor * w1;
|
||||
struct ggml_tensor * w2;
|
||||
struct ggml_tensor * w3;
|
||||
struct ggml_tensor * ffn_gate; // w1
|
||||
struct ggml_tensor * ffn_down; // w2
|
||||
struct ggml_tensor * ffn_up; // w3
|
||||
};
|
||||
|
||||
struct my_llama_model {
|
||||
|
@ -111,13 +111,13 @@ static const char * LLM_TENSOR_FFN_DOWN = "blk.%d.ffn_down";
|
|||
static const char * LLM_TENSOR_FFN_UP = "blk.%d.ffn_up";
|
||||
|
||||
static void print_params(struct my_llama_hparams * params) {
|
||||
printf("%s: n_vocab: %d\n", __func__, params->n_vocab);
|
||||
printf("%s: n_ctx: %d\n", __func__, params->n_ctx);
|
||||
printf("%s: n_embd: %d\n", __func__, params->n_embd);
|
||||
printf("%s: n_head: %d\n", __func__, params->n_head);
|
||||
printf("%s: n_ff: %d\n", __func__, params->n_ff);
|
||||
printf("%s: n_layer: %d\n", __func__, params->n_layer);
|
||||
printf("%s: n_rot: %d\n", __func__, params->n_rot);
|
||||
printf("%s: n_vocab: %u\n", __func__, params->n_vocab);
|
||||
printf("%s: n_ctx: %u\n", __func__, params->n_ctx);
|
||||
printf("%s: n_embd: %u\n", __func__, params->n_embd);
|
||||
printf("%s: n_head: %u\n", __func__, params->n_head);
|
||||
printf("%s: n_ff: %u\n", __func__, params->n_ff);
|
||||
printf("%s: n_layer: %u\n", __func__, params->n_layer);
|
||||
printf("%s: n_rot: %u\n", __func__, params->n_rot);
|
||||
}
|
||||
|
||||
static void set_param_model(struct my_llama_model * model) {
|
||||
|
@ -140,9 +140,9 @@ static void set_param_model(struct my_llama_model * model) {
|
|||
ggml_set_param(ctx, layer.wv);
|
||||
ggml_set_param(ctx, layer.wo);
|
||||
ggml_set_param(ctx, layer.ffn_norm);
|
||||
ggml_set_param(ctx, layer.w1);
|
||||
ggml_set_param(ctx, layer.w2);
|
||||
ggml_set_param(ctx, layer.w3);
|
||||
ggml_set_param(ctx, layer.ffn_gate);
|
||||
ggml_set_param(ctx, layer.ffn_down);
|
||||
ggml_set_param(ctx, layer.ffn_up);
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -198,9 +198,9 @@ static void init_model(struct my_llama_model * model) {
|
|||
|
||||
layer.ffn_norm = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_embd);
|
||||
|
||||
layer.w1 = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_embd, n_ff);
|
||||
layer.w2 = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_ff, n_embd);
|
||||
layer.w3 = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_embd, n_ff);
|
||||
layer.ffn_gate = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_embd, n_ff);
|
||||
layer.ffn_down = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_ff, n_embd);
|
||||
layer.ffn_up = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_embd, n_ff);
|
||||
|
||||
ggml_set_name(layer.attention_norm, tni(LLM_TENSOR_ATTN_NORM, i));
|
||||
|
||||
|
@ -211,9 +211,9 @@ static void init_model(struct my_llama_model * model) {
|
|||
|
||||
ggml_set_name(layer.ffn_norm, tni(LLM_TENSOR_FFN_NORM, i));
|
||||
|
||||
ggml_set_name(layer.w1, tni(LLM_TENSOR_FFN_GATE, i));
|
||||
ggml_set_name(layer.w2, tni(LLM_TENSOR_FFN_DOWN, i));
|
||||
ggml_set_name(layer.w3, tni(LLM_TENSOR_FFN_UP, i));
|
||||
ggml_set_name(layer.ffn_gate, tni(LLM_TENSOR_FFN_GATE, i));
|
||||
ggml_set_name(layer.ffn_down, tni(LLM_TENSOR_FFN_DOWN, i));
|
||||
ggml_set_name(layer.ffn_up, tni(LLM_TENSOR_FFN_UP, i));
|
||||
}
|
||||
|
||||
set_param_model(model);
|
||||
|
@ -244,9 +244,9 @@ static void randomize_model(struct my_llama_model * model, int seed, float mean,
|
|||
|
||||
randomize_tensor_normal(layer.ffn_norm, rnd);
|
||||
|
||||
randomize_tensor_normal(layer.w1, rnd);
|
||||
randomize_tensor_normal(layer.w2, rnd);
|
||||
randomize_tensor_normal(layer.w3, rnd);
|
||||
randomize_tensor_normal(layer.ffn_gate, rnd);
|
||||
randomize_tensor_normal(layer.ffn_down, rnd);
|
||||
randomize_tensor_normal(layer.ffn_up, rnd);
|
||||
}
|
||||
|
||||
free_random_normal_distribution(rnd);
|
||||
|
@ -356,11 +356,11 @@ static struct ggml_tensor * llama_build_train_graphs(
|
|||
struct ggml_tensor * t22 = ggml_rms_norm (ctx, t21, f_norm_rms_eps); set_name(t22, "t22"); assert_shape_2d(t22, n_embd, N*n_batch);
|
||||
struct ggml_tensor * t23 = ggml_repeat (ctx, layer.ffn_norm, t22); set_name(t23, "t23"); assert_shape_2d(t23, n_embd, N*n_batch);
|
||||
struct ggml_tensor * t24 = ggml_mul (ctx, t23, t22); set_name(t24, "t24"); assert_shape_2d(t24, n_embd, N*n_batch);
|
||||
struct ggml_tensor * t25 = ggml_mul_mat (ctx, layer.w3, t24); set_name(t25, "t25"); assert_shape_2d(t25, n_ff, N*n_batch);
|
||||
struct ggml_tensor * t26 = ggml_mul_mat (ctx, layer.w1, t24); set_name(t26, "t26"); assert_shape_2d(t26, n_ff, N*n_batch);
|
||||
struct ggml_tensor * t25 = ggml_mul_mat (ctx, layer.ffn_up, t24); set_name(t25, "t25"); assert_shape_2d(t25, n_ff, N*n_batch);
|
||||
struct ggml_tensor * t26 = ggml_mul_mat (ctx, layer.ffn_gate, t24); set_name(t26, "t26"); assert_shape_2d(t26, n_ff, N*n_batch);
|
||||
struct ggml_tensor * t27 = ggml_silu (ctx, t26); set_name(t27, "t27"); assert_shape_2d(t27, n_ff, N*n_batch);
|
||||
struct ggml_tensor * t28 = ggml_mul (ctx, t27, t25); set_name(t28, "t28"); assert_shape_2d(t28, n_ff, N*n_batch);
|
||||
struct ggml_tensor * t29 = ggml_mul_mat (ctx, layer.w2, t28); set_name(t29, "t29"); assert_shape_2d(t29, n_embd, N*n_batch);
|
||||
struct ggml_tensor * t29 = ggml_mul_mat (ctx, layer.ffn_down, t28); set_name(t29, "t29"); assert_shape_2d(t29, n_embd, N*n_batch);
|
||||
struct ggml_tensor * t30 = ggml_add (ctx, t29, t21); set_name(t30, "t30"); assert_shape_2d(t30, n_embd, N*n_batch);
|
||||
cur = t30;
|
||||
checkpoints.push_back(cur);
|
||||
|
@ -521,9 +521,9 @@ static void load_llama_model_gguf(struct gguf_context * fctx, struct ggml_contex
|
|||
copy_tensor_by_name(layer.wv, f_ggml_ctx, tni(LLM_TENSOR_ATTN_V, i));
|
||||
copy_tensor_by_name(layer.wo, f_ggml_ctx, tni(LLM_TENSOR_ATTN_OUT, i));
|
||||
copy_tensor_by_name(layer.ffn_norm, f_ggml_ctx, tni(LLM_TENSOR_FFN_NORM, i));
|
||||
copy_tensor_by_name(layer.w1, f_ggml_ctx, tni(LLM_TENSOR_FFN_GATE, i));
|
||||
copy_tensor_by_name(layer.w2, f_ggml_ctx, tni(LLM_TENSOR_FFN_DOWN, i));
|
||||
copy_tensor_by_name(layer.w3, f_ggml_ctx, tni(LLM_TENSOR_FFN_UP, i));
|
||||
copy_tensor_by_name(layer.ffn_gate, f_ggml_ctx, tni(LLM_TENSOR_FFN_GATE, i));
|
||||
copy_tensor_by_name(layer.ffn_down, f_ggml_ctx, tni(LLM_TENSOR_FFN_DOWN, i));
|
||||
copy_tensor_by_name(layer.ffn_up, f_ggml_ctx, tni(LLM_TENSOR_FFN_UP, i));
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -664,9 +664,9 @@ static void save_llama_model_gguf(struct gguf_context * fctx, const char * fn_vo
|
|||
gguf_add_tensor(fctx, layer.wv);
|
||||
gguf_add_tensor(fctx, layer.wo);
|
||||
gguf_add_tensor(fctx, layer.ffn_norm);
|
||||
gguf_add_tensor(fctx, layer.w1);
|
||||
gguf_add_tensor(fctx, layer.w2);
|
||||
gguf_add_tensor(fctx, layer.w3);
|
||||
gguf_add_tensor(fctx, layer.ffn_gate);
|
||||
gguf_add_tensor(fctx, layer.ffn_down);
|
||||
gguf_add_tensor(fctx, layer.ffn_up);
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -915,9 +915,9 @@ static int64_t get_parameter_count(struct my_llama_model* model) {
|
|||
nx += ggml_nelements(layer.wv);
|
||||
nx += ggml_nelements(layer.wo);
|
||||
nx += ggml_nelements(layer.ffn_norm);
|
||||
nx += ggml_nelements(layer.w1);
|
||||
nx += ggml_nelements(layer.w2);
|
||||
nx += ggml_nelements(layer.w3);
|
||||
nx += ggml_nelements(layer.ffn_gate);
|
||||
nx += ggml_nelements(layer.ffn_down);
|
||||
nx += ggml_nelements(layer.ffn_up);
|
||||
}
|
||||
return nx;
|
||||
}
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue