remove prediction-related code to reduce code duplication with main
use main instead

parent 5ce92aed37
commit 271c0300de

1 changed file with 2 additions and 613 deletions
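This commit removes the example's own prediction path: the bespoke `forward()` graph, the `my_llama_kv_cache`, the `my_llama_sampler`, and the `--predict` option. Generating from a finished LoRA is meant to go through the main example instead. As a rough, hedged sketch of what that looks like against the public llama.cpp API of this era (a greedy loop over `llama_eval`; the context setup, prompt handling, and thread count are assumptions, not part of this commit):

```cpp
// Hedged sketch only: greedy generation through the public llama.cpp API,
// the way the main example does it. Assumes an already-initialized
// llama_context (model loaded, LoRA applied) and era-appropriate signatures,
// e.g. llama_eval(ctx, tokens, n_tokens, n_past, n_threads).
#include "llama.h"
#include <cstdio>
#include <vector>

void generate_greedy(llama_context * ctx, const std::vector<llama_token> & prompt, int n_gen) {
    int n_past = 0;
    std::vector<llama_token> pending = prompt;

    for (int i = 0; i < n_gen; ++i) {
        llama_eval(ctx, pending.data(), (int) pending.size(), n_past, /*n_threads=*/4);
        n_past += (int) pending.size();

        // logits of the last evaluated token
        float * logits = llama_get_logits(ctx);

        std::vector<llama_token_data> cand;
        for (llama_token id = 0; id < llama_n_vocab(ctx); ++id) {
            cand.push_back({ id, logits[id], 0.0f });
        }
        llama_token_data_array cand_p = { cand.data(), cand.size(), false };

        llama_token next = llama_sample_token_greedy(ctx, &cand_p);
        printf("%s", llama_token_get_text(ctx, next));

        pending.assign(1, next);
    }
}
```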
@@ -61,17 +61,6 @@ float frand_uniform(struct random_uniform_distribution * rnd) {
     return rnd->rd(rnd->gen);
 }
 
-void ggml_graph_compute_helper(std::vector<uint8_t> & buf, ggml_cgraph * graph, int n_threads) {
-    struct ggml_cplan plan = ggml_graph_plan(graph, n_threads);
-
-    if (plan.work_size > 0) {
-        buf.resize(plan.work_size);
-        plan.work_data = buf.data();
-    }
-
-    ggml_graph_compute(graph, &plan);
-}
-
 struct ggml_tensor * randomize_tensor_normal(struct ggml_tensor * tensor, struct random_normal_distribution * rnd) {
     float scale = 1.0f; // xavier
     switch (tensor->n_dims) {
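The removed `ggml_graph_compute_helper` wrapped the standard two-step ggml execution idiom: ask `ggml_graph_plan` how much scratch memory the graph needs, provide it, then run `ggml_graph_compute`. A minimal self-contained sketch of that idiom (illustrative sizes and tensor names, not code from this repository):

```cpp
#include "ggml.h"
#include <vector>

int main() {
    // small illustrative context; sizes are arbitrary
    struct ggml_init_params params = { 16*1024*1024, NULL, false };
    struct ggml_context * ctx = ggml_init(params);

    struct ggml_tensor * a = ggml_set_f32(ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 8), 1.0f);
    struct ggml_tensor * b = ggml_set_f32(ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 8), 2.0f);
    struct ggml_tensor * c = ggml_add(ctx, a, b);

    struct ggml_cgraph * gf = ggml_new_graph(ctx);
    ggml_build_forward_expand(gf, c);

    // plan first, provide scratch memory, then compute
    struct ggml_cplan plan = ggml_graph_plan(gf, /*n_threads=*/4);
    std::vector<uint8_t> work(plan.work_size);
    if (plan.work_size > 0) {
        plan.work_data = work.data();
    }
    ggml_graph_compute(gf, &plan);

    ggml_free(ctx);
    return 0;
}
```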
@@ -165,17 +154,6 @@ struct ggml_tensor * randomize_tensor_uniform(struct ggml_tensor * tensor, struct random_uniform_distribution * rnd) {
     return tensor;
 }
 
-struct my_llama_kv_cache {
-    struct ggml_context * ctx = NULL;
-
-    struct ggml_tensor * k;
-    struct ggml_tensor * v;
-
-    // llama_ctx_buffer buf;
-
-    int n; // number of tokens currently in the cache
-};
-
 struct llama_vocab {
     using id = int32_t;
     using token = std::string;
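The struct removed here is filled in by `init_kv_cache` in the next hunk, which sizes the cache as `n_mem = n_layer*n_ctx*n_batch` and `n_elements = n_embd*n_mem`, then allocates one F32 `k` and one F32 `v` tensor of `n_elements` each. A worked sizing example, with hypothetical hyperparameters picked only for illustration:

```cpp
#include <cstdint>
#include <cstdio>

int main() {
    // hypothetical LLaMA-7B-like shapes; not taken from this commit
    const int64_t n_layer = 32, n_ctx = 512, n_batch = 1, n_embd = 4096;

    const int64_t n_mem      = n_layer*n_ctx*n_batch; // 16384 cached positions
    const int64_t n_elements = n_embd*n_mem;          // 67108864 floats per tensor

    // two F32 tensors (k and v) plus the 2 MiB slack init_kv_cache reserves
    const int64_t bytes = 2*n_elements*(int64_t) sizeof(float) + 2*1024*1024;
    printf("kv cache: %lld bytes (~%.0f MiB)\n", (long long) bytes, bytes/(1024.0*1024.0));
    return 0;
}
```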
@@ -540,293 +518,6 @@ void randomize_lora(struct my_llama_lora * lora, int seed, float mean, float std
     }
 }
 
-bool init_kv_cache(struct my_llama_kv_cache* cache, struct my_llama_model * model, int n_batch) {
-    const auto & hparams = model->hparams;
-
-    const uint32_t n_ctx = hparams.n_ctx;
-    const uint32_t n_embd = hparams.n_embd;
-    const uint32_t n_layer = hparams.n_layer;
-
-    const int64_t n_mem = n_layer*n_ctx*n_batch;
-    const int64_t n_elements = n_embd*n_mem;
-
-    // cache.buf.resize(2u*n_elements*ggml_type_size(wtype) + 2u*MB);
-
-    // struct ggml_init_params params;
-    // params.mem_size = cache.buf.size;
-    // params.mem_buffer = cache.buf.addr;
-    // params.no_alloc = false;
-    if (!cache->ctx) {
-        struct ggml_init_params params;
-        params.mem_size = 2u*n_elements*ggml_type_size(GGML_TYPE_F32) + 2u*1024*1024;
-        params.mem_buffer = NULL;
-        params.no_alloc = false;
-
-        cache->ctx = ggml_init(params);
-
-        if (!cache->ctx) {
-            fprintf(stderr, "%s: failed to allocate memory for kv cache\n", __func__);
-            return false;
-        }
-    }
-
-    cache->k = ggml_new_tensor_1d(cache->ctx, GGML_TYPE_F32, n_elements);
-    cache->v = ggml_new_tensor_1d(cache->ctx, GGML_TYPE_F32, n_elements);
-
-    return true;
-}
-
-struct ggml_tensor * forward(
-        struct my_llama_model * model,
-        struct my_llama_lora * lora,
-        struct my_llama_kv_cache * cache,
-        struct ggml_context * ctx0,
-        struct ggml_cgraph * gf,
-        struct ggml_tensor * tokens_input,
-        const int n_tokens,
-        const int n_past) {
-
-    const int N = n_tokens;
-
-    struct my_llama_kv_cache& kv_self = *cache;
-    const auto & hparams = model->hparams;
-    const int n_ctx = hparams.n_ctx;
-    const int n_embd = hparams.n_embd;
-    const int n_layer = hparams.n_layer;
-    const int n_head = hparams.n_head;
-    const int n_rot = hparams.n_rot;
-
-    const float rms_norm_eps = hparams.f_rms_norm_eps;
-
-    GGML_ASSERT(n_layer == lora->layers.size());
-
-    struct ggml_tensor * tokens = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, N);
-    memcpy(tokens->data, tokens_input->data, N*ggml_element_size(tokens));
-
-    struct ggml_tensor * kc = kv_self.k;
-    struct ggml_tensor * vc = kv_self.v;
-
-    struct ggml_tensor * tok_embeddings = ggml_add(ctx0, model->tok_embeddings, ggml_mul_mat(ctx0, lora->tok_embeddings_a, lora->tok_embeddings_b));
-    struct ggml_tensor * norm = ggml_add(ctx0, model->norm, ggml_mul_mat(ctx0, lora->norm_a, lora->norm_b));
-    struct ggml_tensor * output = ggml_add(ctx0, model->output, ggml_mul_mat(ctx0, lora->output_a, lora->output_b));
-
-
-    // inpL shape [n_embd,N,1,1]
-    struct ggml_tensor * inpL = ggml_get_rows(ctx0, tok_embeddings, tokens);
-    for (int il = 0; il < n_layer; ++il) {
-        struct ggml_tensor * inpSA = inpL;
-
-        struct ggml_tensor * cur;
-
-        // lctx.use_buf(ctx0, 0);
-        struct ggml_tensor * attention_norm = ggml_add(ctx0, model->layers[il].attention_norm, ggml_mul_mat(ctx0, lora->layers[il].attention_norm_a, lora->layers[il].attention_norm_b));
-        struct ggml_tensor * ffn_norm = ggml_add(ctx0, model->layers[il].ffn_norm, ggml_mul_mat(ctx0, lora->layers[il].ffn_norm_a, lora->layers[il].ffn_norm_b));
-        struct ggml_tensor * wq = ggml_add(ctx0, model->layers[il].wq, ggml_mul_mat(ctx0, lora->layers[il].wq_a, lora->layers[il].wq_b));
-        struct ggml_tensor * wk = ggml_add(ctx0, model->layers[il].wk, ggml_mul_mat(ctx0, lora->layers[il].wk_a, lora->layers[il].wk_b));
-        struct ggml_tensor * wv = ggml_add(ctx0, model->layers[il].wv, ggml_mul_mat(ctx0, lora->layers[il].wv_a, lora->layers[il].wv_b));
-        struct ggml_tensor * wo = ggml_add(ctx0, model->layers[il].wo, ggml_mul_mat(ctx0, lora->layers[il].wo_a, lora->layers[il].wo_b));
-        struct ggml_tensor * w1 = ggml_add(ctx0, model->layers[il].w1, ggml_mul_mat(ctx0, lora->layers[il].w1_a, lora->layers[il].w1_b));
-        struct ggml_tensor * w2 = ggml_add(ctx0, model->layers[il].w2, ggml_mul_mat(ctx0, lora->layers[il].w2_a, lora->layers[il].w2_b));
-        struct ggml_tensor * w3 = ggml_add(ctx0, model->layers[il].w3, ggml_mul_mat(ctx0, lora->layers[il].w3_a, lora->layers[il].w3_b));
-
-        // norm
-        {
-            // cur shape [n_embd,N,1,1]
-            cur = ggml_rms_norm(ctx0, inpL, rms_norm_eps);
-
-            // cur = attention_norm*cur
-            cur = ggml_mul(ctx0,
-                        ggml_repeat(ctx0,
-                            attention_norm,
-                            cur),
-                        cur);
-        }
-
-        // self-attention
-        {
-            // compute Q and K and RoPE them
-            // wq   shape [n_embd, n_embd, 1, 1]
-            // wk   shape [n_embd, n_embd, 1, 1]
-            // Qcur shape [n_embd/n_head, n_head, N, 1]
-            // Kcur shape [n_embd/n_head, n_head, N, 1]
-
-            struct ggml_tensor * Qcur = ggml_rope_inplace(ctx0, ggml_reshape_3d(ctx0, ggml_mul_mat(ctx0, wq, cur), n_embd/n_head, n_head, N), n_past, n_rot, 0, n_ctx);
-            struct ggml_tensor * Kcur = ggml_rope_inplace(ctx0, ggml_reshape_3d(ctx0, ggml_mul_mat(ctx0, wk, cur), n_embd/n_head, n_head, N), n_past, n_rot, 0, n_ctx);
-
-            // store key and value to memory
-            {
-                // compute the transposed [N, n_embd] V matrix
-                // wv   shape [n_embd, n_embd, 1, 1]
-                // Vcur shape [n_embd, N, 1, 1]
-                struct ggml_tensor * Vcur = ggml_cont(ctx0, ggml_transpose(ctx0, ggml_reshape_2d(ctx0, ggml_mul_mat(ctx0, wv, cur), n_embd, N)));
-
-                // kv_self.k shape [n_embd * n_ctx * n_layer, 1]
-                // kv_self.v shape [n_embd * n_ctx * n_layer, 1]
-                // k shape [n_embd * N, 1] == kv_self.k[:,n_past:n_past+N,il,0]
-                // v shape [N, n_embd, 1, 1] == kv_self.v[:,n_past:n_past+N,il,0]
-
-                /* {
-                    struct ggml_tensor * k = ggml_view_1d(ctx0, kv_self.k, N*n_embd, (ggml_element_size(kv_self.k)*n_embd)*(il*n_ctx + n_past));
-                    struct ggml_tensor * v = ggml_view_2d(ctx0, kv_self.v, N, n_embd,
-                            (   n_ctx)*ggml_element_size(kv_self.v),
-                            (il*n_ctx)*ggml_element_size(kv_self.v)*n_embd + n_past*ggml_element_size(kv_self.v));
-
-                    // important: storing RoPE-ed version of K in the KV cache!
-                    ggml_build_forward_expand(gf, ggml_cpy(ctx0, Kcur, k));
-                    ggml_build_forward_expand(gf, ggml_cpy(ctx0, Vcur, v));
-                } //*/
-
-                kc = ggml_set_1d_inplace(ctx0, kc, ggml_reshape_1d(ctx0, Kcur, n_embd*N), (ggml_element_size(kv_self.k)*n_embd)*(il*n_ctx + n_past));
-                vc = ggml_set_2d_inplace(ctx0, vc, Vcur, (   n_ctx)*ggml_element_size(kv_self.v),
-                        (il*n_ctx)*ggml_element_size(kv_self.v)*n_embd + n_past*ggml_element_size(kv_self.v));
-            }
-
-            // Qcur shape [n_embd/n_head, n_head, N, 1]
-            // Q    shape [n_embd/n_head, N, n_head, 1]
-            struct ggml_tensor * Q =
-                ggml_permute(ctx0,
-                        Qcur,
-                        0, 2, 1, 3);
-
-            // kv_self.k shape [n_embd * n_ctx * n_layer, 1]
-            // K shape [n_embd/n_head, n_past + N, n_head, 1]
-            struct ggml_tensor * K =
-                ggml_permute(ctx0,
-                        ggml_reshape_3d(ctx0,
-                            ggml_view_1d(ctx0, kc, (n_past + N)*n_embd, il*n_ctx*ggml_element_size(kc)*n_embd),
-                            n_embd/n_head, n_head, n_past + N),
-                        0, 2, 1, 3);
-
-            // K * Q
-            // KQ shape [n_past + N, N, n_head, 1]
-            struct ggml_tensor * KQ = ggml_mul_mat(ctx0, K, Q);
-
-            // KQ_scaled = KQ / sqrt(n_embd/n_head)
-            // KQ_scaled shape [n_past + N, N, n_head, 1]
-            struct ggml_tensor * KQ_scaled =
-                ggml_scale(ctx0,
-                        KQ,
-                        ggml_new_f32(ctx0, 1.0f/sqrtf(float(n_embd)/n_head)));
-
-            // KQ_masked = mask_past(KQ_scaled)
-            // KQ_masked shape [n_past + N, N, n_head, 1]
-            struct ggml_tensor * KQ_masked = ggml_diag_mask_inf(ctx0, KQ_scaled, n_past);
-
-            // KQ = soft_max(KQ_masked)
-            // KQ_soft_max shape [n_past + N, N, n_head, 1]
-            struct ggml_tensor * KQ_soft_max = ggml_soft_max(ctx0, KQ_masked);
-
-            // split cached V into n_head heads
-            //// V shape [n_past + N, n_embd/n_head, n_head, 1]
-            // V shape [n_past + N, n_embd/n_head, n_head, 1] == kv_self.v[:,:(n_past+N),il,1]
-            struct ggml_tensor * V =
-                ggml_view_3d(ctx0, vc,
-                        n_past + N, n_embd/n_head, n_head,
-                        n_ctx*ggml_element_size(vc),
-                        n_ctx*ggml_element_size(vc)*n_embd/n_head,
-                        il*n_ctx*ggml_element_size(vc)*n_embd);
-
-            // KQV shape [n_embd/n_head, N, n_head, 1]
-            struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V, KQ_soft_max);
-
-            // KQV_merged = KQV.permute(0, 2, 1, 3)
-            // KQV_merged shape [n_embd/n_head, n_head, N, 1]
-            struct ggml_tensor * KQV_merged = ggml_permute(ctx0, KQV, 0, 2, 1, 3);
-            // KQV_merged shape
-
-            // cur = KQV_merged.contiguous().view(n_embd, N)
-            // cur shape [n_embd,N,1,1]
-            cur = ggml_reshape_2d(ctx0, ggml_cont(ctx0, KQV_merged), n_embd, N);
-            // cur = ggml_cpy(ctx0,
-            //         KQV_merged,
-            //         ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_embd, N));
-
-            // projection (no bias)
-            // cur shape [n_embd,N,1,1]
-            cur = ggml_mul_mat(ctx0,
-                    wo,
-                    cur);
-        }
-
-        // lctx.use_buf(ctx0, 1);
-
-        // inpFF shape [n_embd,N,1,1]
-        struct ggml_tensor * inpFF = ggml_add(ctx0, cur, inpSA);
-
-        // feed-forward network
-        {
-            // norm
-            {
-                // cur shape [n_embd,N,1,1]
-                cur = ggml_rms_norm(ctx0, inpFF, rms_norm_eps);
-
-                // cur = ffn_norm*cur
-                // cur shape [n_embd,N,1,1]
-                cur = ggml_mul(ctx0,
-                        ggml_repeat(ctx0,
-                            ffn_norm,
-                            cur),
-                        cur);
-            }
-
-            // tmp shape [n_ff,N,1,1]
-            struct ggml_tensor * tmp = ggml_mul_mat(ctx0,
-                    w3,
-                    cur);
-
-            // cur shape [n_ff,N,1,1]
-            cur = ggml_mul_mat(ctx0,
-                    w1,
-                    cur);
-
-            // SILU activation
-            // cur shape [n_ff,N,1,1]
-            cur = ggml_silu(ctx0, cur);
-
-            // cur shape [n_ff,N,1,1]
-            cur = ggml_mul(ctx0, cur, tmp);
-
-            // cur shape [n_embd,N,1,1]
-            cur = ggml_mul_mat(ctx0,
-                    w2,
-                    cur);
-        }
-
-        // cur shape [n_embd,N,1,1]
-        cur = ggml_add(ctx0, cur, inpFF);
-
-        // input for next layer
-        // inpL shape [n_embd,N,1,1]
-        inpL = cur;
-    }
-
-    // norm
-    {
-
-        // inpL shape [n_embd,N,1,1]
-        inpL = ggml_rms_norm(ctx0, inpL, rms_norm_eps);
-
-        // inpL = norm*inpL
-        // inpL shape [n_embd,N,1,1]
-        inpL = ggml_mul(ctx0,
-                    ggml_repeat(ctx0,
-                        norm,
-                        inpL),
-                    inpL);
-
-        //embeddings = inpL;
-    }
-
-    // lm_head
-    // inpL shape [n_vocab,N,1,1]
-    inpL = ggml_mul_mat(ctx0, output, inpL);
-
-    // run the computation
-    ggml_build_forward_expand(gf, inpL);
-
-    return inpL;
-}
-
 void assert_shape_1d(struct ggml_tensor * tensor, int64_t ne0) {
     GGML_ASSERT(tensor->n_dims == 1);
     GGML_ASSERT(tensor->ne[0] == ne0);
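Every weight consumed by the removed `forward()` is assembled as the frozen base tensor plus a low-rank product, `ggml_add(W, ggml_mul_mat(a, b))`, which is the LoRA reparameterization. In the usual notation (a sketch; per the `--lora-alpha` help text further down, the effective scaling is alpha/r, though that factor is not visible in this hunk):

$$
W' = W + \Delta W, \qquad
\Delta W = \frac{\alpha}{r}\, B A, \qquad
B \in \mathbb{R}^{d \times r},\;
A \in \mathbb{R}^{r \times k},\;
r \ll \min(d, k)
$$

Only the rank-r factors (configured by the `--rank-*` options) are trained; the base weights stay frozen, which keeps gradients and optimizer state small.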
@@ -1292,61 +983,6 @@ int32_t get_i32_2d(struct ggml_tensor * tensor, int64_t i0, int64_t i1) {
     return *ptr;
 }
 
-void print_row(struct ggml_tensor * probs, int i) {
-    for (int k = 0; k < probs->ne[0]; ++k) {
-        float p = get_f32_2d(probs, k, i);
-        printf(" %.2f", p);
-    }
-    printf("\n");
-}
-
-void print_matrix(struct ggml_tensor * probs) {
-    assert(probs->n_dims == 2);
-    for (int i = 0; i < probs->ne[1]; ++i) {
-        for (int k = 0; k < probs->ne[0]; ++k) {
-            float p = get_f32_2d(probs, k, i);
-            printf(" %.2f", p);
-        }
-        printf("\n");
-    }
-}
-
-
-void print_token(struct llama_context * ctx, llama_token token) {
-    printf("%s", llama_token_get_text(ctx, token));
-}
-
-void print_tokens(struct llama_context* ctx, struct ggml_tensor * tokens) {
-    for (int i=0; i<tokens->ne[0]; ++i) {
-        int token = ggml_get_i32_1d(tokens, i);
-        print_token(ctx, token);
-    }
-}
-
-void print_tokens_batch(struct llama_context* ctx, struct ggml_tensor * tokens) {
-    for (int i1=0; i1<tokens->ne[1]; ++i1) {
-        //int num_newline = 0;
-        for (int i0=0; i0<tokens->ne[0]; ++i0) {
-            int token = get_i32_2d(tokens, i0, i1);
-            print_token(ctx, token);
-            // bool isnl = (token == llama_token_nl());
-            // if (isnl) {
-            //     ++num_newline;
-            // }
-            // if (isnl) {
-            //     if (num_newline < 2) {
-            //         print_token(ctx, token);
-            //     } else {
-            //         printf("\\n");
-            //     }
-            // } else {
-            //     print_token(ctx, token);
-            // }
-        }
-        printf("\n--\n");
-    }
-}
-
 void get_example_targets(struct llama_context * lctx, const int * train_samples, size_t n_train_samples, const llama_token * train_data, size_t n_train_data, int example_id, struct ggml_tensor * tokens_input, struct ggml_tensor * target_logits, struct ggml_tensor * target_probs) {
     int n_tokens = tokens_input->ne[0];
     int n_vocab = target_logits->ne[0];
@@ -1402,19 +1038,6 @@ void get_example_targets_batch(struct llama_context* lctx, const int * train_samples
         }
     }
 }
 
-
-void lshift_examples(struct ggml_tensor * tokens_input, struct ggml_tensor * target_logits, struct ggml_tensor * target_probs, int n_shift) {
-    int n_tokens = tokens_input->ne[0];
-    int n_vocab = target_logits->ne[0];
-    for (int i=0; i<n_tokens-n_shift; ++i) {
-        ggml_set_i32_1d(tokens_input, i, ggml_get_i32_1d(tokens_input, i + n_shift));
-        for (int k=0; k<n_vocab; ++k) {
-            ggml_set_f32_1d(target_logits, i*n_vocab + k, ggml_get_f32_1d(target_logits, (i + n_shift)*n_vocab + k));
-            ggml_set_f32_1d(target_probs,  i*n_vocab + k, ggml_get_f32_1d(target_probs,  (i + n_shift)*n_vocab + k));
-        }
-    }
-}
-
 #ifdef __GNUC__
 #ifdef __MINGW32__
 __attribute__((format(gnu_printf, 1, 2)))
@@ -1576,112 +1199,6 @@ void shuffle_ints(int * begin, int * end) {
     });
 }
 
-struct my_llama_sampler_params {
-    float temp              = 0.0f;  // <= 0.0 disabled
-    int   top_k             = 20;    // <= 0 to use vocab size
-    float top_p             = 0.95f; // 1.0 = disabled
-    float tfs_z             = 1.00f; // 1.0 = disabled
-    float typical_p         = 1.00f; // 1.0 = disabled
-    int   repeat_last_n     = 64;    // last n tokens to penalize (0 = disable penalty, -1 = context size)
-    float repeat_penalty    = 1.0f;  // 1.0 = disabled
-    float presence_penalty  = 0.0f;  // 0.0 = disabled
-    float frequency_penalty = 0.0f;  // 0.0 = disabled
-    int   mirostat          = 0;     // 0 = disabled, 1 = mirostat, 2 = mirostat 2.0
-    float mirostat_tau      = 5.00f; // target entropy
-    float mirostat_eta      = 0.10f; // learning rate
-    bool  penalize_nl       = true;  // consider newlines as a repeatable token
-};
-
-struct my_llama_sampler {
-    struct llama_context * ctx = NULL;
-    my_llama_sampler_params params;
-
-    int n_vocab = 0;
-    int n_ctx = 0;
-
-    float mirostat_mu;
-
-    std::vector<llama_token_data> candidates;
-    llama_token_data_array candidates_p;
-
-};
-
-void init_sampler(struct my_llama_sampler * sampler, struct llama_context * ctx) {
-    sampler->ctx = ctx;
-    sampler->n_vocab = llama_n_vocab(sampler->ctx);
-    sampler->n_ctx = llama_n_ctx(sampler->ctx);
-    sampler->mirostat_mu = 2.0f * sampler->params.mirostat_tau;
-}
-
-llama_token sample(struct llama_context * lctx, struct my_llama_sampler * sampler, float * logits, const llama_token * last_tokens, int n_last_tokens) {
-    GGML_ASSERT(sampler->ctx != NULL);
-
-    struct llama_context * ctx = sampler->ctx;
-
-    sampler->candidates.resize(sampler->n_vocab);
-    for (llama_token token_id = 0; token_id < sampler->n_vocab; ++token_id) {
-        sampler->candidates[token_id].id = token_id;
-        sampler->candidates[token_id].logit = logits[token_id];
-        sampler->candidates[token_id].p = 0.0;
-    }
-
-    llama_token_data_array * candidates_p = & sampler->candidates_p;
-
-    candidates_p->data = sampler->candidates.data();
-    candidates_p->size = sampler->candidates.size();
-    candidates_p->sorted = false;
-
-    const auto params = sampler->params;
-
-    // Apply penalties
-    const float nl_logit = logits[llama_token_nl(lctx)];
-
-    const int n_last = std::min(std::min(n_last_tokens, params.repeat_last_n), sampler->n_ctx);
-
-    llama_sample_repetition_penalty(
-        ctx,
-        candidates_p,
-        last_tokens + n_last_tokens - n_last,
-        n_last,
-        params.repeat_penalty);
-    llama_sample_frequency_and_presence_penalties(
-        ctx,
-        candidates_p,
-        last_tokens + n_last_tokens - n_last,
-        n_last,
-        params.frequency_penalty,
-        params.presence_penalty);
-
-    if (!params.penalize_nl) {
-        logits[llama_token_nl(lctx)] = nl_logit;
-    }
-
-    llama_token token = 0;
-    if (params.temp <= 0) {
-        // Greedy sampling
-        token = llama_sample_token_greedy(ctx, candidates_p);
-    } else {
-        if (params.mirostat == 1) {
-            int mirostat_m = 100;
-            llama_sample_temperature(ctx, candidates_p, params.temp);
-            token = llama_sample_token_mirostat(ctx, candidates_p, params.mirostat_tau, params.mirostat_eta, mirostat_m, &sampler->mirostat_mu);
-        } else if (params.mirostat == 2) {
-            llama_sample_temperature(ctx, candidates_p, params.temp);
-            token = llama_sample_token_mirostat_v2(ctx, candidates_p, params.mirostat_tau, params.mirostat_eta, &sampler->mirostat_mu);
-        } else {
-            // Temperature sampling
-            llama_sample_top_k        (ctx, candidates_p, params.top_k, 1);
-            llama_sample_tail_free    (ctx, candidates_p, params.tfs_z, 1);
-            llama_sample_typical      (ctx, candidates_p, params.typical_p, 1);
-
-            llama_sample_top_p        (ctx, candidates_p, params.top_p, 1);
-            llama_sample_temperature  (ctx, candidates_p, params.temp);
-            token = llama_sample_token(ctx, candidates_p);
-        }
-    }
-    return token;
-}
-
 void write_tensor(struct llama_file * file, struct ggml_tensor * tensor) {
     if (tensor == NULL) {
         file->write_u32(0);
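For reference, the removed sampler was used like this (a hypothetical fragment assuming an initialized `llama_context * lctx` and one row of `n_vocab` logits from the model; it mirrors the call sites deleted from `main` further down):

```cpp
struct my_llama_sampler sampler;
sampler.params.temp  = 0.8f; // > 0 selects the temperature path above
sampler.params.top_k = 40;
init_sampler(&sampler, lctx);

// last_tokens feeds the repetition penalty; zeros here are only illustrative
std::vector<llama_token> last_tokens(sampler.params.repeat_last_n, 0);
std::vector<float> logits(llama_n_vocab(lctx), 0.0f); // stand-in for model output

llama_token next = sample(lctx, &sampler, logits.data(),
                          last_tokens.data(), (int) last_tokens.size());
```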
@@ -2111,7 +1628,6 @@ struct train_params {
     int n_threads;
     int n_batch;
     int n_examples;
-    int n_predict;
 
     int32_t lora_r;
     int32_t lora_alpha;
@@ -2130,7 +1646,6 @@ struct train_params {
     int n_rank_output;
 
     int print_info_interval;
-    int print_details_interval;
 
     bool samples_start_after_nl;
     bool use_adam;
@@ -2183,7 +1698,6 @@ struct train_params get_default_train_params() {
     params.n_threads = 6;
     params.n_batch = 8;
     params.n_examples = 1;
-    params.n_predict = 1024;
 
     params.lora_alpha = 4;
     params.lora_r = 4;
@@ -2202,7 +1716,6 @@ struct train_params get_default_train_params() {
     params.n_rank_output = 4;
 
     params.print_info_interval = 1;
-    params.print_details_interval = 2;
 
     params.samples_start_after_nl = false;
     params.use_adam = true;
@@ -2256,7 +1769,6 @@ void train_print_usage(int /*argc*/, char ** argv, const struct train_params * params) {
     fprintf(stderr, " -t N, --threads N Number of threads (default %d)\n", params->n_threads);
     fprintf(stderr, " -b N, --batch N Parallel batch size (default %d)\n", params->n_batch);
     fprintf(stderr, " -n N, --examples N Number of examples to train (default %d)\n", params->n_examples);
-    fprintf(stderr, " --predict N Number of tokens to generate after training (default %d)\n", params->n_predict);
     fprintf(stderr, " --lora-alpha N LORA alpha : resulting LORA scaling is alpha/r. (default %d)\n", params->lora_alpha);
     fprintf(stderr, " --lora-r N LORA r : resulting LORA scaling is alpha/r. (default %d)\n", params->lora_r);
     fprintf(stderr, " --rank-att-norm N LORA rank for attention norm tensor (default %d)\n", params->n_rank_attention_norm);
@@ -2272,7 +1784,6 @@ void train_print_usage(int /*argc*/, char ** argv, const struct train_params * params) {
     fprintf(stderr, " --rank-w2 N LORA rank for w2 tensor (default %d)\n", params->n_rank_w2);
     fprintf(stderr, " --rank-w3 N LORA rank for w3 tensor (default %d)\n", params->n_rank_w3);
     fprintf(stderr, " --print-info-interval N Print infos during training each N examples (default %d)\n", params->print_info_interval);
-    fprintf(stderr, " --print-details-interval N Print details during training each N examples (default %d)\n", params->print_details_interval);
     fprintf(stderr, " --samples-after-nl Training samples start after newlines. (default %s)\n", params->samples_start_after_nl ? "on" : "off");
     fprintf(stderr, " --use-lbfgs Use LBFGS optimizer instead of default Adam\n");
     fprintf(stderr, " --use-adam Use Adam optimizer (default)\n");
@@ -2301,7 +1812,7 @@ void train_print_usage(int /*argc*/, char ** argv, const struct train_params * params) {
     fprintf(stderr, " --adam-beta2 N AdamW beta2 in interval [0,1). How much to smooth the second moment of gradients. (default %f)\n", params->adam_beta2);
     fprintf(stderr, " --adam-gclip N AdamW gradient clipping. Disabled when zero. (default %f)\n", params->adam_gclip);
     fprintf(stderr, " --lbfgs-iter N Maximum number of LBFGS optimization iterations for each batch (default %d)\n", params->lbfgs_n_iter);
-    fprintf(stderr, " --mem-lora N Memory to allocate for LORA and cache in gigabytes. (default %d)\n", params->mem_lora_gb);
+    fprintf(stderr, " --mem-lora N Memory to allocate for LORA in gigabytes. (default %d)\n", params->mem_lora_gb);
     fprintf(stderr, " --mem-compute N Memory to allocate for compute in gigabytes. (default %d)\n", params->mem_compute_gb);
     fprintf(stderr, " --mem-compute0 N Memory to allocate for automatic memory allocator in gigabytes. (default %d)\n", params->mem_compute0_gb);
     fprintf(stderr, "\n");
@@ -2397,12 +1908,6 @@ bool train_params_parse(int argc, char ** argv, struct train_params * params) {
                 break;
             }
             params->n_examples = std::stoi(argv[i]);
-        } else if (arg == "--predict") {
-            if (++i >= argc) {
-                invalid_param = true;
-                break;
-            }
-            params->n_predict = std::stoi(argv[i]);
         } else if (arg == "--lora-alpha") {
             if (++i >= argc) {
                 invalid_param = true;
@@ -2493,12 +1998,6 @@ bool train_params_parse(int argc, char ** argv, struct train_params * params) {
                 break;
             }
             params->print_info_interval = std::stoi(argv[i]);
-        } else if (arg == "--print-details-interval") {
-            if (++i >= argc) {
-                invalid_param = true;
-                break;
-            }
-            params->print_details_interval = std::stoi(argv[i]);
         } else if (arg == "--samples-after-nl") {
             params->samples_start_after_nl = true;
         } else if (arg == "--use-lbfgs") {
@@ -2824,17 +2323,12 @@ int main(int argc, char ** argv) {
     }
    printf("%s: number of unique tokens: %d\n", __func__, n_unique_tokens);
 
-    struct my_llama_kv_cache kv_self;
-
     struct ggml_init_params lcparams;
     lcparams.mem_size = 1024ll*1024ll*1024ll*((size_t) params.mem_lora_gb);
     lcparams.mem_buffer = NULL;
     lcparams.no_alloc = false;
 
     lora.ctx = ggml_init(lcparams);
-    kv_self.ctx = lora.ctx;
-
-    my_llama_sampler sampler;
 
     int n_tokens = model.hparams.n_ctx;
     int n_vocab = model.hparams.n_vocab;
@@ -2886,11 +2380,7 @@ int main(int argc, char ** argv) {
         randomize_lora(&lora, params.seed, 0.0f, 1.0f, -1.0f, +1.0f);
     }
 
-    init_kv_cache(&kv_self, &model, 1);
-    // init_kv_cache(&kv_self, &model, n_batch);
-    init_sampler(&sampler, lctx);
-
-    printf("used_mem model+cache: %zu bytes\n", ggml_used_mem(lora.ctx));
+    printf("used_mem model: %zu bytes\n", ggml_used_mem(lora.ctx));
     // ggml_print_tensor_objects(lora.ctx);
 
     // TODO: use std::vector<uint8_t> intead of "new"
@@ -2919,8 +2409,6 @@ int main(int argc, char ** argv) {
         GGML_ASSERT(train_samples[i]+n_tokens-1 < (int) train_tokens.size());
     }
 
-    std::vector<uint8_t> work_buffer;
-
     printf("%s: begin training\n", __func__);
 
     struct opt_callback_data opt_cb_data;
@@ -2959,8 +2447,6 @@ int main(int argc, char ** argv) {
         ggml_set_no_alloc(ctx0, false);
 
         // don't use alloc for input tensors, so we can safely fill them with data
-        struct ggml_tensor * after_opt_best_samples = ggml_new_tensor_2d(ctx0, GGML_TYPE_I32, n_tokens, n_batch);
-        //struct ggml_tensor * after_opt_probs = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_vocab, n_tokens, n_batch);
         struct ggml_tensor * tokens_input  = ggml_new_tensor_2d(ctx0, GGML_TYPE_I32, n_tokens, n_batch);
         struct ggml_tensor * target_logits = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_vocab, n_tokens, n_batch);
         struct ggml_tensor * target_probs  = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_vocab, n_tokens, n_batch);
@@ -3026,31 +2512,6 @@ int main(int argc, char ** argv) {
             printf("used_mem_after_opt: %zu bytes\n", used_mem_after_opt);
         }
 
-        if (params.print_details_interval > 0 && ex % params.print_details_interval == 0) {
-            // set_logits_masked(logits, token_notavail, -1e9);
-            for (int i=0; i<n_batch; ++i) {
-                init_sampler(&sampler, lctx);
-                for (int k=0; k<n_tokens; ++k) {
-                    int32_t token = sample(lctx, &sampler,
-                        (float *) ((char *) logits->data + i*logits->nb[2] + k*logits->nb[1]),
-                        (llama_token *) ((char *) tokens_input->data + i*tokens_input->nb[1]),
-                        k);
-                    * ((int32_t *) ((char *) after_opt_best_samples->data + i*after_opt_best_samples->nb[1] + k*after_opt_best_samples->nb[0])) = token;
-                }
-            }
-
-            // printf("probabilities after optimization:\n");
-            // print_matrix(after_opt_probs);
-            printf("Example:\n---\n");
-            print_tokens_batch(lctx, tokens_input);
-            printf("\n---\n");
-
-            // printf("best samples after optimization:\n---\n");
-            printf("samples after optimization:\n---\n");
-            print_tokens_batch(lctx, after_opt_best_samples);
-            printf("\n---\n");
-        }
-
         ggml_free(ctx0);
     }
 
@@ -3076,78 +2537,6 @@ int main(int argc, char ** argv) {
 
     opt_cb_data.last_save_iter = opt->iter;
 
-    {
-        int n_gen = params.n_predict;
-        int sample_ctx = n_tokens - n_tokens/8;
-
-        // use defaults from common.h
-        sampler.params.top_k = 40;
-        sampler.params.top_p = 0.95f;
-        sampler.params.tfs_z = 1.00f;
-        sampler.params.typical_p = 1.00f;
-        sampler.params.temp = 0.8f;
-        sampler.params.repeat_penalty = 1.1f;
-        sampler.params.repeat_last_n = 64;
-        sampler.params.frequency_penalty = 0.0f;
-        sampler.params.presence_penalty = 0.0f;
-        sampler.params.mirostat = 0;
-        sampler.params.mirostat_tau = 5.00f;
-        sampler.params.mirostat_eta = 0.10f;
-        init_sampler(&sampler, lctx);
-
-        printf("[Prediction context]\n");
-
-        struct ggml_tensor * tokens_input  = ggml_new_tensor_1d(lora.ctx, GGML_TYPE_I32, n_tokens);
-        struct ggml_tensor * target_logits = ggml_new_tensor_2d(lora.ctx, GGML_TYPE_F32, n_vocab, n_tokens);
-        struct ggml_tensor * target_probs  = ggml_new_tensor_2d(lora.ctx, GGML_TYPE_F32, n_vocab, n_tokens);
-
-        get_example_targets(lctx, train_samples.data(), train_samples.size(), train_tokens.data(), train_tokens.size(), rand()%train_samples.size(), tokens_input, target_logits, target_probs);
-        for (int i=sample_ctx; i<n_tokens; ++i) {
-            ggml_set_i32_1d(tokens_input, i, n_vocab/2);
-        }
-
-        for (int i=0; i<sample_ctx-1; ++i) {
-            print_token(lctx, ggml_get_i32_1d(tokens_input, i));
-        }
-
-        printf("\n[Generating %d tokens]\n", n_gen);
-        for (int i=0; i<n_gen; ++i) {
-            struct ggml_init_params cparams = {
-                compute_size, // .mem_size
-                compute_addr, // .mem_buffer
-                false,        // .no_alloc
-            };
-            struct ggml_context * ctx0 = ggml_init(cparams);
-
-            struct ggml_cgraph * gf = ggml_new_graph(ctx0);
-
-            int n_past = 0;
-            struct ggml_tensor * logits = forward(&model, &lora, &kv_self, ctx0, gf, tokens_input, sample_ctx, n_past);
-
-            ggml_build_forward_expand(gf, logits);
-            ggml_graph_compute_helper(work_buffer, gf, params.n_threads);
-
-            //struct ggml_tensor * best_samples = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, sample_ctx);
-            //struct ggml_tensor * probs        = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_vocab, sample_ctx);
-
-            // set_logits_masked(logits, token_notavail, -1e9);
-            int token = sample(lctx, &sampler,
-                (float *) ((char *) logits->data + (sample_ctx-1)*logits->nb[1]),
-                (llama_token *) tokens_input->data,
-                sample_ctx-1);
-            //int token = ggml_get_i32_1d(best_samples, sample_ctx-1);
-
-            // print_row(probs, sample_at);
-            print_token(lctx, token);
-
-            lshift_examples(tokens_input, target_logits, target_probs, 1);
-            ggml_set_i32_1d(tokens_input, 0, 0);
-            ggml_set_i32_1d(tokens_input, sample_ctx-1, token);
-
-            ggml_free(ctx0);
-        }
-    }
-
     if (alloc) {
         ggml_allocr_free(alloc);
     }