assign n_batch value to n_ubatch

This commit is contained in:
Minsoo Cheong 2024-03-23 21:41:22 +09:00
parent 7d819d088e
commit e16279ed0e

View file

@ -108,6 +108,9 @@ int main(int argc, char ** argv) {
return 1;
}
// For BERT models, batch size must be equal to ubatch size
params.n_ubatch = params.n_batch;
if (params.chunk_size <= 0) {
fprintf(stderr, "chunk_size must be positive\n");
return 1;