assign n_batch value to n_ubatch
This commit is contained in:
parent
7d819d088e
commit
e16279ed0e
1 changed files with 3 additions and 0 deletions
|
@ -108,6 +108,9 @@ int main(int argc, char ** argv) {
|
||||||
return 1;
|
return 1;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// For BERT models, batch size must be equal to ubatch size
|
||||||
|
params.n_ubatch = params.n_batch;
|
||||||
|
|
||||||
if (params.chunk_size <= 0) {
|
if (params.chunk_size <= 0) {
|
||||||
fprintf(stderr, "chunk_size must be positive\n");
|
fprintf(stderr, "chunk_size must be positive\n");
|
||||||
return 1;
|
return 1;
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue