fix bug in get_samples which corrupted training targets

This commit is contained in:
xaedes 2023-05-22 16:55:52 +02:00
parent b763d6f1f2
commit cc440bd438
No known key found for this signature in database
GPG key ID: 30030EDD817EA2B1

View file

@ -1184,16 +1184,40 @@ void print_tokens_batch(struct llama_context* ctx, struct ggml_tensor * tokens)
if (isnl) { if (isnl) {
++num_newline; ++num_newline;
} }
if (!isnl || (num_newline < 2)) { if (isnl) {
print_token(ctx, token); if (num_newline < 2) {
print_token(ctx, token);
} else {
printf("\\n");
}
} else { } else {
printf("\\n"); print_token(ctx, token);
} }
} }
printf("\n--\n"); printf("\n--\n");
} }
} }
void set_f32_2d(struct ggml_tensor * tensor, int64_t i0, int64_t i1, float value) {
float * ptr = (float *) ((char *) tensor->data + i0*tensor->nb[0] + i1*tensor->nb[1]);
*ptr = value;
}
void set_i32_2d(struct ggml_tensor * tensor, int64_t i0, int64_t i1, int32_t value) {
int32_t * ptr = (int32_t *) ((char *) tensor->data + i0*tensor->nb[0] + i1*tensor->nb[1]);
*ptr = value;
}
float get_f32_2d(struct ggml_tensor * tensor, int64_t i0, int64_t i1) {
float * ptr = (float *) ((char *) tensor->data + i0*tensor->nb[0] + i1*tensor->nb[1]);
return *ptr;
}
int32_t get_i32_2d(struct ggml_tensor * tensor, int64_t i0, int64_t i1) {
int32_t * ptr = (int32_t *) ((char *) tensor->data + i0*tensor->nb[0] + i1*tensor->nb[1]);
return *ptr;
}
void get_example_targets(const int * train_samples, size_t n_train_samples, const llama_token * train_data, size_t n_train_data, int example_id, struct ggml_tensor * tokens_input, struct ggml_tensor * target_logits, struct ggml_tensor * target_probs) { void get_example_targets(const int * train_samples, size_t n_train_samples, const llama_token * train_data, size_t n_train_data, int example_id, struct ggml_tensor * tokens_input, struct ggml_tensor * target_logits, struct ggml_tensor * target_probs) {
int n_tokens = tokens_input->ne[0]; int n_tokens = tokens_input->ne[0];
int n_vocab = target_logits->ne[0]; int n_vocab = target_logits->ne[0];
@ -1209,8 +1233,8 @@ void get_example_targets(const int * train_samples, size_t n_train_samples, cons
ggml_set_i32_1d(tokens_input, 0, llama_token_bos()); ggml_set_i32_1d(tokens_input, 0, llama_token_bos());
for (int i=1; i<n_tokens+1; ++i) { for (int i=1; i<n_tokens+1; ++i) {
int token = clamp(train_data[sample+i-1], 0, n_vocab-1); int token = clamp(train_data[sample+i-1], 0, n_vocab-1);
ggml_set_f32_1d(target_logits, (i-1)*n_vocab + token, +1.0f); set_f32_2d(target_logits, token, i-1, +1.0f);
ggml_set_f32_1d(target_probs, (i-1)*n_vocab + token, -1.0f); set_f32_2d(target_probs, token, i-1, -1.0f);
if (i<n_tokens) { if (i<n_tokens) {
ggml_set_i32_1d(tokens_input, i, token); ggml_set_i32_1d(tokens_input, i, token);
} }