tests : remove unnecessary funcs
This commit is contained in:
parent
971c689178
commit
68c9fca9c2
1 changed files with 6 additions and 42 deletions
|
@ -208,42 +208,6 @@ struct ggml_tensor * get_random_tensor_i32(
|
||||||
return result;
|
return result;
|
||||||
}
|
}
|
||||||
|
|
||||||
float get_element(const struct ggml_tensor * t, int idx) {
|
|
||||||
switch (t->type) {
|
|
||||||
case GGML_TYPE_F32:
|
|
||||||
return ((float *)t->data)[idx];
|
|
||||||
case GGML_TYPE_I32:
|
|
||||||
return ((int32_t *)t->data)[idx];
|
|
||||||
case GGML_TYPE_F16:
|
|
||||||
return ggml_fp16_to_fp32(((ggml_fp16_t *)t->data)[idx]);
|
|
||||||
case GGML_TYPE_I16:
|
|
||||||
return ((int16_t *)t->data)[idx];
|
|
||||||
default:
|
|
||||||
assert(false);
|
|
||||||
}
|
|
||||||
return INFINITY;
|
|
||||||
}
|
|
||||||
|
|
||||||
void set_element(struct ggml_tensor * t, int idx, float value) {
|
|
||||||
switch (t->type) {
|
|
||||||
case GGML_TYPE_F32:
|
|
||||||
((float *)t->data)[idx] = value;
|
|
||||||
break;
|
|
||||||
case GGML_TYPE_I32:
|
|
||||||
((int32_t *)t->data)[idx] = value;
|
|
||||||
break;
|
|
||||||
case GGML_TYPE_F16:
|
|
||||||
((ggml_fp16_t*)t->data)[idx] = ggml_fp32_to_fp16(value);
|
|
||||||
break;
|
|
||||||
case GGML_TYPE_I16:
|
|
||||||
((int16_t *)t->data)[idx] = value;
|
|
||||||
break;
|
|
||||||
default:
|
|
||||||
assert(false);
|
|
||||||
}
|
|
||||||
;
|
|
||||||
}
|
|
||||||
|
|
||||||
void print_elements(const char* label, const struct ggml_tensor * t) {
|
void print_elements(const char* label, const struct ggml_tensor * t) {
|
||||||
if (!t) {
|
if (!t) {
|
||||||
printf("%s: %s = null\n", __func__, label);
|
printf("%s: %s = null\n", __func__, label);
|
||||||
|
@ -253,7 +217,7 @@ void print_elements(const char* label, const struct ggml_tensor * t) {
|
||||||
printf("%s: %s = [", __func__, label);
|
printf("%s: %s = [", __func__, label);
|
||||||
for (int k = 0; k < nelements; ++k) {
|
for (int k = 0; k < nelements; ++k) {
|
||||||
if (k > 0) { printf(", "); }
|
if (k > 0) { printf(", "); }
|
||||||
printf("%.5f", get_element(t, k));
|
printf("%.5f", ggml_get_f32_1d(t, k));
|
||||||
}
|
}
|
||||||
printf("] shape: [");
|
printf("] shape: [");
|
||||||
for (int k = 0; k < t->n_dims; ++k) {
|
for (int k = 0; k < t->n_dims; ++k) {
|
||||||
|
@ -304,23 +268,23 @@ bool check_gradient(
|
||||||
const int nelements = ggml_nelements(x[i]);
|
const int nelements = ggml_nelements(x[i]);
|
||||||
for (int k = 0; k < nelements; ++k) {
|
for (int k = 0; k < nelements; ++k) {
|
||||||
// compute gradient using finite differences
|
// compute gradient using finite differences
|
||||||
const float x0 = get_element(x[i], k);
|
const float x0 = ggml_get_f32_1d(x[i], k);
|
||||||
const float xm = x0 - eps;
|
const float xm = x0 - eps;
|
||||||
const float xp = x0 + eps;
|
const float xp = x0 + eps;
|
||||||
set_element(x[i], k, xp);
|
ggml_set_f32_1d(x[i], k, xp);
|
||||||
|
|
||||||
ggml_graph_compute_with_ctx(ctx0, &gf, n_threads);
|
ggml_graph_compute_with_ctx(ctx0, &gf, n_threads);
|
||||||
|
|
||||||
const float f0 = ggml_get_f32_1d(f, 0);
|
const float f0 = ggml_get_f32_1d(f, 0);
|
||||||
|
|
||||||
set_element(x[i], k, xm);
|
ggml_set_f32_1d(x[i], k, xm);
|
||||||
|
|
||||||
ggml_graph_compute_with_ctx(ctx0, &gf, n_threads);
|
ggml_graph_compute_with_ctx(ctx0, &gf, n_threads);
|
||||||
|
|
||||||
const float f1 = ggml_get_f32_1d(f, 0);
|
const float f1 = ggml_get_f32_1d(f, 0);
|
||||||
const float g0 = (f0 - f1)/(2.0f*eps);
|
const float g0 = (f0 - f1)/(2.0f*eps);
|
||||||
|
|
||||||
set_element(x[i], k, x0);
|
ggml_set_f32_1d(x[i], k, x0);
|
||||||
|
|
||||||
// compute gradient using backward graph
|
// compute gradient using backward graph
|
||||||
ggml_graph_reset (&gf);
|
ggml_graph_reset (&gf);
|
||||||
|
@ -328,7 +292,7 @@ bool check_gradient(
|
||||||
|
|
||||||
ggml_graph_compute_with_ctx(ctx0, &gb, n_threads);
|
ggml_graph_compute_with_ctx(ctx0, &gb, n_threads);
|
||||||
|
|
||||||
const float g1 = get_element(x[i]->grad, k);
|
const float g1 = ggml_get_f32_1d(x[i]->grad, k);
|
||||||
|
|
||||||
const float error_abs = fabsf(g0 - g1);
|
const float error_abs = fabsf(g0 - g1);
|
||||||
const float error_rel = g0 != 0 ? fabsf(g0 - g1)/fabsf(g0) : 0;
|
const float error_rel = g0 != 0 ? fabsf(g0 - g1)/fabsf(g0) : 0;
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue