diff --git a/examples/baby-llama/baby-llama-text.cpp b/examples/baby-llama/baby-llama-text.cpp
index b56441f9a..e65d2d186 100644
--- a/examples/baby-llama/baby-llama-text.cpp
+++ b/examples/baby-llama/baby-llama-text.cpp
@@ -1025,78 +1025,93 @@ void print_tokens_batch(struct llama_context* ctx, struct ggml_tensor * tokens)
     }
 }

-void get_example_targets(const int * train_samples, size_t n_train_samples, const llama_token * train_data, size_t n_train_data, int example_id, struct ggml_tensor * tokens_input, struct ggml_tensor * targets) {
+void get_example_targets(const int * train_samples, size_t n_train_samples, const llama_token * train_data, size_t n_train_data, int example_id, struct ggml_tensor * tokens_input, struct ggml_tensor * target_logits, struct ggml_tensor * target_probs) {
     int n_tokens = tokens_input->ne[0];
-    int n_vocab = targets->ne[0];
+    int n_vocab = target_logits->ne[0];
+
+    const float eps = 1e-6f;
+    const float target_prob = 1.0f;
     int sample = train_samples[example_id % n_train_samples];
     GGML_ASSERT(sample+n_tokens-1 < n_train_data);
-    ggml_set_f32(targets, -1.0f/n_vocab);
+    ggml_set_f32(target_logits, -1.0f/n_vocab);
+    ggml_set_f32(target_probs, 0.0f);
     ggml_set_i32_1d(tokens_input, 0, llama_token_bos());
     for (int i=1; in_dims == 2);
-    GGML_ASSERT( targets->n_dims == 3);
+void get_example_targets_batch(struct ggml_context * ctx, const int * train_samples, size_t n_train_samples, const llama_token * train_data, size_t n_train_data, int example_id, struct ggml_tensor * tokens_input, struct ggml_tensor * target_logits, struct ggml_tensor * target_probs) {
+    GGML_ASSERT(tokens_input->n_dims == 2);
+    GGML_ASSERT(target_logits->n_dims == 3);
+    GGML_ASSERT(target_probs->n_dims == 3);
     int n_tokens = tokens_input->ne[0];
     int n_batch = tokens_input->ne[1];
-    GGML_ASSERT(n_tokens == targets->ne[1]);
-    GGML_ASSERT(n_batch == targets->ne[2]);
+    GGML_ASSERT(n_tokens == target_logits->ne[1]);
+    GGML_ASSERT(n_batch == target_logits->ne[2]);
+    GGML_ASSERT(n_tokens == target_probs->ne[1]);
+    GGML_ASSERT(n_batch == target_probs->ne[2]);
     for (int k=0; kne[0], k*tokens_input->nb[1]);
-        struct ggml_tensor * targets_k = ggml_view_2d(ctx,
-            targets,
-            targets->ne[0],
-            targets->ne[1],
-            targets->nb[1],
-            k*targets->nb[2]);
+        struct ggml_tensor * target_logits_k = ggml_view_2d(ctx,
+            target_logits,
+            target_logits->ne[0],
+            target_logits->ne[1],
+            target_logits->nb[1],
+            k*target_logits->nb[2]);
+
+        struct ggml_tensor * target_probs_k = ggml_view_2d(ctx,
+            target_probs,
+            target_probs->ne[0],
+            target_probs->ne[1],
+            target_probs->nb[1],
+            k*target_probs->nb[2]);
         get_example_targets(train_samples, n_train_samples, train_data, n_train_data,
-            example_id*n_batch + k, tokens_input_k, targets_k);
+            example_id*n_batch + k, tokens_input_k, target_logits_k, target_probs_k);
     }
 }

-void lshift_examples(struct ggml_tensor * tokens_input, struct ggml_tensor * targets, int n_shift) {
+void lshift_examples(struct ggml_tensor * tokens_input, struct ggml_tensor * target_logits, struct ggml_tensor * target_probs, int n_shift) {
     int n_tokens = tokens_input->ne[0];
-    int n_vocab = targets->ne[0];
+    int n_vocab = target_logits->ne[0];
     for (int i=0; idata + (sample_ctx-1)*logits->nb[1]),
             (llama_token *) tokens_input->data,
@@ -1739,7 +1756,7 @@ int main(int argc, char ** argv) {
         // print_row(probs, sample_at);
         print_token(lctx, token);

-        lshift_examples(tokens_input, targets, 1);
+        lshift_examples(tokens_input, target_logits, target_probs, 1);
         ggml_set_i32_1d(tokens_input, 0, 0);
         ggml_set_i32_1d(tokens_input, sample_ctx-1, token);
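
Note on what the patch means for callers: the single targets tensor is split into target_logits (filled with a baseline of -1.0f/n_vocab) and target_probs (filled with a baseline of 0.0f), and every helper that previously took targets now takes both tensors. The sketch below illustrates that per-position target layout on flat row-major [n_tokens][n_vocab] buffers rather than ggml tensors; it is a minimal illustration under assumptions, not code from this patch: the helper name fill_targets is made up, and marking the observed token with 1.0f in both buffers is inferred from the baseline fills, since the loop body is not visible in this excerpt.

    // Illustration only: same target layout as the patched get_example_targets,
    // but on plain float buffers instead of ggml tensors.
    #include <cstdio>
    #include <vector>

    static void fill_targets(const std::vector<int> & tokens, int n_vocab,
                             std::vector<float> & target_logits,
                             std::vector<float> & target_probs) {
        const int n_tokens = (int) tokens.size();
        // baselines, matching the diff: -1/n_vocab for logits, 0 for probabilities
        target_logits.assign((size_t) n_tokens * n_vocab, -1.0f / n_vocab);
        target_probs .assign((size_t) n_tokens * n_vocab,  0.0f);
        for (int i = 0; i < n_tokens; ++i) {
            // one-hot entry for the token observed at position i (assumed)
            target_logits[(size_t) i * n_vocab + tokens[i]] = 1.0f;
            target_probs [(size_t) i * n_vocab + tokens[i]] = 1.0f;
        }
    }

    int main() {
        const int n_vocab = 8;
        const std::vector<int> tokens = {3, 1, 7};
        std::vector<float> target_logits, target_probs;
        fill_targets(tokens, n_vocab, target_logits, target_probs);
        // row 0 of target_probs: 1.00 at index 3, 0.00 elsewhere
        for (int v = 0; v < n_vocab; ++v) {
            printf("%.2f ", target_probs[v]);
        }
        printf("\n");
        return 0;
    }

Keeping a logit-style target next to a probability-style target presumably lets the training loop drive either a squared-error loss on raw logits or a cross-entropy loss on a probability distribution from the same prepared batch.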