mamba : clarify some comments
This commit is contained in:
parent 2a99d1b243
commit 93fd4b8d5b
2 changed files with 4 additions and 4 deletions
|
@ -1712,14 +1712,14 @@ struct llama_hparams {
|
||||||
return n_embd_head_v * n_head_kv;
|
return n_embd_head_v * n_head_kv;
|
||||||
}
|
}
|
||||||
|
|
||||||
uint32_t n_embd_k_s() const { // dimension of the recurrent convolution state embeddings
|
uint32_t n_embd_k_s() const { // dimension of the rolling state embeddings
|
||||||
// corresponds to Mamba's conv_states size
|
// corresponds to Mamba's conv_states size
|
||||||
// TODO: maybe support other convolution strides than 1
|
// TODO: maybe support other convolution strides than 1
|
||||||
// NOTE: since the first column of the conv_state is shifted out each time, it's not actually needed
|
// NOTE: since the first column of the conv_state is shifted out each time, it's not actually needed
|
||||||
return (ssm_d_conv > 0 ? ssm_d_conv - 1 : 0) * ssm_d_inner;
|
return (ssm_d_conv > 0 ? ssm_d_conv - 1 : 0) * ssm_d_inner;
|
||||||
}
|
}
|
||||||
|
|
||||||
uint32_t n_embd_v_s() const { // dimension of the ssm scan state embeddings
|
uint32_t n_embd_v_s() const { // dimension of the recurrent state embeddings
|
||||||
// corresponds to Mamba's ssm_states size
|
// corresponds to Mamba's ssm_states size
|
||||||
return ssm_d_state * ssm_d_inner;
|
return ssm_d_state * ssm_d_inner;
|
||||||
}
|
}
|
||||||
|
@ -8573,7 +8573,7 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
// For Mamba (and other constant-time-and-size architectures),
|
// For Mamba (and other recurrent architectures),
|
||||||
// update the correct state(s)/sequence(s) for each token of the batch.
|
// update the correct state(s)/sequence(s) for each token of the batch.
|
||||||
// Like with the KQ_mask, if a token in the batch has multiple sequences,
|
// Like with the KQ_mask, if a token in the batch has multiple sequences,
|
||||||
// they are assumed to be equivalent (not here, but in ggml_ssm_scan and ggml_ssm_conv).
|
// they are assumed to be equivalent (not here, but in ggml_ssm_scan and ggml_ssm_conv).
|
||||||
|
|
llama.h (2 changes)
|
@ -235,7 +235,7 @@ extern "C" {
|
||||||
uint32_t seed; // RNG seed, -1 for random
|
uint32_t seed; // RNG seed, -1 for random
|
||||||
uint32_t n_ctx; // text context, 0 = from model
|
uint32_t n_ctx; // text context, 0 = from model
|
||||||
uint32_t n_batch; // prompt processing maximum batch size
|
uint32_t n_batch; // prompt processing maximum batch size
|
||||||
uint32_t n_parallel; // number of parallel sequences
|
uint32_t n_parallel; // number of parallel sequences (i.e. distinct states for recurrent models)
|
||||||
uint32_t n_threads; // number of threads to use for generation
|
uint32_t n_threads; // number of threads to use for generation
|
||||||
uint32_t n_threads_batch; // number of threads to use for batch processing
|
uint32_t n_threads_batch; // number of threads to use for batch processing
|
||||||
|
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue