diff --git a/.gitignore b/.gitignore index fba207045..5ab81445d 100644 --- a/.gitignore +++ b/.gitignore @@ -105,3 +105,4 @@ poetry.toml /tests/test-tokenizer-1-bpe /tests/test-rope /tests/test-backend-ops +/tests/test-autorelease diff --git a/Makefile b/Makefile index 995b89f7a..a8658a596 100644 --- a/Makefile +++ b/Makefile @@ -9,7 +9,7 @@ TEST_TARGETS = \ tests/test-llama-grammar tests/test-grammar-parser tests/test-double-float tests/test-grad0 tests/test-opt \ tests/test-quantize-fns tests/test-quantize-perf tests/test-sampling tests/test-tokenizer-0-llama \ tests/test-tokenizer-0-falcon tests/test-tokenizer-1-llama tests/test-tokenizer-1-bpe tests/test-rope \ - tests/test-backend-ops + tests/test-backend-ops tests/test-autorelease # Code coverage output files COV_TARGETS = *.gcno tests/*.gcno *.gcda tests/*.gcda *.gcov tests/*.gcov lcov-report gcovr-report @@ -747,3 +747,6 @@ tests/test-c.o: tests/test-c.c llama.h tests/test-backend-ops: tests/test-backend-ops.cpp ggml.o $(OBJS) $(CXX) $(CXXFLAGS) $(filter-out %.h,$^) -o $@ $(LDFLAGS) + +tests/test-autorelease: tests/test-autorelease.cpp ggml.o llama.o $(COMMON_DEPS) $(OBJS) + $(CXX) $(CXXFLAGS) $(filter-out %.h,$^) -o $@ $(LDFLAGS) diff --git a/ci/run.sh b/ci/run.sh index 47a254f4c..791b17a19 100755 --- a/ci/run.sh +++ b/ci/run.sh @@ -36,6 +36,10 @@ if [ ! -z ${GG_BUILD_METAL} ]; then CMAKE_EXTRA="${CMAKE_EXTRA} -DLLAMA_METAL_SHADER_DEBUG=ON" fi +if [ ! -z ${GG_BUILD_CUDA} ]; then + CMAKE_EXTRA="${CMAKE_EXTRA} -DLLAMA_CUBLAS=1" +fi + ## helpers # download a file if it does not exist or if it is outdated @@ -160,8 +164,8 @@ function gg_run_open_llama_3b_v2 { set -e - (time cmake -DCMAKE_BUILD_TYPE=Release -DLLAMA_QKK_64=1 .. ) 2>&1 | tee -a $OUT/${ci}-cmake.log - (time make -j ) 2>&1 | tee -a $OUT/${ci}-make.log + (time cmake -DCMAKE_BUILD_TYPE=Release ${CMAKE_EXTRA} -DLLAMA_QKK_64=1 .. 
) 2>&1 | tee -a $OUT/${ci}-cmake.log + (time make -j ) 2>&1 | tee -a $OUT/${ci}-make.log python3 ../convert.py ${path_models} @@ -179,6 +183,8 @@ function gg_run_open_llama_3b_v2 { wiki_test_60="${path_wiki}/wiki.test-60.raw" + ./bin/test-autorelease ${model_f16} + ./bin/quantize ${model_f16} ${model_q8_0} q8_0 ./bin/quantize ${model_f16} ${model_q4_0} q4_0 ./bin/quantize ${model_f16} ${model_q4_1} q4_1 @@ -214,6 +220,8 @@ function gg_run_open_llama_3b_v2 { (time ./bin/perplexity --model ${model_q5_k} -f ${wiki_test_60} -c 128 -b 128 --chunks 2 ) 2>&1 | tee -a $OUT/${ci}-tg-q5_k.log (time ./bin/perplexity --model ${model_q6_k} -f ${wiki_test_60} -c 128 -b 128 --chunks 2 ) 2>&1 | tee -a $OUT/${ci}-tg-q6_k.log + (time ./bin/imatrix --model ${model_f16} -f ${wiki_test_60} -c 128 -b 128 --chunks 2 ) 2>&1 | tee -a $OUT/${ci}-imatrix.log + (time ./bin/save-load-state --model ${model_q4_0} ) 2>&1 | tee -a $OUT/${ci}-save-load-state.log function check_ppl { @@ -241,6 +249,8 @@ function gg_run_open_llama_3b_v2 { check_ppl "q5_k" "$(cat $OUT/${ci}-tg-q5_k.log | grep "^\[1\]")" | tee -a $OUT/${ci}-ppl.log check_ppl "q6_k" "$(cat $OUT/${ci}-tg-q6_k.log | grep "^\[1\]")" | tee -a $OUT/${ci}-ppl.log + cat $OUT/${ci}-imatrix.log | grep "Final" >> $OUT/${ci}-imatrix-sum.log + # lora function compare_ppl { qnt="$1" @@ -282,7 +292,6 @@ function gg_run_open_llama_3b_v2 { (time ./bin/perplexity --model ${model_q8_0} -f ${shakespeare} --lora ${lora_shakespeare} --lora-base ${model_f16} -c 128 -b 128 --chunks 2 ) 2>&1 | tee -a $OUT/${ci}-ppl-shakespeare-lora-q8_0-f16.log compare_ppl "q8_0 / f16 base shakespeare" "$(cat $OUT/${ci}-ppl-shakespeare-q8_0.log | grep "^\[1\]")" "$(cat $OUT/${ci}-ppl-shakespeare-lora-q8_0-f16.log | grep "^\[1\]")" | tee -a $OUT/${ci}-lora-ppl.log - set +e } @@ -292,6 +301,7 @@ function gg_sum_open_llama_3b_v2 { gg_printf 'OpenLLaMA 3B-v2:\n' gg_printf '- status: %s\n' "$(cat $OUT/${ci}.exit)" gg_printf '- perplexity:\n%s\n' "$(cat $OUT/${ci}-ppl.log)" + gg_printf '- imatrix:\n```\n%s\n```\n' "$(cat $OUT/${ci}-imatrix-sum.log)" gg_printf '- lora:\n%s\n' "$(cat $OUT/${ci}-lora-ppl.log)" gg_printf '- f16: \n```\n%s\n```\n' "$(cat $OUT/${ci}-tg-f16.log)" gg_printf '- q8_0:\n```\n%s\n```\n' "$(cat $OUT/${ci}-tg-q8_0.log)" @@ -337,8 +347,8 @@ function gg_run_open_llama_7b_v2 { set -e - (time cmake -DCMAKE_BUILD_TYPE=Release -DLLAMA_CUBLAS=1 .. ) 2>&1 | tee -a $OUT/${ci}-cmake.log - (time make -j ) 2>&1 | tee -a $OUT/${ci}-make.log + (time cmake -DCMAKE_BUILD_TYPE=Release ${CMAKE_EXTRA} -DLLAMA_CUBLAS=1 .. 
) 2>&1 | tee -a $OUT/${ci}-cmake.log + (time make -j ) 2>&1 | tee -a $OUT/${ci}-make.log python3 ../convert.py ${path_models} @@ -391,6 +401,8 @@ function gg_run_open_llama_7b_v2 { (time ./bin/perplexity --model ${model_q5_k} -f ${wiki_test} -t 1 -ngl 999 -c 2048 -b 512 --chunks 4 ) 2>&1 | tee -a $OUT/${ci}-tg-q5_k.log (time ./bin/perplexity --model ${model_q6_k} -f ${wiki_test} -t 1 -ngl 999 -c 2048 -b 512 --chunks 4 ) 2>&1 | tee -a $OUT/${ci}-tg-q6_k.log + (time ./bin/imatrix --model ${model_f16} -f ${wiki_test} -t 1 -ngl 999 -c 2048 -b 512 --chunks 4 ) 2>&1 | tee -a $OUT/${ci}-imatrix.log + (time ./bin/save-load-state --model ${model_q4_0} ) 2>&1 | tee -a $OUT/${ci}-save-load-state.log function check_ppl { @@ -418,6 +430,8 @@ function gg_run_open_llama_7b_v2 { check_ppl "q5_k" "$(cat $OUT/${ci}-tg-q5_k.log | grep "^\[1\]")" | tee -a $OUT/${ci}-ppl.log check_ppl "q6_k" "$(cat $OUT/${ci}-tg-q6_k.log | grep "^\[1\]")" | tee -a $OUT/${ci}-ppl.log + cat $OUT/${ci}-imatrix.log | grep "Final" >> $OUT/${ci}-imatrix-sum.log + # lora function compare_ppl { qnt="$1" @@ -469,6 +483,7 @@ function gg_sum_open_llama_7b_v2 { gg_printf 'OpenLLaMA 7B-v2:\n' gg_printf '- status: %s\n' "$(cat $OUT/${ci}.exit)" gg_printf '- perplexity:\n%s\n' "$(cat $OUT/${ci}-ppl.log)" + gg_printf '- imatrix:\n```\n%s\n```\n' "$(cat $OUT/${ci}-imatrix-sum.log)" gg_printf '- lora:\n%s\n' "$(cat $OUT/${ci}-lora-ppl.log)" gg_printf '- f16: \n```\n%s\n```\n' "$(cat $OUT/${ci}-tg-f16.log)" gg_printf '- q8_0:\n```\n%s\n```\n' "$(cat $OUT/${ci}-tg-q8_0.log)" diff --git a/common/common.cpp b/common/common.cpp index 2b0865fff..ce20360a4 100644 --- a/common/common.cpp +++ b/common/common.cpp @@ -681,6 +681,14 @@ bool gpt_params_parse_ex(int argc, char ** argv, gpt_params & params) { break; } params.hellaswag_tasks = std::stoi(argv[i]); + } else if (arg == "--winogrande") { + params.winogrande = true; + } else if (arg == "--winogrande-tasks") { + if (++i >= argc) { + invalid_param = true; + break; + } + params.winogrande_tasks = std::stoi(argv[i]); } else if (arg == "--ignore-eos") { params.ignore_eos = true; } else if (arg == "--no-penalize-nl") { @@ -926,6 +934,8 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) { printf(" --logits-all return logits for all tokens in the batch (default: disabled)\n"); printf(" --hellaswag compute HellaSwag score over random tasks from datafile supplied with -f\n"); printf(" --hellaswag-tasks N number of tasks to use when computing the HellaSwag score (default: %zu)\n", params.hellaswag_tasks); + printf(" --winogrande compute Winogrande score over random tasks from datafile supplied with -f\n"); + printf(" --winogrande-tasks N number of tasks to use when computing the Winogrande score (default: %zu)\n", params.winogrande_tasks); printf(" --keep N number of tokens to keep from the initial prompt (default: %d, -1 = all)\n", params.n_keep); printf(" --draft N number of tokens to draft for speculative decoding (default: %d)\n", params.n_draft); printf(" --chunks N max number of chunks to process (default: %d, -1 = all)\n", params.n_chunks); diff --git a/common/common.h b/common/common.h index 1f43e6282..0ae9c18b3 100644 --- a/common/common.h +++ b/common/common.h @@ -105,6 +105,9 @@ struct gpt_params { bool hellaswag = false; // compute HellaSwag score over random tasks from datafile supplied in prompt size_t hellaswag_tasks = 400; // number of tasks to use when computing the HellaSwag score + bool winogrande = false; // compute Winogrande score over random tasks from 
datafile supplied in prompt + size_t winogrande_tasks= 0; // number of tasks to use when computing the Winogrande score. If 0, all tasks will be computed + bool mul_mat_q = true; // if true, use mul_mat_q kernels instead of cuBLAS bool random_prompt = false; // do not randomize prompt if none provided bool use_color = false; // use color to distinguish generations and inputs diff --git a/convert.py b/convert.py index b47bb6185..e38ee5315 100755 --- a/convert.py +++ b/convert.py @@ -1100,7 +1100,7 @@ class OutputFile: scores.append(score) toktypes.append(toktype) - assert(len(tokens) == vocab.vocab_size) + assert len(tokens) == vocab.vocab_size return tokens, scores, toktypes diff --git a/examples/imatrix/imatrix.cpp b/examples/imatrix/imatrix.cpp index 1461bc963..af78711c5 100644 --- a/examples/imatrix/imatrix.cpp +++ b/examples/imatrix/imatrix.cpp @@ -33,43 +33,120 @@ class IMatrixCollector { public: IMatrixCollector() = default; void set_parameters(StatParams&& params) { m_params = std::move(params); } - void collect_imatrix(const struct ggml_tensor * src0, const struct ggml_tensor * src1); + bool collect_imatrix(struct ggml_tensor * t, bool ask, void * user_data); void save_imatrix() const; private: std::unordered_map m_stats; StatParams m_params; std::mutex m_mutex; int m_last_call = 0; + std::vector m_src1_data; + std::vector m_ids; // the expert ids from ggml_mul_mat_id }; -void IMatrixCollector::collect_imatrix(const struct ggml_tensor * src0, const struct ggml_tensor * src1) { - if (src1->ne[1] < 16 || src1->type != GGML_TYPE_F32) return; - if (!(strncmp(src0->name, "blk.", 4) == 0 || (m_params.collect_output_weight && strcmp(src0->name, "output.weight") == 0))) return; +bool IMatrixCollector::collect_imatrix(struct ggml_tensor * t, bool ask, void * user_data) { + GGML_UNUSED(user_data); + + const struct ggml_tensor * src0 = t->src[0]; + const struct ggml_tensor * src1 = t->src[1]; + + // when ask is true, the scheduler wants to know if we are interested in data from this tensor + // if we return true, a follow-up call will be made with ask=false in which we can do the actual collection + if (ask) { + if (t->op == GGML_OP_MUL_MAT_ID) return true; // collect all indirect matrix multiplications + if (t->op != GGML_OP_MUL_MAT) return false; + if (src1->ne[1] < 16 || src1->type != GGML_TYPE_F32) return false; + if (!(strncmp(src0->name, "blk.", 4) == 0 || (m_params.collect_output_weight && strcmp(src0->name, "output.weight") == 0))) return false; + return true; + } + std::lock_guard lock(m_mutex); - auto& e = m_stats[src0->name]; - if (e.values.empty()) { - e.values.resize(src1->ne[0], 0); + + // copy the data from the GPU memory if needed + const bool is_host = ggml_backend_buffer_is_host(src1->buffer); + + if (!is_host) { + m_src1_data.resize(ggml_nelements(src1)); + ggml_backend_tensor_get(src1, m_src1_data.data(), 0, ggml_nbytes(src1)); } - else if (e.values.size() != (size_t)src1->ne[0]) { - fprintf(stderr, "Oops: inconsistent size for %s (%d vs %d)\n", src0->name, (int)e.values.size(), (int)src1->ne[0]); - exit(1); //GGML_ASSERT(false); - } - ++e.ncall; - if (m_params.verbosity > 1) { - printf("%s[%d]: %s, %d x %d, %d\n",__func__,m_last_call,src0->name,(int)src1->ne[0],(int)src1->ne[1],(int)src1->type); - } - for (int row = 0; row < (int)src1->ne[1]; ++row) { - const float * x = (const float *)src1->data + row * src1->ne[0]; - for (int j = 0; j < (int)src1->ne[0]; ++j) { - e.values[j] += x[j]*x[j]; - } - } - if (e.ncall > m_last_call) { - m_last_call = e.ncall; - if (m_last_call 
% m_params.n_output_frequency == 0) { - save_imatrix(); + + const float * data = is_host ? (const float *) src1->data : m_src1_data.data(); + + if (t->op == GGML_OP_MUL_MAT_ID) { + const int idx = ((int32_t *) t->op_params)[0]; + const int n_as = ((int32_t *) t->op_params)[1]; + + // the top-k selected expert ids are stored in the src0 tensor + // for simplicity, always copy src0 to host, because it is small + // take into account that src0 is not contiguous! + GGML_ASSERT(src0->ne[1] == src1->ne[1]); + GGML_ASSERT(n_as*ggml_nrows(src0)); + m_ids.resize(ggml_nbytes(src0)/sizeof(int)); + ggml_backend_tensor_get(src0, m_ids.data(), 0, ggml_nbytes(src0)); + + // loop over all possible experts, regardless if they are used or not in the batch + // this is necessary to guarantee equal number of "ncall" for each tensor + for (int ex = 0; ex < n_as; ++ex) { + src0 = t->src[2 + ex]; + auto& e = m_stats[src0->name]; + if (e.values.empty()) { + e.values.resize(src1->ne[0], 0); + } + else if (e.values.size() != (size_t)src1->ne[0]) { + fprintf(stderr, "Oops: inconsistent size for %s (%d vs %d)\n", src0->name, (int)e.values.size(), (int)src1->ne[0]); + exit(1); //GGML_ASSERT(false); + } + // NOTE: since we select top-k experts, the number of calls for the expert tensors will be k times larger + // using the following line, we can correct for that if needed + //if (idx == t->src[0]->ne[0] - 1) ++e.ncall; + ++e.ncall; + if (m_params.verbosity > 1) { + printf("%s[%d]: %32s, %s, %5d x %5d, %d\n", __func__, m_last_call, src0->name, ggml_op_name(t->op), (int)src1->ne[0], (int)src1->ne[1], (int)src1->type); + } + for (int row = 0; row < (int)src1->ne[1]; ++row) { + const int excur = m_ids[row*n_as + idx]; + GGML_ASSERT(excur >= 0 && excur < n_as); // sanity check + if (excur != ex) continue; + const float * x = data + row * src1->ne[0]; + for (int j = 0; j < (int)src1->ne[0]; ++j) { + e.values[j] += x[j]*x[j]; + } + } + if (e.ncall > m_last_call) { + m_last_call = e.ncall; + if (m_last_call % m_params.n_output_frequency == 0) { + save_imatrix(); + } + } + } + } else { + auto& e = m_stats[src0->name]; + if (e.values.empty()) { + e.values.resize(src1->ne[0], 0); + } + else if (e.values.size() != (size_t)src1->ne[0]) { + fprintf(stderr, "Oops: inconsistent size for %s (%d vs %d)\n", src0->name, (int)e.values.size(), (int)src1->ne[0]); + exit(1); //GGML_ASSERT(false); + } + ++e.ncall; + if (m_params.verbosity > 1) { + printf("%s[%d]: %32s, %s, %5d x %5d, %d\n", __func__, m_last_call, src0->name, ggml_op_name(t->op), (int)src1->ne[0], (int)src1->ne[1], (int)src1->type); + } + for (int row = 0; row < (int)src1->ne[1]; ++row) { + const float * x = data + row * src1->ne[0]; + for (int j = 0; j < (int)src1->ne[0]; ++j) { + e.values[j] += x[j]*x[j]; + } + } + if (e.ncall > m_last_call) { + m_last_call = e.ncall; + if (m_last_call % m_params.n_output_frequency == 0) { + save_imatrix(); + } } } + + return true; } void IMatrixCollector::save_imatrix() const { @@ -93,8 +170,8 @@ void IMatrixCollector::save_imatrix() const { static IMatrixCollector g_collector; -static void ik_collect_imatrix(const struct ggml_tensor * src0, const struct ggml_tensor * src1) { - g_collector.collect_imatrix(src0, src1); +static bool ik_collect_imatrix(struct ggml_tensor * t, bool ask, void * user_data) { + return g_collector.collect_imatrix(t, ask, user_data); } @@ -320,8 +397,6 @@ int main(int argc, char ** argv) { g_collector.set_parameters(std::move(sparams)); - ggml_set_imatrix_collection(ik_collect_imatrix); - params.logits_all = true; 
params.n_batch = std::min(params.n_batch, params.n_ctx); @@ -340,16 +415,27 @@ int main(int argc, char ** argv) { llama_backend_init(params.numa); - llama_model * model; - llama_context * ctx; + llama_model_params mparams = llama_model_params_from_gpt_params(params); - // load the model and apply lora adapter, if any - std::tie(model, ctx) = llama_init_from_gpt_params(params); + llama_model * model = llama_load_model_from_file(params.model.c_str(), mparams); if (model == NULL) { fprintf(stderr, "%s: error: unable to load model\n", __func__); return 1; } + llama_context_params cparams = llama_context_params_from_gpt_params(params); + + // pass the callback to the backend scheduler + // it will be executed for each node during the graph computation + cparams.cb_eval = ik_collect_imatrix; + cparams.cb_eval_user_data = NULL; + + llama_context * ctx = llama_new_context_with_model(model, cparams); + if (ctx == NULL) { + fprintf(stderr, "%s: error: unable to create context\n", __func__); + return 1; + } + const int n_ctx_train = llama_n_ctx_train(model); if (params.n_ctx > n_ctx_train) { fprintf(stderr, "%s: warning: model was trained on only %d context tokens (%d specified)\n", diff --git a/examples/perplexity/perplexity.cpp b/examples/perplexity/perplexity.cpp index b4fedf803..ea2c8026c 100644 --- a/examples/perplexity/perplexity.cpp +++ b/examples/perplexity/perplexity.cpp @@ -9,6 +9,9 @@ #include #include #include +#include +#include +#include #if defined(_MSC_VER) #pragma warning(disable: 4244 4267) // possible loss of data @@ -419,9 +422,8 @@ static results_perplexity perplexity(llama_context * ctx, const gpt_params & par return {tokens, ppl, logit_history, prob_history}; } -static std::vector hellaswag_evaluate_tokens( - llama_context * ctx, std::vector & tokens, int n_past, int n_batch, int n_vocab -) { +static std::vector evaluate_tokens(llama_context * ctx, std::vector & tokens, + int n_past, int n_batch, int n_vocab) { std::vector result; result.reserve(tokens.size() * n_vocab); size_t n_chunk = (tokens.size() + n_batch - 1)/n_batch; @@ -468,7 +470,7 @@ static void hellaswag_score(llama_context * ctx, const gpt_params & params) { prompt_lines.push_back(line); } - if( prompt_lines.size() % 6 != 0) { + if (prompt_lines.size() % 6 != 0) { fprintf(stderr, "%s : number of lines in prompt not a multiple of 6.\n", __func__); return; } @@ -483,7 +485,7 @@ static void hellaswag_score(llama_context * ctx, const gpt_params & params) { const bool add_bos = llama_should_add_bos_token(llama_get_model(ctx)); // Number of tasks to use when computing the score - if ( params.hellaswag_tasks < hs_task_count ) { + if (params.hellaswag_tasks < hs_task_count) { hs_task_count = params.hellaswag_tasks; } @@ -500,27 +502,54 @@ static void hellaswag_score(llama_context * ctx, const gpt_params & params) { std::string ending[4]; size_t ending_logprob_count[4]; double ending_logprob[4]; + + size_t i_batch; // starting index in the llama_batch + size_t common_prefix; // max number of initial tokens that are the same in all sentences + size_t required_tokens; // needed number of tokens to evaluate all 4 endings + std::vector seq_tokens[4]; }; fprintf(stderr, "%s : selecting %zu %s tasks.\n", __func__, hs_task_count, (randomize_tasks?"randomized":"the first") ); // Select and read data from prompt lines - hs_data_t *hs_data = new hs_data_t[hs_task_count]; - for (size_t i=0; i < hs_task_count; i++) { + std::vector hs_data(hs_task_count); + for (size_t i = 0; i < hs_task_count; i++) { size_t idx = i; + auto & hs_cur = 
hs_data[i]; + // Select a random example of those left in the prompt if (randomize_tasks) { std::uniform_int_distribution dist(0, prompt_lines.size()/6-1 ) ; idx = dist(rng); } - hs_data[i].context = prompt_lines[idx*6]; - hs_data[i].gold_ending_idx = std::stoi( prompt_lines[idx*6+1] ); - for (size_t j=0; j < 4; j++) { - hs_data[i].ending[j] = prompt_lines[idx*6+2+j]; + hs_cur.context = prompt_lines[idx*6]; + hs_cur.gold_ending_idx = std::stoi( prompt_lines[idx*6+1] ); + for (size_t j = 0; j < 4; j++) { + hs_cur.ending[j] = prompt_lines[idx*6+2+j]; + hs_cur.seq_tokens[j] = ::llama_tokenize(ctx, hs_cur.context + " " + hs_cur.ending[j], add_bos); } + // determine the common prefix of the endings + hs_cur.common_prefix = 0; + hs_cur.required_tokens = 0; + for (size_t k = 0; k < hs_cur.seq_tokens[0].size(); k++) { + if (hs_cur.seq_tokens[0][k] != hs_cur.seq_tokens[1][k] || + hs_cur.seq_tokens[0][k] != hs_cur.seq_tokens[2][k] || + hs_cur.seq_tokens[0][k] != hs_cur.seq_tokens[3][k]) { + break; + } + hs_cur.common_prefix++; + } + hs_cur.required_tokens = hs_cur.common_prefix + + hs_cur.seq_tokens[0].size() - hs_cur.common_prefix + + hs_cur.seq_tokens[1].size() - hs_cur.common_prefix + + hs_cur.seq_tokens[2].size() - hs_cur.common_prefix + + hs_cur.seq_tokens[3].size() - hs_cur.common_prefix; + + //GGML_ASSERT(hs_cur.common_prefix >= ::llama_tokenize(ctx, hs_cur.context, add_bos).size()); + // Delete the selected random example from the prompt if (randomize_tasks) { prompt_lines.erase( std::next(prompt_lines.begin(),idx*6) , std::next(prompt_lines.begin(),idx*6+6) ); @@ -528,154 +557,393 @@ static void hellaswag_score(llama_context * ctx, const gpt_params & params) { } fprintf(stderr, "%s : calculating hellaswag score over selected tasks.\n", __func__); + printf("\ntask\tacc_norm\n"); double acc = 0.0f; - const int n_vocab = llama_n_vocab(llama_get_model(ctx)); - const int n_ctx = llama_n_ctx(ctx); - std::vector> ending_tokens(4); + const int n_vocab = llama_n_vocab(llama_get_model(ctx)); + const int n_ctx = llama_n_ctx(ctx); + const int n_batch = params.n_batch; + + const int max_tasks_per_batch = params.n_parallel; + const int max_seq = 4*max_tasks_per_batch; + + llama_batch batch = llama_batch_init(n_ctx, 0, max_seq); std::vector tok_logits(n_vocab); + std::vector batch_logits(n_ctx*n_vocab); - for (size_t task_idx = 0; task_idx < hs_task_count; task_idx++) { - // Tokenize the context to count tokens - std::vector context_embd = ::llama_tokenize(ctx, hs_data[task_idx].context, add_bos); - size_t context_size = context_embd.size(); + auto decode_helper = [&](llama_context * ctx, llama_batch & batch, int32_t n_batch) { + for (int32_t i = 0; i < (int32_t) batch.n_tokens; i += n_batch) { + const int32_t n_tokens = std::min(n_batch, (int32_t) (batch.n_tokens - i)); - for (int i = 0; i < 4; ++i) { - ending_tokens[i] = ::llama_tokenize(ctx, hs_data[task_idx].context + " " + hs_data[task_idx].ending[i], add_bos); - for (int k = 0; k < int(context_size); ++k) { - if (ending_tokens[i][k] != context_embd[k]) { - fprintf(stderr, "Oops: ending %d of task %d differs from context at position %d\n",i,int(task_idx),k); - break; + llama_batch batch_view = { + n_tokens, + batch.token + i, + nullptr, + batch.pos + i, + batch.n_seq_id + i, + batch.seq_id + i, + batch.logits + i, + 0, 0, 0, // unused + }; + + const int ret = llama_decode(ctx, batch_view); + if (ret != 0) { + LOG_TEE("failed to decode the batch, n_batch = %d, ret = %d\n", n_batch, ret); + return false; + } + + memcpy(batch_logits.data() + i*n_vocab, 
llama_get_logits(ctx), n_tokens*n_vocab*sizeof(float)); + } + + return true; + }; + + for (size_t i0 = 0; i0 < hs_task_count; i0++) { + int n_cur = 0; + + size_t i1 = i0; + size_t i_batch = 0; // this tells us where in `llama_batch` we are currently + + llama_batch_clear(batch); + + // batch as many tasks as possible into the available context + // each task has 4 unique sequence ids - one for each ending + // the common prefix is shared among the 4 sequences to save tokens + // we extract logits only from the last common token and from all ending tokens of each sequence + while (n_cur + (int) hs_data[i1].required_tokens <= n_ctx) { + auto & hs_cur = hs_data[i1]; + + const int s0 = 4*(i1 - i0); + if (s0 + 4 > max_seq) { + break; + } + + for (size_t i = 0; i < hs_cur.common_prefix; ++i) { + llama_batch_add(batch, hs_cur.seq_tokens[0][i], i, { s0 + 0, s0 + 1, s0 + 2, s0 + 3}, false); + } + batch.logits[batch.n_tokens - 1] = true; // we need logits for the last token of the common prefix + + for (int s = 0; s < 4; ++s) { + for (size_t i = hs_cur.common_prefix; i < hs_cur.seq_tokens[s].size(); ++i) { + llama_batch_add(batch, hs_cur.seq_tokens[s][i], i, { s0 + s }, true); } } + + hs_cur.i_batch = i_batch; + i_batch += hs_cur.required_tokens; + + n_cur += hs_data[i1].required_tokens; + if (++i1 == hs_task_count) { + break; + } } - // Do the 1st ending - // In this case we include the context when evaluating - //auto query_embd = ::llama_tokenize(ctx, hs_data[task_idx].context + hs_data[task_idx].ending[0], add_bos); - auto query_embd = ending_tokens[0]; - auto query_size = query_embd.size(); - - // Stop if query wont fit the ctx window - if (query_size > (size_t)n_ctx) { - fprintf(stderr, "%s : number of tokens in query %zu > n_ctxl\n", __func__, query_size); + if (i0 == i1) { + fprintf(stderr, "%s : task %zu does not fit in the context window\n", __func__, i0); return; } - // Speedup small evaluations by evaluating atleast 32 tokens - if (query_size < 32) { - query_embd.resize(32); - } - - // clear the KV cache llama_kv_cache_clear(ctx); - auto logits = hellaswag_evaluate_tokens(ctx, query_embd, 0, params.n_batch, n_vocab); - if (logits.empty()) { + // decode all tasks [i0, i1) + if (!decode_helper(ctx, batch, n_batch)) { + fprintf(stderr, "%s: llama_decode() failed\n", __func__); + return; + } + + // compute the logprobs for each ending of the decoded tasks + for (size_t i = i0; i < i1; ++i) { + auto & hs_cur = hs_data[i]; + + std::memcpy(tok_logits.data(), batch_logits.data() + n_vocab*(hs_cur.i_batch + hs_cur.common_prefix - 1), n_vocab*sizeof(float)); + + const auto first_probs = softmax(tok_logits); + + size_t li = hs_cur.common_prefix; // logits index in the batch + + for (int s = 0; s < 4; ++s) { + hs_cur.ending_logprob_count[s] = 1; + hs_cur.ending_logprob[s] = std::log(first_probs[hs_cur.seq_tokens[s][hs_cur.common_prefix]]); + + // Calculate the logprobs over the ending + for (size_t j = hs_cur.common_prefix; j < hs_cur.seq_tokens[s].size() - 1; j++) { + std::memcpy(tok_logits.data(), batch_logits.data() + n_vocab*(hs_cur.i_batch + li++), n_vocab*sizeof(float)); + + const float prob = softmax(tok_logits)[hs_cur.seq_tokens[s][j + 1]]; + + hs_cur.ending_logprob[s] += std::log(prob); + hs_cur.ending_logprob_count[s]++; + } + + // account that we skip the last token in the ending + ++li; + + // Calculate the mean token logprob for acc_norm + hs_cur.ending_logprob[s] /= hs_cur.ending_logprob_count[s]; + } + + // Find the ending with maximum logprob + size_t ending_logprob_max_idx = 0; + 
double ending_logprob_max_val = hs_cur.ending_logprob[0]; + for (size_t s = 1; s < 4; s++) { + if (hs_cur.ending_logprob[s] > ending_logprob_max_val) { + ending_logprob_max_idx = s; + ending_logprob_max_val = hs_cur.ending_logprob[s]; + } + } + + //printf("max logprob ending idx %lu, gold ending idx %lu\n", ending_logprob_max_idx, hs_cur.gold_ending_idx); + + // If the gold ending got the maximum logprob add one accuracy point + if (ending_logprob_max_idx == hs_cur.gold_ending_idx) { + acc += 1.0; + } + + // Print the accumulated accuracy mean x 100 + printf("%zu\t%.8lf\n", i + 1, acc/double(i + 1)*100.0); + fflush(stdout); + } + + i0 = i1 - 1; + } + + llama_batch_free(batch); + + printf("\n"); +} + +struct winogrande_entry { + std::string first; + std::string second; + std::array<std::string, 2> choices; + int answer; +}; + +static std::vector<winogrande_entry> load_winogrande_from_csv(const std::string& prompt) { + std::vector<winogrande_entry> result; + std::istringstream in(prompt); + std::string line; + std::array<int, 4> comma_pos; + while (true) { + std::getline(in, line); + if (in.fail() || in.eof()) break; + int ipos = 0; + bool quote_open = false; + for (int i = 0; i < int(line.size()); ++i) { + if (!quote_open) { + if (line[i] == ',') { + comma_pos[ipos++] = i; + if (ipos == 4) break; + } + else if (line[i] == '"') { + quote_open = true; + } + } + else { + if (line[i] == '"') { + quote_open = false; + } + } + } + if (ipos != 4) { + printf("%s: failed to find comma separators in <%s>\n", __func__, line.c_str()); + continue; + } + auto sentence = line[comma_pos[0]+1] == '"' ? line.substr(comma_pos[0]+2, comma_pos[1] - comma_pos[0] - 3) + : line.substr(comma_pos[0]+1, comma_pos[1] - comma_pos[0] - 1); + auto choice1 = line.substr(comma_pos[1]+1, comma_pos[2] - comma_pos[1] - 1); + auto choice2 = line.substr(comma_pos[2]+1, comma_pos[3] - comma_pos[2] - 1); + auto answer = line.substr(comma_pos[3]+1, line.size() - comma_pos[3] - 1); + auto index = line.substr(0, comma_pos[0]); + int where = 0; + for ( ; where < int(sentence.size()); ++where) { + if (sentence[where] == '_') break; + } + if (where == int(sentence.size())) { + printf("%s: no _ in <%s>\n", __func__, sentence.c_str()); + continue; + } + std::istringstream stream(answer.c_str()); + int i_answer; stream >> i_answer; + if (stream.fail() || i_answer < 1 || i_answer > 2) { + printf("%s: failed to parse answer <%s>\n", __func__, answer.c_str()); + continue; + } + result.emplace_back(); + auto& wg = result.back(); + wg.first = sentence.substr(0, where); + wg.second = sentence.substr(where + 1, sentence.size() - where - 1); + wg.choices[0] = std::move(choice1); + wg.choices[1] = std::move(choice2); + wg.answer = i_answer; + } + return result; +} + +/* + * Evaluates the Winogrande score. + * Uses a CSV containing task index, sentence, choice 1, choice 2, answer (1 or 2) + * You can get one such dataset from e.g. 
https://huggingface.co/datasets/ikawrakow/winogrande-eval-for-llama.cpp + * As an example, the 1st row in the above dataset is + * + * 0,Sarah was a much better surgeon than Maria so _ always got the easier cases.,Sarah,Maria,2 + * + */ +static void winogrande_score(llama_context * ctx, const gpt_params & params) { + + constexpr int k_min_trailing_ctx = 3; + + auto data = load_winogrande_from_csv(params.prompt); + if (data.empty()) { + fprintf(stderr, "%s: no tasks\n", __func__); + return; + } + + fprintf(stderr, "%s : loaded %zu tasks from prompt.\n", __func__, data.size()); + + if (params.winogrande_tasks > 0 && params.winogrande_tasks < data.size()) { + fprintf(stderr, "%s : selecting %zu random tasks\n", __func__, params.winogrande_tasks); + std::mt19937 rng(1); + std::vector<int> aux(data.size()); + for (int i = 0; i < int(data.size()); ++i) { + aux[i] = i; + } + float scale = 1/(1.f + (float)rng.max()); + std::vector<winogrande_entry> selected; + selected.reserve(params.winogrande_tasks); + for (int i = 0; i < int(params.winogrande_tasks); ++i) { + int j = int(scale*rng()*aux.size()); + selected[i] = std::move(data[aux[j]]); + aux[j] = aux.back(); + aux.pop_back(); + } + data = std::move(selected); + } + + // This is needed as usual for LLaMA models + const bool add_bos = llama_should_add_bos_token(llama_get_model(ctx)); + + fprintf(stderr, "%s : calculating winogrande score over selected tasks.\n", __func__); + + const int n_vocab = llama_n_vocab(llama_get_model(ctx)); + const int n_ctx = llama_n_ctx(ctx); + + std::vector<float> tok_logits(n_vocab); + + int n_correct = 0; + int n_done = 0; + + for (size_t task_idx = 0; task_idx < data.size(); task_idx++) { + const auto& task = data[task_idx]; + + auto base_context = ::llama_tokenize(ctx, task.first, add_bos); + auto base_ctx_1st = ::llama_tokenize(ctx, task.first + task.choices[0], add_bos); + auto base_ctx_2nd = ::llama_tokenize(ctx, task.first + task.choices[1], add_bos); + + auto sentence_1st = task.first + task.choices[0] + task.second; + auto sentence_2nd = task.first + task.choices[1] + task.second; + auto query_1st = ::llama_tokenize(ctx, sentence_1st, add_bos); + auto query_2nd = ::llama_tokenize(ctx, sentence_2nd, add_bos); + + if (query_1st.size() > (size_t)n_ctx || query_2nd.size() > (size_t)n_ctx) { + fprintf(stderr, "%s : number of tokens in queries %zu, %zu > n_ctx\n", __func__, query_1st.size(), query_2nd.size()); + return; + } + + auto query_1st_size = query_1st.size(); + auto query_2nd_size = query_2nd.size(); + + // Speedup small evaluations by evaluating at least 32 tokens + // For Winogrande this seems to slow it down rather than speed it up. 
+ //if (query_1st.size() < 32) query_1st.resize(32); + //if (query_2nd.size() < 32) query_2nd.resize(32); + + llama_kv_cache_clear(ctx); + auto logits_1st = evaluate_tokens(ctx, query_1st, 0, params.n_batch, n_vocab); + + llama_kv_cache_clear(ctx); + auto logits_2nd = evaluate_tokens(ctx, query_2nd, 0, params.n_batch, n_vocab); + + if (logits_1st.empty() || logits_2nd.empty()) { fprintf(stderr, "%s : failed to eval\n", __func__); return; } - std::memcpy(tok_logits.data(), logits.data() + (context_size-1)*n_vocab, n_vocab*sizeof(float)); - const auto first_probs = softmax(tok_logits); + bool skip_choice = query_1st_size - base_ctx_1st.size() > k_min_trailing_ctx && + query_2nd_size - base_ctx_2nd.size() > k_min_trailing_ctx; - hs_data[task_idx].ending_logprob_count[0] = 1; - hs_data[task_idx].ending_logprob[0] = std::log(first_probs[query_embd[context_size]]); + float score_1st = 0; + bool is_nan_1st = false; + const auto& base_1 = skip_choice ? base_ctx_1st : base_context; + const int last_1st = query_1st_size - base_1.size() > 1 ? 1 : 0; + for (size_t j = base_1.size()-1; j < query_1st_size-1-last_1st; ++j) { + std::memcpy(tok_logits.data(), logits_1st.data() + j*n_vocab, n_vocab*sizeof(float)); + const float prob = softmax(tok_logits)[query_1st[j+1]]; + if (std::isnan(prob) || !prob) { + fprintf(stderr, "%s: %g probability for token %zu when evaluating <%s>. Base context has %zu tokens\n", __func__, + prob, j, sentence_1st.c_str(), base_context.size()); + is_nan_1st = true; + break; + } + score_1st += std::log(prob); + } + score_1st /= (query_1st_size - base_1.size() - last_1st); - // Calculate the logprobs over the ending - for (size_t j = context_size; j < query_size - 1; j++) { + float score_2nd = 0; + bool is_nan_2nd = false; + const auto& base_2 = skip_choice ? base_ctx_2nd : base_context; + const int last_2nd = query_2nd_size - base_2.size() > 1 ? 1 : 0; + for (size_t j = base_2.size()-1; j < query_2nd_size-1-last_2nd; ++j) { + std::memcpy(tok_logits.data(), logits_2nd.data() + j*n_vocab, n_vocab*sizeof(float)); + const float prob = softmax(tok_logits)[query_2nd[j+1]]; + if (std::isnan(prob) || !prob) { + fprintf(stderr, "%s: %g probability for token %zu when evaluating <%s>. 
Base context has %zu tokens\n", __func__, + prob, j, sentence_2nd.c_str(), base_context.size()); + is_nan_2nd = true; + break; + } + score_2nd += std::log(prob); + } + score_2nd /= (query_2nd_size - base_2.size() - last_2nd); - std::memcpy(tok_logits.data(), logits.data() + j*n_vocab, n_vocab*sizeof(float)); - - const float prob = softmax(tok_logits)[query_embd[j + 1]]; - - hs_data[task_idx].ending_logprob[0] += std::log(prob); - hs_data[task_idx].ending_logprob_count[0]++; + if (is_nan_1st || is_nan_2nd) { + continue; } - // Calculate the mean token logprob for acc_norm - hs_data[task_idx].ending_logprob[0] /= hs_data[task_idx].ending_logprob_count[0]; - - // Do the remaining endings - // For these, we use the bare ending with n_past = context_size - // - for (size_t ending_idx = 1; ending_idx < 4; ending_idx++) { - - // Tokenize the query - query_embd.resize(ending_tokens[ending_idx].size() - context_size); - std::memcpy(query_embd.data(), ending_tokens[ending_idx].data() + context_size, query_embd.size()*sizeof(int)); - query_size = query_embd.size(); - - // Stop if query wont fit the ctx window - if (context_size + query_size > (size_t)n_ctx) { - fprintf(stderr, "%s : number of tokens in query %zu > n_ctxl\n", __func__, query_size); - return; - } - - // Speedup small evaluations by evaluating atleast 32 tokens - // No, resizing to 32 is actually slightly slower (at least on CUDA) - //if (query_size < 32) { - // query_embd.resize(32); - //} - - // Evaluate the query - logits = hellaswag_evaluate_tokens(ctx, query_embd, context_size, params.n_batch, n_vocab); - if (logits.empty()) { - fprintf(stderr, "%s : failed to eval\n", __func__); - return; - } - - hs_data[task_idx].ending_logprob_count[ending_idx] = 1; - hs_data[task_idx].ending_logprob[ending_idx] = std::log(first_probs[query_embd[0]]); - - // Calculate the logprobs over the ending - for (size_t j = 0; j < query_size - 1; j++) { - std::memcpy(tok_logits.data(), logits.data() + j*n_vocab, n_vocab*sizeof(float)); - - const float prob = softmax(tok_logits)[query_embd[j + 1]]; - - hs_data[task_idx].ending_logprob[ending_idx] += std::log(prob); - hs_data[task_idx].ending_logprob_count[ending_idx]++; - } - - // Calculate the mean token logprob for acc_norm - hs_data[task_idx].ending_logprob[ending_idx] /= hs_data[task_idx].ending_logprob_count[ending_idx]; - - -// printf("task %lu, ending %lu, whole_len %lu, context_len %lu, ending_logprob_count %lu, ending_logprob %.4f\n", -// task_idx,ending_idx,whole_size,context_size, hs_data[task_idx].ending_logprob_count[ending_idx], hs_data[task_idx].ending_logprob[ending_idx] ); + if (std::isnan(score_1st) || std::isnan(score_2nd)) { + printf("================== NaN score %g, %g) for:\n", score_1st, score_2nd); + printf("Q1: <%s> - %zu tokens\n", sentence_1st.c_str(), query_1st_size); + printf("Q2: <%s> - %zu tokens\n", sentence_2nd.c_str(), query_2nd_size); + printf("B : <%s> - %zu tokens\n", task.first.c_str(), base_context.size()); + printf("base_1 has %zu tokens, base_2 has %zu tokens, skip_choice = %d\n", base_1.size(), base_2.size(), skip_choice); + continue; } - // Find the ending with maximum logprob - size_t ending_logprob_max_idx = 0; - double ending_logprob_max_val = hs_data[task_idx].ending_logprob[0]; - for (size_t j = 1; j < 4; j++) { - if (hs_data[task_idx].ending_logprob[j] > ending_logprob_max_val) { - ending_logprob_max_idx = j; - ending_logprob_max_val = hs_data[task_idx].ending_logprob[j]; - } - } + int result = score_1st > score_2nd ? 
1 : 2; -// printf("max logprob ending idx %lu, gold ending idx %lu\n", ending_logprob_max_idx, hs_data[task_idx].gold_ending_idx); - - // If the gold ending got the maximum logprobe add one accuracy point - if (ending_logprob_max_idx == hs_data[task_idx].gold_ending_idx) { - acc += 1.0; + if (result == task.answer) { + ++n_correct; } + ++n_done; // Print the accumulated accuracy mean x 100 - printf("%zu\t%.8lf\n",task_idx+1, acc/double(task_idx+1)*100.0); + printf("%zu\t%.4lf\t%10.6f %10.6f %d %d\n",task_idx+1, 100.0 * n_correct/n_done,score_1st,score_2nd,result,task.answer); fflush(stdout); } - delete [] hs_data; - printf("\n"); + + if (n_done < 100) return; + + const float p = 1.f*n_correct/n_done; + const float sigma = 100.f*sqrt(p*(1-p)/(n_done-1)); + printf("Final Winogrande score(%d tasks): %.4lf +/- %.4lf\n", n_done, 100*p, sigma); } + int main(int argc, char ** argv) { gpt_params params; @@ -733,6 +1001,8 @@ int main(int argc, char ** argv) { struct results_perplexity results; if (params.hellaswag) { hellaswag_score(ctx, params); + } else if (params.winogrande) { + winogrande_score(ctx, params); } else { results = perplexity(ctx, params); } diff --git a/ggml-backend.c b/ggml-backend.c index f5424fb90..ef518dae0 100644 --- a/ggml-backend.c +++ b/ggml-backend.c @@ -692,6 +692,8 @@ GGML_CALL static bool ggml_backend_cpu_graph_compute(ggml_backend_t backend, str GGML_CALL static bool ggml_backend_cpu_supports_op(ggml_backend_t backend, const struct ggml_tensor * op) { switch (op->op) { + case GGML_OP_CPY: + return op->type != GGML_TYPE_IQ2_XXS && op->type != GGML_TYPE_IQ2_XS; // missing type_traits.from_float case GGML_OP_MUL_MAT: return op->src[1]->type == GGML_TYPE_F32 || op->src[1]->type == ggml_internal_get_type_traits(op->src[0]->type).vec_dot_type; default: @@ -802,6 +804,9 @@ struct ggml_backend_sched { __attribute__((aligned(GGML_MEM_ALIGN))) #endif char context_buffer[GGML_MAX_SPLITS*GGML_MAX_SPLIT_INPUTS*sizeof(struct ggml_tensor) + sizeof(struct ggml_cgraph)]; + + ggml_backend_sched_eval_callback callback_eval; + void * callback_eval_user_data; }; #define hash_id(node) ggml_hash_find_or_insert(sched->hash_set, node) @@ -1324,9 +1329,38 @@ static void sched_compute_splits(ggml_backend_sched_t sched) { ggml_graph_dump_dot(split->graph, NULL, split_filename); #endif + uint64_t compute_start_us = ggml_time_us(); - ggml_backend_graph_compute(split_backend, &split->graph); - //ggml_backend_synchronize(split_backend); // necessary to measure compute time + if (!sched->callback_eval) { + ggml_backend_graph_compute(split_backend, &split->graph); + //ggml_backend_synchronize(split_backend); // necessary to measure compute time + } else { + // similar to ggml_backend_compare_graph_backend + for (int j0 = 0; j0 < split->graph.n_nodes; j0++) { + struct ggml_tensor * t = split->graph.nodes[j0]; + + // check if the user needs data from this node + bool need = sched->callback_eval(t, true, sched->callback_eval_user_data); + + int j1 = j0; + + // determine the range [j0, j1] of nodes that can be computed together + while (!need && j1 < split->graph.n_nodes - 1) { + t = split->graph.nodes[++j1]; + need = sched->callback_eval(t, true, sched->callback_eval_user_data); + } + + struct ggml_cgraph gv = ggml_graph_view(&split->graph, j0, j1 + 1); + + ggml_backend_graph_compute(split_backend, &gv); + + if (need && !sched->callback_eval(t, false, sched->callback_eval_user_data)) { + break; + } + + j0 = j1; + } + } uint64_t compute_end_us = ggml_time_us(); compute_us[split_backend_id] += 
compute_end_us - compute_start_us; } @@ -1431,6 +1465,12 @@ void ggml_backend_sched_reset(ggml_backend_sched_t sched) { sched_reset(sched); } + +void ggml_backend_sched_set_eval_callback(ggml_backend_sched_t sched, ggml_backend_sched_eval_callback callback, void * user_data) { + sched->callback_eval = callback; + sched->callback_eval_user_data = user_data; +} + int ggml_backend_sched_get_n_splits(ggml_backend_sched_t sched) { return sched->n_splits; } diff --git a/ggml-backend.h b/ggml-backend.h index 12b4b4ab7..ab4ad773f 100644 --- a/ggml-backend.h +++ b/ggml-backend.h @@ -148,6 +148,14 @@ extern "C" { struct ggml_backend_sched; typedef struct ggml_backend_sched * ggml_backend_sched_t; + // when ask == true, the scheduler wants to know if the user wants to observe this node + // this allows the scheduler to batch nodes together in order to evaluate them in a single call + // + // when ask == false, the scheduler is passing the node tensor to the user for observation + // if the user returns false, the scheduler will cancel the graph compute + // + typedef bool (*ggml_backend_sched_eval_callback)(struct ggml_tensor * t, bool ask, void * user_data); + // Initialize a backend scheduler GGML_API ggml_backend_sched_t ggml_backend_sched_new(ggml_backend_t * backends, ggml_backend_buffer_type_t * bufts, int n_backends, size_t graph_size); GGML_API void ggml_backend_sched_free(ggml_backend_sched_t sched); @@ -168,6 +176,9 @@ extern "C" { // Reset all assignments and allocators - must be called before using the sched allocators to allocate inputs GGML_API void ggml_backend_sched_reset(ggml_backend_sched_t sched); + // Set a callback to be called for each resulting node during graph compute + GGML_API void ggml_backend_sched_set_eval_callback(ggml_backend_sched_t sched, ggml_backend_sched_eval_callback callback, void * user_data); + // // Utils // diff --git a/ggml-cuda.cu b/ggml-cuda.cu index bb65ca642..bafb2ff1c 100644 --- a/ggml-cuda.cu +++ b/ggml-cuda.cu @@ -5131,10 +5131,10 @@ static __global__ void mul_mat_vec_q(const void * __restrict__ vx, const void * const block_q_t * x = (const block_q_t *) vx; const block_q8_1 * y = (const block_q8_1 *) vy; - for (int i = 0; i < blocks_per_row; i += blocks_per_warp) { - const int ibx = row*blocks_per_row + i + threadIdx.x / (qi/vdr); // x block index + for (int i = threadIdx.x / (qi/vdr); i < blocks_per_row; i += blocks_per_warp) { + const int ibx = row*blocks_per_row + i; // x block index - const int iby = (i + threadIdx.x / (qi/vdr)) * (qk/QK8_1); // y block index that aligns with ibx + const int iby = i * (qk/QK8_1); // y block index that aligns with ibx const int iqs = vdr * (threadIdx.x % (qi/vdr)); // x block quant index when casting the quants to int @@ -11058,6 +11058,12 @@ GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, cons if (a->ne[3] != b->ne[3]) { return false; } + ggml_type a_type = a->type; + if (a_type == GGML_TYPE_IQ2_XXS || a_type == GGML_TYPE_IQ2_XS) { + if (b->ne[1] == 1 && ggml_nrows(b) > 1) { + return false; + } + } return true; } break; case GGML_OP_GET_ROWS: diff --git a/ggml-metal.m b/ggml-metal.m index 8bb4edd64..6d88d5c36 100644 --- a/ggml-metal.m +++ b/ggml-metal.m @@ -147,6 +147,7 @@ enum ggml_metal_kernel_type { GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_ASC, GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_DESC, GGML_METAL_KERNEL_TYPE_LEAKY_RELU_F32, + GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16, GGML_METAL_KERNEL_TYPE_CPY_F32_F16, GGML_METAL_KERNEL_TYPE_CPY_F32_F32, GGML_METAL_KERNEL_TYPE_CPY_F32_Q8_0, @@ 
-238,21 +239,19 @@ static void * ggml_metal_host_malloc(size_t n) { static struct ggml_metal_context * ggml_metal_init(int n_cb) { GGML_METAL_LOG_INFO("%s: allocating\n", __func__); - id device; - NSString * s; - -#if TARGET_OS_OSX +#if TARGET_OS_OSX && !GGML_METAL_NDEBUG // Show all the Metal device instances in the system NSArray * devices = MTLCopyAllDevices(); - for (device in devices) { - s = [device name]; + for (id device in devices) { + NSString * s = [device name]; GGML_METAL_LOG_INFO("%s: found device: %s\n", __func__, [s UTF8String]); } + [devices release]; // since it was created by a *Copy* C method #endif // Pick and show default Metal device - device = MTLCreateSystemDefaultDevice(); - s = [device name]; + id device = MTLCreateSystemDefaultDevice(); + NSString * s = [device name]; GGML_METAL_LOG_INFO("%s: picking default device: %s\n", __func__, [s UTF8String]); // Configure context @@ -303,22 +302,21 @@ static struct ggml_metal_context * ggml_metal_init(int n_cb) { return NULL; } - // dictionary of preprocessor macros - NSMutableDictionary * prep = [NSMutableDictionary dictionary]; + @autoreleasepool { + // dictionary of preprocessor macros + NSMutableDictionary * prep = [NSMutableDictionary dictionary]; #ifdef GGML_QKK_64 - prep[@"QK_K"] = @(64); + prep[@"QK_K"] = @(64); #endif - MTLCompileOptions* options = [MTLCompileOptions new]; - options.preprocessorMacros = prep; + MTLCompileOptions* options = [MTLCompileOptions new]; + options.preprocessorMacros = prep; - //[options setFastMathEnabled:false]; + //[options setFastMathEnabled:false]; - ctx->library = [ctx->device newLibraryWithSource:src options:options error:&error]; - - [options release]; - [prep release]; + ctx->library = [ctx->device newLibraryWithSource:src options:options error:&error]; + } } if (error) { @@ -514,6 +512,7 @@ static struct ggml_metal_context * ggml_metal_init(int n_cb) { GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_ASC, argsort_f32_i32_asc, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_DESC, argsort_f32_i32_desc, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_LEAKY_RELU_F32, leaky_relu_f32, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16, flash_attn_ext_f16, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_F16, cpy_f32_f16, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_F32, cpy_f32_f32, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q8_0, cpy_f32_q8_0, true); @@ -668,6 +667,7 @@ static bool ggml_metal_supports_op(const struct ggml_metal_context * ctx, const case GGML_OP_PAD: case GGML_OP_ARGSORT: case GGML_OP_LEAKY_RELU: + case GGML_OP_FLASH_ATTN_EXT: return true; case GGML_OP_MUL_MAT: case GGML_OP_MUL_MAT_ID: @@ -713,7 +713,6 @@ static bool ggml_metal_supports_op(const struct ggml_metal_context * ctx, const static bool ggml_metal_graph_compute( struct ggml_metal_context * ctx, struct ggml_cgraph * gf) { - @autoreleasepool { MTLComputePassDescriptor * edesc = MTLComputePassDescriptor.computePassDescriptor; edesc.dispatchType = MTLDispatchTypeSerial; @@ -2165,6 +2164,53 @@ static bool ggml_metal_graph_compute( [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; } break; + case GGML_OP_FLASH_ATTN_EXT: + { + GGML_ASSERT(src0->type == GGML_TYPE_F16); + + struct ggml_tensor * src2 = gf->nodes[i]->src[2]; + struct ggml_tensor * src3 = gf->nodes[i]->src[3]; + + size_t offs_src2 = 0; + size_t offs_src3 = 0; + + id id_src2 = src2 ? 
ggml_metal_get_buffer(ctx, src2, &offs_src2) : nil; + id id_src3 = src3 ? ggml_metal_get_buffer(ctx, src3, &offs_src3) : nil; + + float scale; + memcpy(&scale, dst->op_params, sizeof(float)); + + id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16].pipeline; + + // TODO: extend if necessary + [encoder setComputePipelineState:pipeline]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; + [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1]; + [encoder setBuffer:id_src2 offset:offs_src2 atIndex:2]; + [encoder setBuffer:id_src3 offset:offs_src3 atIndex:3]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:4]; + [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:5]; + [encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:6]; + [encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:7]; + [encoder setBytes:&ne03 length:sizeof( int64_t) atIndex:8]; + [encoder setBytes:&nb00 length:sizeof(uint64_t) atIndex:9]; + [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:10]; + [encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:11]; + [encoder setBytes:&nb03 length:sizeof(uint64_t) atIndex:12]; + [encoder setBytes:&ne0 length:sizeof( int64_t) atIndex:13]; + [encoder setBytes:&ne1 length:sizeof( int64_t) atIndex:14]; + [encoder setBytes:&ne2 length:sizeof( int64_t) atIndex:15]; + [encoder setBytes:&ne3 length:sizeof( int64_t) atIndex:16]; + [encoder setBytes:&nb0 length:sizeof(uint64_t) atIndex:17]; + [encoder setBytes:&nb1 length:sizeof(uint64_t) atIndex:18]; + [encoder setBytes:&nb2 length:sizeof(uint64_t) atIndex:19]; + [encoder setBytes:&nb3 length:sizeof(uint64_t) atIndex:20]; + [encoder setBytes:&scale length:sizeof( float) atIndex:21]; + + const int nth = MIN(1024, ne0); + + [encoder dispatchThreadgroups:MTLSizeMake(ne1, ne2, ne3) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)]; + } break; case GGML_OP_DUP: case GGML_OP_CPY: case GGML_OP_CONT: @@ -2256,7 +2302,6 @@ static bool ggml_metal_graph_compute( } return true; - } } //////////////////////////////////////////////////////////////////////////////// diff --git a/ggml-metal.metal b/ggml-metal.metal index 029578dc5..28847794c 100644 --- a/ggml-metal.metal +++ b/ggml-metal.metal @@ -1959,6 +1959,35 @@ kernel void kernel_leaky_relu_f32( dst[tpig] = src0[tpig] > 0.0f ? 
src0[tpig] : src0[tpig] * slope; } +kernel void kernel_flash_attn_ext_f16( + device const half * q, + device const half * k, + device const half * v, + device const float * mask, + device float * dst, + constant int64_t & ne00, + constant int64_t & ne01, + constant int64_t & ne02, + constant int64_t & ne03, + constant uint64_t & nb00, + constant uint64_t & nb01, + constant uint64_t & nb02, + constant uint64_t & nb03, + constant int64_t & ne0, + constant int64_t & ne1, + constant int64_t & ne2, + constant int64_t & ne3, + constant uint64_t & nb0, + constant uint64_t & nb1, + constant uint64_t & nb2, + constant uint64_t & nb3, + constant float & scale, + uint3 tgpig[[threadgroup_position_in_grid]], + uint3 tpitg[[thread_position_in_threadgroup]], + uint3 ntg[[threads_per_threadgroup]]) { + // TODO: implement +} + kernel void kernel_cpy_f16_f16( device const half * src0, device half * dst, diff --git a/ggml-quants.c b/ggml-quants.c index 31b053e33..7d2f033e9 100644 --- a/ggml-quants.c +++ b/ggml-quants.c @@ -1274,7 +1274,12 @@ static float make_qx_quants(int n, int nmax, const float * restrict x, int8_t * } float sumlx = 0; float suml2 = 0; +#ifdef HAVE_BUGGY_APPLE_LINKER + // use 'volatile' to prevent unroll and work around a bug in Apple ld64 1015.7 + for (volatile int i = 0; i < n; ++i) { +#else for (int i = 0; i < n; ++i) { +#endif int l = nearest_int(iscale * x[i]); l = MAX(-nmax, MIN(nmax-1, l)); L[i] = l + nmax; @@ -1649,7 +1654,12 @@ static float make_qkx3_quants(int n, int nmax, const float * restrict x, const f float max = x[0]; float sum_w = weights ? weights[0] : x[0]*x[0]; float sum_x = sum_w * x[0]; +#ifdef HAVE_BUGGY_APPLE_LINKER + // use 'volatile' to prevent unroll and work around a bug in Apple ld64 1015.7 + for (volatile int i = 1; i < n; ++i) { +#else for (int i = 1; i < n; ++i) { +#endif if (x[i] < min) min = x[i]; if (x[i] > max) max = x[i]; float w = weights ? 
weights[i] : x[i]*x[i]; @@ -1660,7 +1670,7 @@ static float make_qkx3_quants(int n, int nmax, const float * restrict x, const f min = 0; } if (max <= min) { - for (int i = 0; i < n; ++i) L[i] = 0; + memset(L, 0, n); *the_min = -min; return 0.f; } @@ -1862,7 +1872,7 @@ static void quantize_row_q2_K_impl(const float * restrict x, block_q2_K * restri size_t quantize_q2_K(const float * src, void * dst, int nrow, int n_per_row, int64_t * hist, const float * quant_weights) { (void)hist; - int row_size = ggml_row_size(GGML_TYPE_Q2_K, n_per_row); + size_t row_size = ggml_row_size(GGML_TYPE_Q2_K, n_per_row); if (!quant_weights) { quantize_row_q2_K_reference(src, dst, nrow*n_per_row); } @@ -2181,7 +2191,7 @@ static void quantize_row_q3_K_impl(const float * restrict x, block_q3_K * restri size_t quantize_q3_K(const float * src, void * dst, int nrow, int n_per_row, int64_t * hist, const float * quant_weights) { (void)hist; - int row_size = ggml_row_size(GGML_TYPE_Q3_K, n_per_row); + size_t row_size = ggml_row_size(GGML_TYPE_Q3_K, n_per_row); if (!quant_weights) { quantize_row_q3_K_reference(src, dst, nrow*n_per_row); } @@ -2448,7 +2458,7 @@ static void quantize_row_q4_K_impl(const float * restrict x, block_q4_K * restri size_t quantize_q4_K(const float * src, void * dst, int nrow, int n_per_row, int64_t * hist, const float * quant_weights) { (void)hist; - int row_size = ggml_row_size(GGML_TYPE_Q4_K, n_per_row); + size_t row_size = ggml_row_size(GGML_TYPE_Q4_K, n_per_row); if (!quant_weights) { quantize_row_q4_K_reference(src, dst, nrow*n_per_row); } @@ -2771,7 +2781,7 @@ static void quantize_row_q5_K_impl(const float * restrict x, block_q5_K * restri size_t quantize_q5_K(const float * src, void * dst, int nrow, int n_per_row, int64_t * hist, const float * quant_weights) { (void)hist; - int row_size = ggml_row_size(GGML_TYPE_Q5_K, n_per_row); + size_t row_size = ggml_row_size(GGML_TYPE_Q5_K, n_per_row); if (!quant_weights) { quantize_row_q5_K_reference(src, dst, nrow*n_per_row); } @@ -3025,7 +3035,7 @@ static void quantize_row_q6_K_impl(const float * restrict x, block_q6_K * restri size_t quantize_q6_K(const float * src, void * dst, int nrow, int n_per_row, int64_t * hist, const float * quant_weights) { (void)hist; - int row_size = ggml_row_size(GGML_TYPE_Q6_K, n_per_row); + size_t row_size = ggml_row_size(GGML_TYPE_Q6_K, n_per_row); if (!quant_weights) { quantize_row_q6_K_reference(src, dst, nrow*n_per_row); } @@ -3072,7 +3082,7 @@ size_t quantize_q4_0(const float * src, void * dst, int nrow, int n_per_row, int if (!quant_weights) { return ggml_quantize_q4_0(src, dst, nrow*n_per_row, n_per_row, hist); } - int row_size = ggml_row_size(GGML_TYPE_Q4_0, n_per_row); + size_t row_size = ggml_row_size(GGML_TYPE_Q4_0, n_per_row); char * qrow = (char *)dst; for (int row = 0; row < nrow; ++row) { quantize_row_q4_0_impl(src, (block_q4_0*)qrow, n_per_row, quant_weights); @@ -3116,7 +3126,7 @@ size_t quantize_q4_1(const float * src, void * dst, int nrow, int n_per_row, int if (!quant_weights) { return ggml_quantize_q4_1(src, dst, nrow*n_per_row, n_per_row, hist); } - int row_size = ggml_row_size(GGML_TYPE_Q4_1, n_per_row); + size_t row_size = ggml_row_size(GGML_TYPE_Q4_1, n_per_row); char * qrow = (char *)dst; for (int row = 0; row < nrow; ++row) { quantize_row_q4_1_impl(src, (block_q4_1*)qrow, n_per_row, quant_weights); @@ -3169,7 +3179,7 @@ size_t quantize_q5_0(const float * src, void * dst, int nrow, int n_per_row, int if (!quant_weights) { return ggml_quantize_q5_0(src, dst, nrow*n_per_row, n_per_row, hist); } 
- int row_size = ggml_row_size(GGML_TYPE_Q5_0, n_per_row); + size_t row_size = ggml_row_size(GGML_TYPE_Q5_0, n_per_row); char * qrow = (char *)dst; for (int row = 0; row < nrow; ++row) { quantize_row_q5_0_impl(src, (block_q5_0*)qrow, n_per_row, quant_weights); @@ -3221,7 +3231,7 @@ size_t quantize_q5_1(const float * src, void * dst, int nrow, int n_per_row, int if (!quant_weights) { return ggml_quantize_q5_1(src, dst, nrow*n_per_row, n_per_row, hist); } - int row_size = ggml_row_size(GGML_TYPE_Q5_1, n_per_row); + size_t row_size = ggml_row_size(GGML_TYPE_Q5_1, n_per_row); char * qrow = (char *)dst; for (int row = 0; row < nrow; ++row) { quantize_row_q5_1_impl(src, (block_q5_1*)qrow, n_per_row, quant_weights); @@ -8565,7 +8575,7 @@ static int iq2_compare_func(const void * left, const void * right) { return l[0] < r[0] ? -1 : l[0] > r[0] ? 1 : l[1] < r[1] ? -1 : l[1] > r[1] ? 1 : 0; } -static void q2xs_init_impl(int grid_size) { +void iq2xs_init_impl(int grid_size) { const int gindex = iq2_data_index(grid_size); if (iq2_data[gindex].grid) { return; @@ -8720,19 +8730,7 @@ static void q2xs_init_impl(int grid_size) { free(dist2); } -void ggml_init_iq2_quantization(enum ggml_type type) { - if (type == GGML_TYPE_IQ2_XXS) { - q2xs_init_impl(256); - } - else if (type == GGML_TYPE_IQ2_XS) { - q2xs_init_impl(512); - } - else { - fprintf(stderr, "======================== Why are you calling %s with type %d?\n", __func__, (int)type); - } -} - -static void q2xs_deinit_impl(int grid_size) { +void iq2xs_free_impl(int grid_size) { GGML_ASSERT(grid_size == 256 || grid_size == 512 || grid_size == 1024); const int gindex = iq2_data_index(grid_size); if (iq2_data[gindex].grid) { @@ -8742,18 +8740,6 @@ static void q2xs_deinit_impl(int grid_size) { } } -void ggml_deinit_iq2_quantization(enum ggml_type type) { - if (type == GGML_TYPE_IQ2_XXS) { - q2xs_deinit_impl(256); - } - else if (type == GGML_TYPE_IQ2_XS) { - q2xs_deinit_impl(512); - } - else { - fprintf(stderr, "======================== Why are you calling %s with type %d?\n", __func__, (int)type); - } -} - static int iq2_find_best_neighbour(const uint16_t * restrict neighbours, const uint64_t * restrict grid, const float * restrict xval, const float * restrict weight, float scale, int8_t * restrict L) { int num_neighbors = neighbours[0]; @@ -8786,10 +8772,10 @@ static void quantize_row_iq2_xxs_impl(const float * restrict x, void * restrict const int * kmap_q2xs = iq2_data[gindex].map; const uint16_t * kneighbors_q2xs = iq2_data[gindex].neighbours; - GGML_ASSERT(quant_weights); - GGML_ASSERT(kgrid_q2xs); - GGML_ASSERT(kmap_q2xs); - GGML_ASSERT(kneighbors_q2xs); + GGML_ASSERT(quant_weights && "missing quantization weights"); + GGML_ASSERT(kgrid_q2xs && "forgot to call ggml_quantize_init()?"); + GGML_ASSERT(kmap_q2xs && "forgot to call ggml_quantize_init()?"); + GGML_ASSERT(kneighbors_q2xs && "forgot to call ggml_quantize_init()?"); GGML_ASSERT(n%QK_K == 0); const int kMaxQ = 3; @@ -9005,10 +8991,10 @@ static void quantize_row_iq2_xs_impl(const float * restrict x, void * restrict v const int * kmap_q2xs = iq2_data[gindex].map; const uint16_t * kneighbors_q2xs = iq2_data[gindex].neighbours; - GGML_ASSERT(quant_weights); - GGML_ASSERT(kmap_q2xs); - GGML_ASSERT(kgrid_q2xs); - GGML_ASSERT(kneighbors_q2xs); + GGML_ASSERT(quant_weights && "missing quantization weights"); + GGML_ASSERT(kmap_q2xs && "forgot to call ggml_quantize_init()?"); + GGML_ASSERT(kgrid_q2xs && "forgot to call ggml_quantize_init()?"); + GGML_ASSERT(kneighbors_q2xs && "forgot to call 
ggml_quantize_init()?"); GGML_ASSERT(n%QK_K == 0); const int kMaxQ = 3; diff --git a/ggml-quants.h b/ggml-quants.h index d7fefdb54..7d7cf9178 100644 --- a/ggml-quants.h +++ b/ggml-quants.h @@ -257,3 +257,6 @@ size_t quantize_q4_0 (const float * src, void * dst, int nrows, int n_per_row, size_t quantize_q4_1 (const float * src, void * dst, int nrows, int n_per_row, int64_t * hist, const float * imatrix); size_t quantize_q5_0 (const float * src, void * dst, int nrows, int n_per_row, int64_t * hist, const float * imatrix); size_t quantize_q5_1 (const float * src, void * dst, int nrows, int n_per_row, int64_t * hist, const float * imatrix); + +void iq2xs_init_impl(int grid_size); +void iq2xs_free_impl(int grid_size); diff --git a/ggml.c b/ggml.c index d7e01b81f..9cf4784ce 100644 --- a/ggml.c +++ b/ggml.c @@ -394,12 +394,6 @@ static const size_t CACHE_LINE_SIZE_F32 = CACHE_LINE_SIZE/sizeof(float); static void ggml_vec_dot_f32(const int n, float * restrict s, const float * restrict x, const float * restrict y); static void ggml_vec_dot_f16(const int n, float * restrict s, ggml_fp16_t * restrict x, ggml_fp16_t * restrict y); -ggml_collect_imatrix_t g_imatrix_collect = NULL; - -void ggml_set_imatrix_collection(ggml_collect_imatrix_t imatrix_collect) { - g_imatrix_collect = imatrix_collect; -} - static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = { [GGML_TYPE_I8] = { .type_name = "i8", @@ -1656,6 +1650,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = { "LEAKY_RELU", "FLASH_ATTN", + "FLASH_ATTN_EXT", "FLASH_FF", "FLASH_ATTN_BACK", "WIN_PART", @@ -1680,7 +1675,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = { "CROSS_ENTROPY_LOSS_BACK", }; -static_assert(GGML_OP_COUNT == 72, "GGML_OP_COUNT != 72"); +static_assert(GGML_OP_COUNT == 73, "GGML_OP_COUNT != 73"); static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = { "none", @@ -1742,6 +1737,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = { "leaky_relu(x)", "flash_attn(x)", + "flash_attn_ext(x)", "flash_ff(x)", "flash_attn_back(x)", "win_part(x)", @@ -1766,7 +1762,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = { "cross_entropy_loss_back(x,y)", }; -static_assert(GGML_OP_COUNT == 72, "GGML_OP_COUNT != 72"); +static_assert(GGML_OP_COUNT == 73, "GGML_OP_COUNT != 73"); static_assert(GGML_OP_POOL_COUNT == 2, "GGML_OP_POOL_COUNT != 2"); @@ -5684,6 +5680,46 @@ struct ggml_tensor * ggml_flash_attn( return result; } +// ggml_flash_attn_ext + +struct ggml_tensor * ggml_flash_attn_ext( + struct ggml_context * ctx, + struct ggml_tensor * q, + struct ggml_tensor * k, + struct ggml_tensor * v, + struct ggml_tensor * mask, + float scale) { + GGML_ASSERT(ggml_can_mul_mat(k, q)); + // TODO: check if vT can be multiplied by (k*qT) + if (mask) { + GGML_ASSERT(ggml_is_contiguous(mask)); + GGML_ASSERT(mask->ne[2] == 1); + GGML_ASSERT(mask->ne[3] == 1); + //GGML_ASSERT(ggml_can_repeat_rows(mask, qk)); + } + + bool is_node = false; + + if (q->grad || k->grad || v->grad) { + is_node = true; + } + + //struct ggml_tensor * result = ggml_dup_tensor(ctx, q); + struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, GGML_MAX_DIMS, q->ne); + + float params[] = { scale }; + ggml_set_op_params(result, params, sizeof(params)); + + result->op = GGML_OP_FLASH_ATTN_EXT; + result->grad = is_node ? 
ggml_dup_tensor(ctx, result) : NULL; + result->src[0] = q; + result->src[1] = k; + result->src[2] = v; + result->src[3] = mask; + + return result; +} + // ggml_flash_ff struct ggml_tensor * ggml_flash_ff( @@ -9790,10 +9826,6 @@ static void ggml_compute_forward_mul_mat( const int ith = params->ith; const int nth = params->nth; - if (ith == 1 && g_imatrix_collect) { - g_imatrix_collect(src0, src1); - } - const enum ggml_type type = src0->type; const bool src1_cont = ggml_is_contiguous(src1); @@ -10097,10 +10129,6 @@ static void ggml_compute_forward_mul_mat_id( const struct ggml_tensor * src0_cur = dst->src[cur_a + 2]; - if (ith == 1 && g_imatrix_collect) { - g_imatrix_collect(src0_cur, src1); - } - const void * wdata = (src1->type == vec_dot_type) ? src1->data : params->wdata; const size_t row_size = ggml_row_size(vec_dot_type, ne10); @@ -13226,6 +13254,258 @@ static void ggml_compute_forward_flash_attn( } } +// ggml_compute_forward_flash_attn_ext + +static void ggml_compute_forward_flash_attn_ext_f16( + const struct ggml_compute_params * params, + const struct ggml_tensor * q, + const struct ggml_tensor * k, + const struct ggml_tensor * v, + const struct ggml_tensor * mask, + struct ggml_tensor * dst) { + int64_t t0 = ggml_perf_time_us(); + UNUSED(t0); + + GGML_TENSOR_LOCALS(int64_t, neq, q, ne) + GGML_TENSOR_LOCALS(size_t, nbq, q, nb) + GGML_TENSOR_LOCALS(int64_t, nek, k, ne) + GGML_TENSOR_LOCALS(size_t, nbk, k, nb) + GGML_TENSOR_LOCALS(int64_t, nev, v, ne) + GGML_TENSOR_LOCALS(size_t, nbv, v, nb) + GGML_TENSOR_LOCALS(int64_t, ne, dst, ne) + GGML_TENSOR_LOCALS(size_t, nb, dst, nb) + + const int ith = params->ith; + const int nth = params->nth; + + const int64_t D = neq0; + const int64_t N = neq1; + const int64_t P = nek1 - N; + const int64_t M = P + N; + + const int Mup = ggml_up(M, GGML_SOFT_MAX_UNROLL); + + GGML_ASSERT(ne0 == D); + GGML_ASSERT(ne1 == N); + GGML_ASSERT(P >= 0); + + GGML_ASSERT(nbq0 == sizeof(ggml_fp16_t)); + GGML_ASSERT(nbk0 == sizeof(ggml_fp16_t)); + GGML_ASSERT(nbv0 == sizeof(ggml_fp16_t)); + + GGML_ASSERT(neq0 == D); + GGML_ASSERT(nek0 == D); + GGML_ASSERT(nev1 == D); + + GGML_ASSERT(neq1 == N); + GGML_ASSERT(nek1 == N + P); + GGML_ASSERT(nev1 == D); + + // dst cannot be transposed or permuted + GGML_ASSERT(nb0 == sizeof(float)); + GGML_ASSERT(nb0 <= nb1); + GGML_ASSERT(nb1 <= nb2); + GGML_ASSERT(nb2 <= nb3); + + // broadcast factors + const int64_t rk2 = neq2/nek2; + const int64_t rk3 = neq3/nek3; + + const int64_t rv2 = neq2/nev2; + const int64_t rv3 = neq3/nev3; + + if (params->type == GGML_TASK_INIT) { + return; + } + + if (params->type == GGML_TASK_FINALIZE) { + return; + } + + // parallelize by q rows using ggml_vec_dot_f32 + + // total rows in q + const int nr = neq1*neq2*neq3; + + // rows per thread + const int dr = (nr + nth - 1)/nth; + + // row range for this thread + const int ir0 = dr*ith; + const int ir1 = MIN(ir0 + dr, nr); + + float scale = 1.0f; + memcpy(&scale, (float *) dst->op_params + 0, sizeof(float)); + + //printf("P=%d N=%d D=%d ir0=%d ir1=%d scale = %f\n", P, N, D, ir0, ir1, scale); + + for (int ir = ir0; ir < ir1; ++ir) { + // q indices + const int iq3 = ir/(neq2*neq1); + const int iq2 = (ir - iq3*neq2*neq1)/neq1; + const int iq1 = (ir - iq3*neq2*neq1 - iq2*neq1); + + float * S = (float *) params->wdata + ith*(2*Mup + CACHE_LINE_SIZE_F32); + + for (int i = M; i < Mup; ++i) { + S[i] = -INFINITY; + } + + if (GGML_VEC_DOT_UNROLL > 2 || nek1 % GGML_VEC_DOT_UNROLL != 0) { + for (int64_t ic = 0; ic < nek1; ++ic) { + // k indices + const int ik3 = iq3 
/ rk3; + const int ik2 = iq2 / rk2; + const int ik1 = ic; + + // S indices + const int i1 = ik1; + + ggml_vec_dot_f16(neq0, + S + i1, + (ggml_fp16_t *) ((char *) k->data + (ik1*nbk1 + ik2*nbk2 + ik3*nbk3)), + (ggml_fp16_t *) ((char *) q->data + (iq1*nbq1 + iq2*nbq2 + iq3*nbq3))); + } + } else { + for (int64_t ic = 0; ic < nek1; ic += GGML_VEC_DOT_UNROLL) { + // k indices + const int ik3 = iq3 / rk3; + const int ik2 = iq2 / rk2; + const int ik1 = ic; + + // S indices + const int i1 = ik1; + + ggml_vec_dot_f16_unroll(neq0, nbk1, + S + i1, + ((char *) k->data + (ik1*nbk1 + ik2*nbk2 + ik3*nbk3)), + (ggml_fp16_t *) ((char *) q->data + (iq1*nbq1 + iq2*nbq2 + iq3*nbq3))); + } + } + + // scale + ggml_vec_scale_f32(nek1, S, scale); + + if (mask) { + const float * mp = (float *)((char *) mask->data + (ir%mask->ne[1])*mask->nb[1]); + ggml_vec_acc_f32(M, S, mp); + } + + // softmax + // todo: exclude known -INF S[..] values from max and loop, assuming their results to be zero. + // dont forget to set their S values to zero + { + float max = -INFINITY; + ggml_vec_max_f32(M, &max, S); + + ggml_float sum = 0.0; + { +#ifdef GGML_SOFT_MAX_ACCELERATE + max = -max; + vDSP_vsadd(S, 1, &max, S, 1, Mup); + vvexpf(S, S, &Mup); + ggml_vec_sum_f32(Mup, &sum, S); +#else + uint16_t scvt[GGML_SOFT_MAX_UNROLL]; + ggml_float sump[GGML_SOFT_MAX_UNROLL] = { 0.0 }; + + for (int i = 0; i < Mup; i += GGML_SOFT_MAX_UNROLL) { + float * SS = S + i; + + for (int j = 0; j < GGML_SOFT_MAX_UNROLL; ++j) { + if (SS[j] == -INFINITY) { + SS[j] = 0.0f; + } else { + ggml_fp16_t s = GGML_FP32_TO_FP16(SS[j] - max); + memcpy(&scvt[j], &s, sizeof(uint16_t)); + const float val = GGML_FP16_TO_FP32(ggml_table_exp_f16[scvt[j]]); + sump[j] += (ggml_float)val; + SS[j] = val; + } + } + } + + for (int i = 0; i < GGML_SOFT_MAX_UNROLL; i++) { + sum += sump[i]; + } +#endif + } + + assert(sum > 0.0); + + sum = 1.0/sum; + ggml_vec_scale_f32(M, S, sum); + +#ifndef NDEBUG + for (int i = 0; i < M; ++i) { + assert(!isnan(S[i])); + assert(!isinf(S[i])); + } +#endif + } + + ggml_fp16_t * S16 = (ggml_fp16_t *) ((float *) params->wdata + ith*(2*Mup + CACHE_LINE_SIZE_F32) + Mup); + + for (int64_t i = 0; i < M; i++) { + S16[i] = GGML_FP32_TO_FP16(S[i]); + } + + // todo: exclude known zero S[..] values from dot (reducing nev0 and increasing begin of v and S16). 
+ if (GGML_VEC_DOT_UNROLL == 1 || (nev1 % GGML_VEC_DOT_UNROLL != 0)) { + for (int64_t ic = 0; ic < nev1; ++ic) { + // dst indices + const int i1 = iq1; + const int i2 = iq2; + const int i3 = iq3; + + // v indices + const int iv2 = iq2 / rv2; + const int iv3 = iq3 / rv3; + + ggml_vec_dot_f16(nev0, + (float *) ((char *) dst->data + (ic*nb0 + i1*nb1 + i2*nb2 + i3*nb3)), + (ggml_fp16_t *) ((char *) v->data + ( ic*nbv1 + iv2*nbv2 + iv3*nbv3)), + S16); + } + } else { + for (int64_t ic = 0; ic < nev1; ic += GGML_VEC_DOT_UNROLL) { + // dst indices + const int i1 = iq1; + const int i2 = iq2; + const int i3 = iq3; + + // v indices + const int iv2 = iq2 / rv2; + const int iv3 = iq3 / rv3; + + ggml_vec_dot_f16_unroll(nev0, nbv1, + (float *) ((char *) dst->data + (ic*nb0 + i1*nb1 + i2*nb2 + i3*nb3)), + ((char *) v->data + ( ic*nbv1 + iv2*nbv2 + iv3*nbv3)), + S16); + } + } + } +} + +static void ggml_compute_forward_flash_attn_ext( + const struct ggml_compute_params * params, + const struct ggml_tensor * q, + const struct ggml_tensor * k, + const struct ggml_tensor * v, + const struct ggml_tensor * mask, + struct ggml_tensor * dst) { + switch (q->type) { + case GGML_TYPE_F16: + { + ggml_compute_forward_flash_attn_ext_f16(params, q, k, v, mask, dst); + } break; + default: + { + GGML_ASSERT(false); + } break; + } +} + // ggml_compute_forward_flash_ff static void ggml_compute_forward_flash_ff_f16( @@ -14731,6 +15011,10 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm const bool masked = t != 0; ggml_compute_forward_flash_attn(params, tensor->src[0], tensor->src[1], tensor->src[2], masked, tensor); } break; + case GGML_OP_FLASH_ATTN_EXT: + { + ggml_compute_forward_flash_attn_ext(params, tensor->src[0], tensor->src[1], tensor->src[2], tensor->src[3], tensor); + } break; case GGML_OP_FLASH_FF: { ggml_compute_forward_flash_ff(params, tensor->src[0], tensor->src[1], tensor->src[2], tensor->src[3], tensor->src[4], tensor); @@ -15727,6 +16011,7 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor GGML_ASSERT(false); // TODO: not implemented } break; case GGML_OP_FLASH_ATTN: + case GGML_OP_FLASH_ATTN_EXT: { struct ggml_tensor * flash_grad = NULL; if (src0->grad || src1->grad || tensor->src[2]->grad) { @@ -16452,6 +16737,7 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) { n_tasks = n_threads; } break; case GGML_OP_FLASH_ATTN: + case GGML_OP_FLASH_ATTN_EXT: { n_tasks = n_threads; } break; @@ -16783,6 +17069,7 @@ struct ggml_cplan ggml_graph_plan(const struct ggml_cgraph * cgraph, int n_threa cur += sizeof(ggml_fp16_t)*ne10*ne11*ne12; } break; case GGML_OP_FLASH_ATTN: + case GGML_OP_FLASH_ATTN_EXT: { const int64_t ne11 = ggml_up(node->src[1]->ne[1], GGML_SOFT_MAX_UNROLL); @@ -18538,6 +18825,28 @@ enum ggml_opt_result ggml_opt_resume_g( //////////////////////////////////////////////////////////////////////////////// +void ggml_quantize_init(enum ggml_type type) { + ggml_critical_section_start(); + + switch (type) { + case GGML_TYPE_IQ2_XXS: iq2xs_init_impl(256); break; + case GGML_TYPE_IQ2_XS: iq2xs_init_impl(512); break; + default: // nothing + break; + } + + ggml_critical_section_end(); +} + +void ggml_quantize_free(void) { + ggml_critical_section_start(); + + iq2xs_free_impl(256); + iq2xs_free_impl(512); + + ggml_critical_section_end(); +} + size_t ggml_quantize_q4_0(const float * src, void * dst, int n, int k, int64_t * hist) { assert(k % QK4_0 == 0); const int nb = k / QK4_0; @@ -18665,9 +18974,15 @@ size_t 
ggml_quantize_q8_0(const float * src, void * dst, int n, int k, int64_t * return (n/QK8_0*sizeof(block_q8_0)); } +bool ggml_quantize_requires_imatrix(enum ggml_type type) { + return + type == GGML_TYPE_IQ2_XXS || + type == GGML_TYPE_IQ2_XS; +} + size_t ggml_quantize_chunk(enum ggml_type type, const float * src, void * dst, int start, int nrows, int n_per_row, int64_t * hist, const float * imatrix) { - (void)imatrix; + ggml_quantize_init(type); // this is noop if already initialized size_t result = 0; int n = nrows * n_per_row; switch (type) { @@ -18780,13 +19095,13 @@ size_t ggml_quantize_chunk(enum ggml_type type, const float * src, void * dst, i } break; case GGML_TYPE_F16: { - int elemsize = sizeof(ggml_fp16_t); + size_t elemsize = sizeof(ggml_fp16_t); ggml_fp32_to_fp16_row(src + start, (ggml_fp16_t *)dst + start, n); result = n * elemsize; } break; case GGML_TYPE_F32: { - int elemsize = sizeof(float); + size_t elemsize = sizeof(float); result = n * elemsize; memcpy((uint8_t *)dst + start * elemsize, src + start, result); } break; diff --git a/ggml.h b/ggml.h index 837c52e68..d76fe9d5c 100644 --- a/ggml.h +++ b/ggml.h @@ -452,6 +452,7 @@ extern "C" { GGML_OP_LEAKY_RELU, GGML_OP_FLASH_ATTN, + GGML_OP_FLASH_ATTN_EXT, GGML_OP_FLASH_FF, GGML_OP_FLASH_ATTN_BACK, GGML_OP_WIN_PART, @@ -1619,6 +1620,14 @@ extern "C" { struct ggml_tensor * v, bool masked); + GGML_API struct ggml_tensor * ggml_flash_attn_ext( + struct ggml_context * ctx, + struct ggml_tensor * q, + struct ggml_tensor * k, + struct ggml_tensor * v, + struct ggml_tensor * mask, + float scale); + GGML_API struct ggml_tensor * ggml_flash_attn_back( struct ggml_context * ctx, struct ggml_tensor * q, @@ -2065,6 +2074,18 @@ extern "C" { // quantization // + // - ggml_quantize_init can be called multiple times with the same type + // it will only initialize the quantization tables for the first call or after ggml_quantize_free + // automatically called by ggml_quantize_chunk for convenience + // + // - ggml_quantize_free will free any memory allocated by ggml_quantize_init + // call this at the end of the program to avoid memory leaks + // + // note: these are thread-safe + // + GGML_API void ggml_quantize_init(enum ggml_type type); + GGML_API void ggml_quantize_free(void); + // TODO: these would probably get removed in favor of the more general ggml_quantize_chunk GGML_API size_t ggml_quantize_q4_0(const float * src, void * dst, int n, int k, int64_t * hist); GGML_API size_t ggml_quantize_q4_1(const float * src, void * dst, int n, int k, int64_t * hist); @@ -2078,19 +2099,13 @@ extern "C" { GGML_API size_t ggml_quantize_q5_K(const float * src, void * dst, int n, int k, int64_t * hist); GGML_API size_t ggml_quantize_q6_K(const float * src, void * dst, int n, int k, int64_t * hist); + // some quantization type cannot be used without an importance matrix + GGML_API bool ggml_quantize_requires_imatrix(enum ggml_type type); + + // calls ggml_quantize_init internally (i.e. 
can allocate memory) GGML_API size_t ggml_quantize_chunk(enum ggml_type type, const float * src, void * dst, int start, int nrows, int n_per_row, int64_t * hist, const float * imatrix); - // These are needed for IQ2_XS and IQ2_XXS quantizations - GGML_API void ggml_init_iq2_quantization(enum ggml_type type); - GGML_API void ggml_deinit_iq2_quantization(enum ggml_type type); - - // - // Importance matrix - // - typedef void(*ggml_collect_imatrix_t)(const struct ggml_tensor * src0, const struct ggml_tensor * src1); - GGML_API void ggml_set_imatrix_collection(ggml_collect_imatrix_t imatrix_collect); - // // gguf // diff --git a/llama.cpp b/llama.cpp index 2c5983c67..d4bebe520 100644 --- a/llama.cpp +++ b/llama.cpp @@ -1393,6 +1393,9 @@ struct llama_cparams { bool mul_mat_q; bool offload_kqv; + + ggml_backend_sched_eval_callback cb_eval; + void * cb_eval_user_data; }; struct llama_layer { @@ -4202,38 +4205,6 @@ static struct ggml_tensor * llm_build_kqv( 0); cb(k, "k", il); - struct ggml_tensor * kq = ggml_mul_mat(ctx, k, q); - cb(kq, "kq", il); - - if (model.arch == LLM_ARCH_PHI2) { - // for this arch, we need to perform the KQ multiplication with F32 precision, otherwise we get NaNs - // ref: https://github.com/ggerganov/llama.cpp/pull/4490#issuecomment-1859055847 - ggml_mul_mat_set_prec(kq, GGML_PREC_F32); - } - - if (max_alibi_bias > 0.0f) { - // temporary branch until we figure out how to handle ggml_alibi through ggml_add - kq = ggml_scale(ctx, kq, kq_scale); - cb(kq, "kq_scaled", il); - - if (max_alibi_bias > 0.0f) { - // TODO: n_head or n_head_kv - // TODO: K-shift is likely not working - // TODO: change to ggml_add - kq = ggml_alibi(ctx, kq, /*n_past*/ 0, n_head, max_alibi_bias); - cb(kq, "kq_scaled_alibi", il); - } - - kq = ggml_add(ctx, kq, kq_mask); - cb(kq, "kq_masked", il); - - kq = ggml_soft_max(ctx, kq); - cb(kq, "kq_soft_max", il); - } else { - kq = ggml_soft_max_ext(ctx, kq, kq_mask, kq_scale); - cb(kq, "kq_soft_max_ext", il); - } - // split cached v into n_head heads struct ggml_tensor * v = ggml_view_3d(ctx, kv.v_l[il], @@ -4243,8 +4214,53 @@ static struct ggml_tensor * llm_build_kqv( 0); cb(v, "v", il); - struct ggml_tensor * kqv = ggml_mul_mat(ctx, v, kq); - cb(kqv, "kqv", il); + // TODO: determine if we can use flash attention + const bool supports_flash_attn = true; + + struct ggml_tensor * kqv; + + if (supports_flash_attn) { + //printf("q: %4d %4d %4d %4d\n", q->ne[0], q->ne[1], q->ne[2], q->ne[3]); + //printf("k: %4d %4d %4d %4d\n", k->ne[0], k->ne[1], k->ne[2], k->ne[3]); + //printf("v: %4d %4d %4d %4d\n", v->ne[0], v->ne[1], v->ne[2], v->ne[3]); + //printf("m: %4d %4d %4d %4d\n", kq_mask->ne[0], kq_mask->ne[1], kq_mask->ne[2], kq_mask->ne[3]); + kqv = ggml_flash_attn_ext(ctx, ggml_cast(ctx, q, GGML_TYPE_F16), k, v, kq_mask, kq_scale); + } else { + struct ggml_tensor * kq = ggml_mul_mat(ctx, k, q); + cb(kq, "kq", il); + + if (model.arch == LLM_ARCH_PHI2) { + // for this arch, we need to perform the KQ multiplication with F32 precision, otherwise we get NaNs + // ref: https://github.com/ggerganov/llama.cpp/pull/4490#issuecomment-1859055847 + ggml_mul_mat_set_prec(kq, GGML_PREC_F32); + } + + if (max_alibi_bias > 0.0f) { + // temporary branch until we figure out how to handle ggml_alibi through ggml_add + kq = ggml_scale(ctx, kq, kq_scale); + cb(kq, "kq_scaled", il); + + if (max_alibi_bias > 0.0f) { + // TODO: n_head or n_head_kv + // TODO: K-shift is likely not working + // TODO: change to ggml_add + kq = ggml_alibi(ctx, kq, /*n_past*/ 0, n_head, max_alibi_bias); + 
cb(kq, "kq_scaled_alibi", il); + } + + kq = ggml_add(ctx, kq, kq_mask); + cb(kq, "kq_masked", il); + + kq = ggml_soft_max(ctx, kq); + cb(kq, "kq_soft_max", il); + } else { + kq = ggml_soft_max_ext(ctx, kq, kq_mask, kq_scale); + cb(kq, "kq_soft_max_ext", il); + } + + kqv = ggml_mul_mat(ctx, v, kq); + cb(kqv, "kqv", il); + } struct ggml_tensor * kqv_merged = ggml_permute(ctx, kqv, 0, 2, 1, 3); cb(kqv_merged, "kqv_merged", il); @@ -6254,6 +6270,7 @@ static int llama_decode_internal( //printf("kv_self.n = %5d, kv_self.used = %5d, kv_self.head = %5d\n", kv_self.n, kv_self.used, kv_self.head); ggml_backend_sched_reset(lctx.sched); + ggml_backend_sched_set_eval_callback(lctx.sched, lctx.cparams.cb_eval, lctx.cparams.cb_eval_user_data); ggml_cgraph * gf = llama_build_graph(lctx, batch); @@ -8743,8 +8760,6 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s // placeholder for the meta data ::zeros(fout, meta_size); - std::set used_iq2; - for (int i = 0; i < ml.n_tensors; ++i) { struct ggml_tensor * tensor = ml.get_tensor_meta(i); @@ -8797,11 +8812,6 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s } else { const size_t nelements = ggml_nelements(tensor); - if ((new_type == GGML_TYPE_IQ2_XXS || new_type == GGML_TYPE_IQ2_XS) && used_iq2.find(new_type) == used_iq2.end()) { - ggml_init_iq2_quantization(new_type); - used_iq2.insert(new_type); - } - const float * imatrix = nullptr; if (imatrix_data) { auto it = imatrix_data->find(tensor->name); @@ -8927,10 +8937,6 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s fout.close(); - for (auto type : used_iq2) { - ggml_deinit_iq2_quantization(type); - } - gguf_free(ctx_out); LLAMA_LOG_INFO("%s: model size = %8.2f MB\n", __func__, total_size_org/1024.0/1024.0); @@ -9276,6 +9282,8 @@ struct llama_context_params llama_context_default_params() { /*.yarn_beta_fast =*/ 32.0f, /*.yarn_beta_slow =*/ 1.0f, /*.yarn_orig_ctx =*/ 0, + /*.cb_eval =*/ nullptr, + /*.cb_eval_user_data =*/ nullptr, /*.type_k =*/ GGML_TYPE_F16, /*.type_v =*/ GGML_TYPE_F16, /*.mul_mat_q =*/ true, @@ -9336,6 +9344,7 @@ void llama_backend_free(void) { #ifdef GGML_USE_MPI ggml_mpi_backend_free(); #endif + ggml_quantize_free(); } int64_t llama_time_us(void) { @@ -9416,6 +9425,9 @@ struct llama_context * llama_new_context_with_model( hparams.n_yarn_orig_ctx != 0 ? 
hparams.n_yarn_orig_ctx : hparams.n_ctx_train; + cparams.cb_eval = params.cb_eval; + cparams.cb_eval_user_data = params.cb_eval_user_data; + auto rope_scaling_type = params.rope_scaling_type; if (rope_scaling_type == LLAMA_ROPE_SCALING_UNSPECIFIED) { rope_scaling_type = hparams.rope_scaling_type_train; @@ -9491,8 +9503,7 @@ struct llama_context * llama_new_context_with_model( } ctx->backends.push_back(ctx->backend_cpu); - if (!llama_kv_cache_init(ctx->kv_self, ctx->model, type_k, type_v, - cparams.n_ctx, cparams.offload_kqv)) { + if (!llama_kv_cache_init(ctx->kv_self, ctx->model, type_k, type_v, cparams.n_ctx, cparams.offload_kqv)) { LLAMA_LOG_ERROR("%s: llama_kv_cache_init() failed for self-attention cache\n", __func__); llama_free(ctx); return nullptr; diff --git a/llama.h b/llama.h index a570b0d69..e268d7a1d 100644 --- a/llama.h +++ b/llama.h @@ -2,6 +2,7 @@ #define LLAMA_H #include "ggml.h" +#include "ggml-backend.h" #ifdef GGML_USE_CUBLAS #include "ggml-cuda.h" #define LLAMA_MAX_DEVICES GGML_CUDA_MAX_DEVICES @@ -231,6 +232,9 @@ extern "C" { float yarn_beta_slow; // YaRN high correction dim uint32_t yarn_orig_ctx; // YaRN original context size + ggml_backend_sched_eval_callback cb_eval; + void * cb_eval_user_data; + enum ggml_type type_k; // data type for K cache enum ggml_type type_v; // data type for V cache diff --git a/scripts/get-hellaswag.sh b/scripts/get-hellaswag.sh new file mode 100755 index 000000000..ef8dcceb0 --- /dev/null +++ b/scripts/get-hellaswag.sh @@ -0,0 +1,10 @@ +#!/bin/bash + +wget https://raw.githubusercontent.com/klosax/hellaswag_text_data/main/hellaswag_val_full.txt + +echo "Usage:" +echo "" +echo " ./perplexity --hellaswag --hellaswag-tasks N -f hellaswag_val_full.txt -m modelfile.gguf" +echo "" + +exit 0 diff --git a/scripts/sync-ggml.last b/scripts/sync-ggml.last index be9e408fb..4d52d946b 100644 --- a/scripts/sync-ggml.last +++ b/scripts/sync-ggml.last @@ -1 +1 @@ -b306d6e996ec0ace77118fa5098822cdc7f9c88f +6c1ce0bd591a430c1d3f6797d905194581c878c1 diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt index bc5649989..c89d8b191 100644 --- a/tests/CMakeLists.txt +++ b/tests/CMakeLists.txt @@ -49,6 +49,7 @@ llama_build_and_test_executable(test-llama-grammar.cpp) llama_build_and_test_executable(test-grad0.cpp) # llama_build_and_test_executable(test-opt.cpp) # SLOW llama_build_and_test_executable(test-backend-ops.cpp) +llama_build_and_test_executable(test-autorelease.cpp) llama_build_and_test_executable(test-rope.cpp) diff --git a/tests/test-autorelease.cpp b/tests/test-autorelease.cpp new file mode 100644 index 000000000..289c6ba6c --- /dev/null +++ b/tests/test-autorelease.cpp @@ -0,0 +1,28 @@ +// ref: https://github.com/ggerganov/llama.cpp/issues/4952#issuecomment-1892864763 + +#include +#include +#include + +#include "llama.h" + +// This creates a new context inside a pthread and then tries to exit cleanly. 
+int main(int argc, char ** argv) { + if (argc < 2) { + printf("Usage: %s model.gguf\n", argv[0]); + return 0; // intentionally return success + } + + const std::string fname = argv[1]; + + std::thread([&fname]() { + llama_backend_init(false); + auto * model = llama_load_model_from_file(fname.c_str(), llama_model_default_params()); + auto * ctx = llama_new_context_with_model(model, llama_context_default_params()); + llama_free(ctx); + llama_free_model(model); + llama_backend_free(); + }).join(); + + return 0; +} diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp index 22a7856d4..5693c2197 100644 --- a/tests/test-backend-ops.cpp +++ b/tests/test-backend-ops.cpp @@ -16,39 +16,37 @@ #include static void init_tensor_uniform(ggml_tensor * tensor, float min = -1.0f, float max = 1.0f) { + // static RNG initialization (revisit if n_threads stops being constant) + static const size_t n_threads = std::thread::hardware_concurrency(); + static std::vector generators = []() { + std::random_device rd; + std::vector vec; + vec.reserve(n_threads); + //for (size_t i = 0; i < n_threads; i++) { vec.emplace_back(1234 + i); } // fixed seed + for (size_t i = 0; i < n_threads; i++) { vec.emplace_back(rd()); } + return vec; + }(); + size_t size = ggml_nelements(tensor); std::vector data(size); -#if 0 - static std::default_random_engine generator(1234); - std::uniform_real_distribution distribution(min, max); - - for (size_t i = 0; i < size; i++) { - data[i] = distribution(generator); - } -#else - auto init_thread = [&](size_t start, size_t end) { - std::random_device rd; - std::default_random_engine generator(rd()); + auto init_thread = [&](size_t ith, size_t start, size_t end) { std::uniform_real_distribution distribution(min, max); - for (size_t i = start; i < end; i++) { - data[i] = distribution(generator); + data[i] = distribution(generators[ith]); } }; - size_t n_threads = std::thread::hardware_concurrency(); std::vector threads; threads.reserve(n_threads); for (size_t i = 0; i < n_threads; i++) { size_t start = i*size/n_threads; size_t end = (i+1)*size/n_threads; - threads.emplace_back(init_thread, start, end); + threads.emplace_back(init_thread, i, start, end); } for (auto & t : threads) { t.join(); } -#endif if (tensor->type == GGML_TYPE_F32 || tensor->type == GGML_TYPE_I32) { ggml_backend_tensor_set(tensor, data.data(), 0, size * sizeof(float)); @@ -56,7 +54,16 @@ static void init_tensor_uniform(ggml_tensor * tensor, float min = -1.0f, float m GGML_ASSERT(size % ggml_blck_size(tensor->type) == 0); std::vector dataq(ggml_row_size(tensor->type, size)); int64_t hist[16]; - ggml_quantize_chunk(tensor->type, data.data(), dataq.data(), 0, size/tensor->ne[0], tensor->ne[0], hist, nullptr); + std::vector imatrix(tensor->ne[0], 1.0f); // dummy importance matrix + const float * im = imatrix.data(); + if (!ggml_quantize_requires_imatrix(tensor->type)) { + // when the imatrix is optional, we want to test both quantization with and without imatrix + // use one of the random numbers to decide + if (data[0] > 0.5f*(min + max)) { + im = nullptr; + } + } + ggml_quantize_chunk(tensor->type, data.data(), dataq.data(), 0, size/tensor->ne[0], tensor->ne[0], hist, im); ggml_backend_tensor_set(tensor, dataq.data(), 0, dataq.size()); } else if (tensor->type == GGML_TYPE_I8 || tensor->type == GGML_TYPE_I16 || tensor->type == GGML_TYPE_I32) { // This is going to create some weird integers though. 
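Note: the init_tensor_uniform() change above also illustrates the intended usage pattern for the reworked quantization entry point: supply a (possibly dummy) importance matrix whenever ggml_quantize_requires_imatrix() says the type needs one. A minimal standalone sketch of that pattern follows; the helper name quantize_buffer and the flat all-ones imatrix are illustrative choices, not part of this patch.

// sketch: quantize a row-major f32 buffer with the chunk API from this patch
#include "ggml.h"
#include <cstdint>
#include <vector>

static std::vector<uint8_t> quantize_buffer(enum ggml_type type, const std::vector<float> & src, int nrows, int n_per_row) {
    std::vector<uint8_t> dst(ggml_row_size(type, n_per_row)*nrows);
    std::vector<float>   imatrix(n_per_row, 1.0f); // flat dummy importance matrix
    const float * im = ggml_quantize_requires_imatrix(type) ? imatrix.data() : nullptr;
    int64_t hist[16] = {0};
    // ggml_quantize_chunk() calls ggml_quantize_init() internally if needed
    ggml_quantize_chunk(type, src.data(), dst.data(), /*start*/ 0, nrows, n_per_row, hist, im);
    return dst;
}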
@@ -1377,6 +1384,32 @@ struct test_leaky_relu : public test_case { } }; +// GGML_OP_FLASH_ATTN_EXT +struct test_flash_attn_ext : public test_case { + const ggml_type typeq; + const int64_t hs; // head size + const int64_t nh; // num heads + const int64_t kv; // kv size + const int64_t nt; // tokens + + std::string vars() override { + return VARS_TO_STR5(typeq, hs, nh, kv, nt); + } + + test_flash_attn_ext(ggml_type typeq = GGML_TYPE_F16, + int64_t hs = 128, int64_t nh = 32, int64_t kv = 96, int64_t nt = 8) + : typeq(typeq), hs(hs), nh(nh), kv(kv), nt(nt) {} + + ggml_tensor * build_graph(ggml_context * ctx) override { + ggml_tensor * q = ggml_new_tensor_4d(ctx, typeq, hs, nt, nh, 1); + ggml_tensor * k = ggml_new_tensor_4d(ctx, GGML_TYPE_F16, hs, kv, nh, 1); + ggml_tensor * v = ggml_new_tensor_4d(ctx, GGML_TYPE_F16, kv, hs, nh, 1); + ggml_tensor * mask = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, kv, nt, 1, 1); + ggml_tensor * out = ggml_flash_attn_ext(ctx, q, k, v, mask, 1.0f/sqrtf(hs)); + return out; + } +}; + // Mixtral MOE struct test_moe : public test_case { const int n_experts; @@ -1472,7 +1505,8 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op GGML_TYPE_Q8_0, GGML_TYPE_Q2_K, GGML_TYPE_Q3_K, GGML_TYPE_Q4_K, GGML_TYPE_Q5_K, - GGML_TYPE_Q6_K + GGML_TYPE_Q6_K, + GGML_TYPE_IQ2_XXS, GGML_TYPE_IQ2_XS, }; // unary ops @@ -1642,6 +1676,8 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op test_cases.emplace_back(new test_pad()); test_cases.emplace_back(new test_leaky_relu()); + test_cases.emplace_back(new test_flash_attn_ext(GGML_TYPE_F16, 128, 32, 96, 8)); + #if !defined(__SANITIZE_THREAD__) // FIXME: these tests use too much memory with thread sanitizer test_cases.emplace_back(new test_moe(8, 2, 1, 4096, 8*1024)); @@ -1752,6 +1788,8 @@ int main(int argc, char ** argv) { return 1; } + ggml_quantize_free(); + printf("\033[1;32mOK\033[0m\n"); return 0; }
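Note: ggml_quantize_init()/ggml_quantize_free() replace the removed ggml_init_iq2_quantization()/ggml_deinit_iq2_quantization() pair. A rough sketch of the intended lifecycle, with the explicit init being optional since ggml_quantize_chunk() performs it lazily:

#include "ggml.h"

int main(void) {
    // optional: pre-build the IQ2 grid tables up front instead of lazily inside ggml_quantize_chunk()
    ggml_quantize_init(GGML_TYPE_IQ2_XS);

    // ... quantize tensors with ggml_quantize_chunk(GGML_TYPE_IQ2_XS, ...) ...

    // release the tables once at shutdown; llama_backend_free() now does this as well
    ggml_quantize_free();
    return 0;
}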
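Note: the int to size_t change for row_size (and elemsize) in the quantize_q*_K/q*_0/q*_1 helpers matters because the per-row byte size is multiplied by the row count to produce the returned total; with 32-bit int arithmetic that product can overflow for very large tensors. A distilled sketch of the shared loop shape, not the literal code from the patch (quantize_rows is a made-up name and the per-row call is elided):

// with `int row_size`, `nrow*row_size` would be computed in 32-bit and could overflow
#include "ggml.h"

size_t quantize_rows(const float * src, char * dst, int nrow, int n_per_row) {
    const size_t row_size = ggml_row_size(GGML_TYPE_Q4_K, n_per_row); // bytes per quantized row
    for (int row = 0; row < nrow; ++row) {
        // quantize_row_q4_K_impl(src, (block_q4_K *) dst, n_per_row, quant_weights);
        src += n_per_row;
        dst += row_size;
    }
    return nrow*row_size; // total bytes written
}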
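Note: GGML_OP_FLASH_ATTN_EXT fuses the scale/mask/softmax/V-multiply chain that llm_build_kqv otherwise builds from separate ops. A minimal graph-building sketch mirroring the test_flash_attn_ext case above; the helper name and the head size, head count and KV length are arbitrary example choices:

#include "ggml.h"
#include <math.h>

// builds the fused attention node; the CPU kernel in this patch expects F16 Q/K/V and an F32 mask
static struct ggml_tensor * build_flash_attn_ext(struct ggml_context * ctx,
        int64_t hs, int64_t nh, int64_t kv, int64_t nt) {
    struct ggml_tensor * q    = ggml_new_tensor_4d(ctx, GGML_TYPE_F16, hs, nt, nh, 1);
    struct ggml_tensor * k    = ggml_new_tensor_4d(ctx, GGML_TYPE_F16, hs, kv, nh, 1);
    struct ggml_tensor * v    = ggml_new_tensor_4d(ctx, GGML_TYPE_F16, kv, hs, nh, 1); // transposed layout vs. K
    struct ggml_tensor * mask = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, kv, nt, 1, 1);  // additive mask

    // the scale (typically 1/sqrt(head_size)) is passed as an op parameter instead of a separate ggml_scale()
    return ggml_flash_attn_ext(ctx, q, k, v, mask, 1.0f/sqrtf((float) hs));
}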
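Note: the removed ggml_set_imatrix_collection() hook is superseded by the per-context cb_eval/cb_eval_user_data fields, which llama_decode_internal now forwards to ggml_backend_sched_set_eval_callback(). A sketch of a callback, assuming the (tensor, ask, user_data) -> bool shape that ggml-backend.h declares for ggml_backend_sched_eval_callback; observe_mul_mat is an illustrative name:

#include "llama.h"
#include <cstdio>

// called once with ask == true to select interesting nodes, and again with
// ask == false once a selected node has been evaluated and can be inspected
static bool observe_mul_mat(struct ggml_tensor * t, bool ask, void * user_data) {
    (void) user_data;
    if (ask) {
        return t->op == GGML_OP_MUL_MAT; // only observe matrix multiplications
    }
    fprintf(stderr, "observed node: %s\n", t->name);
    return true; // keep evaluating the graph
}

// when creating the context:
//   llama_context_params cparams = llama_context_default_params();
//   cparams.cb_eval           = observe_mul_mat;
//   cparams.cb_eval_user_data = nullptr;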