speculative : add tree-based sampling example (#3624)

* sampling : one sequence per sampling context

ggml-ci

* speculative : add tree-based sampling support

ggml-ci

* speculative : reuse the n_parallel CLI param

* speculative : refactor sampling

* examples : fix build after sampling refactoring

ggml-ci

* batched : fix n_seq_id

* sampling : fix malloc

ggml-ci

* swift : fix build

ggml-ci

* swift : try to fix build

ggml-ci

* prompts : add assistant.txt

* common : add llama_batch_add() and llama_batch_clear() helpers

* speculative : minor refactor

ggml-ci

* minor : comments + rename

ggml-ci

* speculative : fix off-by-one for n_drafted

* speculative : fix the n_drafted fix + p constants
This commit is contained in:
Georgi Gerganov 2023-10-18 16:21:57 +03:00 committed by GitHub
parent c67fe68e41
commit 0e89203b51
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
21 changed files with 737 additions and 578 deletions

View file

@ -1450,7 +1450,10 @@ static bool llama_kv_cache_find_slot(
for (uint32_t i = 0; i < n_tokens; i++) {
cache.cells[cache.head + i].pos = batch.pos[i];
cache.cells[cache.head + i].seq_id.insert(batch.seq_id[i]);
for (int32_t j = 0; j < batch.n_seq_id[i]; j++) {
cache.cells[cache.head + i].seq_id.insert(batch.seq_id[i][j]);
}
}
return true;
@ -1530,6 +1533,9 @@ static void llama_kv_cache_seq_keep(struct llama_kv_cache & cache, llama_seq_id
cache.cells[i].pos = -1;
cache.cells[i].seq_id.clear();
if (new_head == cache.size) new_head = i;
} else {
cache.cells[i].seq_id.clear();
cache.cells[i].seq_id.insert(seq_id);
}
}
@ -3178,7 +3184,7 @@ static struct ggml_cgraph * llm_build_llama(
for (int h = 0; h < 1; ++h) {
for (int j = 0; j < n_tokens; ++j) {
const llama_pos pos = batch.pos[j];
const llama_seq_id seq_id = batch.seq_id[j];
const llama_seq_id seq_id = batch.seq_id[j][0];
for (int i = 0; i < n_kv; ++i) {
if (!kv_self.cells[i].has_seq_id(seq_id) || kv_self.cells[i].pos > pos) {
@ -3564,7 +3570,7 @@ static struct ggml_cgraph * llm_build_baichaun(
for (int h = 0; h < 1; ++h) {
for (int j = 0; j < n_tokens; ++j) {
const llama_pos pos = batch.pos[j];
const llama_seq_id seq_id = batch.seq_id[j];
const llama_seq_id seq_id = batch.seq_id[j][0];
for (int i = 0; i < n_kv; ++i) {
if (!kv_self.cells[i].has_seq_id(seq_id) || kv_self.cells[i].pos > pos) {
@ -3963,7 +3969,7 @@ static struct ggml_cgraph * llm_build_refact(
for (int h = 0; h < 1; ++h) {
for (int j = 0; j < n_tokens; ++j) {
const llama_pos pos = batch.pos[j];
const llama_seq_id seq_id = batch.seq_id[j];
const llama_seq_id seq_id = batch.seq_id[j][0];
for (int i = 0; i < n_kv; ++i) {
if (!kv_self.cells[i].has_seq_id(seq_id) || kv_self.cells[i].pos > pos) {
@ -4315,7 +4321,7 @@ static struct ggml_cgraph * llm_build_falcon(
for (int h = 0; h < 1; ++h) {
for (int j = 0; j < n_tokens; ++j) {
const llama_pos pos = batch.pos[j];
const llama_seq_id seq_id = batch.seq_id[j];
const llama_seq_id seq_id = batch.seq_id[j][0];
for (int i = 0; i < n_kv; ++i) {
if (!kv_self.cells[i].has_seq_id(seq_id) || kv_self.cells[i].pos > pos) {
@ -4667,7 +4673,7 @@ static struct ggml_cgraph * llm_build_starcoder(
for (int h = 0; h < 1; ++h) {
for (int j = 0; j < n_tokens; ++j) {
const llama_pos pos = batch.pos[j];
const llama_seq_id seq_id = batch.seq_id[j];
const llama_seq_id seq_id = batch.seq_id[j][0];
for (int i = 0; i < n_kv; ++i) {
if (!kv_self.cells[i].has_seq_id(seq_id) || kv_self.cells[i].pos > pos) {
@ -4898,7 +4904,7 @@ static struct ggml_cgraph * llm_build_persimmon(
for (int h = 0; h < 1; ++h) {
for (int j = 0; j < n_tokens; ++j) {
const llama_pos pos = batch.pos[j];
const llama_seq_id seq_id = batch.seq_id[j];
const llama_seq_id seq_id = batch.seq_id[j][0];
for (int i = 0; i < n_kv; ++i) {
if (!kv_self.cells[i].has_seq_id(seq_id) || kv_self.cells[i].pos > pos) {
data[h*(n_kv*n_tokens) + j*n_kv + i] = -INFINITY;
@ -5296,7 +5302,7 @@ static struct ggml_cgraph * llm_build_bloom(
for (int h = 0; h < 1; ++h) {
for (int j = 0; j < n_tokens; ++j) {
const llama_pos pos = batch.pos[j];
const llama_seq_id seq_id = batch.seq_id[j];
const llama_seq_id seq_id = batch.seq_id[j][0];
for (int i = 0; i < n_kv; ++i) {
if (!kv_self.cells[i].has_seq_id(seq_id) || kv_self.cells[i].pos > pos) {
@ -5564,7 +5570,7 @@ static struct ggml_cgraph * llm_build_mpt(
for (int h = 0; h < 1; ++h) {
for (int j = 0; j < n_tokens; ++j) {
const llama_pos pos = batch.pos[j];
const llama_seq_id seq_id = batch.seq_id[j];
const llama_seq_id seq_id = batch.seq_id[j][0];
for (int i = 0; i < n_kv; ++i) {
if (!kv_self.cells[i].has_seq_id(seq_id) || kv_self.cells[i].pos > pos) {
@ -5864,8 +5870,11 @@ static int llama_decode_internal(
// helpers for smoother batch API transition
// after deprecating the llama_eval calls, these will be removed
std::vector<llama_pos> pos;
std::vector<llama_seq_id> seq_id;
std::vector<llama_pos> pos;
std::vector<int32_t> n_seq_id;
std::vector<llama_seq_id *> seq_id_arr;
std::vector<std::vector<llama_seq_id>> seq_id;
if (batch.pos == nullptr) {
pos.resize(n_tokens);
@ -5877,12 +5886,18 @@ static int llama_decode_internal(
}
if (batch.seq_id == nullptr) {
n_seq_id.resize(n_tokens);
seq_id.resize(n_tokens);
seq_id_arr.resize(n_tokens);
for (uint32_t i = 0; i < n_tokens; i++) {
seq_id[i] = batch.all_seq_id;
n_seq_id[i] = 1;
seq_id[i].resize(1);
seq_id[i][0] = batch.all_seq_id;
seq_id_arr[i] = seq_id[i].data();
}
batch.seq_id = seq_id.data();
batch.n_seq_id = n_seq_id.data();
batch.seq_id = seq_id_arr.data();
}
if (!llama_kv_cache_find_slot(kv_self, batch)) {
@ -9109,6 +9124,9 @@ void llama_kv_cache_seq_rm(struct llama_context * ctx, llama_seq_id seq_id, llam
}
// Context-level wrapper around the llama_kv_cache_seq_cp overload that operates
// on the context's KV cache (ctx->kv_self).
//
// seq_id_src, seq_id_dst, p0 and p1 are forwarded unchanged to that overload;
// this wrapper does not interpret the position range itself.
void llama_kv_cache_seq_cp(struct llama_context * ctx, llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) {
// Source and destination are the same sequence: return without touching the
// cache (presumably a self-copy would be a no-op anyway - TODO confirm).
if (seq_id_src == seq_id_dst) {
return;
}
llama_kv_cache_seq_cp(ctx->kv_self, seq_id_src, seq_id_dst, p0, p1);
}
@ -9561,7 +9579,7 @@ int llama_eval_embd(
int n_past) {
llama_kv_cache_tokens_rm(ctx->kv_self, n_past, -1);
llama_batch batch = { n_tokens, nullptr, embd, nullptr, nullptr, nullptr, n_past, 1, 0, };
llama_batch batch = { n_tokens, nullptr, embd, nullptr, nullptr, nullptr, nullptr, n_past, 1, 0, };
const int ret = llama_decode_internal(*ctx, batch);
if (ret < 0) {
@ -9582,20 +9600,21 @@ struct llama_batch llama_batch_get_one(
llama_pos pos_0,
llama_seq_id seq_id) {
return {
/*n_tokens =*/ n_tokens,
/*tokens =*/ tokens,
/*embd =*/ nullptr,
/*pos =*/ nullptr,
/*seq_id =*/ nullptr,
/*logits =*/ nullptr,
/*all_pos_0 =*/ pos_0,
/*all_pos_1 =*/ 1,
/*all_seq_id =*/ seq_id,
/*n_tokens =*/ n_tokens,
/*tokens =*/ tokens,
/*embd =*/ nullptr,
/*pos =*/ nullptr,
/*n_seq_id =*/ nullptr,
/*seq_id =*/ nullptr,
/*logits =*/ nullptr,
/*all_pos_0 =*/ pos_0,
/*all_pos_1 =*/ 1,
/*all_seq_id =*/ seq_id,
};
}
struct llama_batch llama_batch_init(int32_t n_tokens, int32_t embd) {
llama_batch batch = { -1, nullptr, nullptr, nullptr, nullptr, nullptr, 0, 0, 0, };
struct llama_batch llama_batch_init(int32_t n_tokens, int32_t embd, int32_t n_seq_max) {
llama_batch batch = { 0, nullptr, nullptr, nullptr, nullptr, nullptr, nullptr, 0, 0, 0, };
if (embd) {
batch.embd = (float *) malloc(sizeof(float) * n_tokens * embd);
@ -9603,19 +9622,29 @@ struct llama_batch llama_batch_init(int32_t n_tokens, int32_t embd) {
batch.token = (llama_token *) malloc(sizeof(llama_token) * n_tokens);
}
batch.pos = (llama_pos *) malloc(sizeof(llama_pos) * n_tokens);
batch.seq_id = (llama_seq_id *) malloc(sizeof(llama_seq_id) * n_tokens);
batch.logits = (int8_t *) malloc(sizeof(int8_t) * n_tokens);
batch.pos = (llama_pos *) malloc(sizeof(llama_pos) * n_tokens);
batch.n_seq_id = (int32_t *) malloc(sizeof(int32_t) * n_tokens);
batch.seq_id = (llama_seq_id **) malloc(sizeof(llama_seq_id *) * n_tokens);
for (int i = 0; i < n_tokens; ++i) {
batch.seq_id[i] = (llama_seq_id *) malloc(sizeof(llama_seq_id) * n_seq_max);
}
batch.logits = (int8_t *) malloc(sizeof(int8_t) * n_tokens);
return batch;
}
void llama_batch_free(struct llama_batch batch) {
if (batch.token) free(batch.token);
if (batch.embd) free(batch.embd);
if (batch.pos) free(batch.pos);
if (batch.seq_id) free(batch.seq_id);
if (batch.logits) free(batch.logits);
if (batch.token) free(batch.token);
if (batch.embd) free(batch.embd);
if (batch.pos) free(batch.pos);
if (batch.n_seq_id) free(batch.n_seq_id);
if (batch.seq_id) {
for (int i = 0; i < batch.n_tokens; ++i) {
free(batch.seq_id[i]);
}
free(batch.seq_id);
}
if (batch.logits) free(batch.logits);
}
int llama_decode(