add n_ubatch (-ub) parameter
This commit is contained in:
parent
09688c771b
commit
bc98eda9d5
5 changed files with 27 additions and 11 deletions
|
@ -429,6 +429,12 @@ bool gpt_params_parse_ex(int argc, char ** argv, gpt_params & params) {
|
|||
break;
|
||||
}
|
||||
params.n_batch = std::stoi(argv[i]);
|
||||
} else if (arg == "-ub" || arg == "--ubatch-size") {
|
||||
if (++i >= argc) {
|
||||
invalid_param = true;
|
||||
break;
|
||||
}
|
||||
params.n_ubatch = std::stoi(argv[i]);
|
||||
} else if (arg == "--keep") {
|
||||
if (++i >= argc) {
|
||||
invalid_param = true;
|
||||
|
@ -891,6 +897,8 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
|
|||
printf(" -n N, --n-predict N number of tokens to predict (default: %d, -1 = infinity, -2 = until context filled)\n", params.n_predict);
|
||||
printf(" -c N, --ctx-size N size of the prompt context (default: %d, 0 = loaded from model)\n", params.n_ctx);
|
||||
printf(" -b N, --batch-size N batch size for prompt processing (default: %d)\n", params.n_batch);
|
||||
printf(" -ub N, --ubatch-size N\n");
|
||||
printf(" micro batch size for prompt processing (default: %d)\n", params.n_ubatch);
|
||||
printf(" --samplers samplers that will be used for generation in the order, separated by \';\', for example: \"top_k;tfs;typical;top_p;min_p;temp\"\n");
|
||||
printf(" --sampling-seq simplified sequence for samplers that will be used (default: %s)\n", sparams.samplers_sequence.c_str());
|
||||
printf(" --top-k N top-k sampling (default: %d, 0 = disabled)\n", sparams.top_k);
|
||||
|
@ -1133,6 +1141,7 @@ struct llama_context_params llama_context_params_from_gpt_params(const gpt_param
|
|||
|
||||
cparams.n_ctx = params.n_ctx;
|
||||
cparams.n_batch = params.n_batch;
|
||||
cparams.n_ubatch = params.n_ubatch;
|
||||
cparams.n_threads = params.n_threads;
|
||||
cparams.n_threads_batch = params.n_threads_batch == -1 ? params.n_threads : params.n_threads_batch;
|
||||
cparams.mul_mat_q = params.mul_mat_q;
|
||||
|
|
|
@ -51,7 +51,8 @@ struct gpt_params {
|
|||
int32_t n_threads_batch_draft = -1;
|
||||
int32_t n_predict = -1; // new tokens to predict
|
||||
int32_t n_ctx = 512; // context size
|
||||
int32_t n_batch = 512; // batch size for prompt processing (must be >=32 to use BLAS)
|
||||
int32_t n_batch = 4096; // batch size for prompt processing (must be >=32 to use BLAS)
|
||||
int32_t n_ubatch = 256; // micro batch size for prompt processing (must be >=32 to use BLAS)
|
||||
int32_t n_keep = 0; // number of tokens to keep from initial prompt
|
||||
int32_t n_draft = 8; // number of tokens to draft during speculative decoding
|
||||
int32_t n_chunks = -1; // max number of chunks to process (-1 = unlimited)
|
||||
|
|
|
@ -1035,13 +1035,13 @@ static void winogrande_score(llama_context * ctx, const gpt_params & params) {
|
|||
int main(int argc, char ** argv) {
|
||||
gpt_params params;
|
||||
|
||||
params.n_batch = 512;
|
||||
//params.n_batch = 512;
|
||||
if (!gpt_params_parse(argc, argv, params)) {
|
||||
return 1;
|
||||
}
|
||||
|
||||
params.logits_all = true;
|
||||
params.n_batch = std::min(params.n_batch, params.n_ctx);
|
||||
//params.n_batch = std::min(params.n_batch, params.n_ctx);
|
||||
|
||||
if (params.ppl_stride > 0) {
|
||||
fprintf(stderr, "Will perform strided perplexity calculation -> adjusting context size from %d to %d\n",
|
||||
|
|
19
llama.cpp
19
llama.cpp
|
@ -1418,6 +1418,7 @@ struct llama_hparams {
|
|||
struct llama_cparams {
|
||||
uint32_t n_ctx; // context size used during inference
|
||||
uint32_t n_batch;
|
||||
uint32_t n_ubatch;
|
||||
uint32_t n_threads; // number of threads to use for generation
|
||||
uint32_t n_threads_batch; // number of threads to use for batch processing
|
||||
|
||||
|
@ -6629,11 +6630,11 @@ static int llama_decode_internal(
|
|||
#endif
|
||||
|
||||
|
||||
const uint32_t n_microbatch = cparams.n_batch;
|
||||
const uint32_t n_ubatch = cparams.n_ubatch;
|
||||
//const uint32_t n_microbatch = 256;
|
||||
|
||||
for (uint32_t cur_token = 0; cur_token < n_tokens_all; cur_token += n_microbatch) {
|
||||
const uint32_t n_tokens = std::min(n_microbatch, n_tokens_all - cur_token);
|
||||
for (uint32_t cur_token = 0; cur_token < n_tokens_all; cur_token += n_ubatch) {
|
||||
const uint32_t n_tokens = std::min(n_ubatch, n_tokens_all - cur_token);
|
||||
|
||||
llama_batch batch = {
|
||||
/* .n_tokens = */ (int32_t) n_tokens,
|
||||
|
@ -6831,8 +6832,8 @@ static int llama_decode_internal(
|
|||
}
|
||||
}
|
||||
|
||||
//ggml_backend_sched_synchronize(lctx.sched);
|
||||
//lctx.buf_cpu_ub_cur = 0;
|
||||
ggml_backend_sched_synchronize(lctx.sched);
|
||||
lctx.buf_cpu_ub_cur = 0;
|
||||
|
||||
// measure the performance only for the single-token evals
|
||||
if (n_tokens_all == 1) {
|
||||
|
@ -9701,7 +9702,8 @@ struct llama_context_params llama_context_default_params() {
|
|||
struct llama_context_params result = {
|
||||
/*.seed =*/ LLAMA_DEFAULT_SEED,
|
||||
/*.n_ctx =*/ 512,
|
||||
/*.n_batch =*/ 512,
|
||||
/*.n_batch =*/ 4096,
|
||||
/*.n_ubatch =*/ 256,
|
||||
/*.n_threads =*/ GGML_DEFAULT_N_THREADS, // TODO: better default
|
||||
/*.n_threads_batch =*/ GGML_DEFAULT_N_THREADS,
|
||||
/*.rope_scaling_type =*/ LLAMA_ROPE_SCALING_UNSPECIFIED,
|
||||
|
@ -9838,6 +9840,7 @@ struct llama_context * llama_new_context_with_model(
|
|||
auto & cparams = ctx->cparams;
|
||||
|
||||
cparams.n_batch = params.n_batch;
|
||||
cparams.n_ubatch = params.n_ubatch == 0 ? params.n_batch : params.n_ubatch;
|
||||
cparams.n_threads = params.n_threads;
|
||||
cparams.n_threads_batch = params.n_threads_batch;
|
||||
cparams.yarn_ext_factor = params.yarn_ext_factor;
|
||||
|
@ -9876,6 +9879,8 @@ struct llama_context * llama_new_context_with_model(
|
|||
}
|
||||
|
||||
LLAMA_LOG_INFO("%s: n_ctx = %u\n", __func__, cparams.n_ctx);
|
||||
LLAMA_LOG_INFO("%s: n_batch = %u\n", __func__, cparams.n_batch);
|
||||
LLAMA_LOG_INFO("%s: n_ubatch = %u\n", __func__, cparams.n_ubatch);
|
||||
LLAMA_LOG_INFO("%s: freq_base = %.1f\n", __func__, cparams.rope_freq_base);
|
||||
LLAMA_LOG_INFO("%s: freq_scale = %g\n", __func__, cparams.rope_freq_scale);
|
||||
|
||||
|
@ -9985,7 +9990,7 @@ struct llama_context * llama_new_context_with_model(
|
|||
ctx->alloc_cpu = ggml_backend_sched_get_tallocr(ctx->sched, ctx->backend_cpu);
|
||||
|
||||
// build worst-case graph
|
||||
int n_tokens = (int)std::min(cparams.n_ctx, cparams.n_batch);
|
||||
int n_tokens = (int)std::min(cparams.n_ctx, cparams.n_ubatch);
|
||||
int n_past = cparams.n_ctx - n_tokens;
|
||||
llama_token token = llama_token_bos(&ctx->model); // not actually used by llama_build_graph, but required to choose between token and embedding inputs graph
|
||||
ggml_cgraph * gf = llama_build_graph(*ctx, llama_batch_get_one(&token, n_tokens, n_past, 0));
|
||||
|
|
3
llama.h
3
llama.h
|
@ -218,7 +218,8 @@ extern "C" {
|
|||
struct llama_context_params {
|
||||
uint32_t seed; // RNG seed, -1 for random
|
||||
uint32_t n_ctx; // text context, 0 = from model
|
||||
uint32_t n_batch; // prompt processing maximum batch size
|
||||
uint32_t n_batch; // prompt processing maximum batch size (ignored if n_ubatch is set)
|
||||
uint32_t n_ubatch; // prompt processing micro batch size (a batch is decoded in chunks of this size), 0 = use n_batch
|
||||
uint32_t n_threads; // number of threads to use for generation
|
||||
uint32_t n_threads_batch; // number of threads to use for batch processing
|
||||
int8_t rope_scaling_type; // RoPE scaling type, from `enum llama_rope_scaling_type`
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue