ggml : support alibi bias in ggml_soft_max_ext (CPU + Metal)

ggml-ci
This commit is contained in:
Georgi Gerganov 2024-02-14 16:03:42 +02:00
parent 6ca762eccf
commit 5055a0c990
No known key found for this signature in database
GPG key ID: 449E073F9DC10735
6 changed files with 156 additions and 94 deletions

View file

@ -728,6 +728,7 @@ static bool ggml_metal_graph_compute(
size_t offs_src0 = 0; size_t offs_src0 = 0;
size_t offs_src1 = 0; size_t offs_src1 = 0;
size_t offs_src2 = 0;
size_t offs_dst = 0; size_t offs_dst = 0;
id<MTLCommandBuffer> command_buffer = command_buffers[cb_idx]; id<MTLCommandBuffer> command_buffer = command_buffers[cb_idx];
@ -746,6 +747,7 @@ static bool ggml_metal_graph_compute(
struct ggml_tensor * src0 = gf->nodes[i]->src[0]; struct ggml_tensor * src0 = gf->nodes[i]->src[0];
struct ggml_tensor * src1 = gf->nodes[i]->src[1]; struct ggml_tensor * src1 = gf->nodes[i]->src[1];
struct ggml_tensor * src2 = gf->nodes[i]->src[2];
struct ggml_tensor * dst = gf->nodes[i]; struct ggml_tensor * dst = gf->nodes[i];
switch (dst->op) { switch (dst->op) {
@ -807,6 +809,7 @@ static bool ggml_metal_graph_compute(
id<MTLBuffer> id_src0 = src0 ? ggml_metal_get_buffer(src0, &offs_src0) : nil; id<MTLBuffer> id_src0 = src0 ? ggml_metal_get_buffer(src0, &offs_src0) : nil;
id<MTLBuffer> id_src1 = src1 ? ggml_metal_get_buffer(src1, &offs_src1) : nil; id<MTLBuffer> id_src1 = src1 ? ggml_metal_get_buffer(src1, &offs_src1) : nil;
id<MTLBuffer> id_src2 = src2 ? ggml_metal_get_buffer(src2, &offs_src2) : nil;
id<MTLBuffer> id_dst = dst ? ggml_metal_get_buffer(dst, &offs_dst) : nil; id<MTLBuffer> id_dst = dst ? ggml_metal_get_buffer(dst, &offs_dst) : nil;
//GGML_METAL_LOG_INFO("%s: op - %s\n", __func__, ggml_op_name(dst->op)); //GGML_METAL_LOG_INFO("%s: op - %s\n", __func__, ggml_op_name(dst->op));
@ -1197,11 +1200,16 @@ static bool ggml_metal_graph_compute(
} else { } else {
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:1]; [encoder setBuffer:id_src0 offset:offs_src0 atIndex:1];
} }
[encoder setBuffer:id_dst offset:offs_dst atIndex:2]; if (id_src2) {
[encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3]; [encoder setBuffer:id_src2 offset:offs_src2 atIndex:2];
[encoder setBytes:&ne01 length:sizeof(ne01) atIndex:4]; } else {
[encoder setBytes:&ne02 length:sizeof(ne02) atIndex:5]; [encoder setBuffer:id_src0 offset:offs_src0 atIndex:2];
[encoder setBytes:&scale length:sizeof(scale) atIndex:6]; }
[encoder setBuffer:id_dst offset:offs_dst atIndex:3];
[encoder setBytes:&ne00 length:sizeof(ne00) atIndex:4];
[encoder setBytes:&ne01 length:sizeof(ne01) atIndex:5];
[encoder setBytes:&ne02 length:sizeof(ne02) atIndex:6];
[encoder setBytes:&scale length:sizeof(scale) atIndex:7];
[encoder setThreadgroupMemoryLength:32*sizeof(float) atIndex:0]; [encoder setThreadgroupMemoryLength:32*sizeof(float) atIndex:0];
[encoder dispatchThreadgroups:MTLSizeMake(ne01*ne02*ne03, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)]; [encoder dispatchThreadgroups:MTLSizeMake(ne01*ne02*ne03, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
@ -1514,8 +1522,6 @@ static bool ggml_metal_graph_compute(
// max size of the src1ids array in the kernel stack // max size of the src1ids array in the kernel stack
GGML_ASSERT(ne11 <= 512); GGML_ASSERT(ne11 <= 512);
struct ggml_tensor * src2 = gf->nodes[i]->src[2];
const int64_t ne20 = src2 ? src2->ne[0] : 0; const int64_t ne20 = src2 ? src2->ne[0] : 0;
const int64_t ne21 = src2 ? src2->ne[1] : 0; const int64_t ne21 = src2 ? src2->ne[1] : 0;
const int64_t ne22 = src2 ? src2->ne[2] : 0; const int64_t ne22 = src2 ? src2->ne[2] : 0;

View file

@ -351,12 +351,13 @@ kernel void kernel_sum_rows(
kernel void kernel_soft_max( kernel void kernel_soft_max(
device const float * src0, device const float * src0,
device const float * src1, device const float * src1,
device const float * src2,
device float * dst, device float * dst,
constant int64_t & ne00, constant int64_t & ne00,
constant int64_t & ne01, constant int64_t & ne01,
constant int64_t & ne02, constant int64_t & ne02,
constant float & scale, constant float & scale,
threadgroup float * buf [[threadgroup(0)]], threadgroup float * buf [[threadgroup(0)]],
uint tgpig[[threadgroup_position_in_grid]], uint tgpig[[threadgroup_position_in_grid]],
uint tpitg[[thread_position_in_threadgroup]], uint tpitg[[thread_position_in_threadgroup]],
uint sgitg[[simdgroup_index_in_threadgroup]], uint sgitg[[simdgroup_index_in_threadgroup]],
@ -368,13 +369,14 @@ kernel void kernel_soft_max(
device const float * psrc0 = src0 + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00; device const float * psrc0 = src0 + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
device const float * pmask = src1 != src0 ? src1 + i01*ne00 : nullptr; device const float * pmask = src1 != src0 ? src1 + i01*ne00 : nullptr;
const float slope = src2 != src0 ? src2[i02] : 0.0f;
device float * pdst = dst + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00; device float * pdst = dst + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
// parallel max // parallel max
float lmax = -INFINITY; float lmax = -INFINITY;
for (int i00 = tpitg; i00 < ne00; i00 += ntg) { for (int i00 = tpitg; i00 < ne00; i00 += ntg) {
lmax = MAX(lmax, psrc0[i00]*scale + (pmask ? pmask[i00] : 0.0f)); lmax = MAX(lmax, psrc0[i00]*scale + slope*i00 + (pmask ? pmask[i00] : 0.0f));
} }
// find the max value in the block // find the max value in the block
@ -399,7 +401,7 @@ kernel void kernel_soft_max(
// parallel sum // parallel sum
float lsum = 0.0f; float lsum = 0.0f;
for (int i00 = tpitg; i00 < ne00; i00 += ntg) { for (int i00 = tpitg; i00 < ne00; i00 += ntg) {
const float exp_psrc0 = exp((psrc0[i00]*scale + (pmask ? pmask[i00] : 0.0f)) - max_val); const float exp_psrc0 = exp((psrc0[i00]*scale + slope*i00 + (pmask ? pmask[i00] : 0.0f)) - max_val);
lsum += exp_psrc0; lsum += exp_psrc0;
pdst[i00] = exp_psrc0; pdst[i00] = exp_psrc0;
} }
@ -437,12 +439,13 @@ kernel void kernel_soft_max(
kernel void kernel_soft_max_4( kernel void kernel_soft_max_4(
device const float * src0, device const float * src0,
device const float * src1, device const float * src1,
device const float * src2,
device float * dst, device float * dst,
constant int64_t & ne00, constant int64_t & ne00,
constant int64_t & ne01, constant int64_t & ne01,
constant int64_t & ne02, constant int64_t & ne02,
constant float & scale, constant float & scale,
threadgroup float * buf [[threadgroup(0)]], threadgroup float * buf [[threadgroup(0)]],
uint tgpig[[threadgroup_position_in_grid]], uint tgpig[[threadgroup_position_in_grid]],
uint tpitg[[thread_position_in_threadgroup]], uint tpitg[[thread_position_in_threadgroup]],
uint sgitg[[simdgroup_index_in_threadgroup]], uint sgitg[[simdgroup_index_in_threadgroup]],
@ -454,13 +457,16 @@ kernel void kernel_soft_max_4(
device const float4 * psrc4 = (device const float4 *)(src0 + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00); device const float4 * psrc4 = (device const float4 *)(src0 + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00);
device const float4 * pmask = src1 != src0 ? (device const float4 *)(src1 + i01*ne00) : nullptr; device const float4 * pmask = src1 != src0 ? (device const float4 *)(src1 + i01*ne00) : nullptr;
const float slope = src2 != src0 ? src2[i02] : 0.0f;
device float4 * pdst4 = (device float4 *)(dst + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00); device float4 * pdst4 = (device float4 *)(dst + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00);
const float4 s0(0.0f, 1.0f, 2.0f, 3.0f);
// parallel max // parallel max
float4 lmax4 = -INFINITY; float4 lmax4 = -INFINITY;
for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) { for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) {
lmax4 = fmax(lmax4, psrc4[i00]*scale + (pmask ? pmask[i00] : 0.0f)); lmax4 = fmax(lmax4, psrc4[i00]*scale + slope*(4*i00 + s0) + (pmask ? pmask[i00] : 0.0f));
} }
const float lmax = MAX(MAX(lmax4[0], lmax4[1]), MAX(lmax4[2], lmax4[3])); const float lmax = MAX(MAX(lmax4[0], lmax4[1]), MAX(lmax4[2], lmax4[3]));
@ -486,7 +492,7 @@ kernel void kernel_soft_max_4(
// parallel sum // parallel sum
float4 lsum4 = 0.0f; float4 lsum4 = 0.0f;
for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) { for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) {
const float4 exp_psrc4 = exp((psrc4[i00]*scale + (pmask ? pmask[i00] : 0.0f)) - max_val); const float4 exp_psrc4 = exp((psrc4[i00]*scale + slope*(4*i00 + s0) + (pmask ? pmask[i00] : 0.0f)) - max_val);
lsum4 += exp_psrc4; lsum4 += exp_psrc4;
pdst4[i00] = exp_psrc4; pdst4[i00] = exp_psrc4;
} }

36
ggml.c
View file

@ -5060,16 +5060,22 @@ static struct ggml_tensor * ggml_soft_max_impl(
struct ggml_context * ctx, struct ggml_context * ctx,
struct ggml_tensor * a, struct ggml_tensor * a,
struct ggml_tensor * mask, struct ggml_tensor * mask,
struct ggml_tensor * slope,
float scale, float scale,
bool inplace) { bool inplace) {
GGML_ASSERT(ggml_is_contiguous(a)); GGML_ASSERT(ggml_is_contiguous(a));
if (mask) { if (mask) {
GGML_ASSERT(ggml_is_contiguous(mask)); GGML_ASSERT(ggml_is_contiguous(mask));
GGML_ASSERT(mask->ne[2] == 1); GGML_ASSERT(ggml_is_matrix(mask));
GGML_ASSERT(mask->ne[3] == 1);
GGML_ASSERT(ggml_can_repeat_rows(mask, a)); GGML_ASSERT(ggml_can_repeat_rows(mask, a));
} }
if (slope) {
GGML_ASSERT(ggml_is_contiguous(slope));
GGML_ASSERT(ggml_is_vector(slope));
GGML_ASSERT(slope->ne[0] == a->ne[2]);
}
bool is_node = false; bool is_node = false;
if (a->grad) { if (a->grad) {
@ -5085,6 +5091,7 @@ static struct ggml_tensor * ggml_soft_max_impl(
result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
result->src[0] = a; result->src[0] = a;
result->src[1] = mask; result->src[1] = mask;
result->src[2] = slope;
return result; return result;
} }
@ -5092,21 +5099,22 @@ static struct ggml_tensor * ggml_soft_max_impl(
struct ggml_tensor * ggml_soft_max( struct ggml_tensor * ggml_soft_max(
struct ggml_context * ctx, struct ggml_context * ctx,
struct ggml_tensor * a) { struct ggml_tensor * a) {
return ggml_soft_max_impl(ctx, a, NULL, 1.0f, false); return ggml_soft_max_impl(ctx, a, NULL, NULL, 1.0f, false);
} }
struct ggml_tensor * ggml_soft_max_inplace( struct ggml_tensor * ggml_soft_max_inplace(
struct ggml_context * ctx, struct ggml_context * ctx,
struct ggml_tensor * a) { struct ggml_tensor * a) {
return ggml_soft_max_impl(ctx, a, NULL, 1.0f, true); return ggml_soft_max_impl(ctx, a, NULL, NULL, 1.0f, true);
} }
struct ggml_tensor * ggml_soft_max_ext( struct ggml_tensor * ggml_soft_max_ext(
struct ggml_context * ctx, struct ggml_context * ctx,
struct ggml_tensor * a, struct ggml_tensor * a,
struct ggml_tensor * mask, struct ggml_tensor * mask,
struct ggml_tensor * slope,
float scale) { float scale) {
return ggml_soft_max_impl(ctx, a, mask, scale, false); return ggml_soft_max_impl(ctx, a, mask, slope, scale, false);
} }
// ggml_soft_max_back // ggml_soft_max_back
@ -11459,6 +11467,7 @@ static void ggml_compute_forward_soft_max_f32(
const struct ggml_compute_params * params, const struct ggml_compute_params * params,
const struct ggml_tensor * src0, const struct ggml_tensor * src0,
const struct ggml_tensor * src1, const struct ggml_tensor * src1,
const struct ggml_tensor * src2,
struct ggml_tensor * dst) { struct ggml_tensor * dst) {
assert(ggml_is_contiguous(dst)); assert(ggml_is_contiguous(dst));
assert(ggml_are_same_shape(src0, dst)); assert(ggml_are_same_shape(src0, dst));
@ -11475,6 +11484,8 @@ static void ggml_compute_forward_soft_max_f32(
const int ith = params->ith; const int ith = params->ith;
const int nth = params->nth; const int nth = params->nth;
GGML_TENSOR_UNARY_OP_LOCALS
const int64_t ne11 = src1 ? src1->ne[1] : 1; const int64_t ne11 = src1 ? src1->ne[1] : 1;
const int nc = src0->ne[0]; const int nc = src0->ne[0];
@ -11502,6 +11513,16 @@ static void ggml_compute_forward_soft_max_f32(
ggml_vec_acc_f32(nc, wp, mp); ggml_vec_acc_f32(nc, wp, mp);
} }
// alibi bias
if (src2) {
const int h = (i1/ne01)%ne02;
const float slope = ((float *)(src2->data))[h];
for (int i = 0; i < nc; i++) {
wp[i] = wp[i] + slope*i;
}
}
#ifndef NDEBUG #ifndef NDEBUG
for (int i = 0; i < nc; ++i) { for (int i = 0; i < nc; ++i) {
//printf("p[%d] = %f\n", i, p[i]); //printf("p[%d] = %f\n", i, p[i]);
@ -11546,11 +11567,12 @@ static void ggml_compute_forward_soft_max(
const struct ggml_compute_params * params, const struct ggml_compute_params * params,
const struct ggml_tensor * src0, const struct ggml_tensor * src0,
const struct ggml_tensor * src1, const struct ggml_tensor * src1,
const struct ggml_tensor * src2,
struct ggml_tensor * dst) { struct ggml_tensor * dst) {
switch (src0->type) { switch (src0->type) {
case GGML_TYPE_F32: case GGML_TYPE_F32:
{ {
ggml_compute_forward_soft_max_f32(params, src0, src1, dst); ggml_compute_forward_soft_max_f32(params, src0, src1, src2, dst);
} break; } break;
default: default:
{ {
@ -15077,7 +15099,7 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
} break; } break;
case GGML_OP_SOFT_MAX: case GGML_OP_SOFT_MAX:
{ {
ggml_compute_forward_soft_max(params, tensor->src[0], tensor->src[1], tensor); ggml_compute_forward_soft_max(params, tensor->src[0], tensor->src[1], tensor->src[2], tensor);
} break; } break;
case GGML_OP_SOFT_MAX_BACK: case GGML_OP_SOFT_MAX_BACK:
{ {

4
ggml.h
View file

@ -1373,12 +1373,14 @@ extern "C" {
struct ggml_context * ctx, struct ggml_context * ctx,
struct ggml_tensor * a); struct ggml_tensor * a);
// fused soft_max(a*scale + mask) // fused soft_max(a*scale + i*slope + mask)
// mask is optional // mask is optional
// slope is optional
GGML_API struct ggml_tensor * ggml_soft_max_ext( GGML_API struct ggml_tensor * ggml_soft_max_ext(
struct ggml_context * ctx, struct ggml_context * ctx,
struct ggml_tensor * a, struct ggml_tensor * a,
struct ggml_tensor * mask, struct ggml_tensor * mask,
struct ggml_tensor * slope,
float scale); float scale);
GGML_API struct ggml_tensor * ggml_soft_max_back( GGML_API struct ggml_tensor * ggml_soft_max_back(

142
llama.cpp
View file

@ -1923,6 +1923,7 @@ struct llama_context {
struct ggml_tensor * inp_embd; // F32 [n_embd, n_batch] struct ggml_tensor * inp_embd; // F32 [n_embd, n_batch]
struct ggml_tensor * inp_pos; // I32 [n_batch] struct ggml_tensor * inp_pos; // I32 [n_batch]
struct ggml_tensor * inp_KQ_mask; // F32 [n_ctx, n_batch] struct ggml_tensor * inp_KQ_mask; // F32 [n_ctx, n_batch]
struct ggml_tensor * inp_KQ_slope; // F32 [n_head_kv]
struct ggml_tensor * inp_K_shift; // I32 [n_ctx] struct ggml_tensor * inp_K_shift; // I32 [n_ctx]
struct ggml_tensor * inp_sum; // F32 [n_batch, n_batch] struct ggml_tensor * inp_sum; // F32 [n_batch, n_batch]
@ -4782,10 +4783,10 @@ static struct ggml_tensor * llm_build_kqv(
struct ggml_tensor * wo_b, struct ggml_tensor * wo_b,
struct ggml_tensor * q_cur, struct ggml_tensor * q_cur,
struct ggml_tensor * kq_mask, struct ggml_tensor * kq_mask,
struct ggml_tensor * kq_slope,
int64_t n_ctx, int64_t n_ctx,
int32_t n_tokens, int32_t n_tokens,
int32_t n_kv, int32_t n_kv,
float max_alibi_bias,
float kq_scale, float kq_scale,
const llm_build_cb & cb, const llm_build_cb & cb,
int il) { int il) {
@ -4815,28 +4816,8 @@ static struct ggml_tensor * llm_build_kqv(
ggml_mul_mat_set_prec(kq, GGML_PREC_F32); ggml_mul_mat_set_prec(kq, GGML_PREC_F32);
} }
if (max_alibi_bias > 0.0f) { kq = ggml_soft_max_ext(ctx, kq, kq_mask, kq_slope, kq_scale);
// temporary branch until we figure out how to handle ggml_alibi through ggml_add cb(kq, "kq_soft_max_ext", il);
kq = ggml_scale(ctx, kq, kq_scale);
cb(kq, "kq_scaled", il);
if (max_alibi_bias > 0.0f) {
// TODO: n_head or n_head_kv
// TODO: K-shift is likely not working
// TODO: change to ggml_add
kq = ggml_alibi(ctx, kq, /*n_past*/ 0, n_head, max_alibi_bias);
cb(kq, "kq_scaled_alibi", il);
}
kq = ggml_add(ctx, kq, kq_mask);
cb(kq, "kq_masked", il);
kq = ggml_soft_max(ctx, kq);
cb(kq, "kq_soft_max", il);
} else {
kq = ggml_soft_max_ext(ctx, kq, kq_mask, kq_scale);
cb(kq, "kq_soft_max_ext", il);
}
// split cached v into n_head heads // split cached v into n_head heads
struct ggml_tensor * v = struct ggml_tensor * v =
@ -4882,11 +4863,11 @@ static struct ggml_tensor * llm_build_kv(
struct ggml_tensor * v_cur, struct ggml_tensor * v_cur,
struct ggml_tensor * q_cur, struct ggml_tensor * q_cur,
struct ggml_tensor * kq_mask, struct ggml_tensor * kq_mask,
struct ggml_tensor * kq_slope,
int64_t n_ctx, int64_t n_ctx,
int32_t n_tokens, int32_t n_tokens,
int32_t kv_head, int32_t kv_head,
int32_t n_kv, int32_t n_kv,
float max_alibi_bias,
float kq_scale, float kq_scale,
const llm_build_cb & cb, const llm_build_cb & cb,
int il) { int il) {
@ -4900,9 +4881,8 @@ static struct ggml_tensor * llm_build_kv(
llm_build_kv_store(ctx, hparams, kv, graph, k_cur, v_cur, n_ctx, n_tokens, kv_head, cb, il); llm_build_kv_store(ctx, hparams, kv, graph, k_cur, v_cur, n_ctx, n_tokens, kv_head, cb, il);
struct ggml_tensor * cur; struct ggml_tensor * cur;
cur = llm_build_kqv(ctx, model, hparams, kv, graph, cur = llm_build_kqv(ctx, model, hparams, kv, graph, wo, wo_b,
wo, wo_b, q_cur, kq_mask, kq_slope, n_ctx, n_tokens, n_kv, kq_scale, cb, il);
q_cur, kq_mask, n_ctx, n_tokens, n_kv, max_alibi_bias, kq_scale, cb, il);
cb(cur, "kqv_out", il); cb(cur, "kqv_out", il);
return cur; return cur;
@ -5070,7 +5050,7 @@ struct llm_build_context {
} }
Qcur = ggml_rope_custom( Qcur = ggml_rope_custom(
ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos,
hparams.n_rot, 0, 0, n_orig_ctx, freq_base, freq_scale, hparams.n_rot, 0, 0, n_orig_ctx, freq_base, freq_scale,
ext_factor, attn_factor, beta_fast, beta_slow ext_factor, attn_factor, beta_fast, beta_slow
); );
@ -5085,7 +5065,7 @@ struct llm_build_context {
cur = llm_build_kv(ctx0, model, hparams, kv_self, gf, cur = llm_build_kv(ctx0, model, hparams, kv_self, gf,
model.layers[il].wo, model.layers[il].bo, model.layers[il].wo, model.layers[il].bo,
Kcur, Vcur, Qcur, KQ_mask, n_ctx, n_tokens, kv_head, n_kv, -1.0f, 1.0f/sqrtf(float(n_embd_head)), cb, il); Kcur, Vcur, Qcur, KQ_mask, nullptr, n_ctx, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
cb(cur, "kqv_out", il); cb(cur, "kqv_out", il);
} }
@ -5215,6 +5195,9 @@ struct llm_build_context {
struct ggml_tensor * KQ_mask = ggml_view_2d(ctx0, lctx.inp_KQ_mask, n_kv, n_tokens, n_kv*ggml_type_size(lctx.inp_KQ_mask->type), 0); struct ggml_tensor * KQ_mask = ggml_view_2d(ctx0, lctx.inp_KQ_mask, n_kv, n_tokens, n_kv*ggml_type_size(lctx.inp_KQ_mask->type), 0);
cb(KQ_mask, "KQ_mask", -1); cb(KQ_mask, "KQ_mask", -1);
struct ggml_tensor * KQ_slope = ggml_view_1d(ctx0, lctx.inp_KQ_slope, n_head_kv, 0);
cb(KQ_slope, "KQ_slope", -1);
// shift the entire K-cache if needed // shift the entire K-cache if needed
if (do_rope_shift) { if (do_rope_shift) {
llm_build_k_shift(ctx0, hparams, cparams, kv_self, gf, lctx.inp_K_shift, LLM_ROPE, n_ctx, freq_base, freq_scale, cb); llm_build_k_shift(ctx0, hparams, cparams, kv_self, gf, lctx.inp_K_shift, LLM_ROPE, n_ctx, freq_base, freq_scale, cb);
@ -5265,7 +5248,7 @@ struct llm_build_context {
cur = llm_build_kv(ctx0, model, hparams, kv_self, gf, cur = llm_build_kv(ctx0, model, hparams, kv_self, gf,
model.layers[il].wo, NULL, model.layers[il].wo, NULL,
Kcur, Vcur, Qcur, KQ_mask, n_ctx, n_tokens, kv_head, n_kv, hparams.f_max_alibi_bias, 1.0f/sqrtf(float(n_embd_head)), cb, il); Kcur, Vcur, Qcur, KQ_mask, KQ_slope, n_ctx, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
cb(cur, "kqv_out", il); cb(cur, "kqv_out", il);
} }
@ -5389,7 +5372,7 @@ struct llm_build_context {
cur = llm_build_kv(ctx0, model, hparams, kv_self, gf, cur = llm_build_kv(ctx0, model, hparams, kv_self, gf,
model.layers[il].wo, NULL, model.layers[il].wo, NULL,
Kcur, Vcur, Qcur, KQ_mask, n_ctx, n_tokens, kv_head, n_kv, -1.0f, 1.0f/sqrtf(float(n_embd_head)), cb, il); Kcur, Vcur, Qcur, KQ_mask, nullptr, n_ctx, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
cb(cur, "kqv_out", il); cb(cur, "kqv_out", il);
} }
@ -5488,7 +5471,7 @@ struct llm_build_context {
cur = llm_build_kv(ctx0, model, hparams, kv_self, gf, cur = llm_build_kv(ctx0, model, hparams, kv_self, gf,
model.layers[il].wo, model.layers[il].bo, model.layers[il].wo, model.layers[il].bo,
Kcur, Vcur, Qcur, KQ_mask, n_ctx, n_tokens, kv_head, n_kv, -1.0f, 1.0f/sqrtf(float(n_embd_head)), cb, il); Kcur, Vcur, Qcur, KQ_mask, nullptr, n_ctx, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
cb(cur, "kqv_out", il); cb(cur, "kqv_out", il);
} }
@ -5693,7 +5676,7 @@ struct llm_build_context {
cur = llm_build_kv(ctx0, model, hparams, kv_self, gf, cur = llm_build_kv(ctx0, model, hparams, kv_self, gf,
model.layers[il].wo, model.layers[il].bo, model.layers[il].wo, model.layers[il].bo,
Kcur, Vcur, Q, KQ_mask, n_ctx, n_tokens, kv_head, n_kv, -1.0f, 1.0f/sqrtf(float(n_embd_head)), cb, il); Kcur, Vcur, Q, KQ_mask, nullptr, n_ctx, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
cb(cur, "kqv_out", il); cb(cur, "kqv_out", il);
} }
@ -5755,6 +5738,9 @@ struct llm_build_context {
struct ggml_tensor * KQ_mask = ggml_view_2d(ctx0, lctx.inp_KQ_mask, n_kv, n_tokens, n_kv*ggml_type_size(lctx.inp_KQ_mask->type), 0); struct ggml_tensor * KQ_mask = ggml_view_2d(ctx0, lctx.inp_KQ_mask, n_kv, n_tokens, n_kv*ggml_type_size(lctx.inp_KQ_mask->type), 0);
cb(KQ_mask, "KQ_mask", -1); cb(KQ_mask, "KQ_mask", -1);
struct ggml_tensor * KQ_slope = ggml_view_1d(ctx0, lctx.inp_KQ_slope, n_head_kv, 0);
cb(KQ_slope, "KQ_slope", -1);
for (int il = 0; il < n_layer; ++il) { for (int il = 0; il < n_layer; ++il) {
struct ggml_tensor * inpSA = inpL; struct ggml_tensor * inpSA = inpL;
@ -5782,7 +5768,7 @@ struct llm_build_context {
cur = llm_build_kv(ctx0, model, hparams, kv_self, gf, cur = llm_build_kv(ctx0, model, hparams, kv_self, gf,
model.layers[il].wo, NULL, model.layers[il].wo, NULL,
Kcur, Vcur, Qcur, KQ_mask, n_ctx, n_tokens, kv_head, n_kv, hparams.f_max_alibi_bias, 1.0f/sqrtf(float(n_embd_head)), cb, il); Kcur, Vcur, Qcur, KQ_mask, KQ_slope, n_ctx, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
cb(cur, "kqv_out", il); cb(cur, "kqv_out", il);
} }
@ -5882,7 +5868,7 @@ struct llm_build_context {
cur = llm_build_kv(ctx0, model, hparams, kv_self, gf, cur = llm_build_kv(ctx0, model, hparams, kv_self, gf,
model.layers[il].wo, model.layers[il].bo, model.layers[il].wo, model.layers[il].bo,
Kcur, Vcur, Qcur, KQ_mask, n_ctx, n_tokens, kv_head, n_kv, -1.0f, 1.0f/sqrtf(float(n_embd_head)), cb, il); Kcur, Vcur, Qcur, KQ_mask, nullptr, n_ctx, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
cb(cur, "kqv_out", il); cb(cur, "kqv_out", il);
} else { } else {
// compute Q and K and RoPE them // compute Q and K and RoPE them
@ -5913,7 +5899,7 @@ struct llm_build_context {
cur = llm_build_kv(ctx0, model, hparams, kv_self, gf, cur = llm_build_kv(ctx0, model, hparams, kv_self, gf,
model.layers[il].wo, model.layers[il].bo, model.layers[il].wo, model.layers[il].bo,
Kcur, Vcur, Qcur, KQ_mask, n_ctx, n_tokens, kv_head, n_kv, -1.0f, 1.0f/sqrtf(float(n_embd_head)), cb, il); Kcur, Vcur, Qcur, KQ_mask, nullptr, n_ctx, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
cb(cur, "kqv_out", il); cb(cur, "kqv_out", il);
} }
@ -5985,6 +5971,9 @@ struct llm_build_context {
struct ggml_tensor * KQ_mask = ggml_view_2d(ctx0, lctx.inp_KQ_mask, n_kv, n_tokens, n_kv*ggml_type_size(lctx.inp_KQ_mask->type), 0); struct ggml_tensor * KQ_mask = ggml_view_2d(ctx0, lctx.inp_KQ_mask, n_kv, n_tokens, n_kv*ggml_type_size(lctx.inp_KQ_mask->type), 0);
cb(KQ_mask, "KQ_mask", -1); cb(KQ_mask, "KQ_mask", -1);
struct ggml_tensor * KQ_slope = ggml_view_1d(ctx0, lctx.inp_KQ_slope, n_head_kv, 0);
cb(KQ_slope, "KQ_slope", -1);
inpL = llm_build_norm(ctx0, inpL, hparams, inpL = llm_build_norm(ctx0, inpL, hparams,
model.tok_norm, model.tok_norm,
model.tok_norm_b, model.tok_norm_b,
@ -6018,7 +6007,7 @@ struct llm_build_context {
cur = llm_build_kv(ctx0, model, hparams, kv_self, gf, cur = llm_build_kv(ctx0, model, hparams, kv_self, gf,
model.layers[il].wo, model.layers[il].bo, model.layers[il].wo, model.layers[il].bo,
Kcur, Vcur, Qcur, KQ_mask, n_ctx, n_tokens, kv_head, n_kv, hparams.f_max_alibi_bias, 1.0f/sqrtf(float(n_embd_head)), cb, il); Kcur, Vcur, Qcur, KQ_mask, KQ_slope, n_ctx, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
cb(cur, "kqv_out", il); cb(cur, "kqv_out", il);
} }
@ -6078,6 +6067,9 @@ struct llm_build_context {
struct ggml_tensor * KQ_mask = ggml_view_2d(ctx0, lctx.inp_KQ_mask, n_kv, n_tokens, n_kv*ggml_type_size(lctx.inp_KQ_mask->type), 0); struct ggml_tensor * KQ_mask = ggml_view_2d(ctx0, lctx.inp_KQ_mask, n_kv, n_tokens, n_kv*ggml_type_size(lctx.inp_KQ_mask->type), 0);
cb(KQ_mask, "KQ_mask", -1); cb(KQ_mask, "KQ_mask", -1);
struct ggml_tensor * KQ_slope = ggml_view_1d(ctx0, lctx.inp_KQ_slope, n_head_kv, 0);
cb(KQ_slope, "KQ_slope", -1);
for (int il = 0; il < n_layer; ++il) { for (int il = 0; il < n_layer; ++il) {
struct ggml_tensor * attn_norm; struct ggml_tensor * attn_norm;
@ -6111,7 +6103,7 @@ struct llm_build_context {
cur = llm_build_kv(ctx0, model, hparams, kv_self, gf, cur = llm_build_kv(ctx0, model, hparams, kv_self, gf,
model.layers[il].wo, NULL, model.layers[il].wo, NULL,
Kcur, Vcur, Qcur, KQ_mask, n_ctx, n_tokens, kv_head, n_kv, hparams.f_max_alibi_bias, 1.0f/sqrtf(float(n_embd_head)), cb, il); Kcur, Vcur, Qcur, KQ_mask, KQ_slope, n_ctx, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
cb(cur, "kqv_out", il); cb(cur, "kqv_out", il);
} }
@ -6233,7 +6225,7 @@ struct llm_build_context {
cur = llm_build_kv(ctx0, model, hparams, kv_self, gf, cur = llm_build_kv(ctx0, model, hparams, kv_self, gf,
model.layers[il].wo, NULL, model.layers[il].wo, NULL,
Kcur, Vcur, Qcur, KQ_mask, n_ctx, n_tokens, kv_head, n_kv, -1.0f, 1.0f/sqrtf(float(n_embd_head)), cb, il); Kcur, Vcur, Qcur, KQ_mask, nullptr, n_ctx, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
cb(cur, "kqv_out", il); cb(cur, "kqv_out", il);
} }
@ -6348,7 +6340,7 @@ struct llm_build_context {
cur = llm_build_kv(ctx0, model, hparams, kv_self, gf, cur = llm_build_kv(ctx0, model, hparams, kv_self, gf,
model.layers[il].wo, NULL, model.layers[il].wo, NULL,
Kcur, Vcur, Qcur, KQ_mask, n_ctx, n_tokens, kv_head, n_kv, -1.0f, 1.0f/sqrtf(float(n_embd_head)), cb, il); Kcur, Vcur, Qcur, KQ_mask, nullptr, n_ctx, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
cb(cur, "kqv_out", il); cb(cur, "kqv_out", il);
} }
@ -6469,7 +6461,7 @@ struct llm_build_context {
cur = llm_build_kv(ctx0, model, hparams, kv_self, gf, cur = llm_build_kv(ctx0, model, hparams, kv_self, gf,
model.layers[il].wo, model.layers[il].bo, model.layers[il].wo, model.layers[il].bo,
Kcur, Vcur, Qcur, KQ_mask, n_ctx, n_tokens, kv_head, n_kv, -1.0f, 1.0f/sqrtf(float(n_embd_head)), cb, il); Kcur, Vcur, Qcur, KQ_mask, nullptr, n_ctx, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
cb(cur, "kqv_out", il); cb(cur, "kqv_out", il);
} }
@ -6596,7 +6588,7 @@ struct llm_build_context {
cur = llm_build_kv(ctx0, model, hparams, kv_self, gf, cur = llm_build_kv(ctx0, model, hparams, kv_self, gf,
model.layers[il].wo, model.layers[il].bo, model.layers[il].wo, model.layers[il].bo,
Kcur, Vcur, Qcur, KQ_mask, n_ctx, n_tokens, kv_head, n_kv, -1.0f, 1.0f, cb, il); Kcur, Vcur, Qcur, KQ_mask, nullptr, n_ctx, n_tokens, kv_head, n_kv, 1.0f, cb, il);
cb(cur, "kqv_out", il); cb(cur, "kqv_out", il);
} }
@ -6699,7 +6691,7 @@ struct llm_build_context {
cur = llm_build_kv(ctx0, model, hparams, kv_self, gf, cur = llm_build_kv(ctx0, model, hparams, kv_self, gf,
model.layers[il].wo, NULL, model.layers[il].wo, NULL,
Kcur, Vcur, Qcur, KQ_mask, n_ctx, n_tokens, kv_head, n_kv, -1.0f, 1.0f/sqrtf(float(n_embd_head)), cb, il); Kcur, Vcur, Qcur, KQ_mask, nullptr, n_ctx, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
cb(cur, "kqv_out", il); cb(cur, "kqv_out", il);
} }
struct ggml_tensor * sa_out = cur; struct ggml_tensor * sa_out = cur;
@ -6798,7 +6790,7 @@ struct llm_build_context {
cur = llm_build_kv(ctx0, model, hparams, kv_self, gf, cur = llm_build_kv(ctx0, model, hparams, kv_self, gf,
model.layers[il].wo, model.layers[il].bo, model.layers[il].wo, model.layers[il].bo,
Kcur, Vcur, Qcur, KQ_mask, n_ctx, n_tokens, kv_head, n_kv, -1.0f, 1.0f/sqrtf(float(n_embd_head)), cb, il); Kcur, Vcur, Qcur, KQ_mask, nullptr, n_ctx, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
cb(cur, "kqv_out", il); cb(cur, "kqv_out", il);
} }
@ -6907,7 +6899,7 @@ struct llm_build_context {
cur = llm_build_kv(ctx0, model, hparams, kv_self, gf, cur = llm_build_kv(ctx0, model, hparams, kv_self, gf,
model.layers[il].wo, model.layers[il].bo, model.layers[il].wo, model.layers[il].bo,
Kcur, Vcur, Qcur, KQ_mask, n_ctx, n_tokens, kv_head, n_kv, -1.0f, 1.0f/sqrtf(float(n_embd_head)), cb, il); Kcur, Vcur, Qcur, KQ_mask, nullptr, n_ctx, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
cb(cur, "kqv_out", il); cb(cur, "kqv_out", il);
} }
@ -7025,7 +7017,7 @@ struct llm_build_context {
cur = llm_build_kv(ctx0, model, hparams, kv_self, gf, cur = llm_build_kv(ctx0, model, hparams, kv_self, gf,
model.layers[il].wo, NULL, model.layers[il].wo, NULL,
Kcur, Vcur, Qcur, KQ_mask, n_ctx, n_tokens, kv_head, n_kv, -1.0f, 1.0f/sqrtf(float(n_embd_head)), cb, il); Kcur, Vcur, Qcur, KQ_mask, nullptr, n_ctx, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
cb(cur, "kqv_out", il); cb(cur, "kqv_out", il);
} }
@ -7144,7 +7136,7 @@ struct llm_build_context {
cur = llm_build_kv(ctx0, model, hparams, kv_self, gf, cur = llm_build_kv(ctx0, model, hparams, kv_self, gf,
model.layers[il].wo, model.layers[il].bo, model.layers[il].wo, model.layers[il].bo,
Kcur, Vcur, Qcur, KQ_mask, n_ctx, n_tokens, kv_head, n_kv, -1.0f, 1.0f/sqrtf(float(n_embd_head)), cb, il); Kcur, Vcur, Qcur, KQ_mask, nullptr, n_ctx, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
cb(cur, "kqv_out", il); cb(cur, "kqv_out", il);
} }
@ -7276,7 +7268,7 @@ struct llm_build_context {
cur = llm_build_kv(ctx0, model, hparams, kv_self, gf, cur = llm_build_kv(ctx0, model, hparams, kv_self, gf,
model.layers[il].wo, model.layers[il].bo, model.layers[il].wo, model.layers[il].bo,
Kcur, Vcur, Qcur, KQ_mask, n_ctx, n_tokens, kv_head, n_kv, -1.0f, 1.0f/sqrtf(float(n_embd_head)), cb, il); Kcur, Vcur, Qcur, KQ_mask, nullptr, n_ctx, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
cb(cur, "kqv_out", il); cb(cur, "kqv_out", il);
} }
@ -7507,6 +7499,32 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {
} }
} }
// using Alibi bias
if (hparams.f_max_alibi_bias > 0.0f) {
const uint32_t n_head_kv = hparams.n_head_kv;
const float max_bias = hparams.f_max_alibi_bias;
assert(ggml_backend_buffer_is_host(lctx.inp_KQ_slope->buffer));
float * data = (float *) lctx.inp_KQ_slope->data;
// TODO: is this supposed to be ceil instead of floor?
// https://huggingface.co/mosaicml/mpt-7b/blob/main/attention.py#L370
const uint32_t n_head_log2 = 1u << (uint32_t) floor(log2(n_head_kv));
const float m0 = powf(2.0f, -(max_bias ) / n_head_log2);
const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
for (uint32_t h = 0; h < n_head_kv; ++h) {
if (h < n_head_log2) {
data[h] = powf(m0, h + 1);
} else {
data[h] = powf(m1, 2*(h - n_head_log2) + 1);
}
}
}
{ {
assert(ggml_backend_buffer_is_host(lctx.inp_sum->buffer)); assert(ggml_backend_buffer_is_host(lctx.inp_sum->buffer));
float * data = (float *) lctx.inp_sum->data; float * data = (float *) lctx.inp_sum->data;
@ -11412,25 +11430,27 @@ struct llama_context * llama_new_context_with_model(
// graph inputs // graph inputs
{ {
ggml_init_params init_params = { ggml_init_params init_params = {
/* .mem_size */ ggml_tensor_overhead()*7, /* .mem_size */ ggml_tensor_overhead()*8,
/* .mem_buffer */ nullptr, /* .mem_buffer */ nullptr,
/* .no_alloc */ true, /* .no_alloc */ true,
}; };
ctx->ctx_input = ggml_init(init_params); ctx->ctx_input = ggml_init(init_params);
ctx->inp_tokens = ggml_new_tensor_1d(ctx->ctx_input, GGML_TYPE_I32, cparams.n_batch); ctx->inp_tokens = ggml_new_tensor_1d(ctx->ctx_input, GGML_TYPE_I32, cparams.n_batch);
ctx->inp_embd = ggml_new_tensor_2d(ctx->ctx_input, GGML_TYPE_F32, hparams.n_embd, cparams.n_batch); ctx->inp_embd = ggml_new_tensor_2d(ctx->ctx_input, GGML_TYPE_F32, hparams.n_embd, cparams.n_batch);
ctx->inp_pos = ggml_new_tensor_1d(ctx->ctx_input, GGML_TYPE_I32, cparams.n_batch); ctx->inp_pos = ggml_new_tensor_1d(ctx->ctx_input, GGML_TYPE_I32, cparams.n_batch);
ctx->inp_KQ_mask = ggml_new_tensor_2d(ctx->ctx_input, GGML_TYPE_F32, cparams.n_ctx, cparams.n_batch); ctx->inp_KQ_mask = ggml_new_tensor_2d(ctx->ctx_input, GGML_TYPE_F32, cparams.n_ctx, cparams.n_batch);
ctx->inp_K_shift = ggml_new_tensor_1d(ctx->ctx_input, GGML_TYPE_I32, cparams.n_ctx); ctx->inp_KQ_slope = ggml_new_tensor_1d(ctx->ctx_input, GGML_TYPE_F32, hparams.n_head_kv);
ctx->inp_sum = ggml_new_tensor_2d(ctx->ctx_input, GGML_TYPE_F32, cparams.n_batch, cparams.n_batch); ctx->inp_K_shift = ggml_new_tensor_1d(ctx->ctx_input, GGML_TYPE_I32, cparams.n_ctx);
ctx->inp_sum = ggml_new_tensor_2d(ctx->ctx_input, GGML_TYPE_F32, cparams.n_batch, cparams.n_batch);
ggml_set_name(ctx->inp_tokens, "inp_tokens"); ggml_set_name(ctx->inp_tokens, "inp_tokens");
ggml_set_name(ctx->inp_embd, "inp_embd"); ggml_set_name(ctx->inp_embd, "inp_embd");
ggml_set_name(ctx->inp_pos, "inp_pos"); ggml_set_name(ctx->inp_pos, "inp_pos");
ggml_set_name(ctx->inp_KQ_mask, "inp_KQ_mask"); ggml_set_name(ctx->inp_KQ_mask, "inp_KQ_mask");
ggml_set_name(ctx->inp_K_shift, "inp_K_shift"); ggml_set_name(ctx->inp_KQ_slope, "inp_KQ_slope");
ggml_set_name(ctx->inp_sum, "inp_sum"); ggml_set_name(ctx->inp_K_shift, "inp_K_shift");
ggml_set_name(ctx->inp_sum, "inp_sum");
ctx->buf_input = ggml_backend_alloc_ctx_tensors_from_buft(ctx->ctx_input, llama_default_buffer_type_cpu(true)); ctx->buf_input = ggml_backend_alloc_ctx_tensors_from_buft(ctx->ctx_input, llama_default_buffer_type_cpu(true));

View file

@ -1085,24 +1085,28 @@ struct test_diag_mask_inf : public test_case {
struct test_soft_max : public test_case { struct test_soft_max : public test_case {
const ggml_type type; const ggml_type type;
const std::array<int64_t, 4> ne; const std::array<int64_t, 4> ne;
const float scale;
const bool mask; const bool mask;
const bool slope;
const float scale;
std::string vars() override { std::string vars() override {
return VARS_TO_STR4(type, ne, scale, mask); return VARS_TO_STR5(type, ne, mask, slope, scale);
} }
test_soft_max(ggml_type type = GGML_TYPE_F32, test_soft_max(ggml_type type = GGML_TYPE_F32,
std::array<int64_t, 4> ne = {10, 10, 10, 10}, std::array<int64_t, 4> ne = {10, 10, 10, 10},
float scale = 1.0f, bool mask = false,
bool mask = false) bool slope = false,
: type(type), ne(ne), scale(scale), mask(mask) {} float scale = 1.0f)
: type(type), ne(ne), mask(mask), slope(slope), scale(scale) {}
ggml_tensor * build_graph(ggml_context * ctx) override { ggml_tensor * build_graph(ggml_context * ctx) override {
ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne.data()); ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne.data());
ggml_tensor * b = nullptr; ggml_tensor * b = nullptr;
if (mask) { b = ggml_new_tensor_2d(ctx, type, ne[0], ne[1]); } ggml_tensor * c = nullptr;
ggml_tensor * out = ggml_soft_max_ext(ctx, a, b, scale); if (mask) { b = ggml_new_tensor_2d(ctx, type, ne[0], ne[1]); }
if (slope) { c = ggml_new_tensor_1d(ctx, type, ne[2]); }
ggml_tensor * out = ggml_soft_max_ext(ctx, a, b, c, scale);
return out; return out;
} }
}; };
@ -1488,7 +1492,7 @@ struct test_moe : public test_case {
ggml_tensor * cur = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_embd, n_tokens); ggml_tensor * cur = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_embd, n_tokens);
ggml_tensor * logits = ggml_mul_mat(ctx, ffn_gate_inp, cur); ggml_tensor * logits = ggml_mul_mat(ctx, ffn_gate_inp, cur);
ggml_tensor * probs = ggml_soft_max_ext(ctx, logits, nullptr, 1.0f/sqrtf(n_embd)); ggml_tensor * probs = ggml_soft_max_ext(ctx, logits, nullptr, nullptr, 1.0f/sqrtf(n_embd));
// select experts // select experts
ggml_tensor * selected_experts = ggml_top_k(ctx, probs, n_experts_per_tok); ggml_tensor * selected_experts = ggml_top_k(ctx, probs, n_experts_per_tok);
@ -1636,7 +1640,7 @@ public:
struct ggml_tensor * kq = ggml_mul_mat(ctx, k, q); struct ggml_tensor * kq = ggml_mul_mat(ctx, k, q);
kq = ggml_soft_max_ext(ctx, kq, kq_mask, kq_scale); kq = ggml_soft_max_ext(ctx, kq, kq_mask, nullptr, kq_scale);
// split cached v into n_head heads // split cached v into n_head heads
struct ggml_tensor * v = struct ggml_tensor * v =
@ -2091,14 +2095,16 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op
for (int n = 0; n < 10; ++n) { for (int n = 0; n < 10; ++n) {
int64_t ne0 = dist_ne0(rng); int64_t ne0 = dist_ne0(rng);
int64_t ne1 = dist_ne1(rng); int64_t ne1 = dist_ne1(rng);
test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {ne0, ne1, 1, 1})); test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {ne0, ne1, 1, 1}, n/2 == 0, n/3 == 0, 0.1f));
} }
exponent <<= 1; exponent <<= 1;
} }
test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {16, 2, 32, 1}, 0.1f)); test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {16, 2, 32, 1}, false, false, 0.1f));
test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {32, 2, 32, 1}, 0.1f, true)); test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {32, 2, 32, 1}, true, false, 0.1f));
test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {16, 2, 32, 1}, false, true, 0.1f));
test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {32, 2, 32, 1}, true, true, 0.1f));
for (ggml_type type : {GGML_TYPE_F32, GGML_TYPE_F16}) { for (ggml_type type : {GGML_TYPE_F32, GGML_TYPE_F16}) {
test_cases.emplace_back(new test_rope(type, {128, 32, 10, 1}, 128, 0, 512)); // llama 7B test_cases.emplace_back(new test_rope(type, {128, 32, 10, 1}, 128, 0, 512)); // llama 7B