diff --git a/ggml-kompute.cpp b/ggml-kompute.cpp index aafd29850..b3ca984b4 100644 --- a/ggml-kompute.cpp +++ b/ggml-kompute.cpp @@ -764,9 +764,10 @@ static void ggml_vk_gelu(Args&&... args) { static void ggml_vk_soft_max( kp::Sequence& seq, - const std::shared_ptr& in, + const std::shared_ptr& inA, + const std::shared_ptr& inB, const std::shared_ptr& out, - uint32_t inOff, uint32_t outOff, + uint32_t inAOff, uint32_t inBOff, uint32_t outOff, int32_t ne00, int32_t ne01, int32_t ne02, uint32_t ne03, float scale ) { @@ -774,22 +775,27 @@ static void ggml_vk_soft_max( kp::shader_data::op_softmax_comp_spv_len); struct PushConstants { - uint32_t inOff, outOff; + uint32_t inAOff, inBOff, outOff; int32_t ne00, ne01, ne02; float scale; + int32_t mask; } pushConsts { - safe_divide(inOff, 4), safe_divide(outOff, 4), - ne00, ne01, ne02, scale + safe_divide(inAOff, 4), safe_divide(inBOff, 4), safe_divide(outOff, 4), + ne00, ne01, ne02, + scale, + bool(inB) }; + auto & inB_ = inB ? inB : inA; + std::shared_ptr s_algo = nullptr; if (!komputeManager()->hasAlgorithm(__func__)) { // FIXME: The softmax kernel needs to be fixed to use the subgroupsize which can vary by device const uint32_t local_x = 32; - s_algo = komputeManager()->algorithm(__func__, s_kompute_context->pool.get(), {in, out}, spirv, {unsigned(ne01), unsigned(ne02), unsigned(ne03)}, {local_x}, {pushConsts}); + s_algo = komputeManager()->algorithm(__func__, s_kompute_context->pool.get(), {inA, inB_, out}, spirv, {unsigned(ne01), unsigned(ne02), unsigned(ne03)}, {local_x}, {pushConsts}); } else { s_algo = komputeManager()->getAlgorithm(__func__); - s_algo->setTensors({in, out}); + s_algo->setTensors({inA, inB_, out}); s_algo->setWorkgroup({unsigned(ne01), unsigned(ne02), unsigned(ne03)}); s_algo->setPushConstants({pushConsts}); s_algo->updateDescriptors(s_kompute_context->pool.get()); @@ -1552,7 +1558,7 @@ void ggml_vk_graph_compute(struct ggml_kompute_context * ctx, struct ggml_cgraph case GGML_OP_SOFT_MAX: { const float scale = ((float *) dst->op_params)[0]; - ggml_vk_soft_max(seq, id_src0, id_dst, off_src0, off_dst, ne00, ne01, ne02, ne03, scale); + ggml_vk_soft_max(seq, id_src0, id_src1, id_dst, off_src0, off_src1, off_dst, ne00, ne01, ne02, ne03, scale); } break; case GGML_OP_DIAG_MASK_INF: { diff --git a/kompute-shaders/op_softmax.comp b/kompute-shaders/op_softmax.comp index fea371788..7bc9176ca 100644 --- a/kompute-shaders/op_softmax.comp +++ b/kompute-shaders/op_softmax.comp @@ -6,16 +6,19 @@ layout(local_size_x_id = 0) in; -layout(binding = 0) buffer restrict readonly tensorIn { float in_[]; }; -layout(binding = 1) buffer restrict writeonly tensorOut { float out_[]; }; +layout(binding = 0) buffer restrict readonly tensorInA { float inA[]; }; +layout(binding = 1) buffer restrict readonly tensorInB { float inB[]; }; +layout(binding = 2) buffer restrict writeonly tensorOut { float out_[]; }; layout(push_constant) uniform PushConstants { - uint inOff; + uint inAOff; + uint inBOff; uint outOff; int ne00; int ne01; int ne02; float scale; + int mask; } pcs; void main() { @@ -27,20 +30,21 @@ void main() { const uint i01 = gl_WorkGroupID.x; const uint extra_off = i03*pcs.ne02*pcs.ne01*pcs.ne00 + i02*pcs.ne01*pcs.ne00 + i01*pcs.ne00; - const uint psrc0 = extra_off + pcs.inOff; // Based from in_ + const uint psrc0 = extra_off + pcs.inAOff; // Based from inA + const uint pmask = i01*pcs.ne00 + pcs.inBOff; // Based from inB const uint pdst = extra_off + pcs.outOff; // Based from out_ // parallel max float localMax = uintBitsToFloat(0xFF800000); for (uint i00 = gl_SubgroupInvocationID.x; i00 < pcs.ne00; i00 += 32) { - localMax = max(localMax, in_[psrc0 + i00]*pcs.scale); + localMax = max(localMax, inA[psrc0 + i00]*pcs.scale + (pcs.mask!=0 ? inB[pmask + i00] : 0.0f)); } float max_ = subgroupMax(localMax); // parallel sum float localSum = 0.0f; for (uint i00 = gl_SubgroupInvocationID.x; i00 < pcs.ne00; i00 += 32) { - const float exp_psrc0 = exp(in_[psrc0 + i00]*pcs.scale - max_); + const float exp_psrc0 = exp(inA[psrc0 + i00]*pcs.scale + (pcs.mask!=0 ? inB[pmask + i00] : 0.0f) - max_); localSum += exp_psrc0; out_[pdst + i00] = exp_psrc0; }