kompute : support scale parameter of softmax

This commit is contained in:
Jared Van Bortel 2024-01-24 16:16:58 -05:00
parent 1450966071
commit 308f279622
2 changed files with 15 additions and 10 deletions

View file

@ -762,21 +762,24 @@ static void ggml_vk_gelu(Args&&... args) {
ggml_vk_xxlu(spirv, "gelu", std::forward<Args>(args)...);
}
static void ggml_vk_soft_max(kp::Sequence& seq,
const std::shared_ptr<kp::Tensor>& in,
const std::shared_ptr<kp::Tensor>& out,
uint32_t inOff, uint32_t outOff,
int32_t ne00, int32_t ne01, int32_t ne02, uint32_t ne03) {
static void ggml_vk_soft_max(
kp::Sequence& seq,
const std::shared_ptr<kp::Tensor>& in,
const std::shared_ptr<kp::Tensor>& out,
uint32_t inOff, uint32_t outOff,
int32_t ne00, int32_t ne01, int32_t ne02, uint32_t ne03,
float scale
) {
const static auto spirv = getSpirvShader(kp::shader_data::op_softmax_comp_spv,
kp::shader_data::op_softmax_comp_spv_len);
struct PushConstants {
uint32_t inOff, outOff;
int32_t ne00, ne01, ne02;
float scale;
} pushConsts {
safe_divide(inOff, 4), safe_divide(outOff, 4),
ne00, ne01, ne02
ne00, ne01, ne02, scale
};
std::shared_ptr<kp::Algorithm> s_algo = nullptr;
@ -1548,7 +1551,8 @@ void ggml_vk_graph_compute(struct ggml_kompute_context * ctx, struct ggml_cgraph
} break;
case GGML_OP_SOFT_MAX:
{
ggml_vk_soft_max(seq, id_src0, id_dst, off_src0, off_dst, ne00, ne01, ne02, ne03);
const float scale = ((float *) dst->op_params)[0];
ggml_vk_soft_max(seq, id_src0, id_dst, off_src0, off_dst, ne00, ne01, ne02, ne03, scale);
} break;
case GGML_OP_DIAG_MASK_INF:
{

View file

@ -15,6 +15,7 @@ layout(push_constant) uniform PushConstants {
int ne00;
int ne01;
int ne02;
float scale;
} pcs;
void main() {
@ -32,14 +33,14 @@ void main() {
// parallel max
float localMax = uintBitsToFloat(0xFF800000);
for (uint i00 = gl_SubgroupInvocationID.x; i00 < pcs.ne00; i00 += 32) {
localMax = max(localMax, in_[psrc0 + i00]);
localMax = max(localMax, in_[psrc0 + i00]*pcs.scale);
}
float max_ = subgroupMax(localMax);
// parallel sum
float localSum = 0.0f;
for (uint i00 = gl_SubgroupInvocationID.x; i00 < pcs.ne00; i00 += 32) {
const float exp_psrc0 = exp(in_[psrc0 + i00] - max_);
const float exp_psrc0 = exp(in_[psrc0 + i00]*pcs.scale - max_);
localSum += exp_psrc0;
out_[pdst + i00] = exp_psrc0;
}