successfully test silu backward
This commit is contained in:
parent
6fb08b4554
commit
671e5922e2
1 changed files with 24 additions and 0 deletions
|
@ -7,6 +7,9 @@
|
|||
|
||||
#define MAX_NARGS 2
|
||||
|
||||
|
||||
#define GGML_SILU_FP16
|
||||
|
||||
//
|
||||
// logging
|
||||
//
|
||||
|
@ -409,6 +412,27 @@ int main(int argc, const char ** argv) {
|
|||
}
|
||||
}
|
||||
|
||||
// silu
|
||||
{
|
||||
const int nargs = 1;
|
||||
|
||||
for (int ndims = 1; ndims <= 2; ++ndims) {
|
||||
for (int i = 0; i < nargs; ++i) {
|
||||
x[i] = get_random_tensor(ctx0, ndims, ne, -1.0f, 1.0f);
|
||||
ggml_set_param(ctx0, x[i]);
|
||||
}
|
||||
|
||||
struct ggml_tensor * f = ggml_sum(ctx0, ggml_silu(ctx0, x[0]));
|
||||
|
||||
#ifdef GGML_SILU_FP16
|
||||
// due to GGML_SILU_FP16 the finite difference method will be slightly wrong -> increase error bounds.
|
||||
check_gradient("silu", ctx0, x, f, ndims, nargs, 1e-3f, 0.5, INFINITY);
|
||||
#else
|
||||
check_gradient("silu", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, INFINITY);
|
||||
#endif
|
||||
}
|
||||
}
|
||||
|
||||
ggml_free(ctx0);
|
||||
}
|
||||
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue