check n_ubatch >= n_tokens with non-casual attention
This commit is contained in:
parent
54cdd478d7
commit
cda49d3828
3 changed files with 10 additions and 5 deletions
|
@ -1738,7 +1738,8 @@ struct server_context {
|
|||
}
|
||||
|
||||
// process in chunks of params.n_batch
|
||||
int32_t n_batch = params.n_batch;
|
||||
int32_t n_batch = llama_n_batch(ctx);
|
||||
int32_t n_ubatch = llama_n_ubatch(ctx);
|
||||
|
||||
// next, batch any pending prompts without exceeding n_batch
|
||||
if (params.cont_batching || batch.n_tokens == 0) {
|
||||
|
@ -1811,7 +1812,7 @@ struct server_context {
|
|||
|
||||
if (slot.embedding) {
|
||||
// this prompt is too large to process - discard it
|
||||
if (slot.n_prompt_tokens > n_batch) {
|
||||
if (slot.n_prompt_tokens > n_ubatch) {
|
||||
slot.state = SLOT_STATE_PROCESSING;
|
||||
slot.command = SLOT_COMMAND_NONE;
|
||||
slot.release();
|
||||
|
|
|
@ -8774,6 +8774,8 @@ static int llama_decode_internal(
|
|||
|
||||
GGML_ASSERT(n_tokens_all <= cparams.n_batch);
|
||||
|
||||
GGML_ASSERT((cparams.causal_attn || cparams.n_ubatch >= n_tokens_all) && "non-causal attention requires n_ubatch >= n_tokens");
|
||||
|
||||
if (lctx.t_compute_start_us == 0) {
|
||||
lctx.t_compute_start_us = ggml_time_us();
|
||||
}
|
||||
|
@ -9011,9 +9013,6 @@ static int llama_decode_internal(
|
|||
case LLAMA_POOLING_TYPE_CLS:
|
||||
case LLAMA_POOLING_TYPE_MEAN:
|
||||
{
|
||||
// FIXME: this may not work if the sequences are split into different batches
|
||||
GGML_ASSERT(n_tokens_all == n_tokens);
|
||||
|
||||
GGML_ASSERT(strcmp(embd->name, "result_embd_pooled") == 0);
|
||||
|
||||
// extract sequence embeddings
|
||||
|
@ -13076,6 +13075,10 @@ uint32_t llama_n_batch(const struct llama_context * ctx) {
|
|||
return ctx->cparams.n_batch;
|
||||
}
|
||||
|
||||
uint32_t llama_n_ubatch(const struct llama_context * ctx) {
|
||||
return ctx->cparams.n_ubatch;
|
||||
}
|
||||
|
||||
uint32_t llama_n_seq_max(const struct llama_context * ctx) {
|
||||
return ctx->kv_self.size;
|
||||
}
|
||||
|
|
1
llama.h
1
llama.h
|
@ -378,6 +378,7 @@ extern "C" {
|
|||
|
||||
LLAMA_API uint32_t llama_n_ctx (const struct llama_context * ctx);
|
||||
LLAMA_API uint32_t llama_n_batch (const struct llama_context * ctx);
|
||||
LLAMA_API uint32_t llama_n_ubatch (const struct llama_context * ctx);
|
||||
LLAMA_API uint32_t llama_n_seq_max (const struct llama_context * ctx);
|
||||
|
||||
LLAMA_API enum llama_vocab_type llama_vocab_type(const struct llama_model * model);
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue