diff --git a/ggml-vulkan.cpp b/ggml-vulkan.cpp index d6b99aa1f..7b92a7bac 100644 --- a/ggml-vulkan.cpp +++ b/ggml-vulkan.cpp @@ -223,8 +223,8 @@ static const std::string program_source_head = R"(#version 450 #define QR4_0 2 #define QK4_1 32 -#define GELU_COEF_A 0.044715; -#define SQRT_2_OVER_PI 0.79788456080286535587989211986876; +#define GELU_COEF_A 0.044715 +#define SQRT_2_OVER_PI 0.79788456080286535587989211986876 #ifndef QK_K #define QK_K 256 @@ -235,6 +235,12 @@ static const std::string program_source_head = R"(#version 450 #else #define K_SCALE_SIZE 4 #endif + +#define BM 128 +#define BN 128 +#define BK 8 +#define TM 8 +#define TN 8 )"; @@ -651,13 +657,56 @@ void ggml_vk_soft_max(kp::Sequence& seq, } -static const std::string program_mul_mat_f16 = R"( -#define BM 128 -#define BN 128 -#define BK 8 -#define TM 8 -#define TN 8 -)" MULTILINE_QUOTE( +static const std::string program_diag_mask_inf = + MULTILINE_QUOTE( +layout(push_constant) uniform PushConstants { + uint64_t ne00; + uint64_t ne01; + uint inAOff; + uint inBOff; + uint outOff; +} pcs; + +layout(local_size_x = 1) in; +layout(binding = 0) buffer restrict readonly tensorInA { float inA[]; }; +layout(binding = 1) buffer restrict readonly tensorInB { int inB[]; }; +layout(binding = 2) buffer restrict writeonly tensorOut { float out_[]; }; + +void main() { + const uint64_t i02 = uint64_t(gl_GlobalInvocationID.z); + const uint64_t i01 = uint64_t(gl_GlobalInvocationID.y); + const uint64_t i00 = uint64_t(gl_GlobalInvocationID.x); + + const int n_past = inB[pcs.inBOff]; + + if (i00 > n_past + i01) { + out_[uint(i02*pcs.ne01*pcs.ne00 + i01*pcs.ne00 + i00 + pcs.outOff)] = uintBitsToFloat(0xFF800000); + } else { + out_[uint(i02*pcs.ne01*pcs.ne00 + i01*pcs.ne00 + i00 + pcs.outOff)] = inA[uint(i02*pcs.ne01*pcs.ne00 + i01*pcs.ne00 + i00 + pcs.inAOff)]; + } +} +); + +void ggml_vk_diag_mask_inf(kp::Sequence& seq, + const std::shared_ptr& inA, uint32_t inAOff, + const std::shared_ptr& inB, uint32_t inBOff, + const std::shared_ptr& out, uint32_t outOff, + int64_t ne00, int64_t ne01, int64_t ne02) { + const static auto spirv = glsl_compile_source(program_source_head+program_diag_mask_inf, __func__); + + struct PushConstants { + int64_t ne00, ne01; + uint32_t inAOff, inBOff, outOff; + } pushConsts { + ne00, ne01, inAOff, inBOff, outOff + }; + + seq.record(mgr.algorithm({inA, inB, out}, spirv, {unsigned(ne00), unsigned(ne01), unsigned(ne02)}, {}, {pushConsts})); +} + + +static const std::string program_mul_mat_f16 = + MULTILINE_QUOTE( layout(local_size_x = (BM * BN) / (TM * TN), local_size_y = 1, local_size_z = 1) in; layout (binding = 0) readonly buffer tensorInA { float16_t inA[]; }; @@ -800,13 +849,8 @@ void ggml_vk_mul_mat_f16(kp::Sequence& seq, } -static const std::string program_mul_mat_f32 = R"( -#define BM 128 -#define BN 128 -#define BK 8 -#define TM 8 -#define TN 8 -)" MULTILINE_QUOTE( +static const std::string program_mul_mat_f32 = + MULTILINE_QUOTE( layout(local_size_x = (BM * BN) / (TM * TN), local_size_y = 1, local_size_z = 1) in; layout (binding = 0) readonly buffer tensorInA { float inA[]; }; @@ -1041,14 +1085,18 @@ void ggml_vk_graph_compute(struct ggml_kompute_context * ctx, struct ggml_cgraph { ggml_vk_soft_max(seq, id_src0, offs_src0, id_dst, offs_dst, ne00, ne01, ne02, ne03); } break; + case GGML_OP_DIAG_MASK_INF: + { + ggml_vk_diag_mask_inf(seq, id_src0, offs_src0, id_src1, offs_src1, id_dst, offs_dst, ne00, ne01, ne02); + } break; case GGML_OP_MUL_MAT: { if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32) { ggml_vk_mul_mat_f32(seq, id_src0, offs_src0, id_src1, offs_src1, id_dst, offs_dst, ne00, ne01, ne02, ne03, ne10, ne11, nb2, nb3); break; - } else if (src0->type == GGML_TYPE_F32 - && src1->type == GGML_TYPE_F16) { + } else if (src0->type == GGML_TYPE_F16 + && src1->type == GGML_TYPE_F32) { ggml_vk_mul_mat_f16(seq, id_src0, offs_src0, id_src1, offs_src1, id_dst, offs_dst, ne00, ne01, ne02, ne03, ne10, ne11, nb10, nb11, nb12, nb13, nb2, nb3); } }