change sampling parameters for prediction after training to the defaults from common.h

and clarify which tokens are the prediction context and which are the generated tokens
xaedes 2023-07-03 18:24:57 +02:00
parent 17a0898d50
commit 24a4b099f3
No known key found for this signature in database
GPG key ID: 30030EDD817EA2B1


@@ -2799,19 +2799,19 @@ void shuffle_ints(int * begin, int * end) {
}
struct my_llama_sampler_params {
- float temp = 0.0f; // <= 0.0 disabled
- int top_k = 20; // <= 0 to use vocab size
- float top_p = 0.95f; // 1.0 = disabled
- float tfs_z = 1.00f; // 1.0 = disabled
- float typical_p = 1.00f; // 1.0 = disabled
- int repeat_last_n = 64; // last n tokens to penalize (0 = disable penalty, -1 = context size)
- float repeat_penalty = 1.0f; // 1.0 = disabled
- float alpha_presence = 0.0f; // 0.0 = disabled
- float alpha_frequency = 0.0f; // 0.0 = disabled
- int mirostat = 0; // 0 = disabled, 1 = mirostat, 2 = mirostat 2.0
- float mirostat_tau = 5.00f; // target entropy
- float mirostat_eta = 0.10f; // learning rate
- bool penalize_nl = true; // consider newlines as a repeatable token
+ float temp = 0.0f; // <= 0.0 disabled
+ int top_k = 20; // <= 0 to use vocab size
+ float top_p = 0.95f; // 1.0 = disabled
+ float tfs_z = 1.00f; // 1.0 = disabled
+ float typical_p = 1.00f; // 1.0 = disabled
+ int repeat_last_n = 64; // last n tokens to penalize (0 = disable penalty, -1 = context size)
+ float repeat_penalty = 1.0f; // 1.0 = disabled
+ float presence_penalty = 0.0f; // 0.0 = disabled
+ float frequency_penalty = 0.0f; // 0.0 = disabled
+ int mirostat = 0; // 0 = disabled, 1 = mirostat, 2 = mirostat 2.0
+ float mirostat_tau = 5.00f; // target entropy
+ float mirostat_eta = 0.10f; // learning rate
+ bool penalize_nl = true; // consider newlines as a repeatable token
};
struct my_llama_sampler {
@@ -2871,8 +2871,8 @@ llama_token sample(struct my_llama_sampler * sampler, float * logits, const llam
candidates_p,
last_tokens + n_last_tokens - n_last,
n_last,
- params.alpha_frequency,
- params.alpha_presence);
+ params.frequency_penalty,
+ params.presence_penalty);
if (!params.penalize_nl) {
logits[llama_token_nl()] = nl_logit;
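
The renamed fields follow the usual frequency/presence penalty semantics: a frequency penalty grows with how often a token has already appeared in the last-n window, while a presence penalty is a flat cost applied once a token has appeared at all. The following is an illustrative, self-contained sketch of that adjustment (not the llama.cpp implementation; token ids and penalty values are made up):

// illustrative sketch of frequency/presence penalties on raw logits
#include <cstdio>
#include <map>
#include <vector>

static void apply_penalties(std::vector<float> & logits,
                            const std::vector<int> & last_tokens,
                            float frequency_penalty,
                            float presence_penalty) {
    std::map<int, int> counts;
    for (int tok : last_tokens) {
        counts[tok]++;
    }
    for (const auto & kv : counts) {
        const int   tok   = kv.first;
        const float count = (float) kv.second;
        // frequency penalty scales with how often the token already appeared;
        // presence penalty is a flat cost once it has appeared at all
        logits[tok] -= count * frequency_penalty + presence_penalty;
    }
}

int main() {
    std::vector<float> logits      = {1.0f, 2.0f, 3.0f, 4.0f};
    std::vector<int>   last_tokens = {2, 2, 3}; // token 2 seen twice, token 3 once
    apply_penalties(logits, last_tokens, 0.5f, 0.2f);
    for (size_t i = 0; i < logits.size(); ++i) {
        printf("token %zu: %.2f\n", i, logits[i]);
    }
    return 0;
}
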
@@ -4203,12 +4203,22 @@ int main(int argc, char ** argv) {
int n_gen = params.n_predict;
int sample_ctx = n_tokens - n_tokens/8;
- sampler.params.temp = 0.2f;
- sampler.params.repeat_penalty = 1.1f;
- sampler.params.mirostat = 2;
+ // use defaults from common.h
+ sampler.params.top_k = 40;
+ sampler.params.top_p = 0.95f;
+ sampler.params.tfs_z = 1.00f;
+ sampler.params.typical_p = 1.00f;
+ sampler.params.temp = 0.8f;
+ sampler.params.repeat_penalty = 1.1f;
+ sampler.params.repeat_last_n = 64;
+ sampler.params.frequency_penalty = 0.0f;
+ sampler.params.presence_penalty = 0.0f;
+ sampler.params.mirostat = 0;
+ sampler.params.mirostat_tau = 5.00f;
+ sampler.params.mirostat_eta = 0.10f;
init_sampler(&sampler, lctx);
- printf("Generating %d tokens.\n", n_gen);
+ printf("[Prediction context]\n");
struct ggml_tensor * tokens_input = ggml_new_tensor_1d(model.ctx, GGML_TYPE_I32, n_tokens);
struct ggml_tensor * target_logits = ggml_new_tensor_2d(model.ctx, GGML_TYPE_F32, n_vocab, n_tokens);
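
Since the values above are meant to mirror common.h, one way to keep them in sync would be to copy them from a gpt_params instance instead of repeating the literals. The helper below is hypothetical (it is not part of this commit), assumes gpt_params in common.h exposes sampling fields with exactly these names, and would live in the same source file as the my_llama_sampler_params definition:

// hypothetical sketch: mirror common.h sampling defaults instead of hard-coding them
#include "common.h"

static void sampler_params_from_common(struct my_llama_sampler_params * sp,
                                       const gpt_params & gp) {
    sp->top_k             = gp.top_k;
    sp->top_p             = gp.top_p;
    sp->tfs_z             = gp.tfs_z;
    sp->typical_p         = gp.typical_p;
    sp->temp              = gp.temp;
    sp->repeat_penalty    = gp.repeat_penalty;
    sp->repeat_last_n     = gp.repeat_last_n;
    sp->frequency_penalty = gp.frequency_penalty;
    sp->presence_penalty  = gp.presence_penalty;
    sp->mirostat          = gp.mirostat;
    sp->mirostat_tau      = gp.mirostat_tau;
    sp->mirostat_eta      = gp.mirostat_eta;
    sp->penalize_nl       = gp.penalize_nl;
}
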
@@ -4223,7 +4233,7 @@ int main(int argc, char ** argv) {
print_token(lctx, ggml_get_i32_1d(tokens_input, i));
}
printf("---\n");
printf("\n[Generating %d tokens]\n", n_gen);
for (int i=0; i<n_gen; ++i) {
struct ggml_init_params cparams = {
compute_size, // .mem_size
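
For orientation, here is a simplified, self-contained sketch of the flow that the two printf labels describe: the first sample_ctx tokens act as the prediction context, and each of the n_gen generated tokens is sampled and then slid into the window. next_token() is a hypothetical stand-in for the real forward pass plus sample() call, and all numbers are made up; the actual loop body of this example is not shown in the hunk above.

// simplified sketch of "prediction context" vs "generated tokens"
#include <cstdio>
#include <vector>

static int next_token(const std::vector<int> & window) {
    // placeholder: a real implementation would run the model on `window`
    // and sample the next token from the resulting logits
    return window.back() + 1;
}

int main() {
    const int sample_ctx = 8; // tokens kept as prediction context
    const int n_gen      = 4; // tokens to generate

    std::vector<int> window(sample_ctx);
    for (int i = 0; i < sample_ctx; ++i) {
        window[i] = i; // stand-in for the real context tokens
    }

    printf("[Prediction context]\n");
    for (int t : window) printf("%d ", t);

    printf("\n[Generating %d tokens]\n", n_gen);
    for (int i = 0; i < n_gen; ++i) {
        const int tok = next_token(window);
        printf("%d ", tok);
        // slide the window: drop the oldest token, append the new one
        window.erase(window.begin());
        window.push_back(tok);
    }
    printf("\n");
    return 0;
}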