set cparams.n_parallel to the number of sequences
This commit is contained in:
parent
7d999555b7
commit
ac07f7d0f7
2 changed files with 4 additions and 2 deletions
|
@ -1822,7 +1822,9 @@ int main(int argc, char ** argv) {
|
||||||
|
|
||||||
const bool ppl = !params.hellaswag && !params.winogrande && !params.multiple_choice && !params.kl_divergence;
|
const bool ppl = !params.hellaswag && !params.winogrande && !params.multiple_choice && !params.kl_divergence;
|
||||||
if (ppl) {
|
if (ppl) {
|
||||||
int32_t n_kv = std::max(1, params.n_batch / n_ctx) * n_ctx;
|
int n_seq = std::max(1, params.n_batch / n_ctx);
|
||||||
|
int32_t n_kv = n_seq * n_ctx;
|
||||||
|
params.n_parallel = n_seq;
|
||||||
params.n_ctx = n_kv;
|
params.n_ctx = n_kv;
|
||||||
params.n_batch = std::min(params.n_batch, n_kv);
|
params.n_batch = std::min(params.n_batch, n_kv);
|
||||||
} else {
|
} else {
|
||||||
|
|
|
@ -8943,7 +8943,7 @@ static int llama_decode_internal(
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
#ifndef NDEBUG
|
#ifndef NDEBUG
|
||||||
logits_valid[i] = batch.logits[i] == 1;
|
logits_valid[i] = batch.logits[i] != 0;
|
||||||
#endif
|
#endif
|
||||||
}
|
}
|
||||||
} else if (lctx.logits_all) {
|
} else if (lctx.logits_all) {
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue