fix cross entropy loss
- add target probabilities for each sample, which are then used in the cross entropy loss
This commit is contained in:
parent
09b304d015
commit
da86a1d736
1 changed files with 64 additions and 47 deletions
|
@ -1025,78 +1025,93 @@ void print_tokens_batch(struct llama_context* ctx, struct ggml_tensor * tokens)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
void get_example_targets(const int * train_samples, size_t n_train_samples, const llama_token * train_data, size_t n_train_data, int example_id, struct ggml_tensor * tokens_input, struct ggml_tensor * targets) {
|
void get_example_targets(const int * train_samples, size_t n_train_samples, const llama_token * train_data, size_t n_train_data, int example_id, struct ggml_tensor * tokens_input, struct ggml_tensor * target_logits, struct ggml_tensor * target_probs) {
|
||||||
int n_tokens = tokens_input->ne[0];
|
int n_tokens = tokens_input->ne[0];
|
||||||
int n_vocab = targets->ne[0];
|
int n_vocab = target_logits->ne[0];
|
||||||
|
|
||||||
|
const float eps = 1e-6f;
|
||||||
|
const float target_prob = 1.0f;
|
||||||
|
|
||||||
int sample = train_samples[example_id % n_train_samples];
|
int sample = train_samples[example_id % n_train_samples];
|
||||||
GGML_ASSERT(sample+n_tokens-1 < n_train_data);
|
GGML_ASSERT(sample+n_tokens-1 < n_train_data);
|
||||||
|
|
||||||
ggml_set_f32(targets, -1.0f/n_vocab);
|
ggml_set_f32(target_logits, -1.0f/n_vocab);
|
||||||
|
ggml_set_f32(target_probs, 0.0f);
|
||||||
ggml_set_i32_1d(tokens_input, 0, llama_token_bos());
|
ggml_set_i32_1d(tokens_input, 0, llama_token_bos());
|
||||||
for (int i=1; i<n_tokens+1; ++i) {
|
for (int i=1; i<n_tokens+1; ++i) {
|
||||||
int token = clamp(train_data[sample+i-1], 0, n_vocab-1);
|
int token = clamp(train_data[sample+i-1], 0, n_vocab-1);
|
||||||
ggml_set_f32_1d(targets, (i-1)*n_vocab + token, +1.0f);
|
ggml_set_f32_1d(target_logits, (i-1)*n_vocab + token, +1.0f);
|
||||||
|
ggml_set_f32_1d(target_probs, (i-1)*n_vocab + token, -1.0f);
|
||||||
if (i<n_tokens) {
|
if (i<n_tokens) {
|
||||||
ggml_set_i32_1d(tokens_input, i, token);
|
ggml_set_i32_1d(tokens_input, i, token);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
void get_example_targets_batch(struct ggml_context * ctx, const int * train_samples, size_t n_train_samples, const llama_token * train_data, size_t n_train_data, int example_id, struct ggml_tensor * tokens_input, struct ggml_tensor * targets) {
|
void get_example_targets_batch(struct ggml_context * ctx, const int * train_samples, size_t n_train_samples, const llama_token * train_data, size_t n_train_data, int example_id, struct ggml_tensor * tokens_input, struct ggml_tensor * target_logits, struct ggml_tensor * target_probs) {
|
||||||
GGML_ASSERT(tokens_input->n_dims == 2);
|
GGML_ASSERT(tokens_input->n_dims == 2);
|
||||||
GGML_ASSERT( targets->n_dims == 3);
|
GGML_ASSERT(target_logits->n_dims == 3);
|
||||||
|
GGML_ASSERT(target_probs->n_dims == 3);
|
||||||
int n_tokens = tokens_input->ne[0];
|
int n_tokens = tokens_input->ne[0];
|
||||||
int n_batch = tokens_input->ne[1];
|
int n_batch = tokens_input->ne[1];
|
||||||
GGML_ASSERT(n_tokens == targets->ne[1]);
|
GGML_ASSERT(n_tokens == target_logits->ne[1]);
|
||||||
GGML_ASSERT(n_batch == targets->ne[2]);
|
GGML_ASSERT(n_batch == target_logits->ne[2]);
|
||||||
|
GGML_ASSERT(n_tokens == target_probs->ne[1]);
|
||||||
|
GGML_ASSERT(n_batch == target_probs->ne[2]);
|
||||||
|
|
||||||
for (int k=0; k<n_batch; ++k) {
|
for (int k=0; k<n_batch; ++k) {
|
||||||
struct ggml_tensor * tokens_input_k = ggml_view_1d(ctx,
|
struct ggml_tensor * tokens_input_k = ggml_view_1d(ctx,
|
||||||
tokens_input,
|
tokens_input,
|
||||||
tokens_input->ne[0],
|
tokens_input->ne[0],
|
||||||
k*tokens_input->nb[1]);
|
k*tokens_input->nb[1]);
|
||||||
struct ggml_tensor * targets_k = ggml_view_2d(ctx,
|
struct ggml_tensor * target_logits_k = ggml_view_2d(ctx,
|
||||||
targets,
|
target_logits,
|
||||||
targets->ne[0],
|
target_logits->ne[0],
|
||||||
targets->ne[1],
|
target_logits->ne[1],
|
||||||
targets->nb[1],
|
target_logits->nb[1],
|
||||||
k*targets->nb[2]);
|
k*target_logits->nb[2]);
|
||||||
|
|
||||||
|
struct ggml_tensor * target_probs_k = ggml_view_2d(ctx,
|
||||||
|
target_probs,
|
||||||
|
target_probs->ne[0],
|
||||||
|
target_probs->ne[1],
|
||||||
|
target_probs->nb[1],
|
||||||
|
k*target_probs->nb[2]);
|
||||||
|
|
||||||
get_example_targets(train_samples, n_train_samples, train_data, n_train_data,
|
get_example_targets(train_samples, n_train_samples, train_data, n_train_data,
|
||||||
example_id*n_batch + k, tokens_input_k, targets_k);
|
example_id*n_batch + k, tokens_input_k, target_logits_k, target_probs_k);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
void lshift_examples(struct ggml_tensor * tokens_input, struct ggml_tensor * targets, int n_shift) {
|
void lshift_examples(struct ggml_tensor * tokens_input, struct ggml_tensor * target_logits, struct ggml_tensor * target_probs, int n_shift) {
|
||||||
int n_tokens = tokens_input->ne[0];
|
int n_tokens = tokens_input->ne[0];
|
||||||
int n_vocab = targets->ne[0];
|
int n_vocab = target_logits->ne[0];
|
||||||
for (int i=0; i<n_tokens-n_shift; ++i) {
|
for (int i=0; i<n_tokens-n_shift; ++i) {
|
||||||
ggml_set_i32_1d(tokens_input, i, ggml_get_i32_1d(tokens_input, i + n_shift));
|
ggml_set_i32_1d(tokens_input, i, ggml_get_i32_1d(tokens_input, i + n_shift));
|
||||||
for (int k=0; k<n_vocab; ++k) {
|
for (int k=0; k<n_vocab; ++k) {
|
||||||
ggml_set_f32_1d(targets, i*n_vocab + k, ggml_get_f32_1d(targets, (i + n_shift)*n_vocab + k));
|
ggml_set_f32_1d(target_logits, i*n_vocab + k, ggml_get_f32_1d(target_logits, (i + n_shift)*n_vocab + k));
|
||||||
|
ggml_set_f32_1d(target_probs, i*n_vocab + k, ggml_get_f32_1d(target_probs, (i + n_shift)*n_vocab + k));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
struct ggml_tensor * square_error_loss(struct ggml_context * ctx, struct ggml_tensor * a, struct ggml_tensor * b) {
|
struct ggml_tensor * square_error_loss(struct ggml_context * ctx, struct ggml_tensor * a, struct ggml_tensor * target) {
|
||||||
// todo: instead of a-b: a[1:]-b[:-1]
|
return ggml_sum(ctx, ggml_sqr(ctx, ggml_sub(ctx, target, a)));
|
||||||
return ggml_sum(ctx, ggml_sqr(ctx, ggml_sub(ctx, a, b)));
|
|
||||||
}
|
}
|
||||||
|
|
||||||
struct ggml_tensor * cross_entropy_loss(struct ggml_context * ctx, struct ggml_tensor * a, struct ggml_tensor * b) {
|
struct ggml_tensor * cross_entropy_loss(struct ggml_context * ctx, struct ggml_tensor * a, struct ggml_tensor * probs) {
|
||||||
const float eps = 1e-3;
|
const float eps = 1e-9f;
|
||||||
return
|
return
|
||||||
ggml_sum(ctx,
|
ggml_sum(ctx,
|
||||||
ggml_neg(ctx,
|
ggml_mul(ctx,
|
||||||
ggml_sum_rows(ctx,
|
probs,
|
||||||
ggml_mul(ctx,
|
ggml_log(ctx,
|
||||||
ggml_soft_max(ctx, a),
|
ggml_add1(ctx,
|
||||||
ggml_log(ctx,
|
ggml_scale(ctx,
|
||||||
ggml_add1(ctx,
|
ggml_soft_max(ctx, a),
|
||||||
ggml_soft_max(ctx, b),
|
ggml_new_f32(ctx, 1.0f-eps)),
|
||||||
ggml_new_f32(ctx, eps)))))));
|
ggml_new_f32(ctx, eps)))));
|
||||||
}
|
}
|
||||||
|
|
||||||
#ifdef __GNUC__
|
#ifdef __GNUC__
|
||||||
|
@ -1602,21 +1617,22 @@ int main(int argc, char ** argv) {
|
||||||
struct ggml_tensor * after_opt_best_samples = ggml_new_tensor_2d(ctx0, GGML_TYPE_I32, n_tokens, n_batch);
|
struct ggml_tensor * after_opt_best_samples = ggml_new_tensor_2d(ctx0, GGML_TYPE_I32, n_tokens, n_batch);
|
||||||
struct ggml_tensor * after_opt_probs = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_vocab, n_tokens, n_batch);
|
struct ggml_tensor * after_opt_probs = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_vocab, n_tokens, n_batch);
|
||||||
struct ggml_tensor * tokens_input = ggml_new_tensor_2d(ctx0, GGML_TYPE_I32, n_tokens, n_batch);
|
struct ggml_tensor * tokens_input = ggml_new_tensor_2d(ctx0, GGML_TYPE_I32, n_tokens, n_batch);
|
||||||
struct ggml_tensor * targets = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_vocab, n_tokens, n_batch);
|
struct ggml_tensor * target_logits = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_vocab, n_tokens, n_batch);
|
||||||
|
struct ggml_tensor * target_probs = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_vocab, n_tokens, n_batch);
|
||||||
|
|
||||||
int n_past = 0;
|
int n_past = 0;
|
||||||
|
|
||||||
ggml_cgraph gf = {};
|
ggml_cgraph gf = {};
|
||||||
gf.n_threads = 6;
|
gf.n_threads = 6;
|
||||||
|
|
||||||
get_example_targets_batch(ctx0, train_samples.data(), train_samples.size(), train_tokens.data(), train_tokens.size(), ex, tokens_input, targets);
|
get_example_targets_batch(ctx0, train_samples.data(), train_samples.size(), train_tokens.data(), train_tokens.size(), ex, tokens_input, target_logits, target_probs);
|
||||||
|
|
||||||
struct ggml_tensor * logits = forward_batch(&model, &kv_self, ctx0, &gf, tokens_input, n_tokens, n_past, n_batch);
|
struct ggml_tensor * logits = forward_batch(&model, &kv_self, ctx0, &gf, tokens_input, n_tokens, n_past, n_batch);
|
||||||
struct ggml_tensor * se = square_error_loss(ctx0, targets, logits);
|
// struct ggml_tensor * se = square_error_loss(ctx0, logits, target_logits);
|
||||||
// struct ggml_tensor * ce = cross_entropy_loss(ctx0, targets, logits);
|
struct ggml_tensor * ce = cross_entropy_loss(ctx0, logits, target_probs);
|
||||||
// struct ggml_tensor * e = ggml_add(ctx0, se, ce);
|
// struct ggml_tensor * e = ggml_add(ctx0, se, ce);
|
||||||
// struct ggml_tensor * e = ce;
|
struct ggml_tensor * e = ce;
|
||||||
struct ggml_tensor * e = se;
|
// struct ggml_tensor * e = se;
|
||||||
|
|
||||||
ggml_build_forward_expand(&gf, e);
|
ggml_build_forward_expand(&gf, e);
|
||||||
ggml_graph_compute(ctx0, &gf);
|
ggml_graph_compute(ctx0, &gf);
|
||||||
|
@ -1652,12 +1668,12 @@ int main(int argc, char ** argv) {
|
||||||
|
|
||||||
if (ex % 1 == 0) {
|
if (ex % 1 == 0) {
|
||||||
printf("Example %d\n", ex);
|
printf("Example %d\n", ex);
|
||||||
printf("error_before_opt: %.2f\n", error_before_opt);
|
printf("error_before_opt: %.6f\n", error_before_opt);
|
||||||
printf("error_after_opt: %.2f\n", error_after_opt);
|
printf("error_after_opt: %.6f\n", error_after_opt);
|
||||||
}
|
}
|
||||||
|
|
||||||
if (ex % 2 == 0) {
|
if (ex % 2 == 0) {
|
||||||
set_logits_masked(logits, token_notavail, -1e9);
|
// set_logits_masked(logits, token_notavail, -1e9);
|
||||||
for (int i=0; i<n_batch; ++i) {
|
for (int i=0; i<n_batch; ++i) {
|
||||||
init_sampler(&sampler, lctx);
|
init_sampler(&sampler, lctx);
|
||||||
for (int k=0; k<n_tokens; ++k) {
|
for (int k=0; k<n_tokens; ++k) {
|
||||||
|
@ -1695,10 +1711,11 @@ int main(int argc, char ** argv) {
|
||||||
|
|
||||||
printf("Generating %d tokens.\n", n_gen);
|
printf("Generating %d tokens.\n", n_gen);
|
||||||
|
|
||||||
struct ggml_tensor * tokens_input = ggml_new_tensor_1d(model.ctx, GGML_TYPE_I32, n_tokens);
|
struct ggml_tensor * tokens_input = ggml_new_tensor_1d(model.ctx, GGML_TYPE_I32, n_tokens);
|
||||||
struct ggml_tensor * targets = ggml_new_tensor_2d(model.ctx, GGML_TYPE_F32, n_vocab, n_tokens);
|
struct ggml_tensor * target_logits = ggml_new_tensor_2d(model.ctx, GGML_TYPE_F32, n_vocab, n_tokens);
|
||||||
|
struct ggml_tensor * target_probs = ggml_new_tensor_2d(model.ctx, GGML_TYPE_F32, n_vocab, n_tokens);
|
||||||
|
|
||||||
get_example_targets(train_samples.data(), train_samples.size(), train_tokens.data(), train_tokens.size(), 137, tokens_input, targets);
|
get_example_targets(train_samples.data(), train_samples.size(), train_tokens.data(), train_tokens.size(), rand()%train_samples.size(), tokens_input, target_logits, target_probs);
|
||||||
for (int i=sample_ctx; i<n_tokens; ++i) {
|
for (int i=sample_ctx; i<n_tokens; ++i) {
|
||||||
ggml_set_i32_1d(tokens_input, i, n_vocab/2);
|
ggml_set_i32_1d(tokens_input, i, n_vocab/2);
|
||||||
}
|
}
|
||||||
|
@ -1728,7 +1745,7 @@ int main(int argc, char ** argv) {
|
||||||
struct ggml_tensor * best_samples = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, sample_ctx);
|
struct ggml_tensor * best_samples = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, sample_ctx);
|
||||||
struct ggml_tensor * probs = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_vocab, sample_ctx);
|
struct ggml_tensor * probs = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_vocab, sample_ctx);
|
||||||
|
|
||||||
set_logits_masked(logits, token_notavail, -1e9);
|
// set_logits_masked(logits, token_notavail, -1e9);
|
||||||
int token = sample(&sampler,
|
int token = sample(&sampler,
|
||||||
(float *) ((char *) logits->data + (sample_ctx-1)*logits->nb[1]),
|
(float *) ((char *) logits->data + (sample_ctx-1)*logits->nb[1]),
|
||||||
(llama_token *) tokens_input->data,
|
(llama_token *) tokens_input->data,
|
||||||
|
@ -1739,7 +1756,7 @@ int main(int argc, char ** argv) {
|
||||||
// print_row(probs, sample_at);
|
// print_row(probs, sample_at);
|
||||||
print_token(lctx, token);
|
print_token(lctx, token);
|
||||||
|
|
||||||
lshift_examples(tokens_input, targets, 1);
|
lshift_examples(tokens_input, target_logits, target_probs, 1);
|
||||||
ggml_set_i32_1d(tokens_input, 0, 0);
|
ggml_set_i32_1d(tokens_input, 0, 0);
|
||||||
ggml_set_i32_1d(tokens_input, sample_ctx-1, token);
|
ggml_set_i32_1d(tokens_input, sample_ctx-1, token);
|
||||||
|
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue