fix cross entropy loss

- add target probabilities for each sample, which are then used in the cross entropy loss
xaedes 2023-05-19 18:39:38 +02:00
parent 09b304d015
commit da86a1d736

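Editorial note (not part of the commit): the reworked cross_entropy_loss in the diff below builds a ggml graph that, per token position, computes -sum_v p_v * log(softmax(z)_v * (1-eps) + eps), where p is the new target_probs tensor and z the model logits. The following standalone C sketch is illustrative only; the function name cross_entropy_row and its layout are assumptions for the example, not code from this commit.

// Hedged reference sketch (illustrative, not from this commit): scalar version of the
// expression built by cross_entropy_loss for one row of n_vocab logits.
#include <math.h>

static float cross_entropy_row(const float * logits, const float * probs, int n_vocab) {
    const float eps = 1e-9f;
    // numerically stable softmax over the logits
    float max_logit = logits[0];
    for (int i = 1; i < n_vocab; ++i) {
        if (logits[i] > max_logit) max_logit = logits[i];
    }
    float sum = 0.0f;
    for (int i = 0; i < n_vocab; ++i) {
        sum += expf(logits[i] - max_logit);
    }
    float loss = 0.0f;
    for (int i = 0; i < n_vocab; ++i) {
        float q = expf(logits[i] - max_logit) / sum;  // softmax(a)
        // mirrors ggml_add1(ggml_scale(softmax, 1-eps), eps): keeps the log argument in (eps, 1)
        loss -= probs[i] * logf(q * (1.0f - eps) + eps);
    }
    return loss;
}

Scaling the softmax by 1-eps before adding eps keeps the argument of ggml_log strictly positive, so the loss and its gradient stay finite even when the model assigns near-zero probability to a target token.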

@@ -1025,78 +1025,93 @@ void print_tokens_batch(struct llama_context* ctx, struct ggml_tensor * tokens)
     }
 }
-void get_example_targets(const int * train_samples, size_t n_train_samples, const llama_token * train_data, size_t n_train_data, int example_id, struct ggml_tensor * tokens_input, struct ggml_tensor * targets) {
+void get_example_targets(const int * train_samples, size_t n_train_samples, const llama_token * train_data, size_t n_train_data, int example_id, struct ggml_tensor * tokens_input, struct ggml_tensor * target_logits, struct ggml_tensor * target_probs) {
     int n_tokens = tokens_input->ne[0];
-    int n_vocab = targets->ne[0];
+    int n_vocab = target_logits->ne[0];
+    const float eps = 1e-6f;
+    const float target_prob = 1.0f;
     int sample = train_samples[example_id % n_train_samples];
     GGML_ASSERT(sample+n_tokens-1 < n_train_data);
-    ggml_set_f32(targets, -1.0f/n_vocab);
+    ggml_set_f32(target_logits, -1.0f/n_vocab);
+    ggml_set_f32(target_probs, 0.0f);
     ggml_set_i32_1d(tokens_input, 0, llama_token_bos());
     for (int i=1; i<n_tokens+1; ++i) {
         int token = clamp(train_data[sample+i-1], 0, n_vocab-1);
-        ggml_set_f32_1d(targets, (i-1)*n_vocab + token, +1.0f);
+        ggml_set_f32_1d(target_logits, (i-1)*n_vocab + token, +1.0f);
+        ggml_set_f32_1d(target_probs, (i-1)*n_vocab + token, -1.0f);
         if (i<n_tokens) {
             ggml_set_i32_1d(tokens_input, i, token);
         }
     }
 }
-void get_example_targets_batch(struct ggml_context * ctx, const int * train_samples, size_t n_train_samples, const llama_token * train_data, size_t n_train_data, int example_id, struct ggml_tensor * tokens_input, struct ggml_tensor * targets) {
-    GGML_ASSERT(tokens_input->n_dims == 2);
-    GGML_ASSERT( targets->n_dims == 3);
+void get_example_targets_batch(struct ggml_context * ctx, const int * train_samples, size_t n_train_samples, const llama_token * train_data, size_t n_train_data, int example_id, struct ggml_tensor * tokens_input, struct ggml_tensor * target_logits, struct ggml_tensor * target_probs) {
+    GGML_ASSERT(tokens_input->n_dims == 2);
+    GGML_ASSERT(target_logits->n_dims == 3);
+    GGML_ASSERT(target_probs->n_dims == 3);
     int n_tokens = tokens_input->ne[0];
     int n_batch = tokens_input->ne[1];
-    GGML_ASSERT(n_tokens == targets->ne[1]);
-    GGML_ASSERT(n_batch == targets->ne[2]);
+    GGML_ASSERT(n_tokens == target_logits->ne[1]);
+    GGML_ASSERT(n_batch == target_logits->ne[2]);
+    GGML_ASSERT(n_tokens == target_probs->ne[1]);
+    GGML_ASSERT(n_batch == target_probs->ne[2]);
     for (int k=0; k<n_batch; ++k) {
-        struct ggml_tensor * tokens_input_k = ggml_view_1d(ctx,
+        struct ggml_tensor * tokens_input_k = ggml_view_1d(ctx,
                                                 tokens_input,
                                                 tokens_input->ne[0],
                                                 k*tokens_input->nb[1]);
-        struct ggml_tensor * targets_k = ggml_view_2d(ctx,
-                                                targets,
-                                                targets->ne[0],
-                                                targets->ne[1],
-                                                targets->nb[1],
-                                                k*targets->nb[2]);
+        struct ggml_tensor * target_logits_k = ggml_view_2d(ctx,
+                                                target_logits,
+                                                target_logits->ne[0],
+                                                target_logits->ne[1],
+                                                target_logits->nb[1],
+                                                k*target_logits->nb[2]);
+        struct ggml_tensor * target_probs_k = ggml_view_2d(ctx,
+                                                target_probs,
+                                                target_probs->ne[0],
+                                                target_probs->ne[1],
+                                                target_probs->nb[1],
+                                                k*target_probs->nb[2]);
         get_example_targets(train_samples, n_train_samples, train_data, n_train_data,
-            example_id*n_batch + k, tokens_input_k, targets_k);
+            example_id*n_batch + k, tokens_input_k, target_logits_k, target_probs_k);
     }
 }
-void lshift_examples(struct ggml_tensor * tokens_input, struct ggml_tensor * targets, int n_shift) {
+void lshift_examples(struct ggml_tensor * tokens_input, struct ggml_tensor * target_logits, struct ggml_tensor * target_probs, int n_shift) {
     int n_tokens = tokens_input->ne[0];
-    int n_vocab = targets->ne[0];
+    int n_vocab = target_logits->ne[0];
     for (int i=0; i<n_tokens-n_shift; ++i) {
         ggml_set_i32_1d(tokens_input, i, ggml_get_i32_1d(tokens_input, i + n_shift));
         for (int k=0; k<n_vocab; ++k) {
-            ggml_set_f32_1d(targets, i*n_vocab + k, ggml_get_f32_1d(targets, (i + n_shift)*n_vocab + k));
+            ggml_set_f32_1d(target_logits, i*n_vocab + k, ggml_get_f32_1d(target_logits, (i + n_shift)*n_vocab + k));
+            ggml_set_f32_1d(target_probs, i*n_vocab + k, ggml_get_f32_1d(target_probs, (i + n_shift)*n_vocab + k));
         }
     }
 }
-struct ggml_tensor * square_error_loss(struct ggml_context * ctx, struct ggml_tensor * a, struct ggml_tensor * b) {
-    // todo: instead of a-b: a[1:]-b[:-1]
-    return ggml_sum(ctx, ggml_sqr(ctx, ggml_sub(ctx, a, b)));
+struct ggml_tensor * square_error_loss(struct ggml_context * ctx, struct ggml_tensor * a, struct ggml_tensor * target) {
+    return ggml_sum(ctx, ggml_sqr(ctx, ggml_sub(ctx, target, a)));
 }
-struct ggml_tensor * cross_entropy_loss(struct ggml_context * ctx, struct ggml_tensor * a, struct ggml_tensor * b) {
-    const float eps = 1e-3;
+struct ggml_tensor * cross_entropy_loss(struct ggml_context * ctx, struct ggml_tensor * a, struct ggml_tensor * probs) {
+    const float eps = 1e-9f;
     return
         ggml_sum(ctx,
             ggml_neg(ctx,
                 ggml_sum_rows(ctx,
-                    ggml_mul(ctx,
-                        ggml_soft_max(ctx, a),
-                        ggml_log(ctx,
-                            ggml_add1(ctx,
-                                ggml_soft_max(ctx, b),
-                                ggml_new_f32(ctx, eps)))))));
+                    ggml_mul(ctx,
+                        probs,
+                        ggml_log(ctx,
+                            ggml_add1(ctx,
+                                ggml_scale(ctx,
+                                    ggml_soft_max(ctx, a),
+                                    ggml_new_f32(ctx, 1.0f-eps)),
+                                ggml_new_f32(ctx, eps)))))));
 }
 #ifdef __GNUC__
@@ -1602,21 +1617,22 @@ int main(int argc, char ** argv) {
         struct ggml_tensor * after_opt_best_samples = ggml_new_tensor_2d(ctx0, GGML_TYPE_I32, n_tokens, n_batch);
         struct ggml_tensor * after_opt_probs = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_vocab, n_tokens, n_batch);
         struct ggml_tensor * tokens_input = ggml_new_tensor_2d(ctx0, GGML_TYPE_I32, n_tokens, n_batch);
-        struct ggml_tensor * targets = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_vocab, n_tokens, n_batch);
+        struct ggml_tensor * target_logits = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_vocab, n_tokens, n_batch);
+        struct ggml_tensor * target_probs = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_vocab, n_tokens, n_batch);
         int n_past = 0;
         ggml_cgraph gf = {};
         gf.n_threads = 6;
-        get_example_targets_batch(ctx0, train_samples.data(), train_samples.size(), train_tokens.data(), train_tokens.size(), ex, tokens_input, targets);
+        get_example_targets_batch(ctx0, train_samples.data(), train_samples.size(), train_tokens.data(), train_tokens.size(), ex, tokens_input, target_logits, target_probs);
         struct ggml_tensor * logits = forward_batch(&model, &kv_self, ctx0, &gf, tokens_input, n_tokens, n_past, n_batch);
-        struct ggml_tensor * se = square_error_loss(ctx0, targets, logits);
-        // struct ggml_tensor * ce = cross_entropy_loss(ctx0, targets, logits);
+        // struct ggml_tensor * se = square_error_loss(ctx0, logits, target_logits);
+        struct ggml_tensor * ce = cross_entropy_loss(ctx0, logits, target_probs);
         // struct ggml_tensor * e = ggml_add(ctx0, se, ce);
-        // struct ggml_tensor * e = ce;
-        struct ggml_tensor * e = se;
+        struct ggml_tensor * e = ce;
+        // struct ggml_tensor * e = se;
         ggml_build_forward_expand(&gf, e);
         ggml_graph_compute(ctx0, &gf);
@@ -1652,12 +1668,12 @@ int main(int argc, char ** argv) {
         if (ex % 1 == 0) {
             printf("Example %d\n", ex);
-            printf("error_before_opt: %.2f\n", error_before_opt);
-            printf("error_after_opt: %.2f\n", error_after_opt);
+            printf("error_before_opt: %.6f\n", error_before_opt);
+            printf("error_after_opt: %.6f\n", error_after_opt);
         }
         if (ex % 2 == 0) {
-            set_logits_masked(logits, token_notavail, -1e9);
+            // set_logits_masked(logits, token_notavail, -1e9);
             for (int i=0; i<n_batch; ++i) {
                 init_sampler(&sampler, lctx);
                 for (int k=0; k<n_tokens; ++k) {
@@ -1695,10 +1711,11 @@ int main(int argc, char ** argv) {
     printf("Generating %d tokens.\n", n_gen);
-    struct ggml_tensor * tokens_input = ggml_new_tensor_1d(model.ctx, GGML_TYPE_I32, n_tokens);
-    struct ggml_tensor * targets = ggml_new_tensor_2d(model.ctx, GGML_TYPE_F32, n_vocab, n_tokens);
+    struct ggml_tensor * tokens_input = ggml_new_tensor_1d(model.ctx, GGML_TYPE_I32, n_tokens);
+    struct ggml_tensor * target_logits = ggml_new_tensor_2d(model.ctx, GGML_TYPE_F32, n_vocab, n_tokens);
+    struct ggml_tensor * target_probs = ggml_new_tensor_2d(model.ctx, GGML_TYPE_F32, n_vocab, n_tokens);
-    get_example_targets(train_samples.data(), train_samples.size(), train_tokens.data(), train_tokens.size(), 137, tokens_input, targets);
+    get_example_targets(train_samples.data(), train_samples.size(), train_tokens.data(), train_tokens.size(), rand()%train_samples.size(), tokens_input, target_logits, target_probs);
     for (int i=sample_ctx; i<n_tokens; ++i) {
         ggml_set_i32_1d(tokens_input, i, n_vocab/2);
     }
@@ -1728,7 +1745,7 @@ int main(int argc, char ** argv) {
         struct ggml_tensor * best_samples = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, sample_ctx);
         struct ggml_tensor * probs = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_vocab, sample_ctx);
-        set_logits_masked(logits, token_notavail, -1e9);
+        // set_logits_masked(logits, token_notavail, -1e9);
         int token = sample(&sampler,
             (float *) ((char *) logits->data + (sample_ctx-1)*logits->nb[1]),
             (llama_token *) tokens_input->data,
@@ -1739,7 +1756,7 @@ int main(int argc, char ** argv) {
         // print_row(probs, sample_at);
         print_token(lctx, token);
-        lshift_examples(tokens_input, targets, 1);
+        lshift_examples(tokens_input, target_logits, target_probs, 1);
         ggml_set_i32_1d(tokens_input, 0, 0);
         ggml_set_i32_1d(tokens_input, sample_ctx-1, token);