POC: combined scale + diagonal mask infinity + soft max op

Iwan Kawrakow 2023-09-11 13:03:36 +02:00
parent f31b6f4e2d
commit 76a0c903e9
5 changed files with 148 additions and 24 deletions
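The commit adds a single fused op, GGML_OP_SCALE_DIAG_MASK_INF_SOFTMAX, that folds the usual scale -> diag_mask_inf -> soft_max sequence over the KQ attention matrix into one pass, so the matrix is read and written only once. As a rough orientation (a minimal C sketch with a hypothetical helper name, not code from the commit), this is what the fused op computes for one row i01:

#include <math.h>

/* Reference for one row: dst = soft_max(diag_mask_inf(scale*src, n_past)).
   Columns i00 > n_past + i01 are the masked "future" positions and end up as 0. */
static void fused_scale_mask_softmax_row(const float * src, float * dst,
                                         int ne00, int i01, float scale, int n_past) {
    // maximum of the scaled row for numerical stability; like the Metal kernel below,
    // masked columns are included here, which is harmless for the usual scale > 0
    float max = scale*src[0];
    for (int i00 = 1; i00 < ne00; ++i00) {
        max = fmaxf(max, scale*src[i00]);
    }
    float sum = 0.0f;
    for (int i00 = 0; i00 < ne00; ++i00) {
        const float e = i00 > n_past + i01 ? 0.0f : expf(scale*src[i00] - max);
        dst[i00] = e; // exp is cached here and re-used in the normalization pass
        sum += e;
    }
    for (int i00 = 0; i00 < ne00; ++i00) {
        dst[i00] /= sum;
    }
}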

ggml-metal.m

@@ -66,6 +66,7 @@ struct ggml_metal_context {
GGML_METAL_DECL_KERNEL(soft_max_4);
GGML_METAL_DECL_KERNEL(diag_mask_inf);
GGML_METAL_DECL_KERNEL(diag_mask_inf_8);
GGML_METAL_DECL_KERNEL(scale_diag_inf_soft_max);
GGML_METAL_DECL_KERNEL(get_rows_f16);
GGML_METAL_DECL_KERNEL(get_rows_q4_0);
GGML_METAL_DECL_KERNEL(get_rows_q4_1);
@@ -224,6 +225,7 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
GGML_METAL_ADD_KERNEL(soft_max_4);
GGML_METAL_ADD_KERNEL(diag_mask_inf);
GGML_METAL_ADD_KERNEL(diag_mask_inf_8);
GGML_METAL_ADD_KERNEL(scale_diag_inf_soft_max);
GGML_METAL_ADD_KERNEL(get_rows_f16);
GGML_METAL_ADD_KERNEL(get_rows_q4_0);
GGML_METAL_ADD_KERNEL(get_rows_q4_1);
@@ -294,6 +296,8 @@ void ggml_metal_free(struct ggml_metal_context * ctx) {
GGML_METAL_DEL_KERNEL(soft_max);
GGML_METAL_DEL_KERNEL(soft_max_4);
GGML_METAL_DEL_KERNEL(diag_mask_inf_8);
GGML_METAL_DEL_KERNEL(diag_mask_inf);
GGML_METAL_DEL_KERNEL(scale_diag_inf_soft_max);
GGML_METAL_DEL_KERNEL(get_rows_f16);
GGML_METAL_DEL_KERNEL(get_rows_q4_0);
GGML_METAL_DEL_KERNEL(get_rows_q4_1);
@@ -817,6 +821,23 @@ void ggml_metal_graph_compute(
GGML_ASSERT(false);
}
} break;
case GGML_OP_SCALE_DIAG_MASK_INF_SOFTMAX:
{
const float scale = ((float *)(dst->op_params))[0];
const int n_past = ((int32_t *)(dst->op_params))[1];
const int nth = 32;
[encoder setComputePipelineState:ctx->pipeline_scale_diag_inf_soft_max];
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
[encoder setBytes:&ne00 length:sizeof(ne00) atIndex:2];
[encoder setBytes:&ne01 length:sizeof(ne01) atIndex:3];
[encoder setBytes:&ne02 length:sizeof(ne02) atIndex:4];
[encoder setBytes:&scale length:sizeof(float) atIndex:5];
[encoder setBytes:&n_past length:sizeof(int) atIndex:6];
[encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
} break;
case GGML_OP_SOFT_MAX:
{
const int nth = 32;

ggml-metal.metal

@@ -182,6 +182,48 @@ kernel void kernel_soft_max_4(
}
}
kernel void kernel_scale_diag_inf_soft_max(
device const float * src0,
device float * dst,
constant int64_t & ne00,
constant int64_t & ne01,
constant int64_t & ne02,
constant float & scale,
constant int & n_past,
uint3 tgpig[[threadgroup_position_in_grid]],
uint3 tpitg[[thread_position_in_threadgroup]],
uint3 ntg[[threads_per_threadgroup]]) {
const int64_t i03 = tgpig[2];
const int64_t i02 = tgpig[1];
const int64_t i01 = tgpig[0];
device const float * psrc0 = src0 + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
device float * pdst = dst + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
// parallel max
float lmax = psrc0[tpitg[0]];
for (int i00 = tpitg[0] + ntg[0]; i00 < ne00; i00 += ntg[0]) {
lmax = MAX(lmax, psrc0[i00]);
}
const float max = simd_max(lmax) * scale;
// parallel sum
float lsum = 0.0f;
for (int i00 = tpitg[0]; i00 < ne00; i00 += ntg[0]) {
const float exp_psrc0 = i00 > n_past + i01 ? 0.f : exp(scale*psrc0[i00] - max);
lsum += exp_psrc0;
// Remember the result of exp here. exp is expensive, so we really do not
// wish to compute it twice.
pdst[i00] = exp_psrc0;
}
const float sum = simd_sum(lsum);
for (int i00 = tpitg[0]; i00 < ne00; i00 += ntg[0]) {
pdst[i00] /= sum;
}
}
kernel void kernel_diag_mask_inf(
device const float * src0,
device float * dst,

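A note on how the kernel above is parallelized (my reading of the dispatch in ggml-metal.m, not text from the commit): each threadgroup of nth = 32 threads handles one row of the (ne01, ne02, ne03) grid, every thread strides over the row's ne00 columns, and the per-thread partial max and sum are combined with simd_max/simd_sum. Because 32 threads is exactly one Apple simdgroup, those simd reductions already span the whole threadgroup, so no threadgroup memory or barriers are needed. The exp results are stored to dst in the same loop that accumulates the sum, so the expensive exp is evaluated only once per element.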
ggml.c

@@ -4001,7 +4001,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
"CROSS_ENTROPY_LOSS_BACK",
};
static_assert(GGML_OP_COUNT == 68, "GGML_OP_COUNT != 68");
static_assert(GGML_OP_COUNT == 69, "GGML_OP_COUNT != 69");
static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
"none",
@@ -4083,7 +4083,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
"cross_entropy_loss_back(x,y)",
};
static_assert(GGML_OP_COUNT == 68, "GGML_OP_COUNT != 68");
static_assert(GGML_OP_COUNT == 69, "GGML_OP_COUNT != 69");
static_assert(GGML_OP_POOL_COUNT == 2, "GGML_OP_POOL_COUNT != 2");
@@ -6952,6 +6952,32 @@ struct ggml_tensor * ggml_soft_max_back_inplace(
return ggml_soft_max_back_impl(ctx, a, b, true);
}
struct ggml_tensor * ggml_scale_diag_mask_inf_softmax_inplace(
struct ggml_context * ctx,
float scale,
int n_past,
struct ggml_tensor * a) {
//bool is_node = false;
//if (a->grad) {
// is_node = true;
//}
struct ggml_tensor * result = ggml_view_tensor(ctx, a);
int32_t params[2];
memcpy(&params[0], &scale, sizeof(scale));
params[1] = n_past;
ggml_set_op_params(result, params, sizeof(params));
result->op = GGML_OP_SCALE_DIAG_MASK_INF_SOFTMAX;
//result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
result->grad = NULL;
result->src[0] = a;
return result;
}
// ggml_rope
static struct ggml_tensor * ggml_rope_impl(
@@ -15993,6 +16019,11 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
{
// nop
} break;
case GGML_OP_SCALE_DIAG_MASK_INF_SOFTMAX:
{
fprintf(stderr, "GGML_OP_SCALE_DIAG_MASK_INF_SOFTMAX not implemented\n");
GGML_ASSERT(false);
} break;
case GGML_OP_COUNT:
{
GGML_ASSERT(false);
@@ -16861,6 +16892,11 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
{
// nop
} break;
case GGML_OP_SCALE_DIAG_MASK_INF_SOFTMAX:
{
fprintf(stderr, "GGML_OP_SCALE_DIAG_MASK_INF_SOFTMAX not implemented\n");
GGML_ASSERT(false);
} break;
case GGML_OP_COUNT:
{
GGML_ASSERT(false);
@@ -17698,6 +17734,11 @@ struct ggml_cplan ggml_graph_plan(struct ggml_cgraph * cgraph, int n_threads) {
{
n_tasks = 1;
} break;
case GGML_OP_SCALE_DIAG_MASK_INF_SOFTMAX:
{
fprintf(stderr, "GGML_OP_SCALE_DIAG_MASK_INF_SOFTMAX not implemented\n");
GGML_ASSERT(false);
} break;
case GGML_OP_COUNT:
{
GGML_ASSERT(false);

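The scale and n_past values that the Metal encoder reads back out of dst->op_params are packed by ggml_scale_diag_mask_inf_softmax_inplace above: op_params is treated as two 32-bit slots, with the float scale bit-copied into slot 0 and n_past stored as an int32_t in slot 1. A small standalone sketch of that convention (plain C, illustration only; the Metal encoder uses direct pointer casts instead of memcpy, which amounts to the same layout):

#include <stdint.h>
#include <string.h>

static void pack_params(int32_t params[2], float scale, int n_past) {
    memcpy(&params[0], &scale, sizeof(scale)); // preserve the float's bit pattern in slot 0
    params[1] = n_past;                        // plain integer in slot 1
}

static void unpack_params(const int32_t params[2], float * scale, int * n_past) {
    memcpy(scale, &params[0], sizeof(*scale));
    *n_past = params[1];
}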
ggml.h

@@ -415,6 +415,8 @@ extern "C" {
GGML_OP_CROSS_ENTROPY_LOSS,
GGML_OP_CROSS_ENTROPY_LOSS_BACK,
GGML_OP_SCALE_DIAG_MASK_INF_SOFTMAX,
GGML_OP_COUNT,
};
@@ -1209,6 +1211,12 @@ extern "C" {
struct ggml_tensor * a,
struct ggml_tensor * b);
GGML_API struct ggml_tensor * ggml_scale_diag_mask_inf_softmax_inplace(
struct ggml_context * ctx,
float scale,
int n_past,
struct ggml_tensor * a);
// rotary position embedding
// if mode & 1 == 1, skip n_past elements
// if mode & 2 == 1, GPT-NeoX style

llama.cpp

@@ -2316,6 +2316,8 @@ static struct ggml_cgraph * llm_build_llama(
}
ggml_set_name(KQ_scale, "1/sqrt(n_embd_head)");
const float kq_scale = 1.0f/sqrtf(float(n_embd)/n_head);
for (int il = 0; il < n_layer; ++il) {
ggml_format_name(inpL, "layer_inp_%d", il);
@@ -2405,22 +2407,26 @@ static struct ggml_cgraph * llm_build_llama(
offload_func_kq(KQ);
ggml_set_name(KQ, "KQ");
// KQ_scaled = KQ / sqrt(n_embd_head)
// KQ_scaled shape [n_past + N, N, n_head, 1]
struct ggml_tensor * KQ_scaled = ggml_scale_inplace(ctx0, KQ, KQ_scale);
offload_func_kq(KQ_scaled);
ggml_set_name(KQ_scaled, "KQ_scaled");
// KQ_masked = mask_past(KQ_scaled)
struct ggml_tensor * KQ_masked = ggml_diag_mask_inf_inplace(ctx0, KQ_scaled, n_past);
offload_func_kq(KQ_masked);
ggml_set_name(KQ_masked, "KQ_masked");
// KQ = soft_max(KQ_masked)
struct ggml_tensor * KQ_soft_max = ggml_soft_max_inplace(ctx0, KQ_masked);
struct ggml_tensor * KQ_soft_max = ggml_scale_diag_mask_inf_softmax_inplace(ctx0, kq_scale, n_past, KQ);
offload_func_v(KQ_soft_max);
ggml_set_name(KQ_soft_max, "KQ_soft_max");
//// KQ_scaled = KQ / sqrt(n_embd_head)
//// KQ_scaled shape [n_past + N, N, n_head, 1]
//struct ggml_tensor * KQ_scaled = ggml_scale_inplace(ctx0, KQ, KQ_scale);
//offload_func_kq(KQ_scaled);
//ggml_set_name(KQ_scaled, "KQ_scaled");
//// KQ_masked = mask_past(KQ_scaled)
//struct ggml_tensor * KQ_masked = ggml_diag_mask_inf_inplace(ctx0, KQ_scaled, n_past);
//offload_func_kq(KQ_masked);
//ggml_set_name(KQ_masked, "KQ_masked");
//// KQ = soft_max(KQ_masked)
//struct ggml_tensor * KQ_soft_max = ggml_soft_max_inplace(ctx0, KQ_masked);
//offload_func_v(KQ_soft_max);
//ggml_set_name(KQ_soft_max, "KQ_soft_max");
// split cached V into n_head heads
struct ggml_tensor * V =
ggml_view_3d(ctx0, kv_self.v,
@@ -2647,6 +2653,8 @@ static struct ggml_cgraph * llm_build_falcon(
}
ggml_set_name(KQ_scale, "1/sqrt(n_embd_head)");
const float kq_scale = 1.0f/sqrtf(float(n_embd)/n_head);
for (int il = 0; il < n_layer; ++il) {
struct ggml_tensor * attn_norm;
@@ -2764,18 +2772,22 @@ static struct ggml_cgraph * llm_build_falcon(
offload_func_kq(KQ);
ggml_set_name(KQ, "KQ");
struct ggml_tensor * KQ_scaled = ggml_scale_inplace(ctx0, KQ, KQ_scale);
offload_func_kq(KQ_scaled);
ggml_set_name(KQ_scaled, "KQ_scaled");
struct ggml_tensor * KQ_masked = ggml_diag_mask_inf_inplace(ctx0, KQ_scaled, n_past);
offload_func_kq(KQ_masked);
ggml_set_name(KQ_masked, "KQ_masked");
struct ggml_tensor * KQ_soft_max = ggml_soft_max_inplace(ctx0, KQ_masked);
struct ggml_tensor * KQ_soft_max = ggml_scale_diag_mask_inf_softmax_inplace(ctx0, kq_scale, n_past, KQ);
offload_func_v(KQ_soft_max);
ggml_set_name(KQ_soft_max, "KQ_soft_max");
//struct ggml_tensor * KQ_scaled = ggml_scale_inplace(ctx0, KQ, KQ_scale);
//offload_func_kq(KQ_scaled);
//ggml_set_name(KQ_scaled, "KQ_scaled");
//struct ggml_tensor * KQ_masked = ggml_diag_mask_inf_inplace(ctx0, KQ_scaled, n_past);
//offload_func_kq(KQ_masked);
//ggml_set_name(KQ_masked, "KQ_masked");
//struct ggml_tensor * KQ_soft_max = ggml_soft_max_inplace(ctx0, KQ_masked);
//offload_func_v(KQ_soft_max);
//ggml_set_name(KQ_soft_max, "KQ_soft_max");
struct ggml_tensor * V =
ggml_view_3d(ctx0, kv_self.v,
n_past + N, n_embd_head, n_head_kv,