diff --git a/ggml-kompute.cpp b/ggml-kompute.cpp index b8aae17fd..aafd29850 100644 --- a/ggml-kompute.cpp +++ b/ggml-kompute.cpp @@ -762,21 +762,24 @@ static void ggml_vk_gelu(Args&&... args) { ggml_vk_xxlu(spirv, "gelu", std::forward(args)...); } -static void ggml_vk_soft_max(kp::Sequence& seq, - const std::shared_ptr& in, - const std::shared_ptr& out, - uint32_t inOff, uint32_t outOff, - int32_t ne00, int32_t ne01, int32_t ne02, uint32_t ne03) { - +static void ggml_vk_soft_max( + kp::Sequence& seq, + const std::shared_ptr& in, + const std::shared_ptr& out, + uint32_t inOff, uint32_t outOff, + int32_t ne00, int32_t ne01, int32_t ne02, uint32_t ne03, + float scale +) { const static auto spirv = getSpirvShader(kp::shader_data::op_softmax_comp_spv, kp::shader_data::op_softmax_comp_spv_len); struct PushConstants { uint32_t inOff, outOff; int32_t ne00, ne01, ne02; + float scale; } pushConsts { safe_divide(inOff, 4), safe_divide(outOff, 4), - ne00, ne01, ne02 + ne00, ne01, ne02, scale }; std::shared_ptr s_algo = nullptr; @@ -1548,7 +1551,8 @@ void ggml_vk_graph_compute(struct ggml_kompute_context * ctx, struct ggml_cgraph } break; case GGML_OP_SOFT_MAX: { - ggml_vk_soft_max(seq, id_src0, id_dst, off_src0, off_dst, ne00, ne01, ne02, ne03); + const float scale = ((float *) dst->op_params)[0]; + ggml_vk_soft_max(seq, id_src0, id_dst, off_src0, off_dst, ne00, ne01, ne02, ne03, scale); } break; case GGML_OP_DIAG_MASK_INF: { diff --git a/kompute-shaders/op_softmax.comp b/kompute-shaders/op_softmax.comp index 89de1b701..fea371788 100644 --- a/kompute-shaders/op_softmax.comp +++ b/kompute-shaders/op_softmax.comp @@ -15,6 +15,7 @@ layout(push_constant) uniform PushConstants { int ne00; int ne01; int ne02; + float scale; } pcs; void main() { @@ -32,14 +33,14 @@ void main() { // parallel max float localMax = uintBitsToFloat(0xFF800000); for (uint i00 = gl_SubgroupInvocationID.x; i00 < pcs.ne00; i00 += 32) { - localMax = max(localMax, in_[psrc0 + i00]); + localMax = max(localMax, in_[psrc0 + i00]*pcs.scale); } float max_ = subgroupMax(localMax); // parallel sum float localSum = 0.0f; for (uint i00 = gl_SubgroupInvocationID.x; i00 < pcs.ne00; i00 += 32) { - const float exp_psrc0 = exp(in_[psrc0 + i00] - max_); + const float exp_psrc0 = exp(in_[psrc0 + i00]*pcs.scale - max_); localSum += exp_psrc0; out_[pdst + i00] = exp_psrc0; }