examplse : de-shadow

ggml-ci
This commit is contained in:
Georgi Gerganov 2025-01-12 14:25:32 +02:00
parent 82caffa74e
commit 9a735ae6d8
No known key found for this signature in database
GPG key ID: 449E073F9DC10735
16 changed files with 152 additions and 159 deletions

View file

@ -122,9 +122,9 @@ struct slot_params {
samplers.emplace_back(common_sampler_type_to_str(sampler));
}
json lora = json::array();
for (size_t i = 0; i < this->lora.size(); ++i) {
lora.push_back({{"id", i}, {"scale", this->lora[i].scale}});
json json_lora = json::array();
for (size_t i = 0; i < lora.size(); ++i) {
json_lora.push_back({{"id", i}, {"scale", lora[i].scale}});
}
return json {
@ -167,7 +167,7 @@ struct slot_params {
{"speculative.p_min", speculative.p_min},
{"timings_per_token", timings_per_token},
{"post_sampling_probs", post_sampling_probs},
{"lora", lora},
{"lora", json_lora},
};
}
};
@ -1641,7 +1641,7 @@ struct server_context {
llama_context_params cparams_dft;
llama_batch batch = {};
llama_batch batch_main = {};
bool clean_kv_cache = true;
bool add_bos_token = true;
@ -1676,7 +1676,7 @@ struct server_context {
llama_batch_free(slot.batch_spec);
}
llama_batch_free(batch);
llama_batch_free(batch_main);
}
bool load_model(const common_params & params) {
@ -1797,7 +1797,7 @@ struct server_context {
const int32_t n_batch = llama_n_batch(ctx);
// only a single seq_id per token is needed
batch = llama_batch_init(std::max(n_batch, params_base.n_parallel), 0, 1);
batch_main = llama_batch_init(std::max(n_batch, params_base.n_parallel), 0, 1);
}
metrics.init();
@ -2655,7 +2655,7 @@ struct server_context {
}
// start populating the batch for this iteration
common_batch_clear(batch);
common_batch_clear(batch_main);
// track if given slot can be batched with slots already in the batch
server_slot * slot_batched = nullptr;
@ -2673,9 +2673,9 @@ struct server_context {
continue;
}
slot.i_batch = batch.n_tokens;
slot.i_batch = batch_main.n_tokens;
common_batch_add(batch, slot.sampled, slot.n_past, { slot.id }, true);
common_batch_add(batch_main, slot.sampled, slot.n_past, { slot.id }, true);
slot.n_past += 1;
@ -2692,7 +2692,7 @@ struct server_context {
int32_t n_ubatch = llama_n_ubatch(ctx);
// next, batch any pending prompts without exceeding n_batch
if (params_base.cont_batching || batch.n_tokens == 0) {
if (params_base.cont_batching || batch_main.n_tokens == 0) {
for (auto & slot : slots) {
// check if we can batch this slot with the previous one
if (slot.is_processing()) {
@ -2858,7 +2858,7 @@ struct server_context {
// non-causal tasks require to fit the entire prompt in the physical batch
if (slot.is_non_causal()) {
// cannot fit the prompt in the current batch - will try next iter
if (batch.n_tokens + slot.n_prompt_tokens > n_batch) {
if (batch_main.n_tokens + slot.n_prompt_tokens > n_batch) {
continue;
}
}
@ -2878,11 +2878,11 @@ struct server_context {
slot.cache_tokens.resize(slot.n_past);
// add prompt tokens for processing in the current batch
while (slot.n_past < slot.n_prompt_tokens && batch.n_tokens < n_batch) {
while (slot.n_past < slot.n_prompt_tokens && batch_main.n_tokens < n_batch) {
// without pooling, we want to output the embeddings for all the tokens in the batch
const bool need_embd = slot.task_type == SERVER_TASK_TYPE_EMBEDDING && llama_pooling_type(slot.ctx) == LLAMA_POOLING_TYPE_NONE;
common_batch_add(batch, prompt_tokens[slot.n_past], slot.n_past, { slot.id }, need_embd);
common_batch_add(batch_main, prompt_tokens[slot.n_past], slot.n_past, { slot.id }, need_embd);
if (slot.params.cache_prompt) {
slot.cache_tokens.push_back(prompt_tokens[slot.n_past]);
@ -2892,13 +2892,13 @@ struct server_context {
slot.n_past++;
}
SLT_INF(slot, "prompt processing progress, n_past = %d, n_tokens = %d, progress = %f\n", slot.n_past, batch.n_tokens, (float) slot.n_prompt_tokens_processed / slot.n_prompt_tokens);
SLT_INF(slot, "prompt processing progress, n_past = %d, n_tokens = %d, progress = %f\n", slot.n_past, batch_main.n_tokens, (float) slot.n_prompt_tokens_processed / slot.n_prompt_tokens);
// entire prompt has been processed
if (slot.n_past == slot.n_prompt_tokens) {
slot.state = SLOT_STATE_DONE_PROMPT;
GGML_ASSERT(batch.n_tokens > 0);
GGML_ASSERT(batch_main.n_tokens > 0);
common_sampler_reset(slot.smpl);
@ -2908,27 +2908,27 @@ struct server_context {
}
// extract the logits only for the last token
batch.logits[batch.n_tokens - 1] = true;
batch_main.logits[batch_main.n_tokens - 1] = true;
slot.n_decoded = 0;
slot.i_batch = batch.n_tokens - 1;
slot.i_batch = batch_main.n_tokens - 1;
SLT_INF(slot, "prompt done, n_past = %d, n_tokens = %d\n", slot.n_past, batch.n_tokens);
SLT_INF(slot, "prompt done, n_past = %d, n_tokens = %d\n", slot.n_past, batch_main.n_tokens);
}
}
if (batch.n_tokens >= n_batch) {
if (batch_main.n_tokens >= n_batch) {
break;
}
}
}
if (batch.n_tokens == 0) {
if (batch_main.n_tokens == 0) {
SRV_WRN("%s", "no tokens to decode\n");
return;
}
SRV_DBG("decoding batch, n_tokens = %d\n", batch.n_tokens);
SRV_DBG("decoding batch, n_tokens = %d\n", batch_main.n_tokens);
if (slot_batched) {
// make sure we're in the right embedding mode
@ -2938,17 +2938,17 @@ struct server_context {
}
// process the created batch of tokens
for (int32_t i = 0; i < batch.n_tokens; i += n_batch) {
const int32_t n_tokens = std::min(n_batch, batch.n_tokens - i);
for (int32_t i_batch = 0; i_batch < batch_main.n_tokens; i_batch += n_batch) {
const int32_t n_tokens = std::min(n_batch, batch_main.n_tokens - i_batch);
llama_batch batch_view = {
n_tokens,
batch.token + i,
batch_main.token + i_batch,
nullptr,
batch.pos + i,
batch.n_seq_id + i,
batch.seq_id + i,
batch.logits + i,
batch_main.pos + i_batch,
batch_main.n_seq_id + i_batch,
batch_main.seq_id + i_batch,
batch_main.logits + i_batch,
};
const int ret = llama_decode(ctx, batch_view);
@ -2957,7 +2957,7 @@ struct server_context {
if (ret != 0) {
if (n_batch == 1 || ret < 0) {
// if you get here, it means the KV cache is full - try increasing it via the context size
SRV_ERR("failed to decode the batch: KV cache is full - try increasing it via the context size, i = %d, n_batch = %d, ret = %d\n", i, n_batch, ret);
SRV_ERR("failed to decode the batch: KV cache is full - try increasing it via the context size, i_batch = %d, n_batch = %d, ret = %d\n", i_batch, n_batch, ret);
for (auto & slot : slots) {
slot.release();
send_error(slot, "Input prompt is too big compared to KV size. Please try increasing KV size.");
@ -2967,15 +2967,15 @@ struct server_context {
// retry with half the batch size to try to find a free slot in the KV cache
n_batch /= 2;
i -= n_batch;
i_batch -= n_batch;
SRV_WRN("failed to find free space in the KV cache, retrying with smaller batch size - try increasing it via the context size or enable defragmentation, i = %d, n_batch = %d, ret = %d\n", i, n_batch, ret);
SRV_WRN("failed to find free space in the KV cache, retrying with smaller batch size - try increasing it via the context size or enable defragmentation, i_batch = %d, n_batch = %d, ret = %d\n", i_batch, n_batch, ret);
continue; // continue loop of n_batch
}
for (auto & slot : slots) {
if (slot.i_batch < (int) i || slot.i_batch >= (int) (i + n_tokens)) {
if (slot.i_batch < (int) i_batch || slot.i_batch >= (int) (i_batch + n_tokens)) {
continue; // continue loop of slots
}
@ -3001,7 +3001,7 @@ struct server_context {
continue; // continue loop of slots
}
const int tok_idx = slot.i_batch - i;
const int tok_idx = slot.i_batch - i_batch;
llama_token id = common_sampler_sample(slot.smpl, ctx, tok_idx);
@ -3687,8 +3687,8 @@ int main(int argc, char ** argv) {
} else {
// multiple results (multitask)
json arr = json::array();
for (auto & res : results) {
arr.push_back(res->to_json());
for (auto & result : results) {
arr.push_back(result->to_json());
}
res_ok(res, arr);
}