POC: combined scale + diagonal mask infinity + soft max op
This commit is contained in:
parent
f31b6f4e2d
commit
76a0c903e9
5 changed files with 148 additions and 24 deletions
21
ggml-metal.m
21
ggml-metal.m
|
@ -66,6 +66,7 @@ struct ggml_metal_context {
|
|||
GGML_METAL_DECL_KERNEL(soft_max_4);
|
||||
GGML_METAL_DECL_KERNEL(diag_mask_inf);
|
||||
GGML_METAL_DECL_KERNEL(diag_mask_inf_8);
|
||||
GGML_METAL_DECL_KERNEL(scale_diag_inf_soft_max);
|
||||
GGML_METAL_DECL_KERNEL(get_rows_f16);
|
||||
GGML_METAL_DECL_KERNEL(get_rows_q4_0);
|
||||
GGML_METAL_DECL_KERNEL(get_rows_q4_1);
|
||||
|
@ -224,6 +225,7 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
|
|||
GGML_METAL_ADD_KERNEL(soft_max_4);
|
||||
GGML_METAL_ADD_KERNEL(diag_mask_inf);
|
||||
GGML_METAL_ADD_KERNEL(diag_mask_inf_8);
|
||||
GGML_METAL_ADD_KERNEL(scale_diag_inf_soft_max);
|
||||
GGML_METAL_ADD_KERNEL(get_rows_f16);
|
||||
GGML_METAL_ADD_KERNEL(get_rows_q4_0);
|
||||
GGML_METAL_ADD_KERNEL(get_rows_q4_1);
|
||||
|
@ -294,6 +296,8 @@ void ggml_metal_free(struct ggml_metal_context * ctx) {
|
|||
GGML_METAL_DEL_KERNEL(soft_max);
|
||||
GGML_METAL_DEL_KERNEL(soft_max_4);
|
||||
GGML_METAL_DEL_KERNEL(diag_mask_inf_8);
|
||||
GGML_METAL_DEL_KERNEL(diag_mask_inf);
|
||||
GGML_METAL_DEL_KERNEL(scale_diag_inf_soft_max);
|
||||
GGML_METAL_DEL_KERNEL(get_rows_f16);
|
||||
GGML_METAL_DEL_KERNEL(get_rows_q4_0);
|
||||
GGML_METAL_DEL_KERNEL(get_rows_q4_1);
|
||||
|
@ -817,6 +821,23 @@ void ggml_metal_graph_compute(
|
|||
GGML_ASSERT(false);
|
||||
}
|
||||
} break;
|
||||
case GGML_OP_SCALE_DIAG_MASK_INF_SOFTMAX:
|
||||
{
|
||||
const float scale = ((float *)(dst->op_params))[0];
|
||||
const int n_past = ((int32_t *)(dst->op_params))[1];
|
||||
const int nth = 32;
|
||||
|
||||
[encoder setComputePipelineState:ctx->pipeline_scale_diag_inf_soft_max];
|
||||
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
||||
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
|
||||
[encoder setBytes:&ne00 length:sizeof(ne00) atIndex:2];
|
||||
[encoder setBytes:&ne01 length:sizeof(ne01) atIndex:3];
|
||||
[encoder setBytes:&ne02 length:sizeof(ne02) atIndex:4];
|
||||
[encoder setBytes:&scale length:sizeof(float) atIndex:5];
|
||||
[encoder setBytes:&n_past length:sizeof(int) atIndex:6];
|
||||
|
||||
[encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
|
||||
} break;
|
||||
case GGML_OP_SOFT_MAX:
|
||||
{
|
||||
const int nth = 32;
|
||||
|
|
|
@ -182,6 +182,48 @@ kernel void kernel_soft_max_4(
|
|||
}
|
||||
}
|
||||
|
||||
// Fused op: dst = soft_max(diag_mask_inf(scale * src0, n_past)).
// One threadgroup processes one row (i01) of ne00 elements; ntg[0] threads
// cooperate via SIMD-group reductions (assumes ntg[0] <= SIMD width, the host
// dispatches nth = 32).
kernel void kernel_scale_diag_inf_soft_max(
        device const float * src0,
        device       float * dst,
        constant   int64_t & ne00,
        constant   int64_t & ne01,
        constant   int64_t & ne02,
        constant     float & scale,
        constant       int & n_past,
        uint3 tgpig[[threadgroup_position_in_grid]],
        uint3 tpitg[[thread_position_in_threadgroup]],
        uint3   ntg[[threads_per_threadgroup]]) {
    const int64_t i03 = tgpig[2];
    const int64_t i02 = tgpig[1];
    const int64_t i01 = tgpig[0];

    device const float * psrc0 = src0 + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
    device       float * pdst  = dst  + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;

    // parallel max over the scaled, unmasked elements.
    // Start from -INFINITY instead of psrc0[tpitg[0]]: the latter reads out of
    // bounds when ne00 < ntg[0] (e.g. short rows at the start of generation).
    // Scaling inside the loop keeps the result correct for any sign of scale
    // (max(s*x) != s*max(x) when s < 0), and skipping masked elements
    // (i00 > n_past + i01) avoids a needlessly large max causing underflow.
    float lmax = -INFINITY;
    for (int i00 = tpitg[0]; i00 < ne00 && i00 <= n_past + i01; i00 += ntg[0]) {
        lmax = MAX(lmax, scale*psrc0[i00]);
    }
    const float max = simd_max(lmax);

    // parallel sum of exp(scale*x - max); masked elements contribute 0
    float lsum = 0.0f;
    for (int i00 = tpitg[0]; i00 < ne00; i00 += ntg[0]) {
        const float exp_psrc0 = i00 > n_past + i01 ? 0.f : exp(scale*psrc0[i00] - max);
        lsum += exp_psrc0;
        // Remember the result of exp here. exp is expensive, so we really do not
        // wish to compute it twice.
        pdst[i00] = exp_psrc0;
    }

    const float sum = simd_sum(lsum);

    // normalize in place
    for (int i00 = tpitg[0]; i00 < ne00; i00 += ntg[0]) {
        pdst[i00] /= sum;
    }
}
|
||||
|
||||
kernel void kernel_diag_mask_inf(
|
||||
device const float * src0,
|
||||
device float * dst,
|
||||
|
|
45
ggml.c
45
ggml.c
|
@ -4001,7 +4001,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
|
|||
"CROSS_ENTROPY_LOSS_BACK",
|
||||
};
|
||||
|
||||
static_assert(GGML_OP_COUNT == 68, "GGML_OP_COUNT != 68");
|
||||
static_assert(GGML_OP_COUNT == 69, "GGML_OP_COUNT != 69");
|
||||
|
||||
static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
|
||||
"none",
|
||||
|
@ -4083,7 +4083,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
|
|||
"cross_entropy_loss_back(x,y)",
|
||||
};
|
||||
|
||||
static_assert(GGML_OP_COUNT == 68, "GGML_OP_COUNT != 68");
|
||||
static_assert(GGML_OP_COUNT == 69, "GGML_OP_COUNT != 69");
|
||||
|
||||
static_assert(GGML_OP_POOL_COUNT == 2, "GGML_OP_POOL_COUNT != 2");
|
||||
|
||||
|
@ -6952,6 +6952,32 @@ struct ggml_tensor * ggml_soft_max_back_inplace(
|
|||
return ggml_soft_max_back_impl(ctx, a, b, true);
|
||||
}
|
||||
|
||||
struct ggml_tensor * ggml_scale_diag_mask_inf_softmax_inplace(
|
||||
struct ggml_context * ctx,
|
||||
float scale,
|
||||
int n_past,
|
||||
struct ggml_tensor * a) {
|
||||
//bool is_node = false;
|
||||
|
||||
//if (a->grad) {
|
||||
// is_node = true;
|
||||
//}
|
||||
|
||||
struct ggml_tensor * result = ggml_view_tensor(ctx, a);
|
||||
|
||||
int32_t params[2];
|
||||
memcpy(¶ms[0], &scale, sizeof(scale));
|
||||
params[1] = n_past;
|
||||
ggml_set_op_params(result, params, sizeof(params));
|
||||
|
||||
result->op = GGML_OP_SCALE_DIAG_MASK_INF_SOFTMAX;
|
||||
//result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
|
||||
result->grad = NULL;
|
||||
result->src[0] = a;
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
// ggml_rope
|
||||
|
||||
static struct ggml_tensor * ggml_rope_impl(
|
||||
|
@ -15993,6 +16019,11 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
|
|||
{
|
||||
// nop
|
||||
} break;
|
||||
case GGML_OP_SCALE_DIAG_MASK_INF_SOFTMAX:
|
||||
{
|
||||
fprintf(stderr, "GGML_OP_SCALE_DIAG_MASK_INF_SOFTMAX not implemented\n");
|
||||
GGML_ASSERT(false);
|
||||
} break;
|
||||
case GGML_OP_COUNT:
|
||||
{
|
||||
GGML_ASSERT(false);
|
||||
|
@ -16861,6 +16892,11 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
|
|||
{
|
||||
// nop
|
||||
} break;
|
||||
case GGML_OP_SCALE_DIAG_MASK_INF_SOFTMAX:
|
||||
{
|
||||
fprintf(stderr, "GGML_OP_SCALE_DIAG_MASK_INF_SOFTMAX not implemented\n");
|
||||
GGML_ASSERT(false);
|
||||
} break;
|
||||
case GGML_OP_COUNT:
|
||||
{
|
||||
GGML_ASSERT(false);
|
||||
|
@ -17698,6 +17734,11 @@ struct ggml_cplan ggml_graph_plan(struct ggml_cgraph * cgraph, int n_threads) {
|
|||
{
|
||||
n_tasks = 1;
|
||||
} break;
|
||||
case GGML_OP_SCALE_DIAG_MASK_INF_SOFTMAX:
|
||||
{
|
||||
fprintf(stderr, "GGML_OP_SCALE_DIAG_MASK_INF_SOFTMAX not implemented\n");
|
||||
GGML_ASSERT(false);
|
||||
} break;
|
||||
case GGML_OP_COUNT:
|
||||
{
|
||||
GGML_ASSERT(false);
|
||||
|
|
8
ggml.h
8
ggml.h
|
@ -415,6 +415,8 @@ extern "C" {
|
|||
GGML_OP_CROSS_ENTROPY_LOSS,
|
||||
GGML_OP_CROSS_ENTROPY_LOSS_BACK,
|
||||
|
||||
GGML_OP_SCALE_DIAG_MASK_INF_SOFTMAX,
|
||||
|
||||
GGML_OP_COUNT,
|
||||
};
|
||||
|
||||
|
@ -1209,6 +1211,12 @@ extern "C" {
|
|||
struct ggml_tensor * a,
|
||||
struct ggml_tensor * b);
|
||||
|
||||
GGML_API struct ggml_tensor * ggml_scale_diag_mask_inf_softmax_inplace(
|
||||
struct ggml_context * ctx,
|
||||
float scale,
|
||||
int n_past,
|
||||
struct ggml_tensor * a);
|
||||
|
||||
// rotary position embedding
|
||||
// if mode & 1 == 1, skip n_past elements
|
||||
// if mode & 2 == 1, GPT-NeoX style
|
||||
|
|
56
llama.cpp
56
llama.cpp
|
@ -2316,6 +2316,8 @@ static struct ggml_cgraph * llm_build_llama(
|
|||
}
|
||||
ggml_set_name(KQ_scale, "1/sqrt(n_embd_head)");
|
||||
|
||||
const float kq_scale = 1.0f/sqrtf(float(n_embd)/n_head);
|
||||
|
||||
for (int il = 0; il < n_layer; ++il) {
|
||||
ggml_format_name(inpL, "layer_inp_%d", il);
|
||||
|
||||
|
@ -2405,22 +2407,26 @@ static struct ggml_cgraph * llm_build_llama(
|
|||
offload_func_kq(KQ);
|
||||
ggml_set_name(KQ, "KQ");
|
||||
|
||||
// KQ_scaled = KQ / sqrt(n_embd_head)
|
||||
// KQ_scaled shape [n_past + N, N, n_head, 1]
|
||||
struct ggml_tensor * KQ_scaled = ggml_scale_inplace(ctx0, KQ, KQ_scale);
|
||||
offload_func_kq(KQ_scaled);
|
||||
ggml_set_name(KQ_scaled, "KQ_scaled");
|
||||
|
||||
// KQ_masked = mask_past(KQ_scaled)
|
||||
struct ggml_tensor * KQ_masked = ggml_diag_mask_inf_inplace(ctx0, KQ_scaled, n_past);
|
||||
offload_func_kq(KQ_masked);
|
||||
ggml_set_name(KQ_masked, "KQ_masked");
|
||||
|
||||
// KQ = soft_max(KQ_masked)
|
||||
struct ggml_tensor * KQ_soft_max = ggml_soft_max_inplace(ctx0, KQ_masked);
|
||||
struct ggml_tensor * KQ_soft_max = ggml_scale_diag_mask_inf_softmax_inplace(ctx0, kq_scale, n_past, KQ);
|
||||
offload_func_v(KQ_soft_max);
|
||||
ggml_set_name(KQ_soft_max, "KQ_soft_max");
|
||||
|
||||
//// KQ_scaled = KQ / sqrt(n_embd_head)
|
||||
//// KQ_scaled shape [n_past + N, N, n_head, 1]
|
||||
//struct ggml_tensor * KQ_scaled = ggml_scale_inplace(ctx0, KQ, KQ_scale);
|
||||
//offload_func_kq(KQ_scaled);
|
||||
//ggml_set_name(KQ_scaled, "KQ_scaled");
|
||||
|
||||
//// KQ_masked = mask_past(KQ_scaled)
|
||||
//struct ggml_tensor * KQ_masked = ggml_diag_mask_inf_inplace(ctx0, KQ_scaled, n_past);
|
||||
//offload_func_kq(KQ_masked);
|
||||
//ggml_set_name(KQ_masked, "KQ_masked");
|
||||
|
||||
//// KQ = soft_max(KQ_masked)
|
||||
//struct ggml_tensor * KQ_soft_max = ggml_soft_max_inplace(ctx0, KQ_masked);
|
||||
//offload_func_v(KQ_soft_max);
|
||||
//ggml_set_name(KQ_soft_max, "KQ_soft_max");
|
||||
|
||||
// split cached V into n_head heads
|
||||
struct ggml_tensor * V =
|
||||
ggml_view_3d(ctx0, kv_self.v,
|
||||
|
@ -2647,6 +2653,8 @@ static struct ggml_cgraph * llm_build_falcon(
|
|||
}
|
||||
ggml_set_name(KQ_scale, "1/sqrt(n_embd_head)");
|
||||
|
||||
const float kq_scale = 1.0f/sqrtf(float(n_embd)/n_head);
|
||||
|
||||
for (int il = 0; il < n_layer; ++il) {
|
||||
struct ggml_tensor * attn_norm;
|
||||
|
||||
|
@ -2764,18 +2772,22 @@ static struct ggml_cgraph * llm_build_falcon(
|
|||
offload_func_kq(KQ);
|
||||
ggml_set_name(KQ, "KQ");
|
||||
|
||||
struct ggml_tensor * KQ_scaled = ggml_scale_inplace(ctx0, KQ, KQ_scale);
|
||||
offload_func_kq(KQ_scaled);
|
||||
ggml_set_name(KQ_scaled, "KQ_scaled");
|
||||
|
||||
struct ggml_tensor * KQ_masked = ggml_diag_mask_inf_inplace(ctx0, KQ_scaled, n_past);
|
||||
offload_func_kq(KQ_masked);
|
||||
ggml_set_name(KQ_masked, "KQ_masked");
|
||||
|
||||
struct ggml_tensor * KQ_soft_max = ggml_soft_max_inplace(ctx0, KQ_masked);
|
||||
struct ggml_tensor * KQ_soft_max = ggml_scale_diag_mask_inf_softmax_inplace(ctx0, kq_scale, n_past, KQ);
|
||||
offload_func_v(KQ_soft_max);
|
||||
ggml_set_name(KQ_soft_max, "KQ_soft_max");
|
||||
|
||||
//struct ggml_tensor * KQ_scaled = ggml_scale_inplace(ctx0, KQ, KQ_scale);
|
||||
//offload_func_kq(KQ_scaled);
|
||||
//ggml_set_name(KQ_scaled, "KQ_scaled");
|
||||
|
||||
//struct ggml_tensor * KQ_masked = ggml_diag_mask_inf_inplace(ctx0, KQ_scaled, n_past);
|
||||
//offload_func_kq(KQ_masked);
|
||||
//ggml_set_name(KQ_masked, "KQ_masked");
|
||||
|
||||
//struct ggml_tensor * KQ_soft_max = ggml_soft_max_inplace(ctx0, KQ_masked);
|
||||
//offload_func_v(KQ_soft_max);
|
||||
//ggml_set_name(KQ_soft_max, "KQ_soft_max");
|
||||
|
||||
struct ggml_tensor * V =
|
||||
ggml_view_3d(ctx0, kv_self.v,
|
||||
n_past + N, n_embd_head, n_head_kv,
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue