diff --git a/ggml-metal.m b/ggml-metal.m
index 4f3f14e24..9b5c20273 100644
--- a/ggml-metal.m
+++ b/ggml-metal.m
@@ -66,6 +66,7 @@ struct ggml_metal_context {
     GGML_METAL_DECL_KERNEL(soft_max_4);
     GGML_METAL_DECL_KERNEL(diag_mask_inf);
     GGML_METAL_DECL_KERNEL(diag_mask_inf_8);
+    GGML_METAL_DECL_KERNEL(scale_diag_inf_soft_max);
     GGML_METAL_DECL_KERNEL(get_rows_f16);
     GGML_METAL_DECL_KERNEL(get_rows_q4_0);
     GGML_METAL_DECL_KERNEL(get_rows_q4_1);
@@ -224,6 +225,7 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
         GGML_METAL_ADD_KERNEL(soft_max_4);
         GGML_METAL_ADD_KERNEL(diag_mask_inf);
         GGML_METAL_ADD_KERNEL(diag_mask_inf_8);
+        GGML_METAL_ADD_KERNEL(scale_diag_inf_soft_max);
         GGML_METAL_ADD_KERNEL(get_rows_f16);
         GGML_METAL_ADD_KERNEL(get_rows_q4_0);
         GGML_METAL_ADD_KERNEL(get_rows_q4_1);
@@ -294,6 +296,8 @@ void ggml_metal_free(struct ggml_metal_context * ctx) {
     GGML_METAL_DEL_KERNEL(soft_max);
     GGML_METAL_DEL_KERNEL(soft_max_4);
     GGML_METAL_DEL_KERNEL(diag_mask_inf_8);
+    GGML_METAL_DEL_KERNEL(diag_mask_inf);
+    GGML_METAL_DEL_KERNEL(scale_diag_inf_soft_max);
     GGML_METAL_DEL_KERNEL(get_rows_f16);
     GGML_METAL_DEL_KERNEL(get_rows_q4_0);
     GGML_METAL_DEL_KERNEL(get_rows_q4_1);
@@ -817,6 +821,23 @@ void ggml_metal_graph_compute(
                                 GGML_ASSERT(false);
                             }
                     } break;
+                case GGML_OP_SCALE_DIAG_MASK_INF_SOFTMAX:
+                    {
+                        const float scale = ((float *)(dst->op_params))[0];
+                        const int n_past = ((int32_t *)(dst->op_params))[1];
+                        const int nth = 32;
+
+                        [encoder setComputePipelineState:ctx->pipeline_scale_diag_inf_soft_max];
+                        [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+                        [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
+                        [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:2];
+                        [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:3];
+                        [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:4];
+                        [encoder setBytes:&scale length:sizeof(float) atIndex:5];
+                        [encoder setBytes:&n_past length:sizeof(int) atIndex:6];
+
+                        [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
+                    } break;
                 case GGML_OP_SOFT_MAX:
                     {
                         const int nth = 32;
diff --git a/ggml-metal.metal b/ggml-metal.metal
index f45b1490f..c586181a9 100644
--- a/ggml-metal.metal
+++ b/ggml-metal.metal
@@ -182,6 +182,48 @@ kernel void kernel_soft_max_4(
     }
 }
 
+kernel void kernel_scale_diag_inf_soft_max(
+        device const float * src0,
+        device       float * dst,
+        constant   int64_t & ne00,
+        constant   int64_t & ne01,
+        constant   int64_t & ne02,
+        constant     float & scale,
+        constant       int & n_past,
+        uint3 tgpig[[threadgroup_position_in_grid]],
+        uint3 tpitg[[thread_position_in_threadgroup]],
+        uint3   ntg[[threads_per_threadgroup]]) {
+    const int64_t i03 = tgpig[2];
+    const int64_t i02 = tgpig[1];
+    const int64_t i01 = tgpig[0];
+
+    device const float * psrc0 = src0 + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
+    device       float * pdst  = dst  + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
+
+    // parallel max
+    float lmax = psrc0[tpitg[0]];
+    for (int i00 = tpitg[0] + ntg[0]; i00 < ne00; i00 += ntg[0]) {
+        lmax = MAX(lmax, psrc0[i00]);
+    }
+    const float max = simd_max(lmax) * scale;
+
+    // parallel sum
+    float lsum = 0.0f;
+    for (int i00 = tpitg[0]; i00 < ne00; i00 += ntg[0]) {
+        const float exp_psrc0 = i00 > n_past + i01 ? 0.f : exp(scale*psrc0[i00] - max);
+        lsum += exp_psrc0;
+        // Remember the result of exp here. exp is expensive, so we really do not
+        // wish to compute it twice.
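+        // Masked entries (i00 > n_past + i01) are stored as 0, so the normalization loop below needs no mask check.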
+        pdst[i00] = exp_psrc0;
+    }
+
+    const float sum = simd_sum(lsum);
+
+    for (int i00 = tpitg[0]; i00 < ne00; i00 += ntg[0]) {
+        pdst[i00] /= sum;
+    }
+}
+
 kernel void kernel_diag_mask_inf(
         device const float * src0,
         device       float * dst,
diff --git a/ggml.c b/ggml.c
index 3f72379c3..2539edb81 100644
--- a/ggml.c
+++ b/ggml.c
@@ -4001,7 +4001,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
     "CROSS_ENTROPY_LOSS_BACK",
 };
 
-static_assert(GGML_OP_COUNT == 68, "GGML_OP_COUNT != 68");
+static_assert(GGML_OP_COUNT == 69, "GGML_OP_COUNT != 69");
 
 static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
     "none",
@@ -4083,7 +4083,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
     "cross_entropy_loss_back(x,y)",
 };
 
-static_assert(GGML_OP_COUNT == 68, "GGML_OP_COUNT != 68");
+static_assert(GGML_OP_COUNT == 69, "GGML_OP_COUNT != 69");
 
 static_assert(GGML_OP_POOL_COUNT == 2, "GGML_OP_POOL_COUNT != 2");
 
@@ -6952,6 +6952,32 @@ struct ggml_tensor * ggml_soft_max_back_inplace(
     return ggml_soft_max_back_impl(ctx, a, b, true);
 }
 
+struct ggml_tensor * ggml_scale_diag_mask_inf_softmax_inplace(
+        struct ggml_context * ctx,
+        float scale,
+        int n_past,
+        struct ggml_tensor * a) {
+    //bool is_node = false;
+
+    //if (a->grad) {
+    //    is_node = true;
+    //}
+
+    struct ggml_tensor * result = ggml_view_tensor(ctx, a);
+
+    int32_t params[2];
+    memcpy(&params[0], &scale, sizeof(scale));
+    params[1] = n_past;
+    ggml_set_op_params(result, params, sizeof(params));
+
+    result->op = GGML_OP_SCALE_DIAG_MASK_INF_SOFTMAX;
+    //result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
+    result->grad = NULL;
+    result->src[0] = a;
+
+    return result;
+}
+
 // ggml_rope
 
 static struct ggml_tensor * ggml_rope_impl(
@@ -15993,6 +16019,11 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
            {
                // nop
            } break;
+        case GGML_OP_SCALE_DIAG_MASK_INF_SOFTMAX:
+            {
+                fprintf(stderr, "GGML_OP_SCALE_DIAG_MASK_INF_SOFTMAX not implemented\n");
+                GGML_ASSERT(false);
+            } break;
        case GGML_OP_COUNT:
            {
                GGML_ASSERT(false);
@@ -16861,6 +16892,11 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
            {
                // nop
            } break;
+        case GGML_OP_SCALE_DIAG_MASK_INF_SOFTMAX:
+            {
+                fprintf(stderr, "GGML_OP_SCALE_DIAG_MASK_INF_SOFTMAX not implemented\n");
+                GGML_ASSERT(false);
+            } break;
        case GGML_OP_COUNT:
            {
                GGML_ASSERT(false);
@@ -17698,6 +17734,11 @@ struct ggml_cplan ggml_graph_plan(struct ggml_cgraph * cgraph, int n_threads) {
                {
                    n_tasks = 1;
                } break;
+            case GGML_OP_SCALE_DIAG_MASK_INF_SOFTMAX:
+                {
+                    fprintf(stderr, "GGML_OP_SCALE_DIAG_MASK_INF_SOFTMAX not implemented\n");
+                    GGML_ASSERT(false);
+                } break;
            case GGML_OP_COUNT:
                {
                    GGML_ASSERT(false);
diff --git a/ggml.h b/ggml.h
index c936823d6..4f60c133a 100644
--- a/ggml.h
+++ b/ggml.h
@@ -415,6 +415,8 @@ extern "C" {
         GGML_OP_CROSS_ENTROPY_LOSS,
         GGML_OP_CROSS_ENTROPY_LOSS_BACK,
 
+        GGML_OP_SCALE_DIAG_MASK_INF_SOFTMAX,
+
         GGML_OP_COUNT,
     };
 
@@ -1209,6 +1211,12 @@ extern "C" {
             struct ggml_tensor  * a,
             struct ggml_tensor  * b);
 
+    GGML_API struct ggml_tensor * ggml_scale_diag_mask_inf_softmax_inplace(
+            struct ggml_context * ctx,
+            float scale,
+            int n_past,
+            struct ggml_tensor * a);
+
     // rotary position embedding
     // if mode & 1 == 1, skip n_past elements
     // if mode & 2 == 1, GPT-NeoX style
diff --git a/llama.cpp b/llama.cpp
index 2a2a0c9c6..edaf8a461 100644
--- a/llama.cpp
+++ b/llama.cpp
@@ -2316,6 +2316,8 @@ static struct ggml_cgraph * llm_build_llama(
     }
     ggml_set_name(KQ_scale, "1/sqrt(n_embd_head)");
 
+    const float kq_scale = 1.0f/sqrtf(float(n_embd)/n_head);
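+    // same value as the KQ_scale tensor above; the fused op below takes the scale as a plain float through op_params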
+
     for (int il = 0; il < n_layer; ++il) {
         ggml_format_name(inpL, "layer_inp_%d", il);
 
@@ -2405,22 +2407,26 @@ static struct ggml_cgraph * llm_build_llama(
             offload_func_kq(KQ);
             ggml_set_name(KQ, "KQ");
 
-            // KQ_scaled = KQ / sqrt(n_embd_head)
-            // KQ_scaled shape [n_past + N, N, n_head, 1]
-            struct ggml_tensor * KQ_scaled = ggml_scale_inplace(ctx0, KQ, KQ_scale);
-            offload_func_kq(KQ_scaled);
-            ggml_set_name(KQ_scaled, "KQ_scaled");
-
-            // KQ_masked = mask_past(KQ_scaled)
-            struct ggml_tensor * KQ_masked = ggml_diag_mask_inf_inplace(ctx0, KQ_scaled, n_past);
-            offload_func_kq(KQ_masked);
-            ggml_set_name(KQ_masked, "KQ_masked");
-
-            // KQ = soft_max(KQ_masked)
-            struct ggml_tensor * KQ_soft_max = ggml_soft_max_inplace(ctx0, KQ_masked);
+            struct ggml_tensor * KQ_soft_max = ggml_scale_diag_mask_inf_softmax_inplace(ctx0, kq_scale, n_past, KQ);
             offload_func_v(KQ_soft_max);
             ggml_set_name(KQ_soft_max, "KQ_soft_max");
 
+            //// KQ_scaled = KQ / sqrt(n_embd_head)
+            //// KQ_scaled shape [n_past + N, N, n_head, 1]
+            //struct ggml_tensor * KQ_scaled = ggml_scale_inplace(ctx0, KQ, KQ_scale);
+            //offload_func_kq(KQ_scaled);
+            //ggml_set_name(KQ_scaled, "KQ_scaled");
+
+            //// KQ_masked = mask_past(KQ_scaled)
+            //struct ggml_tensor * KQ_masked = ggml_diag_mask_inf_inplace(ctx0, KQ_scaled, n_past);
+            //offload_func_kq(KQ_masked);
+            //ggml_set_name(KQ_masked, "KQ_masked");
+
+            //// KQ = soft_max(KQ_masked)
+            //struct ggml_tensor * KQ_soft_max = ggml_soft_max_inplace(ctx0, KQ_masked);
+            //offload_func_v(KQ_soft_max);
+            //ggml_set_name(KQ_soft_max, "KQ_soft_max");
+
             // split cached V into n_head heads
             struct ggml_tensor * V =
                 ggml_view_3d(ctx0, kv_self.v,
@@ -2647,6 +2653,8 @@ static struct ggml_cgraph * llm_build_falcon(
     }
     ggml_set_name(KQ_scale, "1/sqrt(n_embd_head)");
 
+    const float kq_scale = 1.0f/sqrtf(float(n_embd)/n_head);
+
     for (int il = 0; il < n_layer; ++il) {
         struct ggml_tensor * attn_norm;
 
@@ -2764,18 +2772,22 @@ static struct ggml_cgraph * llm_build_falcon(
             offload_func_kq(KQ);
             ggml_set_name(KQ, "KQ");
 
-            struct ggml_tensor * KQ_scaled = ggml_scale_inplace(ctx0, KQ, KQ_scale);
-            offload_func_kq(KQ_scaled);
-            ggml_set_name(KQ_scaled, "KQ_scaled");
-
-            struct ggml_tensor * KQ_masked = ggml_diag_mask_inf_inplace(ctx0, KQ_scaled, n_past);
-            offload_func_kq(KQ_masked);
-            ggml_set_name(KQ_masked, "KQ_masked");
-
-            struct ggml_tensor * KQ_soft_max = ggml_soft_max_inplace(ctx0, KQ_masked);
+            struct ggml_tensor * KQ_soft_max = ggml_scale_diag_mask_inf_softmax_inplace(ctx0, kq_scale, n_past, KQ);
             offload_func_v(KQ_soft_max);
             ggml_set_name(KQ_soft_max, "KQ_soft_max");
 
+            //struct ggml_tensor * KQ_scaled = ggml_scale_inplace(ctx0, KQ, KQ_scale);
+            //offload_func_kq(KQ_scaled);
+            //ggml_set_name(KQ_scaled, "KQ_scaled");
+
+            //struct ggml_tensor * KQ_masked = ggml_diag_mask_inf_inplace(ctx0, KQ_scaled, n_past);
+            //offload_func_kq(KQ_masked);
+            //ggml_set_name(KQ_masked, "KQ_masked");
+
+            //struct ggml_tensor * KQ_soft_max = ggml_soft_max_inplace(ctx0, KQ_masked);
+            //offload_func_v(KQ_soft_max);
+            //ggml_set_name(KQ_soft_max, "KQ_soft_max");
+
             struct ggml_tensor * V = ggml_view_3d(ctx0, kv_self.v,
                     n_past + N, n_embd_head, n_head_kv,