diff --git a/examples/train-text-from-scratch/train-text-from-scratch.cpp b/examples/train-text-from-scratch/train-text-from-scratch.cpp
index db7a52842..de71dc996 100644
--- a/examples/train-text-from-scratch/train-text-from-scratch.cpp
+++ b/examples/train-text-from-scratch/train-text-from-scratch.cpp
@@ -2581,7 +2581,7 @@ void get_example_targets(const int * train_samples, size_t n_train_samples, cons
     }
 }
 
-void get_example_targets_batch(struct llama_context * /*lctx*/, const int * train_samples, size_t n_train_samples, const llama_token * train_data, size_t n_train_data, int example_id, struct ggml_tensor * tokens_input, struct ggml_tensor * target_logits, struct ggml_tensor * target_probs) {
+void get_example_targets_batch(const int * train_samples, size_t n_train_samples, const llama_token * train_data, size_t n_train_data, int example_id, struct ggml_tensor * tokens_input, struct ggml_tensor * target_logits, struct ggml_tensor * target_probs) {
     GGML_ASSERT(tokens_input->n_dims == 2);
     GGML_ASSERT(target_logits->n_dims == 3);
     GGML_ASSERT(target_probs->n_dims == 3);
@@ -2596,27 +2596,23 @@ void get_example_targets_batch(struct llama_context * /*lctx*/, const int * trai
 
     ggml_set_f32(target_logits, -1.0f/n_vocab);
     ggml_set_f32(target_probs, 0.0f);
+    // printf("%s: example_id=%d n_batch=%d n_train_samples=%zu\n", __func__, example_id, n_batch, n_train_samples);
     for (int k=0; k<n_batch; ++k) {
[...]
@@ ... @@
         struct ggml_cgraph * gf = (struct ggml_cgraph *) gfbuf->data;
         struct ggml_cgraph * gb = (struct ggml_cgraph *) gbbuf->data;
-
-        get_example_targets_batch(lctx, train_samples.data(), train_samples.size(), train_tokens.data(), train_tokens.size(), ex, tokens_input, target_logits, target_probs);
+        get_example_targets_batch(train_samples.data(), train_samples.size(), train_tokens.data(), train_tokens.size(), ex, tokens_input, target_logits, target_probs);
 
         GGML_ASSERT(n_past == 0);
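
The visible changes drop the unused llama_context argument from get_example_targets_batch and from its call site, and add a commented-out debug printf. As a rough illustration of what that function sets up for one batch, here is a minimal standalone sketch using plain std::vector buffers in place of ggml tensors. The layout is an assumption consistent with the 2-D/3-D asserts above (tokens_input as [n_tokens, n_batch], target_logits/target_probs as [n_vocab, n_tokens, n_batch]), and the -1.0f/n_vocab baseline mirrors the ggml_set_f32 call in the hunk; the toy data, the BOS id, the flat indexing, and the +1.0f one-hot writes are illustrative assumptions, not lines from the patch.

    // Sketch only: std::vector buffers stand in for the ggml tensors of
    // get_example_targets_batch. Layout assumed: tokens_input [n_tokens, n_batch],
    // target_logits/target_probs [n_vocab, n_tokens, n_batch]. The -1.0f/n_vocab
    // baseline matches the hunk above; toy data, BOS id and one-hot writes are
    // assumptions for illustration.
    #include <cstdio>
    #include <vector>

    int main() {
        const int n_vocab  = 8;  // toy vocabulary size
        const int n_tokens = 4;  // tokens per training sample
        const int n_batch  = 2;  // samples per batch
        const int bos      = 1;  // stand-in BOS token id

        // toy token stream and the start offsets of the training samples in it
        const std::vector<int> train_data    = {2, 3, 4, 5, 6, 7, 2, 3, 4, 5};
        const std::vector<int> train_samples = {0, 3, 5};

        std::vector<int>   tokens_input (n_tokens * n_batch, 0);
        std::vector<float> target_logits(n_vocab * n_tokens * n_batch, -1.0f/n_vocab);
        std::vector<float> target_probs (n_vocab * n_tokens * n_batch, 0.0f);

        const int example_id = 0;
        for (int k = 0; k < n_batch; ++k) {
            // pick the k-th sample for this example, wrapping around the sample list
            const int sample = train_samples[(example_id*n_batch + k) % (int) train_samples.size()];

            // BOS-prefixed inputs; the target at position i-1 is the token that
            // follows in the training data (one-hot probs, raised logit)
            tokens_input[0 + k*n_tokens] = bos;
            for (int i = 1; i < n_tokens + 1; ++i) {
                const int token = train_data[sample + i - 1] % n_vocab;
                const int idx   = token + (i-1)*n_vocab + k*n_tokens*n_vocab;
                target_logits[idx] = 1.0f;
                target_probs [idx] = 1.0f;
                if (i < n_tokens) {
                    tokens_input[i + k*n_tokens] = token;
                }
            }
        }

        // show the inputs that were built for each batch row
        for (int k = 0; k < n_batch; ++k) {
            printf("batch %d inputs:", k);
            for (int i = 0; i < n_tokens; ++i) {
                printf(" %d", tokens_input[i + k*n_tokens]);
            }
            printf("\n");
        }
        return 0;
    }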