diff --git a/examples/retrieval/retrieval.cpp b/examples/retrieval/retrieval.cpp index 2fa9cfa3f..467e2f027 100644 --- a/examples/retrieval/retrieval.cpp +++ b/examples/retrieval/retrieval.cpp @@ -108,6 +108,9 @@ int main(int argc, char ** argv) { return 1; } + // For BERT models, batch size must be equal to ubatch size + params.n_ubatch = params.n_batch; + if (params.chunk_size <= 0) { fprintf(stderr, "chunk_size must be positive\n"); return 1;