check n_ubatch >= n_tokens with non-casual attention

This commit is contained in:
slaren 2024-03-13 13:59:08 +01:00
parent 54cdd478d7
commit cda49d3828
3 changed files with 10 additions and 5 deletions

View file

@ -1738,7 +1738,8 @@ struct server_context {
} }
// process in chunks of params.n_batch // process in chunks of params.n_batch
int32_t n_batch = params.n_batch; int32_t n_batch = llama_n_batch(ctx);
int32_t n_ubatch = llama_n_ubatch(ctx);
// next, batch any pending prompts without exceeding n_batch // next, batch any pending prompts without exceeding n_batch
if (params.cont_batching || batch.n_tokens == 0) { if (params.cont_batching || batch.n_tokens == 0) {
@ -1811,7 +1812,7 @@ struct server_context {
if (slot.embedding) { if (slot.embedding) {
// this prompt is too large to process - discard it // this prompt is too large to process - discard it
if (slot.n_prompt_tokens > n_batch) { if (slot.n_prompt_tokens > n_ubatch) {
slot.state = SLOT_STATE_PROCESSING; slot.state = SLOT_STATE_PROCESSING;
slot.command = SLOT_COMMAND_NONE; slot.command = SLOT_COMMAND_NONE;
slot.release(); slot.release();

View file

@ -8774,6 +8774,8 @@ static int llama_decode_internal(
GGML_ASSERT(n_tokens_all <= cparams.n_batch); GGML_ASSERT(n_tokens_all <= cparams.n_batch);
GGML_ASSERT((cparams.causal_attn || cparams.n_ubatch >= n_tokens_all) && "non-causal attention requires n_ubatch >= n_tokens");
if (lctx.t_compute_start_us == 0) { if (lctx.t_compute_start_us == 0) {
lctx.t_compute_start_us = ggml_time_us(); lctx.t_compute_start_us = ggml_time_us();
} }
@ -9011,9 +9013,6 @@ static int llama_decode_internal(
case LLAMA_POOLING_TYPE_CLS: case LLAMA_POOLING_TYPE_CLS:
case LLAMA_POOLING_TYPE_MEAN: case LLAMA_POOLING_TYPE_MEAN:
{ {
// FIXME: this may not work if the sequences are split into different batches
GGML_ASSERT(n_tokens_all == n_tokens);
GGML_ASSERT(strcmp(embd->name, "result_embd_pooled") == 0); GGML_ASSERT(strcmp(embd->name, "result_embd_pooled") == 0);
// extract sequence embeddings // extract sequence embeddings
@ -13076,6 +13075,10 @@ uint32_t llama_n_batch(const struct llama_context * ctx) {
return ctx->cparams.n_batch; return ctx->cparams.n_batch;
} }
uint32_t llama_n_ubatch(const struct llama_context * ctx) {
return ctx->cparams.n_ubatch;
}
uint32_t llama_n_seq_max(const struct llama_context * ctx) { uint32_t llama_n_seq_max(const struct llama_context * ctx) {
return ctx->kv_self.size; return ctx->kv_self.size;
} }

View file

@ -378,6 +378,7 @@ extern "C" {
LLAMA_API uint32_t llama_n_ctx (const struct llama_context * ctx); LLAMA_API uint32_t llama_n_ctx (const struct llama_context * ctx);
LLAMA_API uint32_t llama_n_batch (const struct llama_context * ctx); LLAMA_API uint32_t llama_n_batch (const struct llama_context * ctx);
LLAMA_API uint32_t llama_n_ubatch (const struct llama_context * ctx);
LLAMA_API uint32_t llama_n_seq_max (const struct llama_context * ctx); LLAMA_API uint32_t llama_n_seq_max (const struct llama_context * ctx);
LLAMA_API enum llama_vocab_type llama_vocab_type(const struct llama_model * model); LLAMA_API enum llama_vocab_type llama_vocab_type(const struct llama_model * model);