simple : add parallel decoding support

This commit is contained in:
Georgi Gerganov 2023-09-20 13:06:34 +03:00
parent addae65fd4
commit b377bf2266
No known key found for this signature in database
GPG key ID: 449E073F9DC10735
7 changed files with 187 additions and 76 deletions

View file

@ -123,7 +123,7 @@ int main(int argc, char ** argv) {
std::vector<llama_token> tokens_system;
tokens_system = ::llama_tokenize(ctx, k_system, true);
const uint32_t n_tokens_system = tokens_system.size();
const int32_t n_tokens_system = tokens_system.size();
llama_seq_id g_seq_id = 0;
@ -144,7 +144,7 @@ int main(int argc, char ** argv) {
batch.n_tokens = n_tokens_system;
for (uint32_t i = 0; i < batch.n_tokens; ++i) {
for (int32_t i = 0; i < batch.n_tokens; ++i) {
batch.token[i] = tokens_system[i];
batch.pos[i] = i;
batch.seq_id[i] = 0;
@ -156,7 +156,7 @@ int main(int argc, char ** argv) {
return 1;
}
// assign the system KV cachce to all parallel sequences
// assign the system KV cache to all parallel sequences
for (int32_t i = 1; i < n_clients; ++i) {
llama_kv_cache_seq_cp(ctx, 0, i, 0, n_tokens_system);
}
@ -248,7 +248,7 @@ int main(int argc, char ** argv) {
int32_t n_batch = params.n_batch;
for (int32_t i = 0; i < (int32_t) batch.n_tokens; i += n_batch) {
const uint32_t n_tokens = std::min(n_batch, (int32_t) (batch.n_tokens - i));
const int32_t n_tokens = std::min(n_batch, (int32_t) (batch.n_tokens - i));
llama_batch batch_view = {
n_tokens,