kompute: softmax: implement ALiBi support

Signed-off-by: Sergio Lopez <slp@redhat.com>
This commit is contained in:
Sergio Lopez 2024-11-20 07:28:25 +01:00
parent 913536f2a5
commit d8889598d6
3 changed files with 34 additions and 9 deletions

View file

@ -788,7 +788,8 @@ static void ggml_vk_soft_max(
const std::shared_ptr<kp::Tensor>& out, const std::shared_ptr<kp::Tensor>& out,
uint32_t inAOff, uint32_t inBOff, uint32_t outOff, uint32_t inAOff, uint32_t inBOff, uint32_t outOff,
int32_t ne00, int32_t ne01, int32_t ne02, uint32_t ne03, int32_t ne00, int32_t ne01, int32_t ne02, uint32_t ne03,
float scale float scale, float max_bias, float m0, float m1,
uint32_t n_head_log2
) { ) {
const static auto spirv = getSpirvShader(kp::shader_data::op_softmax_comp_spv, const static auto spirv = getSpirvShader(kp::shader_data::op_softmax_comp_spv,
kp::shader_data::op_softmax_comp_spv_len); kp::shader_data::op_softmax_comp_spv_len);
@ -796,12 +797,14 @@ static void ggml_vk_soft_max(
struct PushConstants { struct PushConstants {
uint32_t inAOff, inBOff, outOff; uint32_t inAOff, inBOff, outOff;
int32_t ne00, ne01, ne02; int32_t ne00, ne01, ne02;
float scale; float scale, max_bias, m0, m1;
uint32_t n_head_log2;
int32_t mask; int32_t mask;
} pushConsts { } pushConsts {
safe_divide(inAOff, 4), safe_divide(inBOff, 4), safe_divide(outOff, 4), safe_divide(inAOff, 4), safe_divide(inBOff, 4), safe_divide(outOff, 4),
ne00, ne01, ne02, ne00, ne01, ne02,
scale, scale, max_bias, m0, m1,
n_head_log2,
bool(inB) bool(inB)
}; };
@ -1597,11 +1600,16 @@ static void ggml_vk_graph_compute(struct ggml_kompute_context * ctx, struct ggml
#pragma message("ref: https://github.com/ggerganov/llama.cpp/pull/5021") #pragma message("ref: https://github.com/ggerganov/llama.cpp/pull/5021")
GGML_ASSERT(!src1 || src1t == GGML_TYPE_F32); GGML_ASSERT(!src1 || src1t == GGML_TYPE_F32);
#pragma message("TODO: add ALiBi support") const int64_t nrows_x = ggml_nrows(src0);
#pragma message("ref: https://github.com/ggerganov/llama.cpp/pull/7192") const int64_t nrows_y = src0->ne[1];
GGML_ASSERT(max_bias == 0.0f);
ggml_vk_soft_max(seq, id_src0, id_src1, id_dst, off_src0, off_src1, off_dst, ne00, ne01, ne02, ne03, scale); const uint32_t n_head = nrows_x/nrows_y;
const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head));
const float m0 = powf(2.0f, -(max_bias ) / n_head_log2);
const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
ggml_vk_soft_max(seq, id_src0, id_src1, id_dst, off_src0, off_src1, off_dst, ne00, ne01, ne02, ne03, scale, max_bias, m0, m1, n_head_log2);
} break; } break;
case GGML_OP_DIAG_MASK_INF: case GGML_OP_DIAG_MASK_INF:
{ {

View file

@ -3,6 +3,7 @@
#extension GL_EXT_shader_explicit_arithmetic_types_float16: require #extension GL_EXT_shader_explicit_arithmetic_types_float16: require
#extension GL_EXT_shader_explicit_arithmetic_types_int8: require #extension GL_EXT_shader_explicit_arithmetic_types_int8: require
#extension GL_EXT_shader_explicit_arithmetic_types_int16: require #extension GL_EXT_shader_explicit_arithmetic_types_int16: require
#extension GL_EXT_shader_explicit_arithmetic_types_int64: require
#extension GL_EXT_control_flow_attributes: enable #extension GL_EXT_control_flow_attributes: enable
#extension GL_KHR_shader_subgroup_arithmetic : require #extension GL_KHR_shader_subgroup_arithmetic : require
#extension GL_EXT_debug_printf : enable #extension GL_EXT_debug_printf : enable

View file

@ -18,6 +18,10 @@ layout(push_constant) uniform PushConstants {
int ne01; int ne01;
int ne02; int ne02;
float scale; float scale;
float max_bias;
float m0;
float m1;
uint n_head_log2;
int mask; int mask;
} pcs; } pcs;
@ -34,17 +38,29 @@ void main() {
const uint pmask = i01*pcs.ne00 + pcs.inBOff; // Based from inB const uint pmask = i01*pcs.ne00 + pcs.inBOff; // Based from inB
const uint pdst = extra_off + pcs.outOff; // Based from out_ const uint pdst = extra_off + pcs.outOff; // Based from out_
// Multiplier applied to the mask values read from inB in the max and sum
// loops below. It stays 1.0 (mask used unscaled) unless the ALiBi branch
// overrides it.
float slope = 1.0f;
// ALiBi: only active when the max_bias push constant is positive.
if (pcs.max_bias > 0.0f) {
// NOTE(review): i02 is defined outside this hunk -- presumably the head index; confirm.
int64_t h = i02;
// Indices below n_head_log2 use base m0; the remaining indices use base m1.
float base = h < pcs.n_head_log2 ? pcs.m0 : pcs.m1;
// Exponent is h + 1 for the first group and 1, 3, 5, ... for the second group.
// NOTE(review): this local shares its name with the built-in exp() that the
// sum loop calls later; it is scoped to this if-block only.
int64_t exp = h < pcs.n_head_log2 ? h + 1 : 2*(h - pcs.n_head_log2) + 1;
// slope = base^exp
slope = pow(base, float(exp));
}
// parallel max // parallel max
float localMax = uintBitsToFloat(0xFF800000); float localMax = uintBitsToFloat(0xFF800000);
for (uint i00 = gl_SubgroupInvocationID.x; i00 < pcs.ne00; i00 += 32) { for (uint i00 = gl_SubgroupInvocationID.x; i00 < pcs.ne00; i00 += 32) {
localMax = max(localMax, inA[psrc0 + i00]*pcs.scale + (pcs.mask!=0 ? inB[pmask + i00] : 0.0f)); localMax = max(localMax, inA[psrc0 + i00]*pcs.scale + (pcs.mask!=0 ? slope*inB[pmask + i00] : 0.0f));
} }
float max_ = subgroupMax(localMax); float max_ = subgroupMax(localMax);
// parallel sum // parallel sum
float localSum = 0.0f; float localSum = 0.0f;
for (uint i00 = gl_SubgroupInvocationID.x; i00 < pcs.ne00; i00 += 32) { for (uint i00 = gl_SubgroupInvocationID.x; i00 < pcs.ne00; i00 += 32) {
const float exp_psrc0 = exp(inA[psrc0 + i00]*pcs.scale + (pcs.mask!=0 ? inB[pmask + i00] : 0.0f) - max_); const float exp_psrc0 = exp(inA[psrc0 + i00]*pcs.scale + (pcs.mask!=0 ? slope*inB[pmask + i00] : 0.0f) - max_);
localSum += exp_psrc0; localSum += exp_psrc0;
out_[pdst + i00] = exp_psrc0; out_[pdst + i00] = exp_psrc0;
} }