llama : simplify Mamba with advanced batch splits (#8526)
* llama : advanced batch splits This includes equal-sequence-length batch splits which are useful to simplify recurrent model operators. * llama : always make recurrent state slots contiguous * ggml : simplify mamba operators * llama : fix integer signedness mixing * llama : logits_all has priority over batch->logits Otherwise, the server embeddings tests failed. This was likely an existing problem but was only detected here because of an additional assertion. * llama : apply suggestions Co-authored-by: Georgi Gerganov <ggerganov@gmail.com> * llama : fix t5 segfault * llama : fix Mamba session save and restore * llama : minor cosmetic changes * llama : rename llama_reorder_outputs to llama_output_reorder Also move it closer to llama_output_reserve. * llama : fix pooled embeddings when using batches with equal_seqs * minor : add struct members for clarity ggml-ci * llama : fix T5 segfault again * llama : fix Mamba pooled embeddings with multiple sequences Until the pooled embeddings are refactored to allow splitting across ubatches for causal embeddings, recurrent models can only process a single sequence per ubatch when calculating pooled embeddings. * llama : add llama_model_is_recurrent to simplify figuring that out This will make it easier to more cleanly support RWKV-v6 and Mamba-2. * llama : fix simple splits when the batch contains embeddings --------- Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
This commit is contained in:
parent
fc54ef0d1c
commit
a1631e53f6
4 changed files with 1134 additions and 675 deletions
|
@ -511,6 +511,9 @@ extern "C" {
|
|||
// to the decoder to start generating output sequence. For other models, it returns -1.
|
||||
LLAMA_API llama_token llama_model_decoder_start_token(const struct llama_model * model);
|
||||
|
||||
// Returns true if the model is recurrent (like Mamba, RWKV, etc.)
|
||||
LLAMA_API bool llama_model_is_recurrent(const struct llama_model * model);
|
||||
|
||||
// Returns 0 on success
|
||||
LLAMA_API uint32_t llama_model_quantize(
|
||||
const char * fname_inp,
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue