mamba : clarify some comments

This commit is contained in:
Francis Couture-Harpin 2024-03-04 15:57:40 -05:00
parent 2a99d1b243
commit 93fd4b8d5b
2 changed files with 4 additions and 4 deletions

View file

@@ -1712,14 +1712,14 @@ struct llama_hparams {
return n_embd_head_v * n_head_kv; return n_embd_head_v * n_head_kv;
} }
uint32_t n_embd_k_s() const { // dimension of the recurrent convolution state embeddings uint32_t n_embd_k_s() const { // dimension of the rolling state embeddings
// corresponds to Mamba's conv_states size // corresponds to Mamba's conv_states size
// TODO: maybe support other convolution strides than 1 // TODO: maybe support other convolution strides than 1
// NOTE: since the first column of the conv_state is shifted out each time, it's not actually needed // NOTE: since the first column of the conv_state is shifted out each time, it's not actually needed
return (ssm_d_conv > 0 ? ssm_d_conv - 1 : 0) * ssm_d_inner; return (ssm_d_conv > 0 ? ssm_d_conv - 1 : 0) * ssm_d_inner;
} }
uint32_t n_embd_v_s() const { // dimension of the ssm scan state embeddings uint32_t n_embd_v_s() const { // dimension of the recurrent state embeddings
// corresponds to Mamba's ssm_states size // corresponds to Mamba's ssm_states size
return ssm_d_state * ssm_d_inner; return ssm_d_state * ssm_d_inner;
} }
@@ -8573,7 +8573,7 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {
} }
} }
} }
// For Mamba (and other constant-time-and-size architectures), // For Mamba (and other recurrent architectures),
// update the correct state(s)/sequence(s) for each token of the batch. // update the correct state(s)/sequence(s) for each token of the batch.
// Like with the KQ_mask, if a token in the batch has multiple sequences, // Like with the KQ_mask, if a token in the batch has multiple sequences,
// they are assumed to be equivalent (not here, but in ggml_ssm_scan and ggml_ssm_conv). // they are assumed to be equivalent (not here, but in ggml_ssm_scan and ggml_ssm_conv).

View file

@@ -235,7 +235,7 @@ extern "C" {
uint32_t seed; // RNG seed, -1 for random uint32_t seed; // RNG seed, -1 for random
uint32_t n_ctx; // text context, 0 = from model uint32_t n_ctx; // text context, 0 = from model
uint32_t n_batch; // prompt processing maximum batch size uint32_t n_batch; // prompt processing maximum batch size
uint32_t n_parallel; // number of parallel sequences uint32_t n_parallel; // number of parallel sequences (i.e. distinct states for recurrent models)
uint32_t n_threads; // number of threads to use for generation uint32_t n_threads; // number of threads to use for generation
uint32_t n_threads_batch; // number of threads to use for batch processing uint32_t n_threads_batch; // number of threads to use for batch processing