Fix a bug in the rope calculation

This commit is contained in:
Georgi Gerganov 2023-03-10 23:46:39 +02:00
parent 18ebda34d6
commit 70bc0b8b15
No known key found for this signature in database
GPG key ID: 449E073F9DC10735
4 changed files with 92 additions and 6 deletions

View file

@@ -400,7 +400,7 @@ bool llama_eval(
 const int n_ctx = hparams.n_ctx;
 const int n_head = hparams.n_head;
 const int n_vocab = hparams.n_vocab;
-const int n_rot = hparams.n_rot;
+const int n_rot = hparams.n_embd/hparams.n_head;
 const int d_key = n_embd/n_head;
@@ -628,6 +628,9 @@ int main(int argc, char ** argv) {
 params.prompt = gpt_random_prompt(rng);
 }
+// params.prompt = R"(// this function checks if the number n is prime
+//bool is_prime(int n) {)";
 int64_t t_load_us = 0;
 gpt_vocab vocab;
@@ -691,7 +694,6 @@ int main(int argc, char ** argv) {
 if (i >= embd_inp.size()) {
 // sample next token
-const int top_k = params.top_k;
 const float top_p = params.top_p;
 const float temp = params.temp;
@@ -702,7 +704,7 @@ int main(int argc, char ** argv) {
 {
 const int64_t t_start_sample_us = ggml_time_us();
-id = gpt_sample_top_k_top_p(vocab, logits.data() + (logits.size() - n_vocab), top_k, top_p, temp, rng);
+id = llama_sample_top_p(vocab, logits.data() + (logits.size() - n_vocab), top_p, temp, rng);
 t_sample_us += ggml_time_us() - t_start_sample_us;
 }