successfully test get_rows backward

This commit is contained in:
xaedes 2023-04-28 20:32:00 +02:00
parent 96e773bbde
commit f0302fa71b
No known key found for this signature in database
GPG key ID: 30030EDD817EA2B1

View file

@ -104,8 +104,63 @@ struct ggml_tensor * get_random_tensor(
return result;
}
// Allocate an I32 tensor with `ndims` dimensions of sizes ne[0..ndims-1] in
// ctx0 and fill every element with `irand(imax - imin) + imin`.
//
// ctx0   context the tensor is allocated in
// ndims  number of dimensions; only 1..4 are supported, anything else asserts
// ne     size of each dimension (at least `ndims` entries are read)
// imin   offset added to every irand() result
// imax   irand() is called with `imax - imin`
//
// Returns the newly created tensor.
//
// NOTE(review): irand() is defined elsewhere; presumably it returns a value in
// [0, n), which would put the elements in [imin, imax) - confirm.
struct ggml_tensor * get_random_tensor_int(
        struct ggml_context * ctx0,
        int ndims,
        int64_t ne[],
        int32_t imin,
        int32_t imax) {
    struct ggml_tensor * result = ggml_new_tensor(ctx0, GGML_TYPE_I32, ndims, ne);

    if (ndims < 1 || ndims > 4) {
        assert(false);
        return result;
    }

    // Total element count. The former per-rank nested loops wrote the offsets
    // 0, 1, 2, ... in increasing order (i0 innermost), so one flat loop makes
    // the same irand() calls in the same order.
    int64_t nelements = 1;
    for (int i = 0; i < ndims; i++) {
        if (ne[i] <= 0) {
            // an empty dimension means there is nothing to fill
            return result;
        }
        nelements *= ne[i];
    }

    // 64-bit index: dimension sizes are int64_t, an int counter could overflow.
    int32_t * data = (int32_t *) result->data;
    for (int64_t i = 0; i < nelements; i++) {
        data[i] = irand(imax - imin) + imin;
    }

    return result;
}
// Read element `idx` of tensor `t` and return it as a float.
//
// The data buffer is indexed linearly with `idx`, using the element type
// selected by t->type:
//   GGML_TYPE_F32  value is returned as stored
//   GGML_TYPE_I32  value is converted from int32_t to float
// Any other type triggers an assertion; if assertions are compiled out,
// INFINITY is returned instead.
float get_element(const struct ggml_tensor * t, int idx) {
    // Dispatch on the element type first; reading the buffer as float before
    // checking the type would reinterpret I32 data as float bits.
    if (t->type == GGML_TYPE_F32) {
        return ((float *)t->data)[idx];
    } else if (t->type == GGML_TYPE_I32) {
        return (float) ((int32_t *)t->data)[idx];
    } else {
        assert(false);
        return INFINITY;
    }
}
void set_element(struct ggml_tensor * t, int idx, float value) {
@ -371,7 +426,7 @@ int main(int argc, const char ** argv) {
struct ggml_tensor * f = ggml_sum(ctx0, ggml_div(ctx0, x[0], x[1]));
check_gradient("div", ctx0, x, f, ndims, nargs, 1e-3f, 1e-2f, 1e-1f);
check_gradient("div", ctx0, x, f, ndims, nargs, 1e-3f, 1e-1f, 1e-1f);
}
}
@ -654,6 +709,22 @@ int main(int argc, const char ** argv) {
}
}
// get_rows: gradient check for ggml_get_rows applied to a random 2-D float
// tensor and a random 1-D I32 index tensor
{
// shape of the source tensor: first two entries of ne, remaining dims set to 1
int64_t ne2[4] = {ne[0], ne[1], 1, 1};
// shape of the index tensor: 1 + irand(ne[1]) entries, so never empty
int64_t ne3[4] = {1+irand(ne[1]), 1, 1, 1};
const int nargs = 1;
const int ndims = 2;
x[0] = get_random_tensor(ctx0, ndims, ne2, -1.0f, 1.0f);
// indices are generated with imin = 0 and imax = ne2[1]
x[1] = get_random_tensor_int(ctx0, 1, ne3, 0, ne2[1]);
// only x[0] is marked as a parameter; the index tensor x[1] is not
ggml_set_param(ctx0, x[0]);
// reduce the get_rows result with ggml_sum to form the function under test
struct ggml_tensor * f = ggml_sum(ctx0, ggml_get_rows(ctx0, x[0], x[1]));
// NOTE(review): the three trailing numeric arguments are presumably the step
// size and absolute/relative error bounds - confirm against check_gradient
check_gradient("get_rows", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, INFINITY);
}
// diag_mask_inf
{
const int nargs = 1;