tests : more

This commit is contained in:
Georgi Gerganov 2024-01-29 18:22:28 +02:00
parent abeaf0d90e
commit c6c1132e5e
No known key found for this signature in database
GPG key ID: 449E073F9DC10735
4 changed files with 22 additions and 24 deletions

5
ggml.c
View file

@ -13554,11 +13554,9 @@ static void ggml_compute_forward_flash_attn_ext_f16(
const int64_t D = neq0;
const int64_t N = neq1;
const int64_t P = nek1 - N;
GGML_ASSERT(ne0 == D);
GGML_ASSERT(ne2 == N);
GGML_ASSERT(P >= 0);
GGML_ASSERT(nbq0 == sizeof(float));
GGML_ASSERT(nbk0 == sizeof(ggml_fp16_t));
@ -13569,7 +13567,6 @@ static void ggml_compute_forward_flash_attn_ext_f16(
GGML_ASSERT(nev0 == D);
GGML_ASSERT(neq1 == N);
GGML_ASSERT(nek1 == N + P);
GGML_ASSERT(nev0 == D);
// dst cannot be transposed or permuted
@ -13608,8 +13605,6 @@ static void ggml_compute_forward_flash_attn_ext_f16(
float scale = 1.0f;
memcpy(&scale, (float *) dst->op_params + 0, sizeof(float));
//printf("P=%d N=%d D=%d ir0=%d ir1=%d scale = %f\n", P, N, D, ir0, ir1, scale);
// loop over n_batch and n_head
for (int ir = ir0; ir < ir1; ++ir) {
// q indices