diff --git a/llama.cpp b/llama.cpp
index 9b17bf347..74a802fd4 100644
--- a/llama.cpp
+++ b/llama.cpp
@@ -1712,14 +1712,14 @@ struct llama_hparams {
         return n_embd_head_v * n_head_kv;
     }
 
-    uint32_t n_embd_k_s() const { // dimension of the recurrent convolution state embeddings
+    uint32_t n_embd_k_s() const { // dimension of the rolling state embeddings
         // corresponds to Mamba's conv_states size
         // TODO: maybe support other convolution strides than 1
         // NOTE: since the first column of the conv_state is shifted out each time, it's not actually needed
         return (ssm_d_conv > 0 ? ssm_d_conv - 1 : 0) * ssm_d_inner;
     }
 
-    uint32_t n_embd_v_s() const { // dimension of the ssm scan state embeddings
+    uint32_t n_embd_v_s() const { // dimension of the recurrent state embeddings
         // corresponds to Mamba's ssm_states size
         return ssm_d_state * ssm_d_inner;
     }
@@ -8573,7 +8573,7 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {
             }
         }
     }
-    // For Mamba (and other constant-time-and-size architectures),
+    // For Mamba (and other recurrent architectures),
     // update the correct state(s)/sequence(s) for each token of the batch.
     // Like with the KQ_mask, if a token in the batch has multiple sequences,
    // they are assumed to be equivalent (not here, but in ggml_ssm_scan and ggml_ssm_conv).
diff --git a/llama.h b/llama.h
index f0aca1fe5..ee804a658 100644
--- a/llama.h
+++ b/llama.h
@@ -235,7 +235,7 @@ extern "C" {
         uint32_t seed;              // RNG seed, -1 for random
         uint32_t n_ctx;             // text context, 0 = from model
         uint32_t n_batch;           // prompt processing maximum batch size
-        uint32_t n_parallel;        // number of parallel sequences
+        uint32_t n_parallel;        // number of parallel sequences (i.e. distinct states for recurrent models)
         uint32_t n_threads;         // number of threads to use for generation
         uint32_t n_threads_batch;   // number of threads to use for batch processing
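
For a sense of scale, here is a minimal standalone sketch of what the two renamed per-sequence state sizes work out to. The hyperparameter values below (d_conv = 4, d_inner = 1536, d_state = 16, roughly Mamba-130M-like) are illustrative assumptions, not values read from any model file:

#include <cstdint>
#include <cstdio>

int main() {
    // Illustrative Mamba-like hyperparameters (assumed, not from a real GGUF).
    const uint32_t ssm_d_conv  = 4;    // convolution kernel width
    const uint32_t ssm_d_inner = 1536; // expanded inner dimension
    const uint32_t ssm_d_state = 16;   // SSM state size per inner channel

    // Mirrors llama_hparams::n_embd_k_s(): the rolling conv state only needs
    // d_conv - 1 columns, because the oldest column is shifted out each step.
    const uint32_t n_embd_k_s = (ssm_d_conv > 0 ? ssm_d_conv - 1 : 0) * ssm_d_inner;

    // Mirrors llama_hparams::n_embd_v_s(): the recurrent ssm_state.
    const uint32_t n_embd_v_s = ssm_d_state * ssm_d_inner;

    // Unlike a KV cache, both sizes are independent of context length;
    // each sequence needs its own copy of both states.
    std::printf("n_embd_k_s = %u values per sequence\n", static_cast<unsigned>(n_embd_k_s)); // 4608
    std::printf("n_embd_v_s = %u values per sequence\n", static_cast<unsigned>(n_embd_v_s)); // 24576
}

This constant per-sequence footprint is also why the n_parallel comment in llama.h is clarified: for recurrent models it fixes the number of distinct states allocated up front, rather than merely labeling batch lanes.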