remove unused function argument from get_example_targets_batch

This commit is contained in:
xaedes 2023-07-02 21:38:03 +02:00
parent ce937bc431
commit ff759d957c
No known key found for this signature in database
GPG key ID: 30030EDD817EA2B1

View file

@ -2581,7 +2581,7 @@ void get_example_targets(const int * train_samples, size_t n_train_samples, cons
}
}
void get_example_targets_batch(struct llama_context * /*lctx*/, const int * train_samples, size_t n_train_samples, const llama_token * train_data, size_t n_train_data, int example_id, struct ggml_tensor * tokens_input, struct ggml_tensor * target_logits, struct ggml_tensor * target_probs) {
void get_example_targets_batch(const int * train_samples, size_t n_train_samples, const llama_token * train_data, size_t n_train_data, int example_id, struct ggml_tensor * tokens_input, struct ggml_tensor * target_logits, struct ggml_tensor * target_probs) {
GGML_ASSERT(tokens_input->n_dims == 2);
GGML_ASSERT(target_logits->n_dims == 3);
GGML_ASSERT(target_probs->n_dims == 3);
@ -2596,27 +2596,23 @@ void get_example_targets_batch(struct llama_context * /*lctx*/, const int * trai
ggml_set_f32(target_logits, -1.0f/n_vocab);
ggml_set_f32(target_probs, 0.0f);
// printf("%s: example_id=%d n_batch=%d n_train_samples=%zu\n", __func__, example_id, n_batch, n_train_samples);
for (int k=0; k<n_batch; ++k) {
// printf("%s: batch %d\n", __func__, k);
size_t sample = train_samples[(example_id*n_batch + k) % n_train_samples];
size_t sample_idx = (example_id*n_batch + k) % n_train_samples;
size_t sample = train_samples[sample_idx];
// printf("%s: sample_idx=%zu sample=%zu\n", __func__, sample_idx, sample);
GGML_ASSERT(sample+n_tokens-1 < n_train_data);
set_i32_2d(tokens_input, 0, k, llama_token_bos());
for (int i=1; i<n_tokens+1; ++i) {
int token = clamp(train_data[sample+i-1], 0, n_vocab-1);
// print_token(lctx, token);
set_f32_3d(target_logits, token, i-1, k, +1.0f);
set_f32_3d(target_probs, token, i-1, k, +1.0f);
if (i<n_tokens) {
set_i32_2d(tokens_input, i, k, token);
}
}
// printf("\n=\n");
// for (int i=0; i<n_tokens; ++i) {
// int token = get_i32_2d(tokens_input, i, k);
// print_token(lctx, token);
// }
// printf("\n-\n");
}
}
@ -4011,8 +4007,7 @@ int main(int argc, char ** argv) {
struct ggml_cgraph * gf = (struct ggml_cgraph *) gfbuf->data;
struct ggml_cgraph * gb = (struct ggml_cgraph *) gbbuf->data;
get_example_targets_batch(lctx, train_samples.data(), train_samples.size(), train_tokens.data(), train_tokens.size(), ex, tokens_input, target_logits, target_probs);
get_example_targets_batch(train_samples.data(), train_samples.size(), train_tokens.data(), train_tokens.size(), ex, tokens_input, target_logits, target_probs);
GGML_ASSERT(n_past == 0);