add n_ubatch (-ub) parameter

This commit is contained in:
slaren 2024-01-20 16:49:24 +01:00
parent 09688c771b
commit bc98eda9d5
5 changed files with 27 additions and 11 deletions

View file

@ -429,6 +429,12 @@ bool gpt_params_parse_ex(int argc, char ** argv, gpt_params & params) {
break;
}
params.n_batch = std::stoi(argv[i]);
} else if (arg == "-ub" || arg == "--ubatch-size") {
if (++i >= argc) {
invalid_param = true;
break;
}
params.n_ubatch = std::stoi(argv[i]);
} else if (arg == "--keep") {
if (++i >= argc) {
invalid_param = true;
@ -891,6 +897,8 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
printf(" -n N, --n-predict N number of tokens to predict (default: %d, -1 = infinity, -2 = until context filled)\n", params.n_predict);
printf(" -c N, --ctx-size N size of the prompt context (default: %d, 0 = loaded from model)\n", params.n_ctx);
printf(" -b N, --batch-size N batch size for prompt processing (default: %d)\n", params.n_batch);
printf(" -ub N, --ubatch-size N\n");
printf(" micro batch size for prompt processing (default: %d)\n", params.n_ubatch);
printf(" --samplers samplers that will be used for generation in the order, separated by \';\', for example: \"top_k;tfs;typical;top_p;min_p;temp\"\n");
printf(" --sampling-seq simplified sequence for samplers that will be used (default: %s)\n", sparams.samplers_sequence.c_str());
printf(" --top-k N top-k sampling (default: %d, 0 = disabled)\n", sparams.top_k);
@ -1133,6 +1141,7 @@ struct llama_context_params llama_context_params_from_gpt_params(const gpt_param
cparams.n_ctx = params.n_ctx;
cparams.n_batch = params.n_batch;
cparams.n_ubatch = params.n_ubatch;
cparams.n_threads = params.n_threads;
cparams.n_threads_batch = params.n_threads_batch == -1 ? params.n_threads : params.n_threads_batch;
cparams.mul_mat_q = params.mul_mat_q;

View file

@ -51,7 +51,8 @@ struct gpt_params {
int32_t n_threads_batch_draft = -1;
int32_t n_predict = -1; // new tokens to predict
int32_t n_ctx = 512; // context size
int32_t n_batch = 512; // batch size for prompt processing (must be >=32 to use BLAS)
int32_t n_batch = 4096; // batch size for prompt processing (must be >=32 to use BLAS)
int32_t n_ubatch = 256; // batch size for prompt processing (must be >=32 to use BLAS)
int32_t n_keep = 0; // number of tokens to keep from initial prompt
int32_t n_draft = 8; // number of tokens to draft during speculative decoding
int32_t n_chunks = -1; // max number of chunks to process (-1 = unlimited)

View file

@ -1035,13 +1035,13 @@ static void winogrande_score(llama_context * ctx, const gpt_params & params) {
int main(int argc, char ** argv) {
gpt_params params;
params.n_batch = 512;
//params.n_batch = 512;
if (!gpt_params_parse(argc, argv, params)) {
return 1;
}
params.logits_all = true;
params.n_batch = std::min(params.n_batch, params.n_ctx);
//params.n_batch = std::min(params.n_batch, params.n_ctx);
if (params.ppl_stride > 0) {
fprintf(stderr, "Will perform strided perplexity calculation -> adjusting context size from %d to %d\n",

View file

@ -1418,6 +1418,7 @@ struct llama_hparams {
struct llama_cparams {
uint32_t n_ctx; // context size used during inference
uint32_t n_batch;
uint32_t n_ubatch;
uint32_t n_threads; // number of threads to use for generation
uint32_t n_threads_batch; // number of threads to use for batch processing
@ -6629,11 +6630,11 @@ static int llama_decode_internal(
#endif
const uint32_t n_microbatch = cparams.n_batch;
const uint32_t n_ubatch = cparams.n_ubatch;
//const uint32_t n_microbatch = 256;
for (uint32_t cur_token = 0; cur_token < n_tokens_all; cur_token += n_microbatch) {
const uint32_t n_tokens = std::min(n_microbatch, n_tokens_all - cur_token);
for (uint32_t cur_token = 0; cur_token < n_tokens_all; cur_token += n_ubatch) {
const uint32_t n_tokens = std::min(n_ubatch, n_tokens_all - cur_token);
llama_batch batch = {
/* .n_tokens = */ (int32_t) n_tokens,
@ -6831,8 +6832,8 @@ static int llama_decode_internal(
}
}
//ggml_backend_sched_synchronize(lctx.sched);
//lctx.buf_cpu_ub_cur = 0;
ggml_backend_sched_synchronize(lctx.sched);
lctx.buf_cpu_ub_cur = 0;
// measure the performance only for the single-token evals
if (n_tokens_all == 1) {
@ -9701,7 +9702,8 @@ struct llama_context_params llama_context_default_params() {
struct llama_context_params result = {
/*.seed =*/ LLAMA_DEFAULT_SEED,
/*.n_ctx =*/ 512,
/*.n_batch =*/ 512,
/*.n_batch =*/ 4096,
/*.n_ubatch =*/ 256,
/*.n_threads =*/ GGML_DEFAULT_N_THREADS, // TODO: better default
/*.n_threads_batch =*/ GGML_DEFAULT_N_THREADS,
/*.rope_scaling_type =*/ LLAMA_ROPE_SCALING_UNSPECIFIED,
@ -9838,6 +9840,7 @@ struct llama_context * llama_new_context_with_model(
auto & cparams = ctx->cparams;
cparams.n_batch = params.n_batch;
cparams.n_ubatch = params.n_ubatch == 0 ? params.n_batch : params.n_ubatch;
cparams.n_threads = params.n_threads;
cparams.n_threads_batch = params.n_threads_batch;
cparams.yarn_ext_factor = params.yarn_ext_factor;
@ -9876,6 +9879,8 @@ struct llama_context * llama_new_context_with_model(
}
LLAMA_LOG_INFO("%s: n_ctx = %u\n", __func__, cparams.n_ctx);
LLAMA_LOG_INFO("%s: n_batch = %u\n", __func__, cparams.n_batch);
LLAMA_LOG_INFO("%s: n_ubatch = %u\n", __func__, cparams.n_ubatch);
LLAMA_LOG_INFO("%s: freq_base = %.1f\n", __func__, cparams.rope_freq_base);
LLAMA_LOG_INFO("%s: freq_scale = %g\n", __func__, cparams.rope_freq_scale);
@ -9985,7 +9990,7 @@ struct llama_context * llama_new_context_with_model(
ctx->alloc_cpu = ggml_backend_sched_get_tallocr(ctx->sched, ctx->backend_cpu);
// build worst-case graph
int n_tokens = (int)std::min(cparams.n_ctx, cparams.n_batch);
int n_tokens = (int)std::min(cparams.n_ctx, cparams.n_ubatch);
int n_past = cparams.n_ctx - n_tokens;
llama_token token = llama_token_bos(&ctx->model); // not actually used by llama_build_graph, but required to choose between token and embedding inputs graph
ggml_cgraph * gf = llama_build_graph(*ctx, llama_batch_get_one(&token, n_tokens, n_past, 0));

View file

@ -218,7 +218,8 @@ extern "C" {
struct llama_context_params {
uint32_t seed; // RNG seed, -1 for random
uint32_t n_ctx; // text context, 0 = from model
uint32_t n_batch; // prompt processing maximum batch size
uint32_t n_batch; // prompt processing maximum batch size (ignored if n_ubatch is set)
uint32_t n_ubatch; // prompt processing maximum batch size
uint32_t n_threads; // number of threads to use for generation
uint32_t n_threads_batch; // number of threads to use for batch processing
int8_t rope_scaling_type; // RoPE scaling type, from `enum llama_rope_scaling_type`