parent
e84b71c2c6
commit
d48c88cbd5
5 changed files with 47 additions and 769 deletions
|
@ -1515,90 +1515,50 @@ int main(int argc, const char ** argv) {
|
|||
}
|
||||
|
||||
// flash_attn f32
|
||||
{
|
||||
srand(seed);
|
||||
const int nargs = 3;
|
||||
// TODO: adapt to ggml_flash_attn_ext() changes
|
||||
//{
|
||||
// srand(seed);
|
||||
// const int nargs = 3;
|
||||
|
||||
int64_t ne2[4];
|
||||
// int64_t ne2[4];
|
||||
|
||||
get_random_dims(ne2, 4);
|
||||
int64_t D = ne2[0];
|
||||
int64_t N = ne2[1];
|
||||
int64_t M = ne2[2] + N;
|
||||
int64_t B = ne2[3];
|
||||
// get_random_dims(ne2, 4);
|
||||
// int64_t D = ne2[0];
|
||||
// int64_t N = ne2[1];
|
||||
// int64_t M = ne2[2] + N;
|
||||
// int64_t B = ne2[3];
|
||||
|
||||
for (int masked = 0; masked <= 1; ++masked) {
|
||||
for (int ndims = 2; ndims <= 4; ++ndims) {
|
||||
int max_nrep = (ndims >= 3) ? 2 : 1;
|
||||
for (int nrep = 1; nrep < max_nrep; ++nrep) {
|
||||
int64_t neq[4] = { D, N, B*nrep, ne[3] };
|
||||
int64_t nek[4] = { D, M, B, ne[3] };
|
||||
int64_t nev[4] = { M, D, B, ne[3] };
|
||||
if (ndims == 2) {
|
||||
neq[2] = 1; neq[3] = 1;
|
||||
nek[2] = 1; nek[3] = 1;
|
||||
nev[2] = 1; nev[3] = 1;
|
||||
} else if (ndims == 3) {
|
||||
neq[3] = 1;
|
||||
nek[3] = 1;
|
||||
nev[3] = 1;
|
||||
}
|
||||
x[0] = get_random_tensor_f32(ctx0, ndims, neq, -0.1250f, 0.1250f);
|
||||
x[1] = get_random_tensor_f32(ctx0, ndims, nek, -0.1250f, 0.1250f);
|
||||
x[2] = get_random_tensor_f32(ctx0, ndims, nev, -0.1250f, 0.1250f);
|
||||
ggml_set_param(ctx0, x[0]);
|
||||
ggml_set_param(ctx0, x[1]);
|
||||
ggml_set_param(ctx0, x[2]);
|
||||
// for (int masked = 0; masked <= 1; ++masked) {
|
||||
// for (int ndims = 2; ndims <= 4; ++ndims) {
|
||||
// int max_nrep = (ndims >= 3) ? 2 : 1;
|
||||
// for (int nrep = 1; nrep < max_nrep; ++nrep) {
|
||||
// int64_t neq[4] = { D, N, B*nrep, ne[3] };
|
||||
// int64_t nek[4] = { D, M, B, ne[3] };
|
||||
// int64_t nev[4] = { M, D, B, ne[3] };
|
||||
// if (ndims == 2) {
|
||||
// neq[2] = 1; neq[3] = 1;
|
||||
// nek[2] = 1; nek[3] = 1;
|
||||
// nev[2] = 1; nev[3] = 1;
|
||||
// } else if (ndims == 3) {
|
||||
// neq[3] = 1;
|
||||
// nek[3] = 1;
|
||||
// nev[3] = 1;
|
||||
// }
|
||||
// x[0] = get_random_tensor_f32(ctx0, ndims, neq, -0.1250f, 0.1250f);
|
||||
// x[1] = get_random_tensor_f32(ctx0, ndims, nek, -0.1250f, 0.1250f);
|
||||
// x[2] = get_random_tensor_f32(ctx0, ndims, nev, -0.1250f, 0.1250f);
|
||||
// ggml_set_param(ctx0, x[0]);
|
||||
// ggml_set_param(ctx0, x[1]);
|
||||
// ggml_set_param(ctx0, x[2]);
|
||||
|
||||
struct ggml_tensor * f = ggml_sum(ctx0, ggml_flash_attn(ctx0, x[0], x[1], x[2], (masked == 0)));
|
||||
// struct ggml_tensor * f = ggml_sum(ctx0, ggml_flash_attn(ctx0, x[0], x[1], x[2], (masked == 0)));
|
||||
|
||||
check_gradient("flash_attn f32", ctx0, x, f, ndims, nargs, 1.5e-4f, 1e-3f, INFINITY);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
// check_gradient("flash_attn f32", ctx0, x, f, ndims, nargs, 1.5e-4f, 1e-3f, INFINITY);
|
||||
// }
|
||||
// }
|
||||
// }
|
||||
//}
|
||||
|
||||
// flash_attn f16, not yet fully implemented
|
||||
if(0)
|
||||
{
|
||||
srand(seed);
|
||||
const int nargs = 3;
|
||||
|
||||
int64_t ne2[4];
|
||||
|
||||
get_random_dims(ne2, 4);
|
||||
int64_t D = ne2[0];
|
||||
int64_t N = ne2[1];
|
||||
int64_t M = ne2[2] + N;
|
||||
int64_t B = ne2[3];
|
||||
|
||||
for (int masked = 0; masked <= 1; ++masked) {
|
||||
for (int ndims = 2; ndims <= 4; ++ndims) {
|
||||
int64_t neq[4] = { D, N, B, ne[3] };
|
||||
int64_t nek[4] = { D, M, B, ne[3] };
|
||||
int64_t nev[4] = { M, D, B, ne[3] };
|
||||
if (ndims == 2) {
|
||||
neq[2] = 1; neq[3] = 1;
|
||||
nek[2] = 1; nek[3] = 1;
|
||||
nev[2] = 1; nev[3] = 1;
|
||||
} else if (ndims == 3) {
|
||||
neq[3] = 1;
|
||||
nek[3] = 1;
|
||||
nev[3] = 1;
|
||||
}
|
||||
x[0] = get_random_tensor_f16(ctx0, ndims, neq, -0.1250f, 0.1250f);
|
||||
x[1] = get_random_tensor_f16(ctx0, ndims, nek, -0.1250f, 0.1250f);
|
||||
x[2] = get_random_tensor_f16(ctx0, ndims, nev, -0.1250f, 0.1250f);
|
||||
ggml_set_param(ctx0, x[0]);
|
||||
ggml_set_param(ctx0, x[1]);
|
||||
ggml_set_param(ctx0, x[2]);
|
||||
|
||||
struct ggml_tensor * f = ggml_sum(ctx0, ggml_flash_attn(ctx0, x[0], x[1], x[2], (masked == 0)));
|
||||
|
||||
check_gradient("flash_attn f16", ctx0, x, f, ndims, nargs, 1.5e-4f, 1e-3f, INFINITY);
|
||||
}
|
||||
}
|
||||
}
|
||||
ggml_free(ctx0);
|
||||
}
|
||||
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue