simple : add parallel decoding support
This commit is contained in:
parent
addae65fd4
commit
b377bf2266
7 changed files with 187 additions and 76 deletions
|
@ -123,7 +123,7 @@ int main(int argc, char ** argv) {
|
|||
|
||||
std::vector<llama_token> tokens_system;
|
||||
tokens_system = ::llama_tokenize(ctx, k_system, true);
|
||||
const uint32_t n_tokens_system = tokens_system.size();
|
||||
const int32_t n_tokens_system = tokens_system.size();
|
||||
|
||||
llama_seq_id g_seq_id = 0;
|
||||
|
||||
|
@ -144,7 +144,7 @@ int main(int argc, char ** argv) {
|
|||
|
||||
batch.n_tokens = n_tokens_system;
|
||||
|
||||
for (uint32_t i = 0; i < batch.n_tokens; ++i) {
|
||||
for (int32_t i = 0; i < batch.n_tokens; ++i) {
|
||||
batch.token[i] = tokens_system[i];
|
||||
batch.pos[i] = i;
|
||||
batch.seq_id[i] = 0;
|
||||
|
@ -156,7 +156,7 @@ int main(int argc, char ** argv) {
|
|||
return 1;
|
||||
}
|
||||
|
||||
// assign the system KV cachce to all parallel sequences
|
||||
// assign the system KV cache to all parallel sequences
|
||||
for (int32_t i = 1; i < n_clients; ++i) {
|
||||
llama_kv_cache_seq_cp(ctx, 0, i, 0, n_tokens_system);
|
||||
}
|
||||
|
@ -248,7 +248,7 @@ int main(int argc, char ** argv) {
|
|||
int32_t n_batch = params.n_batch;
|
||||
|
||||
for (int32_t i = 0; i < (int32_t) batch.n_tokens; i += n_batch) {
|
||||
const uint32_t n_tokens = std::min(n_batch, (int32_t) (batch.n_tokens - i));
|
||||
const int32_t n_tokens = std::min(n_batch, (int32_t) (batch.n_tokens - i));
|
||||
|
||||
llama_batch batch_view = {
|
||||
n_tokens,
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue