Implemented ggml_vk_soft_max

This commit is contained in:
niansa 2023-06-28 12:48:41 +02:00
parent e2b721db65
commit de7d1823ed


@@ -212,15 +212,28 @@ std::vector<uint8_t> getVecBlockQ4_0QS(T *x, unsigned nb, unsigned qk) {
};
-static const std::string program_source_head = R"(
-#version 450
+static const std::string program_source_head = R"(#version 450
#extension GL_EXT_shader_explicit_arithmetic_types_float16: enable
#extension GL_EXT_shader_explicit_arithmetic_types_int8: enable
#extension GL_EXT_shader_explicit_arithmetic_types_int64: enable
#define QK4_0 32
#define QR4_0 2
#define QK4_1 32
#define GELU_COEF_A 0.044715
#define SQRT_2_OVER_PI 0.79788456080286535587989211986876
#ifndef QK_K
#define QK_K 256
#endif
#if QK_K == 256
#define K_SCALE_SIZE 12
#else
#define K_SCALE_SIZE 4
#endif
)";
@@ -366,16 +379,6 @@ void ggml_vk_abmath(kp::Sequence& seq,
seq.record<kp::OpAlgoDispatch>(mgr.algorithm<float, PushConstants>({inA, inB, out}, spirv, {size}, {}, {pushConsts}));
}
-template <bool with_row = false, typename... Args>
-void ggml_vk_add(Args&&... args) {
-return ggml_vk_abmath<'+', with_row>(std::forward<Args>(args)...);
-}
-template <bool with_row = false, typename... Args>
-void ggml_vk_mul(Args&&... args) {
-return ggml_vk_abmath<'*', with_row>(std::forward<Args>(args)...);
-}
static const std::string program_scale =
MULTILINE_QUOTE(
@@ -456,8 +459,8 @@ void ggml_vk_silu(Args&&... args) {
static const std::string program_relu =
MULTILINE_QUOTE(
layout(push_constant) uniform PushConstants {
-uint inAOff;
+uint inOff;
uint outOff;
} pcs;
layout(local_size_x = 1) in;
@@ -482,8 +485,8 @@ void ggml_vk_relu(Args&&... args) {
static const std::string program_gelu =
MULTILINE_QUOTE(
layout(push_constant) uniform PushConstants {
-uint inAOff;
+uint inOff;
uint outOff;
} pcs;
layout(local_size_x = 1) in;
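The kernel added in the next hunk evaluates the standard numerically stable softmax over each row, subtracting the row maximum before exponentiating so that exp() cannot overflow:

$$\mathrm{softmax}(x)_i = \frac{e^{x_i - m}}{\sum_j e^{x_j - m}}, \qquad m = \max_j x_j$$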
@@ -506,6 +509,109 @@ void ggml_vk_gelu(Args&&... args) {
}
static const std::string program_soft_max =
MULTILINE_QUOTE(
layout(push_constant) uniform PushConstants {
uint64_t ne00;
uint64_t ne01;
uint64_t ne02;
uint inOff;
uint outOff;
} pcs;
layout(local_size_x = nth) in;
layout(binding = 0) buffer restrict readonly tensorInA { float in_[]; };
layout(binding = 1) buffer restrict writeonly tensorOut { float out_[]; };
shared float buf[nth];
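// one partial result per thread; nth is injected via a "#define nth" prepended
// at compile time (see ggml_vk_soft_max below)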
void main() {
// one workgroup per row, so the row indices come from the workgroup ID
// (gl_GlobalInvocationID.x would be scaled by local_size_x = nth)
const uint64_t i03 = uint64_t(gl_WorkGroupID.z);
const uint64_t i02 = uint64_t(gl_WorkGroupID.y);
const uint64_t i01 = uint64_t(gl_WorkGroupID.x);
const uint extra_off = uint(i03*pcs.ne02*pcs.ne01*pcs.ne00 + i02*pcs.ne01*pcs.ne00 + i01*pcs.ne00);
const uint in_off = pcs.inOff + extra_off;
const uint out_off = pcs.outOff + extra_off;
// parallel max
buf[gl_LocalInvocationID.x] = uintBitsToFloat(0xFF800000); // -INFINITY
for (uint i00 = gl_LocalInvocationID.x; i00 < pcs.ne00; i00 += nth) {
buf[gl_LocalInvocationID.x] = max(buf[gl_LocalInvocationID.x], in_[in_off + i00]);
}
// reduce
barrier();
memoryBarrierShared();
for (uint i = nth/2; i > 0; i /= 2) {
if (gl_LocalInvocationID.x < i) {
buf[gl_LocalInvocationID.x] = max(buf[gl_LocalInvocationID.x], buf[gl_LocalInvocationID.x + i]);
}
barrier();
memoryBarrierShared();
}
// no explicit broadcast needed: the final barrier of the reduce loop
// above already makes buf[0] visible to the whole workgroup
const float max_ = buf[0];
// wait until every thread has read the max before buf is reused below
barrier();
memoryBarrierShared();
// parallel sum
buf[gl_LocalInvocationID.x] = 0.0;
for (uint i00 = gl_LocalInvocationID.x; i00 < pcs.ne00; i00 += nth) {
buf[gl_LocalInvocationID.x] += exp(in_[in_off + i00] - max_);
}
// reduce
barrier();
memoryBarrierShared();
for (uint i = nth/2; i > 0; i /= 2) {
if (gl_LocalInvocationID.x < i) {
buf[gl_LocalInvocationID.x] += buf[gl_LocalInvocationID.x + i];
}
barrier();
memoryBarrierShared();
}
// as above, buf[0] is already visible after the reduce loop
const float sum = buf[0];
for (uint i00 = gl_LocalInvocationID.x; i00 < pcs.ne00; i00 += nth) {
out_[out_off + i00] = exp(in_[in_off + i00] - max_) / sum;
}
}
);
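For validation, a plain CPU reference performing the same three passes (row max, exponential sum, normalize) could look like the following; softmax_row_reference is a hypothetical helper, not part of this commit:

#include <cmath>
#include <cstddef>

// Hypothetical CPU reference for one row of ne00 floats, mirroring the
// shader: subtract the row max, exponentiate, normalize by the sum.
static void softmax_row_reference(const float * in, float * out, size_t ne00) {
    float max_ = -INFINITY;
    for (size_t i = 0; i < ne00; ++i) max_ = in[i] > max_ ? in[i] : max_;
    float sum = 0.0f;
    for (size_t i = 0; i < ne00; ++i) sum += std::exp(in[i] - max_);
    for (size_t i = 0; i < ne00; ++i) out[i] = std::exp(in[i] - max_) / sum;
}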
void ggml_vk_soft_max(kp::Sequence& seq,
const std::shared_ptr<kp::Tensor>& in, uint32_t inOff,
const std::shared_ptr<kp::Tensor>& out, uint32_t outOff,
int64_t ne00, int64_t ne01, int64_t ne02, int64_t ne03) {
const static unsigned nth = 32;
const static auto spirv = compileSource(program_source_head+"#define nth "+std::to_string(nth)+"\n"+program_soft_max, __func__);
struct PushConstants {
int64_t ne00, ne01, ne02;
uint32_t inOff, outOff;
} pushConsts {
ne00, ne01, ne02, inOff, outOff
};
seq.record<kp::OpAlgoDispatch>(mgr.algorithm<float, PushConstants>({in, out}, spirv, {unsigned(ne01), unsigned(ne02), unsigned(ne03)}, {}, {pushConsts}));
}
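As a concrete sizing example (illustrative numbers, not from the commit): for attention scores with ne00 = 4096 columns and ne01 = 32 rows (ne02 = ne03 = 1), the record above dispatches {32, 1, 1} workgroups of nth = 32 threads each, so every thread strides its row in steps of 32 and touches 4096 / 32 = 128 elements.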
void ggml_vk_graph_compute(struct ggml_kompute_context * ctx, struct ggml_cgraph * gf) {
printf("%s: evaluating graph\n", __func__);
@@ -585,15 +691,15 @@ void ggml_vk_graph_compute(struct ggml_kompute_context * ctx, struct ggml_cgraph
} break;
case GGML_OP_ADD:
{
-ggml_vk_add(seq, id_src0, offs_src0, id_src1, offs_src1, id_dst, offs_dst, ggml_nelements(dst));
+ggml_vk_abmath<'+'>(seq, id_src0, offs_src0, id_src1, offs_src1, id_dst, offs_dst, ggml_nelements(dst));
} break;
case GGML_OP_MUL:
{
if (ggml_nelements(src1) == ne10) {
// src1 is a row
-ggml_vk_mul<true>(seq, id_src0, offs_src0, id_src1, offs_src1, id_dst, offs_dst, ggml_nelements(dst), ne00);
+ggml_vk_abmath<'*', true>(seq, id_src0, offs_src0, id_src1, offs_src1, id_dst, offs_dst, ggml_nelements(dst), ne00);
} else {
-ggml_vk_mul(seq, id_src0, offs_src0, id_src1, offs_src1, id_dst, offs_dst, ggml_nelements(dst));
+ggml_vk_abmath<'*'>(seq, id_src0, offs_src0, id_src1, offs_src1, id_dst, offs_dst, ggml_nelements(dst));
}
} break;
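// sketch of the template flag's semantics, inferred from the call sites
// rather than quoted from the shaders: with_row == true broadcasts src1
// as a single row of length ne00, i.e. roughly
//     dst[i] = src0[i] * src1[i % ne00];
// while the default remains element-wise dst[i] = src0[i] * src1[i]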
case GGML_OP_SCALE:
@@ -613,6 +719,10 @@ void ggml_vk_graph_compute(struct ggml_kompute_context * ctx, struct ggml_cgraph
{
ggml_vk_gelu(seq, id_src0, offs_src0, id_dst, offs_dst, ggml_nelements(dst));
} break;
case GGML_OP_SOFT_MAX:
{
ggml_vk_soft_max(seq, id_src0, offs_src0, id_dst, offs_dst, ne00, ne01, ne02, ne03);
} break;
default:
fprintf(stderr, "%s: node %3d, op = %8s not implemented\n", __func__, i, ggml_op_name(dst->op));
//GGML_ASSERT(false);