backend : add eval callback

ggml-ci
This commit is contained in:
Georgi Gerganov 2024-01-14 16:48:16 +02:00
parent 4483396751
commit 65648b341f
No known key found for this signature in database
GPG key ID: 449E073F9DC10735
5 changed files with 90 additions and 4 deletions

View file

@ -6,11 +6,36 @@
#include <string>
#include <vector>
// a function that can be called for every computed node during graph evaluation
// the user can choose to whether to observe the data of the node depending on the tensor parameters
static bool observe_compute(int node_index, struct ggml_tensor * t, void * user_data) {
GGML_UNUSED(user_data);
// check if name contains soft_max
if (strstr(t->name, "soft_max") != 0) {
printf("%s: node_index = %5d, t->name = %32s, t->op = %12s, [%5d, %5d, %5d, %5d]\n",
__func__, node_index, t->name, ggml_op_name(t->op), (int) t->ne[0], (int) t->ne[1], (int) t->ne[2], (int) t->ne[3]);
std::vector<float> t_data(ggml_nelements(t));
ggml_backend_tensor_get(t, t_data.data(), 0, ggml_nbytes(t));
// print first row
for (int i = 0; i < t->ne[0]; i++) {
printf("%8.4f ", t_data[i]);
}
printf("\n");
}
return true;
}
int main(int argc, char ** argv) {
gpt_params params;
bool observe = false;
if (argc == 1 || argv[1][0] == '-') {
printf("usage: %s MODEL_PATH [PROMPT]\n" , argv[0]);
printf("usage: %s MODEL_PATH [PROMPT] [OBSERV]\n" , argv[0]);
return 1 ;
}
@ -22,6 +47,10 @@ int main(int argc, char ** argv) {
params.prompt = argv[2];
}
if (argc >= 4) {
observe = atoi(argv[3]);
}
if (params.prompt.empty()) {
params.prompt = "Hello my name is";
}
@ -37,7 +66,7 @@ int main(int argc, char ** argv) {
llama_model_params model_params = llama_model_default_params();
// model_params.n_gpu_layers = 99; // offload all layers to the GPU
model_params.n_gpu_layers = 99; // offload all layers to the GPU
llama_model * model = llama_load_model_from_file(params.model.c_str(), model_params);
@ -55,6 +84,9 @@ int main(int argc, char ** argv) {
ctx_params.n_threads = params.n_threads;
ctx_params.n_threads_batch = params.n_threads_batch == -1 ? params.n_threads : params.n_threads_batch;
ctx_params.cb_eval = observe ? observe_compute : NULL;
ctx_params.cb_eval_user_data = NULL;
llama_context * ctx = llama_new_context_with_model(model, ctx_params);
if (ctx == NULL) {