remove unused function argument from get_example_targets_batch
This commit is contained in:
parent
ce937bc431
commit
ff759d957c
1 changed files with 6 additions and 11 deletions
|
@ -2581,7 +2581,7 @@ void get_example_targets(const int * train_samples, size_t n_train_samples, cons
|
|||
}
|
||||
}
|
||||
|
||||
void get_example_targets_batch(struct llama_context * /*lctx*/, const int * train_samples, size_t n_train_samples, const llama_token * train_data, size_t n_train_data, int example_id, struct ggml_tensor * tokens_input, struct ggml_tensor * target_logits, struct ggml_tensor * target_probs) {
|
||||
void get_example_targets_batch(const int * train_samples, size_t n_train_samples, const llama_token * train_data, size_t n_train_data, int example_id, struct ggml_tensor * tokens_input, struct ggml_tensor * target_logits, struct ggml_tensor * target_probs) {
|
||||
GGML_ASSERT(tokens_input->n_dims == 2);
|
||||
GGML_ASSERT(target_logits->n_dims == 3);
|
||||
GGML_ASSERT(target_probs->n_dims == 3);
|
||||
|
@ -2596,27 +2596,23 @@ void get_example_targets_batch(struct llama_context * /*lctx*/, const int * trai
|
|||
|
||||
ggml_set_f32(target_logits, -1.0f/n_vocab);
|
||||
ggml_set_f32(target_probs, 0.0f);
|
||||
// printf("%s: example_id=%d n_batch=%d n_train_samples=%zu\n", __func__, example_id, n_batch, n_train_samples);
|
||||
for (int k=0; k<n_batch; ++k) {
|
||||
// printf("%s: batch %d\n", __func__, k);
|
||||
size_t sample = train_samples[(example_id*n_batch + k) % n_train_samples];
|
||||
size_t sample_idx = (example_id*n_batch + k) % n_train_samples;
|
||||
size_t sample = train_samples[sample_idx];
|
||||
// printf("%s: sample_idx=%zu sample=%zu\n", __func__, sample_idx, sample);
|
||||
GGML_ASSERT(sample+n_tokens-1 < n_train_data);
|
||||
|
||||
set_i32_2d(tokens_input, 0, k, llama_token_bos());
|
||||
for (int i=1; i<n_tokens+1; ++i) {
|
||||
int token = clamp(train_data[sample+i-1], 0, n_vocab-1);
|
||||
// print_token(lctx, token);
|
||||
set_f32_3d(target_logits, token, i-1, k, +1.0f);
|
||||
set_f32_3d(target_probs, token, i-1, k, +1.0f);
|
||||
if (i<n_tokens) {
|
||||
set_i32_2d(tokens_input, i, k, token);
|
||||
}
|
||||
}
|
||||
// printf("\n=\n");
|
||||
// for (int i=0; i<n_tokens; ++i) {
|
||||
// int token = get_i32_2d(tokens_input, i, k);
|
||||
// print_token(lctx, token);
|
||||
// }
|
||||
// printf("\n-\n");
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -4011,8 +4007,7 @@ int main(int argc, char ** argv) {
|
|||
struct ggml_cgraph * gf = (struct ggml_cgraph *) gfbuf->data;
|
||||
struct ggml_cgraph * gb = (struct ggml_cgraph *) gbbuf->data;
|
||||
|
||||
|
||||
get_example_targets_batch(lctx, train_samples.data(), train_samples.size(), train_tokens.data(), train_tokens.size(), ex, tokens_input, target_logits, target_probs);
|
||||
get_example_targets_batch(train_samples.data(), train_samples.size(), train_tokens.data(), train_tokens.size(), ex, tokens_input, target_logits, target_probs);
|
||||
|
||||
GGML_ASSERT(n_past == 0);
|
||||
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue