llama-bench : add -fa,--flash-attn arg
This commit is contained in:
parent
87968de9a9
commit
260cdb2d08
1 changed files with 27 additions and 3 deletions
|
@ -174,6 +174,7 @@ struct cmd_params {
|
||||||
std::vector<llama_split_mode> split_mode;
|
std::vector<llama_split_mode> split_mode;
|
||||||
std::vector<int> main_gpu;
|
std::vector<int> main_gpu;
|
||||||
std::vector<bool> no_kv_offload;
|
std::vector<bool> no_kv_offload;
|
||||||
|
std::vector<bool> flash_attn;
|
||||||
std::vector<std::vector<float>> tensor_split;
|
std::vector<std::vector<float>> tensor_split;
|
||||||
std::vector<bool> use_mmap;
|
std::vector<bool> use_mmap;
|
||||||
std::vector<bool> embeddings;
|
std::vector<bool> embeddings;
|
||||||
|
@ -195,6 +196,7 @@ static const cmd_params cmd_params_defaults = {
|
||||||
/* split_mode */ {LLAMA_SPLIT_MODE_LAYER},
|
/* split_mode */ {LLAMA_SPLIT_MODE_LAYER},
|
||||||
/* main_gpu */ {0},
|
/* main_gpu */ {0},
|
||||||
/* no_kv_offload */ {false},
|
/* no_kv_offload */ {false},
|
||||||
|
/* flash_attn */ {false},
|
||||||
/* tensor_split */ {std::vector<float>(llama_max_devices(), 0.0f)},
|
/* tensor_split */ {std::vector<float>(llama_max_devices(), 0.0f)},
|
||||||
/* use_mmap */ {true},
|
/* use_mmap */ {true},
|
||||||
/* embeddings */ {false},
|
/* embeddings */ {false},
|
||||||
|
@ -220,6 +222,7 @@ static void print_usage(int /* argc */, char ** argv) {
|
||||||
printf(" -sm, --split-mode <none|layer|row> (default: %s)\n", join(transform_to_str(cmd_params_defaults.split_mode, split_mode_str), ",").c_str());
|
printf(" -sm, --split-mode <none|layer|row> (default: %s)\n", join(transform_to_str(cmd_params_defaults.split_mode, split_mode_str), ",").c_str());
|
||||||
printf(" -mg, --main-gpu <i> (default: %s)\n", join(cmd_params_defaults.main_gpu, ",").c_str());
|
printf(" -mg, --main-gpu <i> (default: %s)\n", join(cmd_params_defaults.main_gpu, ",").c_str());
|
||||||
printf(" -nkvo, --no-kv-offload <0|1> (default: %s)\n", join(cmd_params_defaults.no_kv_offload, ",").c_str());
|
printf(" -nkvo, --no-kv-offload <0|1> (default: %s)\n", join(cmd_params_defaults.no_kv_offload, ",").c_str());
|
||||||
|
printf(" -fa, --flash-attn <0|1> (default: %s)\n", join(cmd_params_defaults.flash_attn, ",").c_str());
|
||||||
printf(" -mmp, --mmap <0|1> (default: %s)\n", join(cmd_params_defaults.use_mmap, ",").c_str());
|
printf(" -mmp, --mmap <0|1> (default: %s)\n", join(cmd_params_defaults.use_mmap, ",").c_str());
|
||||||
printf(" -embd, --embeddings <0|1> (default: %s)\n", join(cmd_params_defaults.embeddings, ",").c_str());
|
printf(" -embd, --embeddings <0|1> (default: %s)\n", join(cmd_params_defaults.embeddings, ",").c_str());
|
||||||
printf(" -ts, --tensor-split <ts0/ts1/..> (default: 0)\n");
|
printf(" -ts, --tensor-split <ts0/ts1/..> (default: 0)\n");
|
||||||
|
@ -393,6 +396,13 @@ static cmd_params parse_cmd_params(int argc, char ** argv) {
|
||||||
}
|
}
|
||||||
auto p = split<bool>(argv[i], split_delim);
|
auto p = split<bool>(argv[i], split_delim);
|
||||||
params.no_kv_offload.insert(params.no_kv_offload.end(), p.begin(), p.end());
|
params.no_kv_offload.insert(params.no_kv_offload.end(), p.begin(), p.end());
|
||||||
|
} else if (arg == "-fa" || arg == "--flash-attn") {
|
||||||
|
if (++i >= argc) {
|
||||||
|
invalid_param = true;
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
auto p = split<bool>(argv[i], split_delim);
|
||||||
|
params.flash_attn.insert(params.flash_attn.end(), p.begin(), p.end());
|
||||||
} else if (arg == "-mmp" || arg == "--mmap") {
|
} else if (arg == "-mmp" || arg == "--mmap") {
|
||||||
if (++i >= argc) {
|
if (++i >= argc) {
|
||||||
invalid_param = true;
|
invalid_param = true;
|
||||||
|
@ -477,6 +487,7 @@ static cmd_params parse_cmd_params(int argc, char ** argv) {
|
||||||
if (params.split_mode.empty()) { params.split_mode = cmd_params_defaults.split_mode; }
|
if (params.split_mode.empty()) { params.split_mode = cmd_params_defaults.split_mode; }
|
||||||
if (params.main_gpu.empty()) { params.main_gpu = cmd_params_defaults.main_gpu; }
|
if (params.main_gpu.empty()) { params.main_gpu = cmd_params_defaults.main_gpu; }
|
||||||
if (params.no_kv_offload.empty()){ params.no_kv_offload = cmd_params_defaults.no_kv_offload; }
|
if (params.no_kv_offload.empty()){ params.no_kv_offload = cmd_params_defaults.no_kv_offload; }
|
||||||
|
if (params.flash_attn.empty()) { params.flash_attn = cmd_params_defaults.flash_attn; }
|
||||||
if (params.tensor_split.empty()) { params.tensor_split = cmd_params_defaults.tensor_split; }
|
if (params.tensor_split.empty()) { params.tensor_split = cmd_params_defaults.tensor_split; }
|
||||||
if (params.use_mmap.empty()) { params.use_mmap = cmd_params_defaults.use_mmap; }
|
if (params.use_mmap.empty()) { params.use_mmap = cmd_params_defaults.use_mmap; }
|
||||||
if (params.embeddings.empty()) { params.embeddings = cmd_params_defaults.embeddings; }
|
if (params.embeddings.empty()) { params.embeddings = cmd_params_defaults.embeddings; }
|
||||||
|
@ -498,6 +509,7 @@ struct cmd_params_instance {
|
||||||
llama_split_mode split_mode;
|
llama_split_mode split_mode;
|
||||||
int main_gpu;
|
int main_gpu;
|
||||||
bool no_kv_offload;
|
bool no_kv_offload;
|
||||||
|
bool flash_attn;
|
||||||
std::vector<float> tensor_split;
|
std::vector<float> tensor_split;
|
||||||
bool use_mmap;
|
bool use_mmap;
|
||||||
bool embeddings;
|
bool embeddings;
|
||||||
|
@ -532,6 +544,7 @@ struct cmd_params_instance {
|
||||||
cparams.type_k = type_k;
|
cparams.type_k = type_k;
|
||||||
cparams.type_v = type_v;
|
cparams.type_v = type_v;
|
||||||
cparams.offload_kqv = !no_kv_offload;
|
cparams.offload_kqv = !no_kv_offload;
|
||||||
|
cparams.flash_attn = flash_attn;
|
||||||
cparams.embeddings = embeddings;
|
cparams.embeddings = embeddings;
|
||||||
|
|
||||||
return cparams;
|
return cparams;
|
||||||
|
@ -554,6 +567,7 @@ static std::vector<cmd_params_instance> get_cmd_params_instances(const cmd_param
|
||||||
for (const auto & tk : params.type_k)
|
for (const auto & tk : params.type_k)
|
||||||
for (const auto & tv : params.type_v)
|
for (const auto & tv : params.type_v)
|
||||||
for (const auto & nkvo : params.no_kv_offload)
|
for (const auto & nkvo : params.no_kv_offload)
|
||||||
|
for (const auto & fa : params.flash_attn)
|
||||||
for (const auto & nt : params.n_threads) {
|
for (const auto & nt : params.n_threads) {
|
||||||
for (const auto & n_prompt : params.n_prompt) {
|
for (const auto & n_prompt : params.n_prompt) {
|
||||||
if (n_prompt == 0) {
|
if (n_prompt == 0) {
|
||||||
|
@ -572,6 +586,7 @@ static std::vector<cmd_params_instance> get_cmd_params_instances(const cmd_param
|
||||||
/* .split_mode = */ sm,
|
/* .split_mode = */ sm,
|
||||||
/* .main_gpu = */ mg,
|
/* .main_gpu = */ mg,
|
||||||
/* .no_kv_offload= */ nkvo,
|
/* .no_kv_offload= */ nkvo,
|
||||||
|
/* .flash_attn = */ fa,
|
||||||
/* .tensor_split = */ ts,
|
/* .tensor_split = */ ts,
|
||||||
/* .use_mmap = */ mmp,
|
/* .use_mmap = */ mmp,
|
||||||
/* .embeddings = */ embd,
|
/* .embeddings = */ embd,
|
||||||
|
@ -596,6 +611,7 @@ static std::vector<cmd_params_instance> get_cmd_params_instances(const cmd_param
|
||||||
/* .split_mode = */ sm,
|
/* .split_mode = */ sm,
|
||||||
/* .main_gpu = */ mg,
|
/* .main_gpu = */ mg,
|
||||||
/* .no_kv_offload= */ nkvo,
|
/* .no_kv_offload= */ nkvo,
|
||||||
|
/* .flash_attn = */ fa,
|
||||||
/* .tensor_split = */ ts,
|
/* .tensor_split = */ ts,
|
||||||
/* .use_mmap = */ mmp,
|
/* .use_mmap = */ mmp,
|
||||||
/* .embeddings = */ embd,
|
/* .embeddings = */ embd,
|
||||||
|
@ -633,6 +649,7 @@ struct test {
|
||||||
llama_split_mode split_mode;
|
llama_split_mode split_mode;
|
||||||
int main_gpu;
|
int main_gpu;
|
||||||
bool no_kv_offload;
|
bool no_kv_offload;
|
||||||
|
bool flash_attn;
|
||||||
std::vector<float> tensor_split;
|
std::vector<float> tensor_split;
|
||||||
bool use_mmap;
|
bool use_mmap;
|
||||||
bool embeddings;
|
bool embeddings;
|
||||||
|
@ -657,6 +674,7 @@ struct test {
|
||||||
split_mode = inst.split_mode;
|
split_mode = inst.split_mode;
|
||||||
main_gpu = inst.main_gpu;
|
main_gpu = inst.main_gpu;
|
||||||
no_kv_offload = inst.no_kv_offload;
|
no_kv_offload = inst.no_kv_offload;
|
||||||
|
flash_attn = inst.flash_attn;
|
||||||
tensor_split = inst.tensor_split;
|
tensor_split = inst.tensor_split;
|
||||||
use_mmap = inst.use_mmap;
|
use_mmap = inst.use_mmap;
|
||||||
embeddings = inst.embeddings;
|
embeddings = inst.embeddings;
|
||||||
|
@ -731,7 +749,7 @@ struct test {
|
||||||
"n_batch", "n_ubatch",
|
"n_batch", "n_ubatch",
|
||||||
"n_threads", "type_k", "type_v",
|
"n_threads", "type_k", "type_v",
|
||||||
"n_gpu_layers", "split_mode",
|
"n_gpu_layers", "split_mode",
|
||||||
"main_gpu", "no_kv_offload",
|
"main_gpu", "no_kv_offload", "flash_attn",
|
||||||
"tensor_split", "use_mmap", "embeddings",
|
"tensor_split", "use_mmap", "embeddings",
|
||||||
"n_prompt", "n_gen", "test_time",
|
"n_prompt", "n_gen", "test_time",
|
||||||
"avg_ns", "stddev_ns",
|
"avg_ns", "stddev_ns",
|
||||||
|
@ -753,7 +771,7 @@ struct test {
|
||||||
}
|
}
|
||||||
if (field == "cuda" || field == "opencl" || field == "vulkan" || field == "kompute" || field == "metal" ||
|
if (field == "cuda" || field == "opencl" || field == "vulkan" || field == "kompute" || field == "metal" ||
|
||||||
field == "gpu_blas" || field == "blas" || field == "sycl" ||field == "f16_kv" || field == "no_kv_offload" ||
|
field == "gpu_blas" || field == "blas" || field == "sycl" ||field == "f16_kv" || field == "no_kv_offload" ||
|
||||||
field == "use_mmap" || field == "embeddings") {
|
field == "flash_attn" || field == "use_mmap" || field == "embeddings") {
|
||||||
return BOOL;
|
return BOOL;
|
||||||
}
|
}
|
||||||
if (field == "avg_ts" || field == "stddev_ts") {
|
if (field == "avg_ts" || field == "stddev_ts") {
|
||||||
|
@ -787,7 +805,7 @@ struct test {
|
||||||
std::to_string(n_batch), std::to_string(n_ubatch),
|
std::to_string(n_batch), std::to_string(n_ubatch),
|
||||||
std::to_string(n_threads), ggml_type_name(type_k), ggml_type_name(type_v),
|
std::to_string(n_threads), ggml_type_name(type_k), ggml_type_name(type_v),
|
||||||
std::to_string(n_gpu_layers), split_mode_str(split_mode),
|
std::to_string(n_gpu_layers), split_mode_str(split_mode),
|
||||||
std::to_string(main_gpu), std::to_string(no_kv_offload),
|
std::to_string(main_gpu), std::to_string(no_kv_offload), std::to_string(flash_attn),
|
||||||
tensor_split_str, std::to_string(use_mmap), std::to_string(embeddings),
|
tensor_split_str, std::to_string(use_mmap), std::to_string(embeddings),
|
||||||
std::to_string(n_prompt), std::to_string(n_gen), test_time,
|
std::to_string(n_prompt), std::to_string(n_gen), test_time,
|
||||||
std::to_string(avg_ns()), std::to_string(stdev_ns()),
|
std::to_string(avg_ns()), std::to_string(stdev_ns()),
|
||||||
|
@ -955,6 +973,9 @@ struct markdown_printer : public printer {
|
||||||
if (field == "no_kv_offload") {
|
if (field == "no_kv_offload") {
|
||||||
return "nkvo";
|
return "nkvo";
|
||||||
}
|
}
|
||||||
|
if (field == "flash_attn") {
|
||||||
|
return "fa";
|
||||||
|
}
|
||||||
if (field == "use_mmap") {
|
if (field == "use_mmap") {
|
||||||
return "mmap";
|
return "mmap";
|
||||||
}
|
}
|
||||||
|
@ -1001,6 +1022,9 @@ struct markdown_printer : public printer {
|
||||||
if (params.no_kv_offload.size() > 1 || params.no_kv_offload != cmd_params_defaults.no_kv_offload) {
|
if (params.no_kv_offload.size() > 1 || params.no_kv_offload != cmd_params_defaults.no_kv_offload) {
|
||||||
fields.emplace_back("no_kv_offload");
|
fields.emplace_back("no_kv_offload");
|
||||||
}
|
}
|
||||||
|
if (params.flash_attn.size() > 1 || params.flash_attn != cmd_params_defaults.flash_attn) {
|
||||||
|
fields.emplace_back("flash_attn");
|
||||||
|
}
|
||||||
if (params.tensor_split.size() > 1 || params.tensor_split != cmd_params_defaults.tensor_split) {
|
if (params.tensor_split.size() > 1 || params.tensor_split != cmd_params_defaults.tensor_split) {
|
||||||
fields.emplace_back("tensor_split");
|
fields.emplace_back("tensor_split");
|
||||||
}
|
}
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue