llama : avoid hardcoded special tokens

This commit is contained in:
Georgi Gerganov 2023-08-18 17:29:20 +03:00
parent 035d511457
commit 5d2656d670
No known key found for this signature in database
GPG key ID: 449E073F9DC10735
11 changed files with 61 additions and 65 deletions

View file

@ -1996,7 +1996,7 @@ void print_tokens_batch(struct llama_context* ctx, struct ggml_tensor * tokens)
}
}
void get_example_targets(const int * train_samples, size_t n_train_samples, const llama_token * train_data, size_t n_train_data, int example_id, struct ggml_tensor * tokens_input, struct ggml_tensor * target_logits, struct ggml_tensor * target_probs) {
void get_example_targets(struct llama_context * lctx, const int * train_samples, size_t n_train_samples, const llama_token * train_data, size_t n_train_data, int example_id, struct ggml_tensor * tokens_input, struct ggml_tensor * target_logits, struct ggml_tensor * target_probs) {
int n_tokens = tokens_input->ne[0];
int n_vocab = target_logits->ne[0];
@ -2005,7 +2005,7 @@ void get_example_targets(const int * train_samples, size_t n_train_samples, cons
ggml_set_f32(target_logits, -1.0f/n_vocab);
ggml_set_f32(target_probs, 0.0f);
ggml_set_i32_1d(tokens_input, 0, llama_token_bos());
ggml_set_i32_1d(tokens_input, 0, llama_token_bos(lctx));
for (int i=1; i<n_tokens+1; ++i) {
int token = clamp(train_data[sample+i-1], 0, n_vocab-1);
set_f32_2d(target_logits, token, i-1, +1.0f);
@ -2016,7 +2016,7 @@ void get_example_targets(const int * train_samples, size_t n_train_samples, cons
}
}
void get_example_targets_batch(struct llama_context * /*lctx*/, const int * train_samples, size_t n_train_samples, const llama_token * train_data, size_t n_train_data, int example_id, struct ggml_tensor * tokens_input, struct ggml_tensor * target_logits, struct ggml_tensor * target_probs) {
void get_example_targets_batch(struct llama_context * lctx, const int * train_samples, size_t n_train_samples, const llama_token * train_data, size_t n_train_data, int example_id, struct ggml_tensor * tokens_input, struct ggml_tensor * target_logits, struct ggml_tensor * target_probs) {
GGML_ASSERT(tokens_input->n_dims == 2);
GGML_ASSERT(target_logits->n_dims == 3);
GGML_ASSERT(target_probs->n_dims == 3);
@ -2036,7 +2036,7 @@ void get_example_targets_batch(struct llama_context * /*lctx*/, const int * trai
size_t sample = train_samples[(example_id*n_batch + k) % n_train_samples];
GGML_ASSERT(sample+n_tokens-1 < n_train_data);
set_i32_2d(tokens_input, 0, k, llama_token_bos());
set_i32_2d(tokens_input, 0, k, llama_token_bos(lctx));
for (int i=1; i<n_tokens+1; ++i) {
int token = clamp(train_data[sample+i-1], 0, n_vocab-1);
// print_token(lctx, token);
@ -2294,7 +2294,7 @@ llama_token sample(struct my_llama_sampler * sampler, float * logits, const llam
const auto params = sampler->params;
// Apply penalties
const float nl_logit = logits[llama_token_nl()];
const float nl_logit = logits[llama_token_nl(ctx)];
const int n_last = std::min(std::min(n_last_tokens, params.repeat_last_n), sampler->n_ctx);
@ -2313,7 +2313,7 @@ llama_token sample(struct my_llama_sampler * sampler, float * logits, const llam
params.alpha_presence);
if (!params.penalize_nl) {
logits[llama_token_nl()] = nl_logit;
logits[llama_token_nl(ctx)] = nl_logit;
}
llama_token token = 0;
@ -3181,7 +3181,7 @@ int main(int argc, char ** argv) {
std::vector<int> train_samples;
train_samples.push_back(0);
for (int i = 1; i < (int) train_tokens.size() - n_tokens; ++i) {
if (!params.samples_start_after_nl || (train_tokens[i-1] == llama_token_nl())) {
if (!params.samples_start_after_nl || (train_tokens[i-1] == llama_token_nl(lctx))) {
train_samples.push_back(i);
}
}
@ -3341,7 +3341,7 @@ int main(int argc, char ** argv) {
struct ggml_tensor * target_logits = ggml_new_tensor_2d(model.ctx, GGML_TYPE_F32, n_vocab, n_tokens);
struct ggml_tensor * target_probs = ggml_new_tensor_2d(model.ctx, GGML_TYPE_F32, n_vocab, n_tokens);
get_example_targets(train_samples.data(), train_samples.size(), train_tokens.data(), train_tokens.size(), rand()%train_samples.size(), tokens_input, target_logits, target_probs);
get_example_targets(lctx, train_samples.data(), train_samples.size(), train_tokens.data(), train_tokens.size(), rand()%train_samples.size(), tokens_input, target_logits, target_probs);
for (int i=sample_ctx; i<n_tokens; ++i) {
ggml_set_i32_1d(tokens_input, i, n_vocab/2);
}