tests : more
This commit is contained in:
parent
abeaf0d90e
commit
c6c1132e5e
4 changed files with 22 additions and 24 deletions
5
ggml.c
5
ggml.c
|
@ -13554,11 +13554,9 @@ static void ggml_compute_forward_flash_attn_ext_f16(
|
|||
|
||||
const int64_t D = neq0;
|
||||
const int64_t N = neq1;
|
||||
const int64_t P = nek1 - N;
|
||||
|
||||
GGML_ASSERT(ne0 == D);
|
||||
GGML_ASSERT(ne2 == N);
|
||||
GGML_ASSERT(P >= 0);
|
||||
|
||||
GGML_ASSERT(nbq0 == sizeof(float));
|
||||
GGML_ASSERT(nbk0 == sizeof(ggml_fp16_t));
|
||||
|
@ -13569,7 +13567,6 @@ static void ggml_compute_forward_flash_attn_ext_f16(
|
|||
GGML_ASSERT(nev0 == D);
|
||||
|
||||
GGML_ASSERT(neq1 == N);
|
||||
GGML_ASSERT(nek1 == N + P);
|
||||
GGML_ASSERT(nev0 == D);
|
||||
|
||||
// dst cannot be transposed or permuted
|
||||
|
@ -13608,8 +13605,6 @@ static void ggml_compute_forward_flash_attn_ext_f16(
|
|||
float scale = 1.0f;
|
||||
memcpy(&scale, (float *) dst->op_params + 0, sizeof(float));
|
||||
|
||||
//printf("P=%d N=%d D=%d ir0=%d ir1=%d scale = %f\n", P, N, D, ir0, ir1, scale);
|
||||
|
||||
// loop over n_batch and n_head
|
||||
for (int ir = ir0; ir < ir1; ++ir) {
|
||||
// q indices
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue