Simple mul_mat_f16 for speed and removal of unused mul_mat_f32
This commit is contained in:
parent
f0e1429d7f
commit
2fc8249ba3
1 changed files with 38 additions and 131 deletions
169
ggml-vulkan.cpp
169
ggml-vulkan.cpp
|
@ -976,6 +976,42 @@ void main() {
|
|||
}
|
||||
);
|
||||
|
||||
static const std::string program_fast_mul_mat_f16 =
|
||||
MULTILINE_QUOTE(
|
||||
layout(local_size_x = 32, local_size_y = 32, local_size_z = 1) in;
|
||||
|
||||
layout (binding = 0) readonly buffer tensorInA { float16_t inA[]; };
|
||||
layout (binding = 1) readonly buffer tensorInB { float16_t inB[]; };
|
||||
layout (binding = 2) writeonly buffer tensorOut { float out_[]; };
|
||||
|
||||
layout (push_constant) uniform parameter {
|
||||
int M;
|
||||
int N;
|
||||
int K;
|
||||
int inAStride;
|
||||
int inBStride;
|
||||
int outStride;
|
||||
uint inAOff;
|
||||
uint inBOff;
|
||||
uint outOff;
|
||||
} pcs;
|
||||
|
||||
void main() {
|
||||
int row = int(gl_GlobalInvocationID.x);
|
||||
int col = int(gl_GlobalInvocationID.y);
|
||||
|
||||
if (row < pcs.M && col < pcs.N) {
|
||||
float sum = 0.0f;
|
||||
|
||||
for (int i = 0; i < pcs.K; i++) {
|
||||
sum += float(inA[row * pcs.inAStride + i]) * float(inB[col * pcs.inBStride + i]);
|
||||
}
|
||||
|
||||
out_[col * pcs.outStride + row] = sum;
|
||||
}
|
||||
}
|
||||
);
|
||||
|
||||
void ggml_vk_mul_mat_f16(kp::Sequence& seq,
|
||||
const std::shared_ptr<kp::Tensor>& inA, uint32_t inAOff,
|
||||
const std::shared_ptr<kp::Tensor>& inB, uint32_t inBOff,
|
||||
|
@ -984,7 +1020,7 @@ void ggml_vk_mul_mat_f16(kp::Sequence& seq,
|
|||
int64_t ne10, int64_t ne11,
|
||||
int nb10, int nb11, int nb12, int nb13,
|
||||
int nb2, int nb3) {
|
||||
const static auto spirv = glsl_compile_source(program_source_head+program_mul_mat_f16, __func__);
|
||||
const static auto spirv = glsl_compile_source(program_source_head+program_fast_mul_mat_f16, __func__);
|
||||
|
||||
const bool inB_cont_rows = nb10 == sizeof(float);
|
||||
const bool inB_cont_cols = (size_t)nb11 == ne11 * sizeof(float);
|
||||
|
@ -1025,131 +1061,6 @@ void ggml_vk_mul_mat_f16(kp::Sequence& seq,
|
|||
}
|
||||
|
||||
|
||||
static const std::string program_mul_mat_f32 =
|
||||
MULTILINE_QUOTE(
|
||||
layout(local_size_x = (BM * BN) / (TM * TN), local_size_y = 1, local_size_z = 1) in;
|
||||
|
||||
layout (binding = 0) readonly buffer tensorInA { float inA[]; };
|
||||
layout (binding = 1) readonly buffer tensorInB { float inB[]; };
|
||||
layout (binding = 2) writeonly buffer tensorOut { float out_[]; };
|
||||
|
||||
layout (push_constant) uniform parameter {
|
||||
int M;
|
||||
int N;
|
||||
int K;
|
||||
int inAStride;
|
||||
int inBStride;
|
||||
int outStride;
|
||||
uint inAOff;
|
||||
uint inBOff;
|
||||
uint outOff;
|
||||
} pcs;
|
||||
|
||||
shared float bufA[BM * (BK+1)];
|
||||
shared float bufB[BN * (BK+1)];
|
||||
|
||||
void main() {
|
||||
const int ir = int(gl_WorkGroupID.x);
|
||||
const int ic = int(gl_WorkGroupID.y);
|
||||
|
||||
const int rstride = BM / TM;
|
||||
|
||||
const int lr = int(gl_WorkGroupID.x % rstride);
|
||||
const int lc = int(gl_WorkGroupID.x / rstride);
|
||||
|
||||
const int loadr = int(gl_WorkGroupID.x % BK);
|
||||
const int loadc = int(gl_WorkGroupID.x / BK);
|
||||
|
||||
const int loadstride = int(gl_WorkGroupSize.x);
|
||||
|
||||
int posA = ir * BM * pcs.inAStride;
|
||||
int posB = ic * BN * pcs.inBStride;
|
||||
|
||||
float sums[TM * TN];
|
||||
float cacheA[TM];
|
||||
float cacheB[TN];
|
||||
|
||||
[[unroll]] for (int i = 0; i < TM*TN; i++) {
|
||||
sums[i] = 0.0f;
|
||||
}
|
||||
|
||||
[[unroll]] for (int block = 0; block < pcs.K; block += BK) {
|
||||
[[unroll]] for (int l = 0; l < BM * BK; l += loadstride) {
|
||||
const int lr = l % BK;
|
||||
const int lc = l / BK;
|
||||
bufA[(loadc + lc) * (BK+1) + loadr + lr] = inA[posA + (loadc + lc) * pcs.inAStride + loadr + lr + pcs.inAOff];
|
||||
}
|
||||
[[unroll]] for (int l = 0; l < BN * BK; l += loadstride) {
|
||||
const int lr = l % BK;
|
||||
const int lc = l / BK;
|
||||
bufB[(loadc + lc) * (BK+1) + loadr + lr] = inB[posB + (loadc + lc) * pcs.inBStride + loadr + lr + pcs.inBOff];
|
||||
}
|
||||
|
||||
barrier();
|
||||
memoryBarrierShared();
|
||||
|
||||
posA += BK;
|
||||
posB += BK;
|
||||
|
||||
[[unroll]] for (int i = 0; i < BK; i++) {
|
||||
// Load from shared into cache
|
||||
[[unroll]] for (int j = 0; j < BM; j++) {
|
||||
cacheA[j] = bufA[(lr + j*rstride) * (BK+1) + i];
|
||||
}
|
||||
[[unroll]] for (int j = 0; j < TN; j++) {
|
||||
cacheB[j] = bufB[(lc * TN + j) * (BK+1) + i];
|
||||
}
|
||||
|
||||
[[unroll]] for (int cc = 0; cc < TN; cc++) {
|
||||
[[unroll]] for (int cr = 0; cr < TM; cr++) {
|
||||
sums[cc * TM + cr] += cacheA[cr] * cacheB[cc];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
barrier();
|
||||
}
|
||||
|
||||
const int dr = ir * BM + lr;
|
||||
const int dc = ic * BN + lc * TN;
|
||||
|
||||
[[unroll]] for (int cc = 0; cc < TN; cc++) {
|
||||
[[unroll]] for (int cr = 0; cr < TM; cr++) {
|
||||
out_[(dc + cc) * pcs.outStride + dr + cr*rstride + pcs.outOff] = sums[cc * TM + cr];
|
||||
}
|
||||
}
|
||||
}
|
||||
);
|
||||
|
||||
void ggml_vk_mul_mat_f32(kp::Sequence& seq,
|
||||
const std::shared_ptr<kp::Tensor>& inA, uint32_t inAOff,
|
||||
const std::shared_ptr<kp::Tensor>& inB, uint32_t inBOff,
|
||||
const std::shared_ptr<kp::Tensor>& out, uint32_t outOff,
|
||||
int64_t ne00, int64_t ne01, int64_t ne02, uint64_t ne03,
|
||||
int64_t ne10, int64_t ne11,
|
||||
int nb2, int nb3) {
|
||||
const static auto spirv = glsl_compile_source(program_source_head+program_mul_mat_f32, __func__);
|
||||
|
||||
struct PushConstants {
|
||||
int32_t M, N, K, inAStride, inBStride, outStride;
|
||||
uint32_t inAOff, inBOff, outOff;
|
||||
} pushConsts {
|
||||
(int)ne01, (int)ne11, (int)ne10, (int)ne00, (int)ne10, (int)ne01,
|
||||
inAOff, inBOff, outOff
|
||||
};
|
||||
|
||||
for (int64_t i03 = 0; i03 < ne03; i03++) {
|
||||
for (int64_t i02 = 0; i02 < ne02; i02++) {
|
||||
auto off = i02*nb2 + i03*nb3;
|
||||
pushConsts.inAOff = inAOff + off;
|
||||
pushConsts.inBOff = inBOff + off;
|
||||
pushConsts.outOff = outOff + off;
|
||||
seq.record<kp::OpAlgoDispatch>(mgr.algorithm<float, PushConstants>({inA, inB, out}, spirv, {uint32_t(ne01/128), uint32_t(ne11/128)}, {}, {pushConsts}));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
void ggml_vk_graph_compute(struct ggml_kompute_context * ctx, struct ggml_cgraph * gf) {
|
||||
printf("%s: evaluating graph\n", __func__);
|
||||
|
||||
|
@ -1266,11 +1177,7 @@ void ggml_vk_graph_compute(struct ggml_kompute_context * ctx, struct ggml_cgraph
|
|||
} break;
|
||||
case GGML_OP_MUL_MAT:
|
||||
{
|
||||
if (src0->type == GGML_TYPE_F32
|
||||
&& src1->type == GGML_TYPE_F32) {
|
||||
ggml_vk_mul_mat_f32(seq, id_src0, offs_src0, id_src1, offs_src1, id_dst, offs_dst, ne00, ne01, ne02, ne03, ne10, ne11, nb2, nb3);
|
||||
break;
|
||||
} else if (src0->type == GGML_TYPE_F16
|
||||
if (src0->type == GGML_TYPE_F16
|
||||
&& src1->type == GGML_TYPE_F32) {
|
||||
ggml_vk_mul_mat_f16(seq, id_src0, offs_src0, id_src1, offs_src1, id_dst, offs_dst, ne00, ne01, ne02, ne03, ne10, ne11, nb10, nb11, nb12, nb13, nb2, nb3);
|
||||
break;
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue