diff --git a/common/common.cpp b/common/common.cpp index ce20360a4..4f92ef357 100644 --- a/common/common.cpp +++ b/common/common.cpp @@ -429,6 +429,12 @@ bool gpt_params_parse_ex(int argc, char ** argv, gpt_params & params) { break; } params.n_batch = std::stoi(argv[i]); + } else if (arg == "-ub" || arg == "--ubatch-size") { + if (++i >= argc) { + invalid_param = true; + break; + } + params.n_ubatch = std::stoi(argv[i]); } else if (arg == "--keep") { if (++i >= argc) { invalid_param = true; @@ -891,6 +897,8 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) { printf(" -n N, --n-predict N number of tokens to predict (default: %d, -1 = infinity, -2 = until context filled)\n", params.n_predict); printf(" -c N, --ctx-size N size of the prompt context (default: %d, 0 = loaded from model)\n", params.n_ctx); printf(" -b N, --batch-size N batch size for prompt processing (default: %d)\n", params.n_batch); + printf(" -ub N, --ubatch-size N\n"); + printf(" micro batch size for prompt processing (default: %d)\n", params.n_ubatch); printf(" --samplers samplers that will be used for generation in the order, separated by \';\', for example: \"top_k;tfs;typical;top_p;min_p;temp\"\n"); printf(" --sampling-seq simplified sequence for samplers that will be used (default: %s)\n", sparams.samplers_sequence.c_str()); printf(" --top-k N top-k sampling (default: %d, 0 = disabled)\n", sparams.top_k); @@ -1133,6 +1141,7 @@ struct llama_context_params llama_context_params_from_gpt_params(const gpt_param cparams.n_ctx = params.n_ctx; cparams.n_batch = params.n_batch; + cparams.n_ubatch = params.n_ubatch; cparams.n_threads = params.n_threads; cparams.n_threads_batch = params.n_threads_batch == -1 ? params.n_threads : params.n_threads_batch; cparams.mul_mat_q = params.mul_mat_q; diff --git a/common/common.h b/common/common.h index 0ae9c18b3..edc1264de 100644 --- a/common/common.h +++ b/common/common.h @@ -51,7 +51,8 @@ struct gpt_params { int32_t n_threads_batch_draft = -1; int32_t n_predict = -1; // new tokens to predict int32_t n_ctx = 512; // context size - int32_t n_batch = 512; // batch size for prompt processing (must be >=32 to use BLAS) + int32_t n_batch = 4096; // batch size for prompt processing (must be >=32 to use BLAS) + int32_t n_ubatch = 256; // batch size for prompt processing (must be >=32 to use BLAS) int32_t n_keep = 0; // number of tokens to keep from initial prompt int32_t n_draft = 8; // number of tokens to draft during speculative decoding int32_t n_chunks = -1; // max number of chunks to process (-1 = unlimited) diff --git a/examples/perplexity/perplexity.cpp b/examples/perplexity/perplexity.cpp index f91f5795a..76ccd08ab 100644 --- a/examples/perplexity/perplexity.cpp +++ b/examples/perplexity/perplexity.cpp @@ -1035,13 +1035,13 @@ static void winogrande_score(llama_context * ctx, const gpt_params & params) { int main(int argc, char ** argv) { gpt_params params; - params.n_batch = 512; + //params.n_batch = 512; if (!gpt_params_parse(argc, argv, params)) { return 1; } params.logits_all = true; - params.n_batch = std::min(params.n_batch, params.n_ctx); + //params.n_batch = std::min(params.n_batch, params.n_ctx); if (params.ppl_stride > 0) { fprintf(stderr, "Will perform strided perplexity calculation -> adjusting context size from %d to %d\n", diff --git a/llama.cpp b/llama.cpp index 7ac046ed0..413dd0480 100644 --- a/llama.cpp +++ b/llama.cpp @@ -1418,6 +1418,7 @@ struct llama_hparams { struct llama_cparams { uint32_t n_ctx; // context size used during inference uint32_t n_batch; + uint32_t n_ubatch; uint32_t n_threads; // number of threads to use for generation uint32_t n_threads_batch; // number of threads to use for batch processing @@ -6629,11 +6630,11 @@ static int llama_decode_internal( #endif - const uint32_t n_microbatch = cparams.n_batch; + const uint32_t n_ubatch = cparams.n_ubatch; //const uint32_t n_microbatch = 256; - for (uint32_t cur_token = 0; cur_token < n_tokens_all; cur_token += n_microbatch) { - const uint32_t n_tokens = std::min(n_microbatch, n_tokens_all - cur_token); + for (uint32_t cur_token = 0; cur_token < n_tokens_all; cur_token += n_ubatch) { + const uint32_t n_tokens = std::min(n_ubatch, n_tokens_all - cur_token); llama_batch batch = { /* .n_tokens = */ (int32_t) n_tokens, @@ -6831,8 +6832,8 @@ static int llama_decode_internal( } } - //ggml_backend_sched_synchronize(lctx.sched); - //lctx.buf_cpu_ub_cur = 0; + ggml_backend_sched_synchronize(lctx.sched); + lctx.buf_cpu_ub_cur = 0; // measure the performance only for the single-token evals if (n_tokens_all == 1) { @@ -9701,7 +9702,8 @@ struct llama_context_params llama_context_default_params() { struct llama_context_params result = { /*.seed =*/ LLAMA_DEFAULT_SEED, /*.n_ctx =*/ 512, - /*.n_batch =*/ 512, + /*.n_batch =*/ 4096, + /*.n_ubatch =*/ 256, /*.n_threads =*/ GGML_DEFAULT_N_THREADS, // TODO: better default /*.n_threads_batch =*/ GGML_DEFAULT_N_THREADS, /*.rope_scaling_type =*/ LLAMA_ROPE_SCALING_UNSPECIFIED, @@ -9838,6 +9840,7 @@ struct llama_context * llama_new_context_with_model( auto & cparams = ctx->cparams; cparams.n_batch = params.n_batch; + cparams.n_ubatch = params.n_ubatch == 0 ? params.n_batch : params.n_ubatch; cparams.n_threads = params.n_threads; cparams.n_threads_batch = params.n_threads_batch; cparams.yarn_ext_factor = params.yarn_ext_factor; @@ -9876,6 +9879,8 @@ struct llama_context * llama_new_context_with_model( } LLAMA_LOG_INFO("%s: n_ctx = %u\n", __func__, cparams.n_ctx); + LLAMA_LOG_INFO("%s: n_batch = %u\n", __func__, cparams.n_batch); + LLAMA_LOG_INFO("%s: n_ubatch = %u\n", __func__, cparams.n_ubatch); LLAMA_LOG_INFO("%s: freq_base = %.1f\n", __func__, cparams.rope_freq_base); LLAMA_LOG_INFO("%s: freq_scale = %g\n", __func__, cparams.rope_freq_scale); @@ -9985,7 +9990,7 @@ struct llama_context * llama_new_context_with_model( ctx->alloc_cpu = ggml_backend_sched_get_tallocr(ctx->sched, ctx->backend_cpu); // build worst-case graph - int n_tokens = (int)std::min(cparams.n_ctx, cparams.n_batch); + int n_tokens = (int)std::min(cparams.n_ctx, cparams.n_ubatch); int n_past = cparams.n_ctx - n_tokens; llama_token token = llama_token_bos(&ctx->model); // not actually used by llama_build_graph, but required to choose between token and embedding inputs graph ggml_cgraph * gf = llama_build_graph(*ctx, llama_batch_get_one(&token, n_tokens, n_past, 0)); diff --git a/llama.h b/llama.h index e268d7a1d..99425d6d0 100644 --- a/llama.h +++ b/llama.h @@ -218,7 +218,8 @@ extern "C" { struct llama_context_params { uint32_t seed; // RNG seed, -1 for random uint32_t n_ctx; // text context, 0 = from model - uint32_t n_batch; // prompt processing maximum batch size + uint32_t n_batch; // prompt processing maximum batch size (ignored if n_ubatch is set) + uint32_t n_ubatch; // prompt processing maximum batch size uint32_t n_threads; // number of threads to use for generation uint32_t n_threads_batch; // number of threads to use for batch processing int8_t rope_scaling_type; // RoPE scaling type, from `enum llama_rope_scaling_type`