successfully test diag_mask_inf and diag_mask_zero backward
This commit is contained in:
parent
d42531fa56
commit
19f51592b5
1 changed files with 30 additions and 1 deletions
|
@ -627,7 +627,6 @@ int main(int argc, const char ** argv) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
// transpose
|
// transpose
|
||||||
{
|
{
|
||||||
int64_t ne2[4];
|
int64_t ne2[4];
|
||||||
|
@ -655,6 +654,36 @@ int main(int argc, const char ** argv) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// diag_mask_inf
|
||||||
|
{
|
||||||
|
const int nargs = 1;
|
||||||
|
const int ndims = 2;
|
||||||
|
|
||||||
|
x[0] = get_random_tensor(ctx0, ndims, ne, -1.0f, 1.0f);
|
||||||
|
ggml_set_param(ctx0, x[0]);
|
||||||
|
|
||||||
|
int n_past = irand(ne[0]);
|
||||||
|
|
||||||
|
struct ggml_tensor * f = ggml_sum(ctx0, ggml_diag_mask_inf(ctx0, x[0], n_past));
|
||||||
|
|
||||||
|
check_gradient("diag_mask_inf", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, INFINITY);
|
||||||
|
}
|
||||||
|
|
||||||
|
// diag_mask_zero
|
||||||
|
{
|
||||||
|
const int nargs = 1;
|
||||||
|
const int ndims = 2;
|
||||||
|
|
||||||
|
x[0] = get_random_tensor(ctx0, ndims, ne, -1.0f, 1.0f);
|
||||||
|
ggml_set_param(ctx0, x[0]);
|
||||||
|
|
||||||
|
int n_past = irand(ne[0]);
|
||||||
|
|
||||||
|
struct ggml_tensor * f = ggml_sum(ctx0, ggml_diag_mask_zero(ctx0, x[0], n_past));
|
||||||
|
|
||||||
|
check_gradient("diag_mask_zero", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, INFINITY);
|
||||||
|
}
|
||||||
|
|
||||||
// softmax
|
// softmax
|
||||||
{
|
{
|
||||||
const int nargs = 1;
|
const int nargs = 1;
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue