successfully test get_rows backward
This commit is contained in:
parent
96e773bbde
commit
f0302fa71b
1 changed files with 73 additions and 2 deletions
|
@ -104,8 +104,63 @@ struct ggml_tensor * get_random_tensor(
|
||||||
return result;
|
return result;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
struct ggml_tensor * get_random_tensor_int(
|
||||||
|
struct ggml_context * ctx0,
|
||||||
|
int ndims,
|
||||||
|
int64_t ne[],
|
||||||
|
int32_t imin,
|
||||||
|
int32_t imax) {
|
||||||
|
struct ggml_tensor * result = ggml_new_tensor(ctx0, GGML_TYPE_I32, ndims, ne);
|
||||||
|
|
||||||
|
switch (ndims) {
|
||||||
|
case 1:
|
||||||
|
for (int i0 = 0; i0 < ne[0]; i0++) {
|
||||||
|
((int32_t *)result->data)[i0] = irand(imax - imin) + imin;
|
||||||
|
}
|
||||||
|
break;
|
||||||
|
case 2:
|
||||||
|
for (int i1 = 0; i1 < ne[1]; i1++) {
|
||||||
|
for (int i0 = 0; i0 < ne[0]; i0++) {
|
||||||
|
((int32_t *)result->data)[i1*ne[0] + i0] = irand(imax - imin) + imin;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
break;
|
||||||
|
case 3:
|
||||||
|
for (int i2 = 0; i2 < ne[2]; i2++) {
|
||||||
|
for (int i1 = 0; i1 < ne[1]; i1++) {
|
||||||
|
for (int i0 = 0; i0 < ne[0]; i0++) {
|
||||||
|
((int32_t *)result->data)[i2*ne[1]*ne[0] + i1*ne[0] + i0] = irand(imax - imin) + imin;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
break;
|
||||||
|
case 4:
|
||||||
|
for (int i3 = 0; i3 < ne[3]; i3++) {
|
||||||
|
for (int i2 = 0; i2 < ne[2]; i2++) {
|
||||||
|
for (int i1 = 0; i1 < ne[1]; i1++) {
|
||||||
|
for (int i0 = 0; i0 < ne[0]; i0++) {
|
||||||
|
((int32_t *)result->data)[i3*ne[2]*ne[1]*ne[0] + i2*ne[1]*ne[0] + i1*ne[0] + i0] = irand(imax - imin) + imin;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
break;
|
||||||
|
default:
|
||||||
|
assert(false);
|
||||||
|
};
|
||||||
|
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
|
||||||
float get_element(const struct ggml_tensor * t, int idx) {
|
float get_element(const struct ggml_tensor * t, int idx) {
|
||||||
|
if (t->type == GGML_TYPE_F32) {
|
||||||
return ((float *)t->data)[idx];
|
return ((float *)t->data)[idx];
|
||||||
|
} else if (t->type == GGML_TYPE_I32) {
|
||||||
|
return ((int32_t *)t->data)[idx];
|
||||||
|
} else {
|
||||||
|
assert(false);
|
||||||
|
return INFINITY;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
void set_element(struct ggml_tensor * t, int idx, float value) {
|
void set_element(struct ggml_tensor * t, int idx, float value) {
|
||||||
|
@ -371,7 +426,7 @@ int main(int argc, const char ** argv) {
|
||||||
|
|
||||||
struct ggml_tensor * f = ggml_sum(ctx0, ggml_div(ctx0, x[0], x[1]));
|
struct ggml_tensor * f = ggml_sum(ctx0, ggml_div(ctx0, x[0], x[1]));
|
||||||
|
|
||||||
check_gradient("div", ctx0, x, f, ndims, nargs, 1e-3f, 1e-2f, 1e-1f);
|
check_gradient("div", ctx0, x, f, ndims, nargs, 1e-3f, 1e-1f, 1e-1f);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -654,6 +709,22 @@ int main(int argc, const char ** argv) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// get_rows
|
||||||
|
{
|
||||||
|
int64_t ne2[4] = {ne[0], ne[1], 1, 1};
|
||||||
|
int64_t ne3[4] = {1+irand(ne[1]), 1, 1, 1};
|
||||||
|
const int nargs = 1;
|
||||||
|
const int ndims = 2;
|
||||||
|
x[0] = get_random_tensor(ctx0, ndims, ne2, -1.0f, 1.0f);
|
||||||
|
x[1] = get_random_tensor_int(ctx0, 1, ne3, 0, ne2[1]);
|
||||||
|
|
||||||
|
ggml_set_param(ctx0, x[0]);
|
||||||
|
|
||||||
|
struct ggml_tensor * f = ggml_sum(ctx0, ggml_get_rows(ctx0, x[0], x[1]));
|
||||||
|
|
||||||
|
check_gradient("get_rows", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, INFINITY);
|
||||||
|
}
|
||||||
|
|
||||||
// diag_mask_inf
|
// diag_mask_inf
|
||||||
{
|
{
|
||||||
const int nargs = 1;
|
const int nargs = 1;
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue