diff --git a/examples/train-text-from-scratch/train-text-from-scratch.cpp b/examples/train-text-from-scratch/train-text-from-scratch.cpp index 770b41e7b..2c17d0b99 100644 --- a/examples/train-text-from-scratch/train-text-from-scratch.cpp +++ b/examples/train-text-from-scratch/train-text-from-scratch.cpp @@ -2799,19 +2799,19 @@ void shuffle_ints(int * begin, int * end) { } struct my_llama_sampler_params { - float temp = 0.0f; // <= 0.0 disabled - int top_k = 20; // <= 0 to use vocab size - float top_p = 0.95f; // 1.0 = disabled - float tfs_z = 1.00f; // 1.0 = disabled - float typical_p = 1.00f; // 1.0 = disabled - int repeat_last_n = 64; // last n tokens to penalize (0 = disable penalty, -1 = context size) - float repeat_penalty = 1.0f; // 1.0 = disabled - float alpha_presence = 0.0f; // 0.0 = disabled - float alpha_frequency = 0.0f; // 0.0 = disabled - int mirostat = 0; // 0 = disabled, 1 = mirostat, 2 = mirostat 2.0 - float mirostat_tau = 5.00f; // target entropy - float mirostat_eta = 0.10f; // learning rate - bool penalize_nl = true; // consider newlines as a repeatable token + float temp = 0.0f; // <= 0.0 disabled + int top_k = 20; // <= 0 to use vocab size + float top_p = 0.95f; // 1.0 = disabled + float tfs_z = 1.00f; // 1.0 = disabled + float typical_p = 1.00f; // 1.0 = disabled + int repeat_last_n = 64; // last n tokens to penalize (0 = disable penalty, -1 = context size) + float repeat_penalty = 1.0f; // 1.0 = disabled + float presence_penalty = 0.0f; // 0.0 = disabled + float frequency_penalty = 0.0f; // 0.0 = disabled + int mirostat = 0; // 0 = disabled, 1 = mirostat, 2 = mirostat 2.0 + float mirostat_tau = 5.00f; // target entropy + float mirostat_eta = 0.10f; // learning rate + bool penalize_nl = true; // consider newlines as a repeatable token }; struct my_llama_sampler { @@ -2871,8 +2871,8 @@ llama_token sample(struct my_llama_sampler * sampler, float * logits, const llam candidates_p, last_tokens + n_last_tokens - n_last, n_last, - params.alpha_frequency, - params.alpha_presence); + params.frequency_penalty, + params.presence_penalty); if (!params.penalize_nl) { logits[llama_token_nl()] = nl_logit; @@ -4203,12 +4203,22 @@ int main(int argc, char ** argv) { int n_gen = params.n_predict; int sample_ctx = n_tokens - n_tokens/8; - sampler.params.temp = 0.2f; - sampler.params.repeat_penalty = 1.1f; - sampler.params.mirostat = 2; + // use defaults from common.h + sampler.params.top_k = 40; + sampler.params.top_p = 0.95f; + sampler.params.tfs_z = 1.00f; + sampler.params.typical_p = 1.00f; + sampler.params.temp = 0.8f; + sampler.params.repeat_penalty = 1.1f; + sampler.params.repeat_last_n = 64; + sampler.params.frequency_penalty = 0.0f; + sampler.params.presence_penalty = 0.0f; + sampler.params.mirostat = 0; + sampler.params.mirostat_tau = 5.00f; + sampler.params.mirostat_eta = 0.10f; init_sampler(&sampler, lctx); - printf("Generating %d tokens.\n", n_gen); + printf("[Prediction context]\n"); struct ggml_tensor * tokens_input = ggml_new_tensor_1d(model.ctx, GGML_TYPE_I32, n_tokens); struct ggml_tensor * target_logits = ggml_new_tensor_2d(model.ctx, GGML_TYPE_F32, n_vocab, n_tokens); @@ -4223,7 +4233,7 @@ int main(int argc, char ** argv) { print_token(lctx, ggml_get_i32_1d(tokens_input, i)); } - printf("---\n"); + printf("\n[Generating %d tokens]\n", n_gen); for (int i=0; i