test-grad0.c : add print_elements to help with debugging

This commit is contained in:
xaedes 2023-04-28 17:46:55 +02:00
parent 339b2adf48
commit 86b44a02e4
No known key found for this signature in database
GPG key ID: 30030EDD817EA2B1

View file

@ -112,6 +112,26 @@ void set_element(struct ggml_tensor * t, int idx, float value) {
((float *)t->data)[idx] = value;
}
void print_elements(const char* label, const struct ggml_tensor * t) {
if (!t) {
printf("%s: %s = null\n", __func__, label);
return;
}
const int nelements = ggml_nelements(t);
printf("%s: %s = [", __func__, label);
for (int k = 0; k < nelements; ++k) {
if (k > 0) { printf(", "); }
printf("%.5f", get_element(t, k));
}
printf("] shape: [");
for (int k = 0; k < t->n_dims; ++k) {
if (k > 0) { printf(", "); }
printf("%d", t->ne[k]);
}
printf("]\n");
}
bool check_gradient(
const char * op_name,
struct ggml_context * ctx0,