Fix MUL_MAT_ID matrix vector shader and dispatch code

This commit is contained in:
0cc4m 2024-06-01 19:35:57 +02:00
parent c8f93774c9
commit 2c3d0b42f3
3 changed files with 53283 additions and 55263 deletions

File diff suppressed because it is too large Load diff

View file

@ -228,31 +228,27 @@ typedef std::vector<vk_submission> vk_sequence;
struct vk_mat_mat_push_constants {
uint32_t M; uint32_t N; uint32_t K;
uint32_t stride_a; uint32_t stride_b; uint32_t stride_d; uint32_t k_split;
uint32_t ne02; uint32_t ne12; uint32_t broadcast2; uint32_t broadcast3;
uint32_t stride_a; uint32_t stride_b; uint32_t stride_d;
uint32_t batch_stride_a; uint32_t batch_stride_b; uint32_t batch_stride_d;
uint32_t k_split;
uint32_t ne02; uint32_t ne12; uint32_t broadcast2; uint32_t broadcast3;
};
struct vk_mat_vec_push_constants {
uint32_t ncols; uint32_t stride_a; uint32_t stride_b; uint32_t stride_d;
uint32_t batch_stride_a; uint32_t batch_stride_b; uint32_t batch_stride_d;
uint32_t ne02; uint32_t ne12; uint32_t broadcast2; uint32_t broadcast3;
};
struct vk_mat_mat_id_push_constants {
uint32_t M; uint32_t N; uint32_t K;
uint32_t stride_a; uint32_t stride_b; uint32_t stride_d; uint32_t k_split;
uint32_t ne02; uint32_t ne12; uint32_t broadcast2; uint32_t broadcast3;
uint32_t batch_stride_a; uint32_t batch_stride_b; uint32_t batch_stride_d;
uint32_t expert_stride_a; uint32_t expert_stride_d;
uint32_t n_as; uint32_t nei0; uint32_t nei1; uint32_t nbi1; uint ne11;
};
struct vk_mat_vec_push_constants {
uint32_t ncols; uint32_t stride_a; uint32_t stride_b; uint32_t stride_d;
uint32_t ne02; uint32_t ne12; uint32_t broadcast2; uint32_t broadcast3;
uint32_t stride_a; uint32_t stride_b; uint32_t stride_d;
uint32_t batch_stride_a; uint32_t batch_stride_b; uint32_t batch_stride_d;
uint32_t nei0; uint32_t nei1; uint32_t nbi1; uint32_t ne11;
};
struct vk_mat_vec_id_push_constants {
uint32_t ncols; uint32_t stride_a; uint32_t stride_b; uint32_t stride_d;
uint32_t ne02; uint32_t ne12; uint32_t broadcast2; uint32_t broadcast3;
uint32_t batch_stride_a; uint32_t batch_stride_b; uint32_t batch_stride_d;
uint expert_stride_a; uint expert_stride_d;
uint n_as; uint nei0; uint nbi1;
uint32_t nei0; uint32_t ne11;
};
struct vk_op_push_constants {
@ -2206,7 +2202,11 @@ static void ggml_vk_dispatch_pipeline(ggml_backend_vk_context * ctx, vk_context
const uint32_t wg1 = CEIL_DIV(elements[1], pipeline->wg_denoms[1]);
const uint32_t wg2 = CEIL_DIV(elements[2], pipeline->wg_denoms[2]);
#ifdef GGML_VULKAN_DEBUG
std::cerr << "ggml_vk_dispatch_pipeline(" << pipeline->name << ", (" << wg0 << "," << wg1 << "," << wg2 << "))" << std::endl;
std::cerr << "ggml_vk_dispatch_pipeline(" << pipeline->name << ", {";
for (auto& buffer : buffers) {
std::cerr << "(" << buffer.buffer << ", " << buffer.offset << ", " << buffer.size << "), ";
}
std::cerr << "}, (" << wg0 << "," << wg1 << "," << wg2 << "))" << std::endl;
#endif
std::vector<vk::DescriptorBufferInfo> descriptor_buffer_infos;
std::vector<vk::WriteDescriptorSet> write_descriptor_sets;
@ -2787,21 +2787,21 @@ static void ggml_vk_matmul(
ggml_backend_vk_context * ctx, vk_context * subctx, vk_pipeline& pipeline,
vk_subbuffer&& a, vk_subbuffer&& b, vk_subbuffer&& d, vk_subbuffer&& split_k_buffer,
uint32_t m, uint32_t n, uint32_t k, uint32_t stride_a, uint32_t stride_b, uint32_t stride_d,
uint32_t split_k, uint32_t batch, uint32_t ne02, uint32_t ne12, uint32_t broadcast2, uint32_t broadcast3,
uint32_t batch_stride_a, uint32_t batch_stride_b, uint32_t batch_stride_d) {
uint32_t batch_stride_a, uint32_t batch_stride_b, uint32_t batch_stride_d,
uint32_t split_k, uint32_t batch, uint32_t ne02, uint32_t ne12, uint32_t broadcast2, uint32_t broadcast3) {
#ifdef GGML_VULKAN_DEBUG
std::cerr << "ggml_vk_matmul(a: (" << a.buffer->buffer << ", " << a.offset << ", " << a.size << "), b: (" << b.buffer->buffer << ", " << b.offset << ", " << b.size << "), d: (" << d.buffer->buffer << ", " << d.offset << ", " << d.size << "), split_k: (" << (split_k_buffer.buffer != nullptr ? split_k_buffer.buffer->buffer : VK_NULL_HANDLE) << ", " << split_k_buffer.offset << ", " << split_k_buffer.size << "), m: " << m << ", n: " << n << ", k: " << k << ", stride_a: " << stride_a << ", stride_b: " << stride_b << ", stride_d: " << stride_d << ", split_k: " << split_k << ", batch: " << batch << ", ne02: " << ne02 << ", ne12: " << ne12 << ", broadcast2: " << broadcast2 << ", broadcast3: " << broadcast3 << ", batch_stride_a: " << batch_stride_a << ", batch_stride_b: " << batch_stride_b << ", batch_stride_d: " << batch_stride_d << ")" << std::endl;
std::cerr << "ggml_vk_matmul(a: (" << a.buffer->buffer << ", " << a.offset << ", " << a.size << "), b: (" << b.buffer->buffer << ", " << b.offset << ", " << b.size << "), d: (" << d.buffer->buffer << ", " << d.offset << ", " << d.size << "), split_k: (" << (split_k_buffer.buffer != nullptr ? split_k_buffer.buffer->buffer : VK_NULL_HANDLE) << ", " << split_k_buffer.offset << ", " << split_k_buffer.size << "), m: " << m << ", n: " << n << ", k: " << k << ", stride_a: " << stride_a << ", stride_b: " << stride_b << ", stride_d: " << stride_d << ", batch_stride_a: " << batch_stride_a << ", batch_stride_b: " << batch_stride_b << ", batch_stride_d: " << batch_stride_d << ", split_k: " << split_k << ", batch: " << batch << ", ne02: " << ne02 << ", ne12: " << ne12 << ", broadcast2: " << broadcast2 << ", broadcast3: " << broadcast3 << ")" << std::endl;
#endif
ggml_vk_sync_buffers(subctx);
if (split_k == 1) {
const vk_mat_mat_push_constants pc = { m, n, k, stride_a, stride_b, stride_d, k, ne02, ne12, broadcast2, broadcast3, batch_stride_a, batch_stride_b, batch_stride_d };
const vk_mat_mat_push_constants pc = { m, n, k, stride_a, stride_b, stride_d, batch_stride_a, batch_stride_b, batch_stride_d, k, ne02, ne12, broadcast2, broadcast3 };
ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { a, b, d }, sizeof(vk_mat_mat_push_constants), &pc, { m, n, batch });
return;
}
GGML_ASSERT(batch_stride_d == m * n);
const vk_mat_mat_push_constants pc1 = { m, n, k, stride_a, stride_b, stride_d, CEIL_DIV(k, split_k), ne02, ne12, broadcast2, broadcast3, batch_stride_a, batch_stride_b, batch_stride_d };
const vk_mat_mat_push_constants pc1 = { m, n, k, stride_a, stride_b, stride_d, batch_stride_a, batch_stride_b, batch_stride_d, CEIL_DIV(k, split_k), ne02, ne12, broadcast2, broadcast3 };
// Make sure enough workgroups get assigned for split k to work
ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { a, b, split_k_buffer }, sizeof(vk_mat_mat_push_constants), &pc1, { (CEIL_DIV(m, pipeline->wg_denoms[0]) * pipeline->wg_denoms[0]) * split_k, n, batch });
ggml_vk_sync_buffers(subctx);
@ -2813,19 +2813,17 @@ static void ggml_vk_matmul_id(
ggml_backend_vk_context * ctx, vk_context * subctx, vk_pipeline& pipeline,
vk_subbuffer&& a, vk_subbuffer&& b, vk_subbuffer&& d, vk_subbuffer&& ids,
uint32_t m, uint32_t n, uint32_t k, uint32_t stride_a, uint32_t stride_b, uint32_t stride_d,
uint32_t ne02, uint32_t ne12, uint32_t broadcast2, uint32_t broadcast3,
uint32_t batch_stride_a, uint32_t batch_stride_b, uint32_t batch_stride_d,
uint32_t expert_stride_a, uint32_t expert_stride_d,
uint32_t n_as, uint32_t nei0, uint32_t nei1, uint32_t nbi1, uint32_t ne11) {
#ifdef GGML_VULKAN_DEBUG
std::cerr << "ggml_vk_matmul_id(a: (" << a.buffer->buffer << ", " << a.offset << ", " << a.size << "), b: (" << b.buffer->buffer << ", " << b.offset << ", " << b.size << "), d: (" << d.buffer->buffer << ", " << d.offset << ", " << d.size << "), ids: (" << ids.buffer->buffer << ", " << ids.offset << ", " << ids.size << "), " <<
"m: " << m << ", n: " << n << ", k: " << k << ", stride_a: " << stride_a << ", stride_b: " << stride_b << ", stride_d: " << stride_d << ", ne02: " << ne02 << ", ne12: " << ne12 << ", broadcast2: " << broadcast2 << ", broadcast3: " << broadcast3 << ", " <<
"m: " << m << ", n: " << n << ", k: " << k << ", stride_a: " << stride_a << ", stride_b: " << stride_b << ", stride_d: " << stride_d << ", " <<
"batch_stride_a: " << batch_stride_a << ", batch_stride_b: " << batch_stride_b << ", batch_stride_d: " << batch_stride_d << ", " <<
"expert_stride_a: " << expert_stride_a << ", expert_stride_d: " << expert_stride_d << ", n_as: " << n_as << ", nei0: " << nei0 << ", nei1: " << nei1 << ", nbi1: " << nbi1 << ", ne11: " << ne11 << ")" << std::endl;
"n_as: " << n_as << ", nei0: " << nei0 << ", nei1: " << nei1 << ", nbi1: " << nbi1 << ", ne11: " << ne11 << ")" << std::endl;
#endif
ggml_vk_sync_buffers(subctx);
const vk_mat_mat_id_push_constants pc = { m, n, k, stride_a, stride_b, stride_d, k, ne02, ne12, broadcast2, broadcast3, batch_stride_a, batch_stride_b, batch_stride_d,
expert_stride_a, expert_stride_d, n_as, nei0, nei1, nbi1, ne11 };
const vk_mat_mat_id_push_constants pc = { m, n, k, stride_a, stride_b, stride_d, batch_stride_a, batch_stride_b, batch_stride_d,
nei0, nei1, nbi1, ne11 };
ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { a, b, d, ids }, sizeof(vk_mat_mat_id_push_constants), &pc, { m, nei1, n_as });
}
@ -3040,7 +3038,9 @@ static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context * su
ctx, subctx, pipeline,
{ d_X, x_buf_offset, x_sz * ne02 * ne03 }, { d_Y, y_buf_offset, y_sz * ne12 * ne13 },
{ d_D, d_buf_offset, d_sz * ne12 * ne13 }, { ctx->prealloc_split_k, 0, d_sz * ne12 * ne13 * split_k },
ne01, ne11, ne10, ne10, ne10, ne01, split_k, ne12*ne13, ne02, ne12, r2, r3, stride_batch_x, stride_batch_y, ne20*ne21
ne01, ne11, ne10,
ne10, ne10, ne01, stride_batch_x, stride_batch_y, ne20*ne21,
split_k, ne12*ne13, ne02, ne12, r2, r3
); // NOLINT
}
@ -3192,8 +3192,8 @@ static void ggml_vk_mul_mat_vec_q_f16(ggml_backend_vk_context * ctx, vk_context
// compute
const vk_mat_vec_push_constants pc = {
(uint32_t)ne00, (uint32_t)ne10, (uint32_t)ne10, (uint32_t)ne01,
(uint32_t)ne02, (uint32_t)ne12, (uint32_t)r2, (uint32_t)r3,
stride_batch_x, stride_batch_y, (uint32_t)(ne20*ne21),
(uint32_t)ne02, (uint32_t)ne12, (uint32_t)r2, (uint32_t)r3,
};
ggml_vk_sync_buffers(subctx);
ggml_vk_dispatch_pipeline(ctx, subctx, dmmv, { { d_X, x_buf_offset, x_sz * ne02 * ne03 }, { d_Y, y_buf_offset, y_sz * ne12 * ne13 }, { d_D, d_buf_offset, d_sz * ne22 * ne23} }, sizeof(vk_mat_vec_push_constants), &pc, { (uint32_t)ne01, (uint32_t)(ne12 * ne13), 1});
@ -3396,9 +3396,6 @@ static void ggml_vk_mul_mat_id_q_f16(ggml_backend_vk_context * ctx, vk_context *
const uint64_t ne22 = dst->ne[2];
const uint64_t ne23 = dst->ne[3];
const uint64_t r2 = ne12 / ne02;
const uint64_t r3 = ne13 / ne03;
const uint64_t n_as = ne02;
GGML_ASSERT(n_as <= 8);
@ -3552,9 +3549,8 @@ static void ggml_vk_mul_mat_id_q_f16(ggml_backend_vk_context * ctx, vk_context *
ctx, subctx, pipeline,
{ d_X, x_buf_offset, x_sz * ne02 * ne03 }, { d_Y, y_buf_offset, y_sz * ne12 * ne13 },
{ d_D, d_buf_offset, d_sz * ne22 * ne23 }, { d_ids, ids_buf_offset, ids_sz },
ne01, ne21, ne10, ne10, ne10, ne01, ne02, ne12, r2, r3,
ne01, ne21, ne10, ne10, ne10, ne01,
stride_batch_x, stride_batch_y, ne20*ne21,
x_ne, ne20,
n_as, nei0, nei1, nbi1 / ggml_type_size(ids->type), ne11
); // NOLINT
}
@ -3580,23 +3576,18 @@ static void ggml_vk_mul_mat_vec_id_q_f16(ggml_backend_vk_context * ctx, vk_conte
const uint64_t ne12 = src1->ne[2];
const uint64_t ne13 = src1->ne[3];
GGML_ASSERT(ne11 == 1);
const uint64_t nei0 = ids->ne[0];
const uint64_t nei1 = ids->ne[1];
const uint64_t nbi1 = ids->nb[1];
const uint64_t nbi2 = ids->nb[2];
GGML_ASSERT(nei1 == 1);
const uint64_t ne20 = dst->ne[0];
const uint64_t ne21 = dst->ne[1];
const uint64_t ne22 = dst->ne[2];
const uint64_t ne23 = dst->ne[3];
const uint64_t r2 = ne12 / ne02;
const uint64_t r3 = ne13 / ne03;
const uint64_t n_as = ne02;
ggml_tensor_extra_gpu * extra = (ggml_tensor_extra_gpu *) dst->extra;
ggml_tensor_extra_gpu * extra_src0 = (ggml_tensor_extra_gpu *) src0->extra;
ggml_tensor_extra_gpu * extra_src1 = (ggml_tensor_extra_gpu *) src1->extra;
@ -3635,7 +3626,7 @@ static void ggml_vk_mul_mat_vec_id_q_f16(ggml_backend_vk_context * ctx, vk_conte
const uint64_t x_ne = ne01 * ne00;
const uint64_t y_ne = ne11 * ne10;
const uint64_t d_ne = ne11 * ne01;
const uint64_t d_ne = ne21 * ne20;
const uint64_t qx_sz = ggml_vk_align_size(ggml_type_size(src0->type) * x_ne / ggml_blck_size(src0->type), ctx->device->properties.limits.minStorageBufferOffsetAlignment);
const uint64_t qy_sz = ggml_type_size(src1->type) * y_ne / ggml_blck_size(src1->type);
@ -3723,22 +3714,20 @@ static void ggml_vk_mul_mat_vec_id_q_f16(ggml_backend_vk_context * ctx, vk_conte
// compute
const vk_mat_vec_id_push_constants pc = {
(uint32_t)ne00, (uint32_t)ne10, (uint32_t)ne10, (uint32_t)ne01,
(uint32_t)ne02, (uint32_t)ne12, (uint32_t)r2, (uint32_t)r3,
0, stride_batch_y, (uint32_t)(ne20*ne21),
(uint32_t)x_ne, (uint32_t)ne20,
(uint32_t)n_as, (uint32_t)nei0, (uint32_t)(nbi1 / ggml_type_size(ids->type)),
(uint32_t)x_ne, stride_batch_y, (uint32_t)(ne20*ne21),
(uint32_t)nei0, (uint32_t)ne11,
};
ggml_vk_sync_buffers(subctx);
ggml_vk_dispatch_pipeline(ctx, subctx, dmmv,
{ { d_X, x_buf_offset, x_sz * ne02 * ne03 }, { d_Y, y_buf_offset, y_sz * ne12 * ne13 }, { d_D, d_buf_offset, d_sz * ne22 * ne23}, { d_ids, ids_buf_offset, ids_sz } },
sizeof(vk_mat_vec_id_push_constants), &pc, { (uint32_t)ne01, (uint32_t)(ne22 * ne23), 1 });
sizeof(vk_mat_vec_id_push_constants), &pc, { (uint32_t)ne01, (uint32_t)nei0, 1 });
}
static void ggml_vk_mul_mat_id(ggml_backend_vk_context * ctx, vk_context * subctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, ggml_tensor * dst) {
#ifdef GGML_VULKAN_DEBUG
std::cerr << "ggml_vk_mul_mat_id(" << src0 << ", " << src1 << ", " << src2 << ", " << dst << ")" << std::endl;
#endif
if (dst->ne[1] == 1 && (src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16 || ggml_is_quantized(src0->type))) {
if (src2->ne[1] == 1 && (src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16 || ggml_is_quantized(src0->type))) {
ggml_vk_mul_mat_vec_id_q_f16(ctx, subctx, src0, src1, src2, dst);
} else {
ggml_vk_mul_mat_id_q_f16(ctx, subctx, src0, src1, src2, dst);
@ -4662,7 +4651,9 @@ static void ggml_vk_test_matmul(ggml_backend_vk_context * ctx, size_t m, size_t
ggml_vk_ctx_begin(ctx, subctx);
ggml_vk_matmul(
ctx, subctx, p, ggml_vk_subbuffer(d_X), ggml_vk_subbuffer(d_Y), ggml_vk_subbuffer(d_D), ggml_vk_subbuffer(ctx->prealloc_split_k),
m, n, k, k, k, m, split_k, batch, batch, batch, 1, 1, k*m, k*n, m*n, 0, 0, 0, 0, 1
m, n, k,
k, k, m, k*m, k*n, m*n,
split_k, batch, batch, batch, 1, 1
);
ggml_vk_ctx_end(subctx);
}
@ -5166,7 +5157,9 @@ static void ggml_vk_test_dequant_matmul(ggml_backend_vk_context * ctx, size_t m,
ggml_vk_ctx_begin(ctx, subctx);
ggml_vk_matmul(
ctx, subctx, p, ggml_vk_subbuffer(qx_buf), ggml_vk_subbuffer(y_buf), ggml_vk_subbuffer(d_buf), ggml_vk_subbuffer(ctx->prealloc_split_k),
m, n, k, k, k, m, split_k, batch, batch, batch, 1, 1, k*m, k*n, m*n, 0, 0, 0, 0, 1
m, n, k,
k, k, m, k*m, k*n, m*n,
split_k, batch, batch, batch, 1, 1
);
ggml_vk_ctx_end(subctx);
}

View file

@ -257,26 +257,22 @@ layout (push_constant) uniform parameter
uint stride_a;
uint stride_b;
uint stride_d;
uint k_split;
uint ne02;
uint ne12;
uint broadcast2;
uint broadcast3;
uint batch_stride_a;
uint batch_stride_b;
uint batch_stride_d;
#ifdef MUL_MAT_ID
uint expert_stride_a;
uint expert_stride_d;
uint n_as;
uint nei0;
uint nei1;
uint nbi1;
uint ne11;
#else
uint k_split;
uint ne02;
uint ne12;
uint broadcast2;
uint broadcast3;
#endif
} p;
@ -354,12 +350,17 @@ void main() {
if (ic * BN >= _ne1) return;
#endif
#ifdef MUL_MAT_ID
const uint start_k = 0;
const uint end_k = p.K;
#else
const uint start_k = ik * p.k_split;
const uint end_k = min(p.K, (ik + 1) * p.k_split);
#endif
uint pos_a = (
#ifdef MUL_MAT_ID
expert_idx * p.expert_stride_a +
expert_idx * p.batch_stride_a +
#else
batch_idx_a * p.batch_stride_a +
#endif
@ -1193,29 +1194,26 @@ layout (push_constant) uniform parameter
uint stride_b;
uint stride_d;
uint ne02;
uint ne12;
uint broadcast2;
uint broadcast3;
uint batch_stride_a;
uint batch_stride_b;
uint batch_stride_d;
#ifdef MUL_MAT_ID
uint expert_stride_a;
uint expert_stride_d;
uint n_as;
uint nei0;
uint nbi1;
uint ne11;
#else
uint ne02;
uint ne12;
uint broadcast2;
uint broadcast3;
#endif
} p;
void get_offsets(out uint a_offset, out uint b_offset, out uint d_offset) {
const uint batch_idx = gl_GlobalInvocationID.y;
#ifdef MUL_MAT_ID
const uint expert_idx = gl_GlobalInvocationID.z;
const uint expert_idx = gl_GlobalInvocationID.y;
#else
const uint batch_idx = gl_GlobalInvocationID.y;
#endif
#ifndef MUL_MAT_ID
@ -1227,21 +1225,27 @@ void get_offsets(out uint a_offset, out uint b_offset, out uint d_offset) {
const uint batch_idx_a = i03 * p.ne02 + i02;
#else
const uint expert_id = data_ids[batch_idx * p.nbi1 + expert_idx];
const uint expert_id = data_ids[expert_idx];
#endif
a_offset =
#ifdef MUL_MAT_ID
expert_id * p.expert_stride_a;
expert_id * p.batch_stride_a;
#else
batch_idx_a * p.batch_stride_a;
#endif
b_offset = batch_idx * p.batch_stride_b;
b_offset =
#ifdef MUL_MAT_ID
(expert_idx % p.ne11) * p.stride_b;
#else
batch_idx * p.batch_stride_b;
#endif
d_offset =
#ifdef MUL_MAT_ID
expert_idx * p.expert_stride_d +
#endif
expert_idx * p.stride_d;
#else
batch_idx * p.batch_stride_d;
#endif
}
"""