From 067e2947830cd0bfd192f6bc5dde2375cee7b337 Mon Sep 17 00:00:00 2001
From: Pierrick HYMBERT
Date: Wed, 10 Apr 2024 03:35:57 +0200
Subject: [PATCH] ggml-debug: example of how to use the ggml callback for debugging

---
 examples/CMakeLists.txt            |   1 +
 examples/ggml-debug/CMakeLists.txt |   5 +
 examples/ggml-debug/README.md      | 106 +++++++++++++++++
 examples/ggml-debug/ggml-debug.cpp | 180 +++++++++++++++++++++++++++++
 llama.cpp                          |   2 +-
 5 files changed, 293 insertions(+), 1 deletion(-)
 create mode 100644 examples/ggml-debug/CMakeLists.txt
 create mode 100644 examples/ggml-debug/README.md
 create mode 100644 examples/ggml-debug/ggml-debug.cpp

diff --git a/examples/CMakeLists.txt b/examples/CMakeLists.txt
index 76496bf06..df39b6236 100644
--- a/examples/CMakeLists.txt
+++ b/examples/CMakeLists.txt
@@ -22,6 +22,7 @@ else()
     add_subdirectory(finetune)
     add_subdirectory(gritlm)
     add_subdirectory(gguf-split)
+    add_subdirectory(ggml-debug)
     add_subdirectory(infill)
     add_subdirectory(llama-bench)
     add_subdirectory(llava)
diff --git a/examples/ggml-debug/CMakeLists.txt b/examples/ggml-debug/CMakeLists.txt
new file mode 100644
index 000000000..36987d0c7
--- /dev/null
+++ b/examples/ggml-debug/CMakeLists.txt
@@ -0,0 +1,5 @@
+set(TARGET ggml-debug)
+add_executable(${TARGET} ggml-debug.cpp)
+install(TARGETS ${TARGET} RUNTIME)
+target_link_libraries(${TARGET} PRIVATE common llama ${CMAKE_THREAD_LIBS_INIT})
+target_compile_features(${TARGET} PRIVATE cxx_std_11)
diff --git a/examples/ggml-debug/README.md b/examples/ggml-debug/README.md
new file mode 100644
index 000000000..825facd4d
--- /dev/null
+++ b/examples/ggml-debug/README.md
@@ -0,0 +1,106 @@
+# llama.cpp/examples/ggml-debug
+
+A simple example which demonstrates how to use a callback during inference.
+It prints every graph operation and the resulting tensor data to the console.
+
+Usage:
+
+```shell
+ggml-debug \
+    --hf-repo ggml-org/models \
+    --hf-file phi-2/ggml-model-q4_0.gguf \
+    --model phi-2-q4_0.gguf \
+    --prompt hello \
+    --seed 42 \
+    -ngl 33
+```
+
+Will print:
+
+```shell
+llm_load_tensors: offloaded 33/33 layers to GPU
+...
+llama_new_context_with_model: n_ctx = 512
+...
+llama_new_context_with_model: CUDA0 compute buffer size = 105.00 MiB
+llama_new_context_with_model: CUDA_Host compute buffer size = 6.01 MiB
+llama_new_context_with_model: graph nodes = 1225
+llama_new_context_with_model: graph splits = 2
+
+system_info: n_threads = 6 / 12 | AVX = 1 | AVX_VNNI = 0 | AVX2 = 1 | AVX512 = 1 | AVX512_VBMI = 1 | AVX512_VNNI = 1 | FMA = 1 | NEON = 0 | ARM_FMA = 0 | F16C = 1 | FP16_VA = 0 | WASM_SIMD = 0 | BLAS = 1 | SSE3 = 1 | SSSE3 = 1 | VSX = 0 | MATMUL_INT8 = 0 |
+ggml_debug: inp_embd = GET_ROWS(token_embd.weight{2560, 51200, 1, 1}, inp_tokens{1, 1, 1, 1}}) = {2560, 1, 1, 1}
+  [
+    [
+      [ -0.0181, -0.0181,  0.0453, ...],
+    ],
+  ]
+ggml_debug: norm-0 = NORM(CUDA0#inp_embd#0{2560, 1, 1, 1}, }) = {2560, 1, 1, 1}
+  [
+    [
+      [ -0.6989, -0.6989,  1.7686, ...],
+    ],
+  ]
+ggml_debug: norm_w-0 = MUL(norm-0{2560, 1, 1, 1}, blk.0.attn_norm.weight{2560, 1, 1, 1}}) = {2560, 1, 1, 1}
+  [
+    [
+      [ -0.1800, -0.1788,  0.4663, ...],
+    ],
+  ]
+ggml_debug: attn_norm-0 = ADD(norm_w-0{2560, 1, 1, 1}, blk.0.attn_norm.bias{2560, 1, 1, 1}}) = {2560, 1, 1, 1}
+  [
+    [
+      [ -0.1863, -0.1712,  0.4750, ...],
+    ],
+  ]
+ggml_debug: wqkv-0 = MUL_MAT(blk.0.attn_qkv.weight{2560, 7680, 1, 1}, attn_norm-0{2560, 1, 1, 1}}) = {7680, 1, 1, 1}
+  [
+    [
+      [ -1.1238, -2.3523, -1.6938, ...],
+    ],
+  ]
+ggml_debug: bqkv-0 = ADD(wqkv-0{7680, 1, 1, 1}, blk.0.attn_qkv.bias{7680, 1, 1, 1}}) = {7680, 1, 1, 1}
+  [
+    [
+      [ -1.1135, -2.5451, -1.8321, ...],
+    ],
+  ]
+ggml_debug: bqkv-0 (view) = VIEW(bqkv-0{7680, 1, 1, 1}, }) = {2560, 1, 1, 1}
+  [
+    [
+      [ -1.1135, -2.5451, -1.8321, ...],
+    ],
+  ]
+ggml_debug: Qcur-0 = CONT(bqkv-0 (view){2560, 1, 1, 1}, }) = {2560, 1, 1, 1}
+  [
+    [
+      [ -1.1135, -2.5451, -1.8321, ...],
+    ],
+  ]
+ggml_debug: Qcur-0 (reshaped) = RESHAPE(Qcur-0{2560, 1, 1, 1}, }) = {80, 32, 1, 1}
+  [
+    [
+      [ -1.1135,  0.8348,  0.8010, ...],
+      [ -2.5451, -1.1920,  0.0546, ...],
+      [ -1.8321, -0.0515,  0.8186, ...],
+      ...
+    ],
+  ]
+ggml_debug: Qcur-0 = ROPE(Qcur-0 (reshaped){80, 32, 1, 1}, CUDA0#inp_pos#0{1, 1, 1, 1}}) = {80, 32, 1, 1}
+  [
+    [
+      [ -1.1135,  0.8348,  0.8010, ...],
+      [ -2.5451, -1.1920,  0.0546, ...],
+      [ -1.8321, -0.0515,  0.8186, ...],
+      ...
+    ],
+  ]
+ggml_debug: Qcur-0 = SCALE(Qcur-0{80, 32, 1, 1}, }) = {80, 32, 1, 1}
+  [
+    [
+      [ -0.1245,  0.0933,  0.0896, ...],
+      [ -0.2845, -0.1333,  0.0061, ...],
+      [ -0.2048, -0.0058,  0.0915, ...],
+      ...
+ ], + ] +``` diff --git a/examples/ggml-debug/ggml-debug.cpp b/examples/ggml-debug/ggml-debug.cpp new file mode 100644 index 000000000..168725a7d --- /dev/null +++ b/examples/ggml-debug/ggml-debug.cpp @@ -0,0 +1,180 @@ +#include "common.h" +#include "llama.h" +#include "ggml.h" + +#include +#include +#include +#include +#include + +struct callback_data { + std::mutex m_mutex; + std::vector data; +}; + +static std::string ggml_ne_string(const ggml_tensor * t) { + std::string str; + for (int i = 0; i < GGML_MAX_DIMS; ++i) { + str += std::to_string(t->ne[i]); + if (i + 1 < GGML_MAX_DIMS) { + str += ", "; + } + } + return str; +} + +static void ggml_print_tensor(const float * data, const int64_t * ne) { + int i, j, k; + printf(" [\n"); + for (i = 0; i < ne[2] && i < 3; i++) { + printf(" [\n"); + for (j = 0; j < ne[1] && j < 3; j++) { + printf(" ["); + for (k = 0; k < ne[0] && k < 3; k++) { + printf("%8.4f", data[k * ne[1] * ne[2] + j * ne[2] + i]); + if (k < ne[0] - 1 && k < 2) printf(", "); + } + if (ne[0] > 3) printf(", ..."); + printf("],\n"); + } + if (ne[1] > 3) printf(" ...\n"); + printf(" ],\n"); + } + if (ne[2] > 3) printf(" ...\n"); + printf(" ]\n"); +} + +/** + * GGML operations callback during the graph execution. + * + * @param t current tensor + * @param ask when ask is true, the scheduler wants to know if we are interested in data from this tensor + * if we return true, a follow-up call will be made with ask=false in which we can do the actual collection. + * @param user_data user data to pass at each call back + * @return true to receive data or continue the graph, false otherwise + */ +static bool ggml_debug(struct ggml_tensor * t, bool ask, void * user_data) { + auto * cb_data = (callback_data *) user_data; + + const struct ggml_tensor * src0 = t->src[0]; + const struct ggml_tensor * src1 = t->src[1]; + + if (ask) { + return true; // Always retrieve data + } + + std::lock_guard lock(cb_data->m_mutex); + + char src1_str[128] = {0}; + if (src1) { + sprintf(src1_str, "%s{%s}", src1->name, ggml_ne_string(src1).c_str()); + } + + printf("%s: %24s = %10s(%s{%s}, %s}) = {%s} \n", __func__, + t->name, ggml_op_name(t->op), + src0->name, ggml_ne_string(src0).c_str(), + src1 ? src1_str : "", + ggml_ne_string(t).c_str()); + + + // copy the data from the GPU memory if needed + const bool is_host = ggml_backend_buffer_is_host(t->buffer); + + if (!is_host) { + auto n_bytes = ggml_nbytes(t); + cb_data->data.resize(n_bytes / sizeof(float)); + ggml_backend_tensor_get(t, cb_data->data.data(), 0, n_bytes); + } + + const float * data = is_host ? 
(const float *) t->data : cb_data->data.data(); + ggml_print_tensor(data, t->ne); + + return true; +} + +static bool run(llama_context * ctx, const gpt_params & params) { + const bool add_bos = llama_should_add_bos_token(llama_get_model(ctx)); + + std::vector tokens = ::llama_tokenize(ctx, params.prompt, add_bos); + + if (llama_decode(ctx, llama_batch_get_one(tokens.data(), tokens.size(), 0, 0))) { + fprintf(stderr, "%s : failed to eval\n", __func__); + return false; + } + + return true; +} + +int main(int argc, char ** argv) { + + callback_data cb_data; + + gpt_params params; + params.n_batch = 512; + if (!gpt_params_parse(argc, argv, params)) { + return 1; + } + params.n_batch = std::min(params.n_batch, params.n_ctx); + + print_build_info(); + + std::mt19937 rng(params.seed); + if (params.random_prompt) { + params.prompt = gpt_random_prompt(rng); + } + + llama_backend_init(); + llama_numa_init(params.numa); + + auto mparams = llama_model_params_from_gpt_params(params); + + llama_model * model = nullptr; + + if (!params.hf_repo.empty() && !params.hf_file.empty()) { + model = llama_load_model_from_hf(params.hf_repo.c_str(), params.hf_file.c_str(), params.model.c_str(), mparams); + } else if (!params.model_url.empty()) { + model = llama_load_model_from_url(params.model_url.c_str(), params.model.c_str(), mparams); + } else { + model = llama_load_model_from_file(params.model.c_str(), mparams); + } + + if (model == NULL) { + fprintf(stderr, "%s: error: failed to load model '%s'\n", __func__, params.model.c_str()); + return 1; + } + + auto cparams = llama_context_params_from_gpt_params(params); + + // pass the callback to the backend scheduler + // it will be executed for each node during the graph computation + cparams.cb_eval = ggml_debug; + cparams.cb_eval_user_data = &cb_data; + + llama_context * ctx = llama_new_context_with_model(model, cparams); + if (ctx == NULL) { + fprintf(stderr, "%s: error: failed to create context with model '%s'\n", __func__, params.model.c_str()); + llama_free_model(model); + return 1; + } + + // print system information + { + fprintf(stderr, "\n"); + fprintf(stderr, "%s\n", get_system_info(params).c_str()); + } + + bool OK = run(ctx, params); + if (!OK) { + return 1; + } + + llama_print_timings(ctx); + + llama_free(ctx); + llama_free_model(model); + + llama_backend_free(); + + return 0; +} diff --git a/llama.cpp b/llama.cpp index 217726184..fd664c9f5 100644 --- a/llama.cpp +++ b/llama.cpp @@ -11054,7 +11054,7 @@ struct llm_tokenizer_bpe { add_new_bigram(bigram.left, left_symbol.next); // right side of current symbol } - // add the fnished tokens to the final list keeping correct order for next and prev + // add the finished tokens to the final list keeping correct order for next and prev for (auto & sym : symbols) { if (sym.n > 0) { sym.prev = final_prev_index;
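
Note: the `ask` phase of the callback can also be used to filter which nodes are inspected, since answering `false` there means no follow-up call (and no device-to-host copy) is made for that tensor. The snippet below is a hypothetical variant of `ggml_debug`, not part of the patch above; it is a sketch that only dumps the results of matrix multiplications whose name contains a chosen substring. It reuses `callback_data` and `ggml_ne_string` from the example and assumes an extra `std::string filter` member added to `callback_data` (and an additional `<cstring>` include for `strstr`).

```cpp
// Hypothetical filtering callback (sketch only, not included in the patch).
// Assumes callback_data has been extended with a `std::string filter` member
// and that <cstring> is included for strstr.
static bool ggml_debug_filtered(struct ggml_tensor * t, bool ask, void * user_data) {
    auto * cb_data = (callback_data *) user_data;

    const bool wanted = t->op == GGML_OP_MUL_MAT &&
                        strstr(t->name, cb_data->filter.c_str()) != nullptr;

    if (ask) {
        // returning false here tells the scheduler we do not need this tensor,
        // so it will not call us back with its data
        return wanted;
    }

    std::lock_guard<std::mutex> lock(cb_data->m_mutex);
    printf("%s: %s = %s -> {%s}\n", __func__,
           t->name, ggml_op_name(t->op), ggml_ne_string(t).c_str());

    return true; // keep executing the graph
}
```

Such a variant would be registered exactly like `ggml_debug`, by assigning it to `cparams.cb_eval` before creating the context.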