ggml : ggml_flash_attn_ext() support ALiBi (CPU)
This commit is contained in:
parent d0592d495d
commit 166e60bf9b
3 changed files with 25 additions and 11 deletions
ggml.c: 27 changed lines

@@ -6436,7 +6436,8 @@ struct ggml_tensor * ggml_flash_attn_ext(
         struct ggml_tensor  * k,
         struct ggml_tensor  * v,
         struct ggml_tensor  * mask,
-        float                 scale) {
+        float                 scale,
+        float                 max_bias) {
     GGML_ASSERT(ggml_can_mul_mat(k, q));
     // TODO: check if vT can be multiplied by (k*qT)
     if (mask) {
@@ -6458,7 +6459,7 @@ struct ggml_tensor * ggml_flash_attn_ext(
     int64_t ne[4] = { q->ne[0], q->ne[2], q->ne[1], q->ne[3] };
     struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne);

-    float params[] = { scale };
+    float params[] = { scale, max_bias };
     ggml_set_op_params(result, params, sizeof(params));

     result->op = GGML_OP_FLASH_ATTN_EXT;
@@ -6478,7 +6479,7 @@ void ggml_flash_attn_ext_set_prec(
     const int32_t prec_i32 = (int32_t) prec;

-    ggml_set_op_params_i32(a, 1, prec_i32); // scale is on first pos
+    ggml_set_op_params_i32(a, 2, prec_i32); // scale is on first pos, max_bias on second
 }

 // ggml_flash_ff
@@ -13308,8 +13309,8 @@ static void ggml_compute_forward_soft_max_f32(

     // TODO: is this supposed to be ceil instead of floor?
     //   https://huggingface.co/mosaicml/mpt-7b/blob/main/attention.py#L370
-    const uint32_t n_head_kv   = ne02;
-    const uint32_t n_head_log2 = 1u << (uint32_t) floor(log2(n_head_kv));
+    const uint32_t n_head      = ne02;
+    const uint32_t n_head_log2 = 1u << (uint32_t) floor(log2(n_head));

     const float m0 = powf(2.0f, -(max_bias       ) / n_head_log2);
     const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
@@ -15525,7 +15526,16 @@ static void ggml_compute_forward_flash_attn_ext_f16(
     const int ir1 = MIN(ir0 + dr, nr);

     float scale = 1.0f;
+    float max_bias = 0.0f;
+
     memcpy(&scale, (float *) dst->op_params + 0, sizeof(float));
+    memcpy(&max_bias, (float *) dst->op_params + 1, sizeof(float));
+
+    const uint32_t n_head = neq2;
+    const uint32_t n_head_log2 = 1u << (uint32_t) floor(log2(n_head));
+
+    const float m0 = powf(2.0f, -(max_bias       ) / n_head_log2);
+    const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);

     // loop over n_batch and n_head
     for (int ir = ir0; ir < ir1; ++ir) {
@@ -15534,6 +15544,9 @@ static void ggml_compute_forward_flash_attn_ext_f16(
         const int iq2 = (ir - iq3*neq2*neq1)/neq1;
         const int iq1 = (ir - iq3*neq2*neq1 - iq2*neq1);

+        const int h = iq2; // head
+        const float slope = (max_bias > 0.0f) ? h < n_head_log2 ? powf(m0, h + 1) : powf(m1, 2*(h - n_head_log2) + 1) : 1.0f;
+
         float S = 0.0f;
         float M = -INFINITY;
@@ -15557,7 +15570,7 @@ static void ggml_compute_forward_flash_attn_ext_f16(
         // loop over n_kv and n_head_kv
         // ref: https://arxiv.org/pdf/2112.05682.pdf
         for (int64_t ic = 0; ic < nek1; ++ic) {
-            const float mv = mp ? GGML_FP16_TO_FP32(mp[ic]) : 0.0f;
+            const float mv = mp ? slope*GGML_FP16_TO_FP32(mp[ic]) : 0.0f;
             if (mv == -INFINITY) {
                 continue;
             }
@@ -15628,7 +15641,7 @@ static void ggml_compute_forward_flash_attn_ext(
         const struct ggml_tensor * v,
         const struct ggml_tensor * mask,
         struct ggml_tensor * dst) {
-    switch (dst->op_params[1]) {
+    switch (dst->op_params[2]) {
         case GGML_PREC_DEFAULT:
         case GGML_PREC_F32:
             {
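For reference, the per-head ALiBi slope that the new kernel multiplies into the mask can be computed in isolation. The following is a minimal standalone sketch of the m0/m1/n_head_log2 formula added above; the helper name alibi_slope and the sample values for max_bias and n_head are illustrative only, not part of ggml.

// sketch: compute the ALiBi slope per head, mirroring the expression used in
// ggml_compute_forward_flash_attn_ext_f16 (and the soft_max path) above
#include <math.h>
#include <stdint.h>
#include <stdio.h>

static float alibi_slope(float max_bias, uint32_t n_head_log2, float m0, float m1, uint32_t h) {
    if (max_bias <= 0.0f) {
        return 1.0f; // ALiBi disabled: the mask is used unscaled
    }
    return h < n_head_log2 ? powf(m0, h + 1) : powf(m1, 2*(h - n_head_log2) + 1);
}

int main(void) {
    const float    max_bias    = 8.0f; // illustrative; llama.cpp passes hparams.f_max_alibi_bias
    const uint32_t n_head      = 32;   // illustrative; the kernel uses neq2
    const uint32_t n_head_log2 = 1u << (uint32_t) floor(log2(n_head));

    const float m0 = powf(2.0f, -(max_bias       ) / n_head_log2);
    const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);

    for (uint32_t h = 0; h < n_head; ++h) {
        // the kernel applies this as: mv = slope*GGML_FP16_TO_FP32(mp[ic])
        printf("head %2u: slope = %.6f\n", h, alibi_slope(max_bias, n_head_log2, m0, m1, h));
    }
    return 0;
}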
ggml.h: 3 changed lines

@@ -1731,7 +1731,8 @@ extern "C" {
             struct ggml_tensor  * k,
             struct ggml_tensor  * v,
             struct ggml_tensor  * mask,
-            float                 scale);
+            float                 scale,
+            float                 max_bias);

     GGML_API void ggml_flash_attn_ext_set_prec(
             struct ggml_tensor * a,
llama.cpp

@@ -6537,7 +6537,7 @@ static struct ggml_tensor * llm_build_kqv(
                 0);
         cb(v, "v", il);

-        cur = ggml_flash_attn_ext(ctx, q, k, v, kq_mask, kq_scale);
+        cur = ggml_flash_attn_ext(ctx, q, k, v, kq_mask, kq_scale, hparams.f_max_alibi_bias);

         if (model.arch == LLM_ARCH_PHI2 || model.arch == LLM_ARCH_PHI3) {
             ggml_flash_attn_ext_set_prec(cur, GGML_PREC_F32);
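In llama.cpp the extra argument is hparams.f_max_alibi_bias, which is 0.0f for models that do not use ALiBi, so the slope evaluates to 1.0f and the mask is applied exactly as before. Below is a minimal caller-side sketch of the updated signature; the wrapper build_attn, its tensors, and the forced F32 precision are placeholders, not part of this commit.

// sketch: building a flash-attention node with the extended signature
#include "ggml.h"

struct ggml_tensor * build_attn(
        struct ggml_context * ctx,
        struct ggml_tensor  * q,
        struct ggml_tensor  * k,
        struct ggml_tensor  * v,
        struct ggml_tensor  * kq_mask,
        float                 kq_scale,
        float                 max_alibi_bias) {
    // max_alibi_bias == 0.0f keeps the previous (non-ALiBi) behavior
    struct ggml_tensor * cur = ggml_flash_attn_ext(ctx, q, k, v, kq_mask, kq_scale, max_alibi_bias);

    // precision now lives at op_params index 2, after scale and max_bias
    ggml_flash_attn_ext_set_prec(cur, GGML_PREC_F32);

    return cur;
}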