ggml : alternative ALiBi without extra tensor

We compute the slopes in the kernel

ggml-ci
This commit is contained in:
Georgi Gerganov 2024-02-14 17:37:48 +02:00
parent 5261fb2dbe
commit 97d6a0cc06
No known key found for this signature in database
GPG key ID: 449E073F9DC10735
6 changed files with 103 additions and 123 deletions

View file

@ -1191,7 +1191,8 @@ static bool ggml_metal_graph_compute(
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SOFT_MAX].pipeline; pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SOFT_MAX].pipeline;
} }
const float scale = ((float *) dst->op_params)[0]; const float scale = ((float *) dst->op_params)[0];
const float max_bias = ((float *) dst->op_params)[1];
[encoder setComputePipelineState:pipeline]; [encoder setComputePipelineState:pipeline];
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
@ -1200,16 +1201,12 @@ static bool ggml_metal_graph_compute(
} else { } else {
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:1]; [encoder setBuffer:id_src0 offset:offs_src0 atIndex:1];
} }
if (id_src2) { [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
[encoder setBuffer:id_src2 offset:offs_src2 atIndex:2]; [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3];
} else { [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:4];
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:2]; [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:5];
} [encoder setBytes:&scale length:sizeof(scale) atIndex:6];
[encoder setBuffer:id_dst offset:offs_dst atIndex:3]; [encoder setBytes:&max_bias length:sizeof(max_bias) atIndex:7];
[encoder setBytes:&ne00 length:sizeof(ne00) atIndex:4];
[encoder setBytes:&ne01 length:sizeof(ne01) atIndex:5];
[encoder setBytes:&ne02 length:sizeof(ne02) atIndex:6];
[encoder setBytes:&scale length:sizeof(scale) atIndex:7];
[encoder setThreadgroupMemoryLength:32*sizeof(float) atIndex:0]; [encoder setThreadgroupMemoryLength:32*sizeof(float) atIndex:0];
[encoder dispatchThreadgroups:MTLSizeMake(ne01*ne02*ne03, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)]; [encoder dispatchThreadgroups:MTLSizeMake(ne01*ne02*ne03, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];

View file

@ -351,12 +351,12 @@ kernel void kernel_sum_rows(
kernel void kernel_soft_max( kernel void kernel_soft_max(
device const float * src0, device const float * src0,
device const float * src1, device const float * src1,
device const float * src2,
device float * dst, device float * dst,
constant int64_t & ne00, constant int64_t & ne00,
constant int64_t & ne01, constant int64_t & ne01,
constant int64_t & ne02, constant int64_t & ne02,
constant float & scale, constant float & scale,
constant float & max_bias,
threadgroup float * buf [[threadgroup(0)]], threadgroup float * buf [[threadgroup(0)]],
uint tgpig[[threadgroup_position_in_grid]], uint tgpig[[threadgroup_position_in_grid]],
uint tpitg[[thread_position_in_threadgroup]], uint tpitg[[thread_position_in_threadgroup]],
@ -369,9 +369,22 @@ kernel void kernel_soft_max(
device const float * psrc0 = src0 + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00; device const float * psrc0 = src0 + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
device const float * pmask = src1 != src0 ? src1 + i01*ne00 : nullptr; device const float * pmask = src1 != src0 ? src1 + i01*ne00 : nullptr;
const float slope = src2 != src0 ? src2[i02] : 0.0f;
device float * pdst = dst + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00; device float * pdst = dst + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
float slope = 0.0f;
if (max_bias > 0.0f) {
const uint32_t n_head_kv = ne02;
const uint32_t n_head_log2 = 1u << (uint32_t) floor(log2((float) n_head_kv));
const float m0 = pow(2.0f, -(max_bias ) / n_head_log2);
const float m1 = pow(2.0f, -(max_bias / 2.0f) / n_head_log2);
const int64_t h = i02;
slope = h < n_head_log2 ? pow(m0, h + 1) : pow(m1, 2*(h - n_head_log2) + 1);
}
// parallel max // parallel max
float lmax = -INFINITY; float lmax = -INFINITY;
@ -439,12 +452,12 @@ kernel void kernel_soft_max(
kernel void kernel_soft_max_4( kernel void kernel_soft_max_4(
device const float * src0, device const float * src0,
device const float * src1, device const float * src1,
device const float * src2,
device float * dst, device float * dst,
constant int64_t & ne00, constant int64_t & ne00,
constant int64_t & ne01, constant int64_t & ne01,
constant int64_t & ne02, constant int64_t & ne02,
constant float & scale, constant float & scale,
constant float & max_bias,
threadgroup float * buf [[threadgroup(0)]], threadgroup float * buf [[threadgroup(0)]],
uint tgpig[[threadgroup_position_in_grid]], uint tgpig[[threadgroup_position_in_grid]],
uint tpitg[[thread_position_in_threadgroup]], uint tpitg[[thread_position_in_threadgroup]],
@ -457,11 +470,24 @@ kernel void kernel_soft_max_4(
device const float4 * psrc4 = (device const float4 *)(src0 + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00); device const float4 * psrc4 = (device const float4 *)(src0 + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00);
device const float4 * pmask = src1 != src0 ? (device const float4 *)(src1 + i01*ne00) : nullptr; device const float4 * pmask = src1 != src0 ? (device const float4 *)(src1 + i01*ne00) : nullptr;
const float slope = src2 != src0 ? src2[i02] : 0.0f;
device float4 * pdst4 = (device float4 *)(dst + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00); device float4 * pdst4 = (device float4 *)(dst + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00);
const float4 s0(0.0f, 1.0f, 2.0f, 3.0f); const float4 s0(0.0f, 1.0f, 2.0f, 3.0f);
float slope = 0.0f;
if (max_bias > 0.0f) {
const uint32_t n_head_kv = ne02;
const uint32_t n_head_log2 = 1u << (uint32_t) floor(log2((float) n_head_kv));
const float m0 = pow(2.0f, -(max_bias ) / n_head_log2);
const float m1 = pow(2.0f, -(max_bias / 2.0f) / n_head_log2);
const int64_t h = i02;
slope = h < n_head_log2 ? pow(m0, h + 1) : pow(m1, 2*(h - n_head_log2) + 1);
}
// parallel max // parallel max
float4 lmax4 = -INFINITY; float4 lmax4 = -INFINITY;

48
ggml.c
View file

@ -5060,8 +5060,8 @@ static struct ggml_tensor * ggml_soft_max_impl(
struct ggml_context * ctx, struct ggml_context * ctx,
struct ggml_tensor * a, struct ggml_tensor * a,
struct ggml_tensor * mask, struct ggml_tensor * mask,
struct ggml_tensor * slope,
float scale, float scale,
float max_bias,
bool inplace) { bool inplace) {
GGML_ASSERT(ggml_is_contiguous(a)); GGML_ASSERT(ggml_is_contiguous(a));
if (mask) { if (mask) {
@ -5070,12 +5070,6 @@ static struct ggml_tensor * ggml_soft_max_impl(
GGML_ASSERT(ggml_can_repeat_rows(mask, a)); GGML_ASSERT(ggml_can_repeat_rows(mask, a));
} }
if (slope) {
GGML_ASSERT(ggml_is_contiguous(slope));
GGML_ASSERT(ggml_is_vector(slope));
GGML_ASSERT(slope->ne[0] == a->ne[2]);
}
bool is_node = false; bool is_node = false;
if (a->grad) { if (a->grad) {
@ -5084,14 +5078,13 @@ static struct ggml_tensor * ggml_soft_max_impl(
struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a); struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
float params[] = { scale }; float params[] = { scale, max_bias };
ggml_set_op_params(result, params, sizeof(params)); ggml_set_op_params(result, params, sizeof(params));
result->op = GGML_OP_SOFT_MAX; result->op = GGML_OP_SOFT_MAX;
result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
result->src[0] = a; result->src[0] = a;
result->src[1] = mask; result->src[1] = mask;
result->src[2] = slope;
return result; return result;
} }
@ -5099,22 +5092,22 @@ static struct ggml_tensor * ggml_soft_max_impl(
struct ggml_tensor * ggml_soft_max( struct ggml_tensor * ggml_soft_max(
struct ggml_context * ctx, struct ggml_context * ctx,
struct ggml_tensor * a) { struct ggml_tensor * a) {
return ggml_soft_max_impl(ctx, a, NULL, NULL, 1.0f, false); return ggml_soft_max_impl(ctx, a, NULL, 1.0f, 0.0f, false);
} }
struct ggml_tensor * ggml_soft_max_inplace( struct ggml_tensor * ggml_soft_max_inplace(
struct ggml_context * ctx, struct ggml_context * ctx,
struct ggml_tensor * a) { struct ggml_tensor * a) {
return ggml_soft_max_impl(ctx, a, NULL, NULL, 1.0f, true); return ggml_soft_max_impl(ctx, a, NULL, 1.0f, 0.0f, true);
} }
struct ggml_tensor * ggml_soft_max_ext( struct ggml_tensor * ggml_soft_max_ext(
struct ggml_context * ctx, struct ggml_context * ctx,
struct ggml_tensor * a, struct ggml_tensor * a,
struct ggml_tensor * mask, struct ggml_tensor * mask,
struct ggml_tensor * slope, float scale,
float scale) { float max_bias) {
return ggml_soft_max_impl(ctx, a, mask, slope, scale, false); return ggml_soft_max_impl(ctx, a, mask, scale, max_bias, false);
} }
// ggml_soft_max_back // ggml_soft_max_back
@ -11467,7 +11460,6 @@ static void ggml_compute_forward_soft_max_f32(
const struct ggml_compute_params * params, const struct ggml_compute_params * params,
const struct ggml_tensor * src0, const struct ggml_tensor * src0,
const struct ggml_tensor * src1, const struct ggml_tensor * src1,
const struct ggml_tensor * src2,
struct ggml_tensor * dst) { struct ggml_tensor * dst) {
assert(ggml_is_contiguous(dst)); assert(ggml_is_contiguous(dst));
assert(ggml_are_same_shape(src0, dst)); assert(ggml_are_same_shape(src0, dst));
@ -11476,8 +11468,11 @@ static void ggml_compute_forward_soft_max_f32(
return; return;
} }
float scale = 1.0f; float scale = 1.0f;
memcpy(&scale, (float *) dst->op_params + 0, sizeof(float)); float max_bias = 0.0f;
memcpy(&scale, (float *) dst->op_params + 0, sizeof(float));
memcpy(&max_bias, (float *) dst->op_params + 1, sizeof(float));
// TODO: handle transposed/permuted matrices // TODO: handle transposed/permuted matrices
@ -11488,6 +11483,14 @@ static void ggml_compute_forward_soft_max_f32(
const int64_t ne11 = src1 ? src1->ne[1] : 1; const int64_t ne11 = src1 ? src1->ne[1] : 1;
// TODO: is this supposed to be ceil instead of floor?
// https://huggingface.co/mosaicml/mpt-7b/blob/main/attention.py#L370
const uint32_t n_head_kv = ne02;
const uint32_t n_head_log2 = 1u << (uint32_t) floor(log2(n_head_kv));
const float m0 = powf(2.0f, -(max_bias ) / n_head_log2);
const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
const int nc = src0->ne[0]; const int nc = src0->ne[0];
const int nr = ggml_nrows(src0); const int nr = ggml_nrows(src0);
@ -11514,9 +11517,9 @@ static void ggml_compute_forward_soft_max_f32(
} }
// alibi bias // alibi bias
if (src2) { if (max_bias > 0.0f) {
const int h = (i1/ne01)%ne02; const uint32_t h = (i1/ne01)%ne02; // head
const float slope = ((float *)(src2->data))[h]; const float slope = h < n_head_log2 ? powf(m0, h + 1) : powf(m1, 2*(h - n_head_log2) + 1);
for (int i = 0; i < nc; i++) { for (int i = 0; i < nc; i++) {
wp[i] = wp[i] + slope*i; wp[i] = wp[i] + slope*i;
@ -11567,12 +11570,11 @@ static void ggml_compute_forward_soft_max(
const struct ggml_compute_params * params, const struct ggml_compute_params * params,
const struct ggml_tensor * src0, const struct ggml_tensor * src0,
const struct ggml_tensor * src1, const struct ggml_tensor * src1,
const struct ggml_tensor * src2,
struct ggml_tensor * dst) { struct ggml_tensor * dst) {
switch (src0->type) { switch (src0->type) {
case GGML_TYPE_F32: case GGML_TYPE_F32:
{ {
ggml_compute_forward_soft_max_f32(params, src0, src1, src2, dst); ggml_compute_forward_soft_max_f32(params, src0, src1, dst);
} break; } break;
default: default:
{ {
@ -15099,7 +15101,7 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
} break; } break;
case GGML_OP_SOFT_MAX: case GGML_OP_SOFT_MAX:
{ {
ggml_compute_forward_soft_max(params, tensor->src[0], tensor->src[1], tensor->src[2], tensor); ggml_compute_forward_soft_max(params, tensor->src[0], tensor->src[1], tensor);
} break; } break;
case GGML_OP_SOFT_MAX_BACK: case GGML_OP_SOFT_MAX_BACK:
{ {

8
ggml.h
View file

@ -1373,15 +1373,15 @@ extern "C" {
struct ggml_context * ctx, struct ggml_context * ctx,
struct ggml_tensor * a); struct ggml_tensor * a);
// fused soft_max(a*scale + i*slope + mask) // fused soft_max(a*scale + mask + ALiBi bias)
// mask is optional // mask is optional
// slope is optional // max_bias = 0.0f for no ALiBi
GGML_API struct ggml_tensor * ggml_soft_max_ext( GGML_API struct ggml_tensor * ggml_soft_max_ext(
struct ggml_context * ctx, struct ggml_context * ctx,
struct ggml_tensor * a, struct ggml_tensor * a,
struct ggml_tensor * mask, struct ggml_tensor * mask,
struct ggml_tensor * slope, float scale,
float scale); float max_bias);
GGML_API struct ggml_tensor * ggml_soft_max_back( GGML_API struct ggml_tensor * ggml_soft_max_back(
struct ggml_context * ctx, struct ggml_context * ctx,

View file

@ -1923,7 +1923,6 @@ struct llama_context {
struct ggml_tensor * inp_embd; // F32 [n_embd, n_batch] struct ggml_tensor * inp_embd; // F32 [n_embd, n_batch]
struct ggml_tensor * inp_pos; // I32 [n_batch] struct ggml_tensor * inp_pos; // I32 [n_batch]
struct ggml_tensor * inp_KQ_mask; // F32 [n_ctx, n_batch] struct ggml_tensor * inp_KQ_mask; // F32 [n_ctx, n_batch]
struct ggml_tensor * inp_KQ_slope; // F32 [n_head_kv]
struct ggml_tensor * inp_K_shift; // I32 [n_ctx] struct ggml_tensor * inp_K_shift; // I32 [n_ctx]
struct ggml_tensor * inp_sum; // F32 [n_batch, n_batch] struct ggml_tensor * inp_sum; // F32 [n_batch, n_batch]
@ -4783,7 +4782,6 @@ static struct ggml_tensor * llm_build_kqv(
struct ggml_tensor * wo_b, struct ggml_tensor * wo_b,
struct ggml_tensor * q_cur, struct ggml_tensor * q_cur,
struct ggml_tensor * kq_mask, struct ggml_tensor * kq_mask,
struct ggml_tensor * kq_slope,
int64_t n_ctx, int64_t n_ctx,
int32_t n_tokens, int32_t n_tokens,
int32_t n_kv, int32_t n_kv,
@ -4816,7 +4814,7 @@ static struct ggml_tensor * llm_build_kqv(
ggml_mul_mat_set_prec(kq, GGML_PREC_F32); ggml_mul_mat_set_prec(kq, GGML_PREC_F32);
} }
kq = ggml_soft_max_ext(ctx, kq, kq_mask, kq_slope, kq_scale); kq = ggml_soft_max_ext(ctx, kq, kq_mask, kq_scale, hparams.f_max_alibi_bias);
cb(kq, "kq_soft_max_ext", il); cb(kq, "kq_soft_max_ext", il);
// split cached v into n_head heads // split cached v into n_head heads
@ -4863,7 +4861,6 @@ static struct ggml_tensor * llm_build_kv(
struct ggml_tensor * v_cur, struct ggml_tensor * v_cur,
struct ggml_tensor * q_cur, struct ggml_tensor * q_cur,
struct ggml_tensor * kq_mask, struct ggml_tensor * kq_mask,
struct ggml_tensor * kq_slope,
int64_t n_ctx, int64_t n_ctx,
int32_t n_tokens, int32_t n_tokens,
int32_t kv_head, int32_t kv_head,
@ -4882,7 +4879,7 @@ static struct ggml_tensor * llm_build_kv(
struct ggml_tensor * cur; struct ggml_tensor * cur;
cur = llm_build_kqv(ctx, model, hparams, kv, graph, wo, wo_b, cur = llm_build_kqv(ctx, model, hparams, kv, graph, wo, wo_b,
q_cur, kq_mask, kq_slope, n_ctx, n_tokens, n_kv, kq_scale, cb, il); q_cur, kq_mask, n_ctx, n_tokens, n_kv, kq_scale, cb, il);
cb(cur, "kqv_out", il); cb(cur, "kqv_out", il);
return cur; return cur;
@ -5065,7 +5062,7 @@ struct llm_build_context {
cur = llm_build_kv(ctx0, model, hparams, kv_self, gf, cur = llm_build_kv(ctx0, model, hparams, kv_self, gf,
model.layers[il].wo, model.layers[il].bo, model.layers[il].wo, model.layers[il].bo,
Kcur, Vcur, Qcur, KQ_mask, nullptr, n_ctx, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); Kcur, Vcur, Qcur, KQ_mask, n_ctx, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
cb(cur, "kqv_out", il); cb(cur, "kqv_out", il);
} }
@ -5195,9 +5192,6 @@ struct llm_build_context {
struct ggml_tensor * KQ_mask = ggml_view_2d(ctx0, lctx.inp_KQ_mask, n_kv, n_tokens, n_kv*ggml_type_size(lctx.inp_KQ_mask->type), 0); struct ggml_tensor * KQ_mask = ggml_view_2d(ctx0, lctx.inp_KQ_mask, n_kv, n_tokens, n_kv*ggml_type_size(lctx.inp_KQ_mask->type), 0);
cb(KQ_mask, "KQ_mask", -1); cb(KQ_mask, "KQ_mask", -1);
struct ggml_tensor * KQ_slope = ggml_view_1d(ctx0, lctx.inp_KQ_slope, n_head_kv, 0);
cb(KQ_slope, "KQ_slope", -1);
// shift the entire K-cache if needed // shift the entire K-cache if needed
if (do_rope_shift) { if (do_rope_shift) {
llm_build_k_shift(ctx0, hparams, cparams, kv_self, gf, lctx.inp_K_shift, LLM_ROPE, n_ctx, freq_base, freq_scale, cb); llm_build_k_shift(ctx0, hparams, cparams, kv_self, gf, lctx.inp_K_shift, LLM_ROPE, n_ctx, freq_base, freq_scale, cb);
@ -5248,7 +5242,7 @@ struct llm_build_context {
cur = llm_build_kv(ctx0, model, hparams, kv_self, gf, cur = llm_build_kv(ctx0, model, hparams, kv_self, gf,
model.layers[il].wo, NULL, model.layers[il].wo, NULL,
Kcur, Vcur, Qcur, KQ_mask, KQ_slope, n_ctx, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); Kcur, Vcur, Qcur, KQ_mask, n_ctx, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
cb(cur, "kqv_out", il); cb(cur, "kqv_out", il);
} }
@ -5372,7 +5366,7 @@ struct llm_build_context {
cur = llm_build_kv(ctx0, model, hparams, kv_self, gf, cur = llm_build_kv(ctx0, model, hparams, kv_self, gf,
model.layers[il].wo, NULL, model.layers[il].wo, NULL,
Kcur, Vcur, Qcur, KQ_mask, nullptr, n_ctx, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); Kcur, Vcur, Qcur, KQ_mask, n_ctx, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
cb(cur, "kqv_out", il); cb(cur, "kqv_out", il);
} }
@ -5471,7 +5465,7 @@ struct llm_build_context {
cur = llm_build_kv(ctx0, model, hparams, kv_self, gf, cur = llm_build_kv(ctx0, model, hparams, kv_self, gf,
model.layers[il].wo, model.layers[il].bo, model.layers[il].wo, model.layers[il].bo,
Kcur, Vcur, Qcur, KQ_mask, nullptr, n_ctx, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); Kcur, Vcur, Qcur, KQ_mask, n_ctx, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
cb(cur, "kqv_out", il); cb(cur, "kqv_out", il);
} }
@ -5676,7 +5670,7 @@ struct llm_build_context {
cur = llm_build_kv(ctx0, model, hparams, kv_self, gf, cur = llm_build_kv(ctx0, model, hparams, kv_self, gf,
model.layers[il].wo, model.layers[il].bo, model.layers[il].wo, model.layers[il].bo,
Kcur, Vcur, Q, KQ_mask, nullptr, n_ctx, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); Kcur, Vcur, Q, KQ_mask, n_ctx, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
cb(cur, "kqv_out", il); cb(cur, "kqv_out", il);
} }
@ -5738,9 +5732,6 @@ struct llm_build_context {
struct ggml_tensor * KQ_mask = ggml_view_2d(ctx0, lctx.inp_KQ_mask, n_kv, n_tokens, n_kv*ggml_type_size(lctx.inp_KQ_mask->type), 0); struct ggml_tensor * KQ_mask = ggml_view_2d(ctx0, lctx.inp_KQ_mask, n_kv, n_tokens, n_kv*ggml_type_size(lctx.inp_KQ_mask->type), 0);
cb(KQ_mask, "KQ_mask", -1); cb(KQ_mask, "KQ_mask", -1);
struct ggml_tensor * KQ_slope = ggml_view_1d(ctx0, lctx.inp_KQ_slope, n_head_kv, 0);
cb(KQ_slope, "KQ_slope", -1);
for (int il = 0; il < n_layer; ++il) { for (int il = 0; il < n_layer; ++il) {
struct ggml_tensor * inpSA = inpL; struct ggml_tensor * inpSA = inpL;
@ -5768,7 +5759,7 @@ struct llm_build_context {
cur = llm_build_kv(ctx0, model, hparams, kv_self, gf, cur = llm_build_kv(ctx0, model, hparams, kv_self, gf,
model.layers[il].wo, NULL, model.layers[il].wo, NULL,
Kcur, Vcur, Qcur, KQ_mask, KQ_slope, n_ctx, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); Kcur, Vcur, Qcur, KQ_mask, n_ctx, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
cb(cur, "kqv_out", il); cb(cur, "kqv_out", il);
} }
@ -5868,7 +5859,7 @@ struct llm_build_context {
cur = llm_build_kv(ctx0, model, hparams, kv_self, gf, cur = llm_build_kv(ctx0, model, hparams, kv_self, gf,
model.layers[il].wo, model.layers[il].bo, model.layers[il].wo, model.layers[il].bo,
Kcur, Vcur, Qcur, KQ_mask, nullptr, n_ctx, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); Kcur, Vcur, Qcur, KQ_mask, n_ctx, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
cb(cur, "kqv_out", il); cb(cur, "kqv_out", il);
} else { } else {
// compute Q and K and RoPE them // compute Q and K and RoPE them
@ -5899,7 +5890,7 @@ struct llm_build_context {
cur = llm_build_kv(ctx0, model, hparams, kv_self, gf, cur = llm_build_kv(ctx0, model, hparams, kv_self, gf,
model.layers[il].wo, model.layers[il].bo, model.layers[il].wo, model.layers[il].bo,
Kcur, Vcur, Qcur, KQ_mask, nullptr, n_ctx, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); Kcur, Vcur, Qcur, KQ_mask, n_ctx, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
cb(cur, "kqv_out", il); cb(cur, "kqv_out", il);
} }
@ -5971,9 +5962,6 @@ struct llm_build_context {
struct ggml_tensor * KQ_mask = ggml_view_2d(ctx0, lctx.inp_KQ_mask, n_kv, n_tokens, n_kv*ggml_type_size(lctx.inp_KQ_mask->type), 0); struct ggml_tensor * KQ_mask = ggml_view_2d(ctx0, lctx.inp_KQ_mask, n_kv, n_tokens, n_kv*ggml_type_size(lctx.inp_KQ_mask->type), 0);
cb(KQ_mask, "KQ_mask", -1); cb(KQ_mask, "KQ_mask", -1);
struct ggml_tensor * KQ_slope = ggml_view_1d(ctx0, lctx.inp_KQ_slope, n_head_kv, 0);
cb(KQ_slope, "KQ_slope", -1);
inpL = llm_build_norm(ctx0, inpL, hparams, inpL = llm_build_norm(ctx0, inpL, hparams,
model.tok_norm, model.tok_norm,
model.tok_norm_b, model.tok_norm_b,
@ -6007,7 +5995,7 @@ struct llm_build_context {
cur = llm_build_kv(ctx0, model, hparams, kv_self, gf, cur = llm_build_kv(ctx0, model, hparams, kv_self, gf,
model.layers[il].wo, model.layers[il].bo, model.layers[il].wo, model.layers[il].bo,
Kcur, Vcur, Qcur, KQ_mask, KQ_slope, n_ctx, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); Kcur, Vcur, Qcur, KQ_mask, n_ctx, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
cb(cur, "kqv_out", il); cb(cur, "kqv_out", il);
} }
@ -6067,9 +6055,6 @@ struct llm_build_context {
struct ggml_tensor * KQ_mask = ggml_view_2d(ctx0, lctx.inp_KQ_mask, n_kv, n_tokens, n_kv*ggml_type_size(lctx.inp_KQ_mask->type), 0); struct ggml_tensor * KQ_mask = ggml_view_2d(ctx0, lctx.inp_KQ_mask, n_kv, n_tokens, n_kv*ggml_type_size(lctx.inp_KQ_mask->type), 0);
cb(KQ_mask, "KQ_mask", -1); cb(KQ_mask, "KQ_mask", -1);
struct ggml_tensor * KQ_slope = ggml_view_1d(ctx0, lctx.inp_KQ_slope, n_head_kv, 0);
cb(KQ_slope, "KQ_slope", -1);
for (int il = 0; il < n_layer; ++il) { for (int il = 0; il < n_layer; ++il) {
struct ggml_tensor * attn_norm; struct ggml_tensor * attn_norm;
@ -6103,7 +6088,7 @@ struct llm_build_context {
cur = llm_build_kv(ctx0, model, hparams, kv_self, gf, cur = llm_build_kv(ctx0, model, hparams, kv_self, gf,
model.layers[il].wo, NULL, model.layers[il].wo, NULL,
Kcur, Vcur, Qcur, KQ_mask, KQ_slope, n_ctx, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); Kcur, Vcur, Qcur, KQ_mask, n_ctx, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
cb(cur, "kqv_out", il); cb(cur, "kqv_out", il);
} }
@ -6225,7 +6210,7 @@ struct llm_build_context {
cur = llm_build_kv(ctx0, model, hparams, kv_self, gf, cur = llm_build_kv(ctx0, model, hparams, kv_self, gf,
model.layers[il].wo, NULL, model.layers[il].wo, NULL,
Kcur, Vcur, Qcur, KQ_mask, nullptr, n_ctx, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); Kcur, Vcur, Qcur, KQ_mask, n_ctx, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
cb(cur, "kqv_out", il); cb(cur, "kqv_out", il);
} }
@ -6340,7 +6325,7 @@ struct llm_build_context {
cur = llm_build_kv(ctx0, model, hparams, kv_self, gf, cur = llm_build_kv(ctx0, model, hparams, kv_self, gf,
model.layers[il].wo, NULL, model.layers[il].wo, NULL,
Kcur, Vcur, Qcur, KQ_mask, nullptr, n_ctx, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); Kcur, Vcur, Qcur, KQ_mask, n_ctx, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
cb(cur, "kqv_out", il); cb(cur, "kqv_out", il);
} }
@ -6461,7 +6446,7 @@ struct llm_build_context {
cur = llm_build_kv(ctx0, model, hparams, kv_self, gf, cur = llm_build_kv(ctx0, model, hparams, kv_self, gf,
model.layers[il].wo, model.layers[il].bo, model.layers[il].wo, model.layers[il].bo,
Kcur, Vcur, Qcur, KQ_mask, nullptr, n_ctx, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); Kcur, Vcur, Qcur, KQ_mask, n_ctx, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
cb(cur, "kqv_out", il); cb(cur, "kqv_out", il);
} }
@ -6588,7 +6573,7 @@ struct llm_build_context {
cur = llm_build_kv(ctx0, model, hparams, kv_self, gf, cur = llm_build_kv(ctx0, model, hparams, kv_self, gf,
model.layers[il].wo, model.layers[il].bo, model.layers[il].wo, model.layers[il].bo,
Kcur, Vcur, Qcur, KQ_mask, nullptr, n_ctx, n_tokens, kv_head, n_kv, 1.0f, cb, il); Kcur, Vcur, Qcur, KQ_mask, n_ctx, n_tokens, kv_head, n_kv, 1.0f, cb, il);
cb(cur, "kqv_out", il); cb(cur, "kqv_out", il);
} }
@ -6691,7 +6676,7 @@ struct llm_build_context {
cur = llm_build_kv(ctx0, model, hparams, kv_self, gf, cur = llm_build_kv(ctx0, model, hparams, kv_self, gf,
model.layers[il].wo, NULL, model.layers[il].wo, NULL,
Kcur, Vcur, Qcur, KQ_mask, nullptr, n_ctx, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); Kcur, Vcur, Qcur, KQ_mask, n_ctx, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
cb(cur, "kqv_out", il); cb(cur, "kqv_out", il);
} }
struct ggml_tensor * sa_out = cur; struct ggml_tensor * sa_out = cur;
@ -6790,7 +6775,7 @@ struct llm_build_context {
cur = llm_build_kv(ctx0, model, hparams, kv_self, gf, cur = llm_build_kv(ctx0, model, hparams, kv_self, gf,
model.layers[il].wo, model.layers[il].bo, model.layers[il].wo, model.layers[il].bo,
Kcur, Vcur, Qcur, KQ_mask, nullptr, n_ctx, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); Kcur, Vcur, Qcur, KQ_mask, n_ctx, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
cb(cur, "kqv_out", il); cb(cur, "kqv_out", il);
} }
@ -6899,7 +6884,7 @@ struct llm_build_context {
cur = llm_build_kv(ctx0, model, hparams, kv_self, gf, cur = llm_build_kv(ctx0, model, hparams, kv_self, gf,
model.layers[il].wo, model.layers[il].bo, model.layers[il].wo, model.layers[il].bo,
Kcur, Vcur, Qcur, KQ_mask, nullptr, n_ctx, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); Kcur, Vcur, Qcur, KQ_mask, n_ctx, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
cb(cur, "kqv_out", il); cb(cur, "kqv_out", il);
} }
@ -7017,7 +7002,7 @@ struct llm_build_context {
cur = llm_build_kv(ctx0, model, hparams, kv_self, gf, cur = llm_build_kv(ctx0, model, hparams, kv_self, gf,
model.layers[il].wo, NULL, model.layers[il].wo, NULL,
Kcur, Vcur, Qcur, KQ_mask, nullptr, n_ctx, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); Kcur, Vcur, Qcur, KQ_mask, n_ctx, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
cb(cur, "kqv_out", il); cb(cur, "kqv_out", il);
} }
@ -7136,7 +7121,7 @@ struct llm_build_context {
cur = llm_build_kv(ctx0, model, hparams, kv_self, gf, cur = llm_build_kv(ctx0, model, hparams, kv_self, gf,
model.layers[il].wo, model.layers[il].bo, model.layers[il].wo, model.layers[il].bo,
Kcur, Vcur, Qcur, KQ_mask, nullptr, n_ctx, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); Kcur, Vcur, Qcur, KQ_mask, n_ctx, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
cb(cur, "kqv_out", il); cb(cur, "kqv_out", il);
} }
@ -7268,7 +7253,7 @@ struct llm_build_context {
cur = llm_build_kv(ctx0, model, hparams, kv_self, gf, cur = llm_build_kv(ctx0, model, hparams, kv_self, gf,
model.layers[il].wo, model.layers[il].bo, model.layers[il].wo, model.layers[il].bo,
Kcur, Vcur, Qcur, KQ_mask, nullptr, n_ctx, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); Kcur, Vcur, Qcur, KQ_mask, n_ctx, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
cb(cur, "kqv_out", il); cb(cur, "kqv_out", il);
} }
@ -7499,32 +7484,6 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {
} }
} }
// using Alibi bias
if (hparams.f_max_alibi_bias > 0.0f) {
const uint32_t n_head_kv = hparams.n_head_kv;
const float max_bias = hparams.f_max_alibi_bias;
assert(ggml_backend_buffer_is_host(lctx.inp_KQ_slope->buffer));
float * data = (float *) lctx.inp_KQ_slope->data;
// TODO: is this supposed to be ceil instead of floor?
// https://huggingface.co/mosaicml/mpt-7b/blob/main/attention.py#L370
const uint32_t n_head_log2 = 1u << (uint32_t) floor(log2(n_head_kv));
const float m0 = powf(2.0f, -(max_bias ) / n_head_log2);
const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
for (uint32_t h = 0; h < n_head_kv; ++h) {
if (h < n_head_log2) {
data[h] = powf(m0, h + 1);
} else {
data[h] = powf(m1, 2*(h - n_head_log2) + 1);
}
}
}
{ {
assert(ggml_backend_buffer_is_host(lctx.inp_sum->buffer)); assert(ggml_backend_buffer_is_host(lctx.inp_sum->buffer));
float * data = (float *) lctx.inp_sum->data; float * data = (float *) lctx.inp_sum->data;
@ -11440,7 +11399,6 @@ struct llama_context * llama_new_context_with_model(
ctx->inp_embd = ggml_new_tensor_2d(ctx->ctx_input, GGML_TYPE_F32, hparams.n_embd, cparams.n_batch); ctx->inp_embd = ggml_new_tensor_2d(ctx->ctx_input, GGML_TYPE_F32, hparams.n_embd, cparams.n_batch);
ctx->inp_pos = ggml_new_tensor_1d(ctx->ctx_input, GGML_TYPE_I32, cparams.n_batch); ctx->inp_pos = ggml_new_tensor_1d(ctx->ctx_input, GGML_TYPE_I32, cparams.n_batch);
ctx->inp_KQ_mask = ggml_new_tensor_2d(ctx->ctx_input, GGML_TYPE_F32, cparams.n_ctx, cparams.n_batch); ctx->inp_KQ_mask = ggml_new_tensor_2d(ctx->ctx_input, GGML_TYPE_F32, cparams.n_ctx, cparams.n_batch);
ctx->inp_KQ_slope = ggml_new_tensor_1d(ctx->ctx_input, GGML_TYPE_F32, hparams.n_head_kv);
ctx->inp_K_shift = ggml_new_tensor_1d(ctx->ctx_input, GGML_TYPE_I32, cparams.n_ctx); ctx->inp_K_shift = ggml_new_tensor_1d(ctx->ctx_input, GGML_TYPE_I32, cparams.n_ctx);
ctx->inp_sum = ggml_new_tensor_2d(ctx->ctx_input, GGML_TYPE_F32, cparams.n_batch, cparams.n_batch); ctx->inp_sum = ggml_new_tensor_2d(ctx->ctx_input, GGML_TYPE_F32, cparams.n_batch, cparams.n_batch);
@ -11448,7 +11406,6 @@ struct llama_context * llama_new_context_with_model(
ggml_set_name(ctx->inp_embd, "inp_embd"); ggml_set_name(ctx->inp_embd, "inp_embd");
ggml_set_name(ctx->inp_pos, "inp_pos"); ggml_set_name(ctx->inp_pos, "inp_pos");
ggml_set_name(ctx->inp_KQ_mask, "inp_KQ_mask"); ggml_set_name(ctx->inp_KQ_mask, "inp_KQ_mask");
ggml_set_name(ctx->inp_KQ_slope, "inp_KQ_slope");
ggml_set_name(ctx->inp_K_shift, "inp_K_shift"); ggml_set_name(ctx->inp_K_shift, "inp_K_shift");
ggml_set_name(ctx->inp_sum, "inp_sum"); ggml_set_name(ctx->inp_sum, "inp_sum");

View file

@ -1086,27 +1086,25 @@ struct test_soft_max : public test_case {
const ggml_type type; const ggml_type type;
const std::array<int64_t, 4> ne; const std::array<int64_t, 4> ne;
const bool mask; const bool mask;
const bool slope;
const float scale; const float scale;
const float max_bias;
std::string vars() override { std::string vars() override {
return VARS_TO_STR5(type, ne, mask, slope, scale); return VARS_TO_STR5(type, ne, mask, scale, max_bias);
} }
test_soft_max(ggml_type type = GGML_TYPE_F32, test_soft_max(ggml_type type = GGML_TYPE_F32,
std::array<int64_t, 4> ne = {10, 10, 10, 10}, std::array<int64_t, 4> ne = {10, 10, 10, 10},
bool mask = false, bool mask = false,
bool slope = false, float scale = 1.0f,
float scale = 1.0f) float max_bias = 0.0f)
: type(type), ne(ne), mask(mask), slope(slope), scale(scale) {} : type(type), ne(ne), mask(mask), scale(scale), max_bias(max_bias) {}
ggml_tensor * build_graph(ggml_context * ctx) override { ggml_tensor * build_graph(ggml_context * ctx) override {
ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne.data()); ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne.data());
ggml_tensor * b = nullptr; ggml_tensor * b = nullptr;
ggml_tensor * c = nullptr; if (mask) { b = ggml_new_tensor_2d(ctx, type, ne[0], ne[1]); }
if (mask) { b = ggml_new_tensor_2d(ctx, type, ne[0], ne[1]); } ggml_tensor * out = ggml_soft_max_ext(ctx, a, b, scale, max_bias);
if (slope) { c = ggml_new_tensor_1d(ctx, type, ne[2]); }
ggml_tensor * out = ggml_soft_max_ext(ctx, a, b, c, scale);
return out; return out;
} }
}; };
@ -1492,7 +1490,7 @@ struct test_moe : public test_case {
ggml_tensor * cur = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_embd, n_tokens); ggml_tensor * cur = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_embd, n_tokens);
ggml_tensor * logits = ggml_mul_mat(ctx, ffn_gate_inp, cur); ggml_tensor * logits = ggml_mul_mat(ctx, ffn_gate_inp, cur);
ggml_tensor * probs = ggml_soft_max_ext(ctx, logits, nullptr, nullptr, 1.0f/sqrtf(n_embd)); ggml_tensor * probs = ggml_soft_max_ext(ctx, logits, nullptr, 1.0f/sqrtf(n_embd), 0.0f);
// select experts // select experts
ggml_tensor * selected_experts = ggml_top_k(ctx, probs, n_experts_per_tok); ggml_tensor * selected_experts = ggml_top_k(ctx, probs, n_experts_per_tok);
@ -1640,7 +1638,7 @@ public:
struct ggml_tensor * kq = ggml_mul_mat(ctx, k, q); struct ggml_tensor * kq = ggml_mul_mat(ctx, k, q);
kq = ggml_soft_max_ext(ctx, kq, kq_mask, nullptr, kq_scale); kq = ggml_soft_max_ext(ctx, kq, kq_mask, kq_scale, 0.0f);
// split cached v into n_head heads // split cached v into n_head heads
struct ggml_tensor * v = struct ggml_tensor * v =
@ -2095,16 +2093,16 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op
for (int n = 0; n < 10; ++n) { for (int n = 0; n < 10; ++n) {
int64_t ne0 = dist_ne0(rng); int64_t ne0 = dist_ne0(rng);
int64_t ne1 = dist_ne1(rng); int64_t ne1 = dist_ne1(rng);
test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {ne0, ne1, 1, 1}, n/2 == 0, n/3 == 0 && ne0 < 1000, 0.1f)); test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {ne0, ne1, 1, 1}, n/2 == 0, 0.1f, 4.0f));
} }
exponent <<= 1; exponent <<= 1;
} }
test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {16, 2, 32, 1}, false, false, 0.1f)); test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {16, 2, 32, 1}, false, 0.1f, 0.0f));
test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {32, 2, 32, 1}, true, false, 0.1f)); test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {32, 2, 32, 1}, true, 0.1f, 0.0f));
test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {16, 2, 32, 1}, false, true, 0.1f)); test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {16, 2, 32, 1}, false, 0.1f, 8.0f));
test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {32, 2, 32, 1}, true, true, 0.1f)); test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {32, 2, 32, 1}, true, 0.1f, 8.0f));
for (ggml_type type : {GGML_TYPE_F32, GGML_TYPE_F16}) { for (ggml_type type : {GGML_TYPE_F32, GGML_TYPE_F16}) {
test_cases.emplace_back(new test_rope(type, {128, 32, 10, 1}, 128, 0, 512)); // llama 7B test_cases.emplace_back(new test_rope(type, {128, 32, 10, 1}, 128, 0, 512)); // llama 7B