ggml : online attention (CPU)

This commit is contained in:
Georgi Gerganov 2024-01-20 12:26:49 +02:00
parent c3cdfffa88
commit a9681febd6
No known key found for this signature in database
GPG key ID: 449E073F9DC10735
6 changed files with 231 additions and 198 deletions

View file

@@ -1390,21 +1390,21 @@ struct test_flash_attn_ext : public test_case {
 const int64_t hs; // head size
 const int64_t nh; // num heads
 const int64_t kv; // kv size
-const int64_t nt; // tokens
+const int64_t nb; // batch size
 std::string vars() override {
-return VARS_TO_STR5(typeq, hs, nh, kv, nt);
+return VARS_TO_STR5(typeq, hs, nh, kv, nb);
 }
 test_flash_attn_ext(ggml_type typeq = GGML_TYPE_F16,
-int64_t hs = 128, int64_t nh = 32, int64_t kv = 96, int64_t nt = 8)
-: typeq(typeq), hs(hs), nh(nh), kv(kv), nt(nt) {}
+int64_t hs = 128, int64_t nh = 32, int64_t kv = 96, int64_t nb = 8)
+: typeq(typeq), hs(hs), nh(nh), kv(kv), nb(nb) {}
 ggml_tensor * build_graph(ggml_context * ctx) override {
-ggml_tensor * q = ggml_new_tensor_4d(ctx, typeq, hs, nt, nh, 1);
+ggml_tensor * q = ggml_new_tensor_4d(ctx, typeq, hs, nb, nh, 1);
 ggml_tensor * k = ggml_new_tensor_4d(ctx, GGML_TYPE_F16, hs, kv, nh, 1);
-ggml_tensor * v = ggml_new_tensor_4d(ctx, GGML_TYPE_F16, kv, hs, nh, 1);
-ggml_tensor * mask = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, kv, nt, 1, 1);
+ggml_tensor * v = ggml_new_tensor_4d(ctx, GGML_TYPE_F16, hs, kv, nh, 1);
+ggml_tensor * mask = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, kv, nb, 1, 1);
 ggml_tensor * out = ggml_flash_attn_ext(ctx, q, k, v, mask, 1.0f/sqrtf(hs));
 return out;
 }