Fix matmul k-split bug

Author: 0cc4m
Date:   2023-11-05 12:24:09 +01:00
Parent: 00bea85cf2
Commit: bd7fa3f9e4
2 changed files with 128 additions and 97 deletions
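The bug: the split-k matmul path computed each slice's share of the K dimension from the padded row stride (CEIL_DIV(stride_a, split_k)) rather than from K itself, and the shader let the last slice run past the end of K (end_k = (ik + 1) * p.k_split). This commit derives the slice size from K, clamps end_k to p.K, restricts split-k to the aligned shader variants (ggml_vk_guess_split_k gains an aligned flag, and the unaligned pipelines are now created with align 1), accumulates the shader sums in FLOAT_TYPE, and adds a 100 x 46 x 558 case to the test matrix, where K is not a multiple of the slice size.

A minimal standalone sketch of the slicing arithmetic being fixed (the k = 558 and split_k = 4 values come from the diff; the rest is illustrative, not code from this repository):

#include <algorithm>
#include <cstdio>

#define CEIL_DIV(a, b) (((a) + (b) - 1) / (b))

int main() {
    const int k = 558;      // K from the new 100 x 46 x 558 test case
    const int split_k = 4;  // split-k factor exercised by the tests

    // Fixed push constant: per-slice size derived from k, not stride_a.
    const int k_split = CEIL_DIV(k, split_k); // 140

    for (int ik = 0; ik < split_k; ik++) {
        const int start_k = ik * k_split;
        // Fixed shader bound: without the min(), the last slice would
        // cover [420, 560) and overrun k = 558.
        const int end_k = std::min(k, (ik + 1) * k_split);
        std::printf("slice %d covers [%d, %d)\n", ik, start_k, end_k);
    }
    return 0;
}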


@@ -695,24 +695,24 @@ static void ggml_vk_load_shaders() {
auto warptile_s = { 32, 32, 32, 8, 32, 32, 2, 2, 2 };
if (vk_device.fp16) {
vk_pipeline_matmul_f32_l = ggml_vk_create_pipeline("matmul_f32_l", matmul_f32_l_len, matmul_f32_l_data, "main", 3, 7 * sizeof(int), {128, 128, 1}, warptile_l, 128);
vk_pipeline_matmul_f32_m = ggml_vk_create_pipeline("matmul_f32_m", matmul_f32_m_len, matmul_f32_m_data, "main", 3, 7 * sizeof(int), {64, 64, 1}, warptile_m, 64);
vk_pipeline_matmul_f32_s = ggml_vk_create_pipeline("matmul_f32_s", matmul_f32_s_len, matmul_f32_s_data, "main", 3, 7 * sizeof(int), { 32, 32, 1}, warptile_s, 32);
vk_pipeline_matmul_f32_l = ggml_vk_create_pipeline("matmul_f32_l", matmul_f32_l_len, matmul_f32_l_data, "main", 3, 7 * sizeof(int), {128, 128, 1}, warptile_l, 1);
vk_pipeline_matmul_f32_m = ggml_vk_create_pipeline("matmul_f32_m", matmul_f32_m_len, matmul_f32_m_data, "main", 3, 7 * sizeof(int), { 64, 64, 1}, warptile_m, 1);
vk_pipeline_matmul_f32_s = ggml_vk_create_pipeline("matmul_f32_s", matmul_f32_s_len, matmul_f32_s_data, "main", 3, 7 * sizeof(int), { 32, 32, 1}, warptile_s, 1);
vk_pipeline_matmul_f32_aligned_l = ggml_vk_create_pipeline("matmul_f32_aligned_l", matmul_f32_aligned_l_len, matmul_f32_aligned_l_data, "main", 3, 7 * sizeof(int), {128, 128, 1}, warptile_l, 128);
vk_pipeline_matmul_f32_aligned_m = ggml_vk_create_pipeline("matmul_f32_aligned_m", matmul_f32_aligned_m_len, matmul_f32_aligned_m_data, "main", 3, 7 * sizeof(int), { 64, 64, 1}, warptile_m, 64);
vk_pipeline_matmul_f32_aligned_s = ggml_vk_create_pipeline("matmul_f32_aligned_s", matmul_f32_aligned_s_len, matmul_f32_aligned_s_data, "main", 3, 7 * sizeof(int), { 32, 32, 1}, warptile_s, 32);
vk_pipeline_matmul_f16_l = ggml_vk_create_pipeline("matmul_f16_l", matmul_f16_l_len, matmul_f16_l_data, "main", 3, 7 * sizeof(int), {128, 128, 1}, warptile_l, 128);
vk_pipeline_matmul_f16_m = ggml_vk_create_pipeline("matmul_f16_m", matmul_f16_m_len, matmul_f16_m_data, "main", 3, 7 * sizeof(int), {64, 64, 1}, warptile_m, 64);
vk_pipeline_matmul_f16_s = ggml_vk_create_pipeline("matmul_f16_s", matmul_f16_s_len, matmul_f16_s_data, "main", 3, 7 * sizeof(int), { 32, 32, 1}, warptile_s, 32);
vk_pipeline_matmul_f16_l = ggml_vk_create_pipeline("matmul_f16_l", matmul_f16_l_len, matmul_f16_l_data, "main", 3, 7 * sizeof(int), {128, 128, 1}, warptile_l, 1);
vk_pipeline_matmul_f16_m = ggml_vk_create_pipeline("matmul_f16_m", matmul_f16_m_len, matmul_f16_m_data, "main", 3, 7 * sizeof(int), { 64, 64, 1}, warptile_m, 1);
vk_pipeline_matmul_f16_s = ggml_vk_create_pipeline("matmul_f16_s", matmul_f16_s_len, matmul_f16_s_data, "main", 3, 7 * sizeof(int), { 32, 32, 1}, warptile_s, 1);
vk_pipeline_matmul_f16_aligned_l = ggml_vk_create_pipeline("matmul_f16_aligned_l", matmul_f16_aligned_l_len, matmul_f16_aligned_l_data, "main", 3, 7 * sizeof(int), {128, 128, 1}, warptile_l, 128);
vk_pipeline_matmul_f16_aligned_m = ggml_vk_create_pipeline("matmul_f16_aligned_m", matmul_f16_aligned_m_len, matmul_f16_aligned_m_data, "main", 3, 7 * sizeof(int), { 64, 64, 1}, warptile_m, 64);
vk_pipeline_matmul_f16_aligned_s = ggml_vk_create_pipeline("matmul_f16_aligned_s", matmul_f16_aligned_s_len, matmul_f16_aligned_s_data, "main", 3, 7 * sizeof(int), { 32, 32, 1}, warptile_s, 32);
vk_pipeline_matmul_f16_f32_l = ggml_vk_create_pipeline("matmul_f16_f32_l", matmul_f16_f32_l_len, matmul_f16_f32_l_data, "main", 3, 7 * sizeof(int), {128, 128, 1}, warptile_l, 128);
vk_pipeline_matmul_f16_f32_m = ggml_vk_create_pipeline("matmul_f16_f32_m", matmul_f16_f32_m_len, matmul_f16_f32_m_data, "main", 3, 7 * sizeof(int), {64, 64, 1}, warptile_m, 64);
vk_pipeline_matmul_f16_f32_s = ggml_vk_create_pipeline("matmul_f16_f32_s", matmul_f16_f32_s_len, matmul_f16_f32_s_data, "main", 3, 7 * sizeof(int), { 32, 32, 1}, warptile_s, 32);
vk_pipeline_matmul_f16_f32_l = ggml_vk_create_pipeline("matmul_f16_f32_l", matmul_f16_f32_l_len, matmul_f16_f32_l_data, "main", 3, 7 * sizeof(int), {128, 128, 1}, warptile_l, 1);
vk_pipeline_matmul_f16_f32_m = ggml_vk_create_pipeline("matmul_f16_f32_m", matmul_f16_f32_m_len, matmul_f16_f32_m_data, "main", 3, 7 * sizeof(int), { 64, 64, 1}, warptile_m, 1);
vk_pipeline_matmul_f16_f32_s = ggml_vk_create_pipeline("matmul_f16_f32_s", matmul_f16_f32_s_len, matmul_f16_f32_s_data, "main", 3, 7 * sizeof(int), { 32, 32, 1}, warptile_s, 1);
vk_pipeline_matmul_f16_f32_aligned_l = ggml_vk_create_pipeline("matmul_f16_f32_aligned_l", matmul_f16_f32_aligned_l_len, matmul_f16_f32_aligned_l_data, "main", 3, 7 * sizeof(int), {128, 128, 1}, warptile_l, 128);
vk_pipeline_matmul_f16_f32_aligned_m = ggml_vk_create_pipeline("matmul_f16_f32_aligned_m", matmul_f16_f32_aligned_m_len, matmul_f16_f32_aligned_m_data, "main", 3, 7 * sizeof(int), { 64, 64, 1}, warptile_m, 64);
vk_pipeline_matmul_f16_f32_aligned_s = ggml_vk_create_pipeline("matmul_f16_f32_aligned_s", matmul_f16_f32_aligned_s_len, matmul_f16_f32_aligned_s_data, "main", 3, 7 * sizeof(int), { 32, 32, 1}, warptile_s, 32);
@@ -766,24 +766,24 @@ static void ggml_vk_load_shaders() {
vk_pipeline_scale_f32 = ggml_vk_create_pipeline("scale_f32", scale_f32_len, scale_f32_data, "main", 2, sizeof(vk_op_push_constants), {32, 32, 1}, {}, 1);
} else {
vk_pipeline_matmul_f32_l = ggml_vk_create_pipeline("matmul_f32_l", matmul_f32_l_fp32_len, matmul_f32_l_fp32_data, "main", 3, 7 * sizeof(int), {128, 128, 1}, warptile_l, 128);
vk_pipeline_matmul_f32_m = ggml_vk_create_pipeline("matmul_f32_m", matmul_f32_m_fp32_len, matmul_f32_m_fp32_data, "main", 3, 7 * sizeof(int), {64, 64, 1}, warptile_m, 64);
vk_pipeline_matmul_f32_s = ggml_vk_create_pipeline("matmul_f32_s", matmul_f32_s_fp32_len, matmul_f32_s_fp32_data, "main", 3, 7 * sizeof(int), { 32, 32, 1}, warptile_s, 32);
vk_pipeline_matmul_f32_l = ggml_vk_create_pipeline("matmul_f32_l", matmul_f32_l_fp32_len, matmul_f32_l_fp32_data, "main", 3, 7 * sizeof(int), {128, 128, 1}, warptile_l, 1);
vk_pipeline_matmul_f32_m = ggml_vk_create_pipeline("matmul_f32_m", matmul_f32_m_fp32_len, matmul_f32_m_fp32_data, "main", 3, 7 * sizeof(int), { 64, 64, 1}, warptile_m, 1);
vk_pipeline_matmul_f32_s = ggml_vk_create_pipeline("matmul_f32_s", matmul_f32_s_fp32_len, matmul_f32_s_fp32_data, "main", 3, 7 * sizeof(int), { 32, 32, 1}, warptile_s, 1);
vk_pipeline_matmul_f32_aligned_l = ggml_vk_create_pipeline("matmul_f32_aligned_l", matmul_f32_aligned_l_fp32_len, matmul_f32_aligned_l_fp32_data, "main", 3, 7 * sizeof(int), {128, 128, 1}, warptile_l, 128);
vk_pipeline_matmul_f32_aligned_m = ggml_vk_create_pipeline("matmul_f32_aligned_m", matmul_f32_aligned_m_fp32_len, matmul_f32_aligned_m_fp32_data, "main", 3, 7 * sizeof(int), { 64, 64, 1}, warptile_m, 64);
vk_pipeline_matmul_f32_aligned_s = ggml_vk_create_pipeline("matmul_f32_aligned_s", matmul_f32_aligned_s_fp32_len, matmul_f32_aligned_s_fp32_data, "main", 3, 7 * sizeof(int), { 32, 32, 1}, warptile_s, 32);
vk_pipeline_matmul_f16_l = ggml_vk_create_pipeline("matmul_f16_l", matmul_f16_l_fp32_len, matmul_f16_l_fp32_data, "main", 3, 7 * sizeof(int), {128, 128, 1}, warptile_l, 128);
vk_pipeline_matmul_f16_m = ggml_vk_create_pipeline("matmul_f16_m", matmul_f16_m_fp32_len, matmul_f16_m_fp32_data, "main", 3, 7 * sizeof(int), {64, 64, 1}, warptile_m, 64);
vk_pipeline_matmul_f16_s = ggml_vk_create_pipeline("matmul_f16_s", matmul_f16_s_fp32_len, matmul_f16_s_fp32_data, "main", 3, 7 * sizeof(int), { 32, 32, 1}, warptile_s, 32);
vk_pipeline_matmul_f16_l = ggml_vk_create_pipeline("matmul_f16_l", matmul_f16_l_fp32_len, matmul_f16_l_fp32_data, "main", 3, 7 * sizeof(int), {128, 128, 1}, warptile_l, 1);
vk_pipeline_matmul_f16_m = ggml_vk_create_pipeline("matmul_f16_m", matmul_f16_m_fp32_len, matmul_f16_m_fp32_data, "main", 3, 7 * sizeof(int), { 64, 64, 1}, warptile_m, 1);
vk_pipeline_matmul_f16_s = ggml_vk_create_pipeline("matmul_f16_s", matmul_f16_s_fp32_len, matmul_f16_s_fp32_data, "main", 3, 7 * sizeof(int), { 32, 32, 1}, warptile_s, 1);
vk_pipeline_matmul_f16_aligned_l = ggml_vk_create_pipeline("matmul_f16_aligned_l", matmul_f16_aligned_l_fp32_len, matmul_f16_aligned_l_fp32_data, "main", 3, 7 * sizeof(int), {128, 128, 1}, warptile_l, 128);
vk_pipeline_matmul_f16_aligned_m = ggml_vk_create_pipeline("matmul_f16_aligned_m", matmul_f16_aligned_m_fp32_len, matmul_f16_aligned_m_fp32_data, "main", 3, 7 * sizeof(int), { 64, 64, 1}, warptile_m, 64);
vk_pipeline_matmul_f16_aligned_s = ggml_vk_create_pipeline("matmul_f16_aligned_s", matmul_f16_aligned_s_fp32_len, matmul_f16_aligned_s_fp32_data, "main", 3, 7 * sizeof(int), { 32, 32, 1}, warptile_s, 32);
vk_pipeline_matmul_f16_f32_l = ggml_vk_create_pipeline("matmul_f16_f32_l", matmul_f16_f32_l_fp32_len, matmul_f16_f32_l_fp32_data, "main", 3, 7 * sizeof(int), {128, 128, 1}, warptile_l, 128);
vk_pipeline_matmul_f16_f32_m = ggml_vk_create_pipeline("matmul_f16_f32_m", matmul_f16_f32_m_fp32_len, matmul_f16_f32_m_fp32_data, "main", 3, 7 * sizeof(int), {64, 64, 1}, warptile_m, 64);
vk_pipeline_matmul_f16_f32_s = ggml_vk_create_pipeline("matmul_f16_f32_s", matmul_f16_f32_s_fp32_len, matmul_f16_f32_s_fp32_data, "main", 3, 7 * sizeof(int), { 32, 32, 1}, warptile_s, 32);
vk_pipeline_matmul_f16_f32_l = ggml_vk_create_pipeline("matmul_f16_f32_l", matmul_f16_f32_l_fp32_len, matmul_f16_f32_l_fp32_data, "main", 3, 7 * sizeof(int), {128, 128, 1}, warptile_l, 1);
vk_pipeline_matmul_f16_f32_m = ggml_vk_create_pipeline("matmul_f16_f32_m", matmul_f16_f32_m_fp32_len, matmul_f16_f32_m_fp32_data, "main", 3, 7 * sizeof(int), { 64, 64, 1}, warptile_m, 1);
vk_pipeline_matmul_f16_f32_s = ggml_vk_create_pipeline("matmul_f16_f32_s", matmul_f16_f32_s_fp32_len, matmul_f16_f32_s_fp32_data, "main", 3, 7 * sizeof(int), { 32, 32, 1}, warptile_s, 1);
vk_pipeline_matmul_f16_f32_aligned_l = ggml_vk_create_pipeline("matmul_f16_f32_aligned_l", matmul_f16_f32_aligned_l_fp32_len, matmul_f16_f32_aligned_l_fp32_data, "main", 3, 7 * sizeof(int), {128, 128, 1}, warptile_l, 128);
vk_pipeline_matmul_f16_f32_aligned_m = ggml_vk_create_pipeline("matmul_f16_f32_aligned_m", matmul_f16_f32_aligned_m_fp32_len, matmul_f16_f32_aligned_m_fp32_data, "main", 3, 7 * sizeof(int), { 64, 64, 1}, warptile_m, 64);
vk_pipeline_matmul_f16_f32_aligned_s = ggml_vk_create_pipeline("matmul_f16_f32_aligned_s", matmul_f16_f32_aligned_s_fp32_len, matmul_f16_f32_aligned_s_fp32_data, "main", 3, 7 * sizeof(int), { 32, 32, 1}, warptile_s, 32);
@@ -1013,6 +1013,7 @@ std::cerr << "ggml_vulkan: Validation layers enabled" << std::endl;
ggml_vk_test_transfer(1024 * 1024 * m);
}
const std::vector<size_t> vals {
100, 46, 558,
1024, 2, 4096,
512, 1, 256,
128, 110, 622,
@@ -1032,30 +1033,15 @@ std::cerr << "ggml_vulkan: Validation layers enabled" << std::endl;
4096, 512, 11008,
32000, 512, 4096,
};
const size_t num_it = 100;
for (size_t i = 0; i < vals.size(); i += 3) {
ggml_vk_test_matmul_f32(vals[i], vals[i + 1], vals[i + 2], 1000, 1, 0);
ggml_vk_test_matmul_f32(vals[i], vals[i + 1], vals[i + 2], 1000, 4, 0);
ggml_vk_test_matmul_f32(vals[i], vals[i + 1], vals[i + 2], 1000, 1, 1);
ggml_vk_test_matmul_f32(vals[i], vals[i + 1], vals[i + 2], 1000, 4, 1);
ggml_vk_test_matmul_f32(vals[i], vals[i + 1], vals[i + 2], 1000, 1, 2);
ggml_vk_test_matmul_f32(vals[i], vals[i + 1], vals[i + 2], 1000, 4, 2);
ggml_vk_test_matmul_f16(vals[i], vals[i + 1], vals[i + 2], num_it, 1, 0);
ggml_vk_test_matmul_f16(vals[i], vals[i + 1], vals[i + 2], num_it, 4, 0);
ggml_vk_test_matmul_f16(vals[i], vals[i + 1], vals[i + 2], num_it, 1, 1);
ggml_vk_test_matmul_f16(vals[i], vals[i + 1], vals[i + 2], num_it, 4, 1);
ggml_vk_test_matmul_f16(vals[i], vals[i + 1], vals[i + 2], num_it, 1, 2);
ggml_vk_test_matmul_f16(vals[i], vals[i + 1], vals[i + 2], num_it, 4, 2);
std::cerr << std::endl;
ggml_vk_test_matmul_f16(vals[i], vals[i + 1], vals[i + 2], 1000, 1, 0);
ggml_vk_test_matmul_f16(vals[i], vals[i + 1], vals[i + 2], 1000, 4, 0);
ggml_vk_test_matmul_f16(vals[i], vals[i + 1], vals[i + 2], 1000, 1, 1);
ggml_vk_test_matmul_f16(vals[i], vals[i + 1], vals[i + 2], 1000, 4, 1);
ggml_vk_test_matmul_f16(vals[i], vals[i + 1], vals[i + 2], 1000, 1, 2);
ggml_vk_test_matmul_f16(vals[i], vals[i + 1], vals[i + 2], 1000, 4, 2);
std::cerr << std::endl;
ggml_vk_test_matmul_f16_f32(vals[i], vals[i + 1], vals[i + 2], 1000, 1, 0);
ggml_vk_test_matmul_f16_f32(vals[i], vals[i + 1], vals[i + 2], 1000, 4, 0);
ggml_vk_test_matmul_f16_f32(vals[i], vals[i + 1], vals[i + 2], 1000, 1, 1);
ggml_vk_test_matmul_f16_f32(vals[i], vals[i + 1], vals[i + 2], 1000, 4, 1);
ggml_vk_test_matmul_f16_f32(vals[i], vals[i + 1], vals[i + 2], 1000, 1, 2);
ggml_vk_test_matmul_f16_f32(vals[i], vals[i + 1], vals[i + 2], 1000, 4, 2);
std::cerr << std::endl << std::endl;
}
#endif
}
@@ -1583,11 +1569,11 @@ static vk_sequence ggml_vk_h2d_tensor_2d(vk_buffer* dst, size_t offset, const st
}
}
static int ggml_vk_guess_split_k(int m, int n, int k) {
static int ggml_vk_guess_split_k(int m, int n, int k, bool aligned) {
#ifdef VK_DEBUG
std::cerr << "ggml_vk_guess_split_k()";
std::cerr << "ggml_vk_guess_split_k(" << m << ", " << n << ", " << k << ")";
#endif
if (k > 128 && (m < 128 || n < 128)) {
if (aligned && k > 128 && (m < 128 || n < 128)) {
#ifdef VK_DEBUG
std::cerr << " = 4" << std::endl;
#endif
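Condensed, the split-k heuristic after this change reads roughly as below; the early-return value comes from the debug output in this hunk, while the fall-through return 1 is not shown here and is assumed:

static int ggml_vk_guess_split_k(int m, int n, int k, bool aligned) {
    // Split-k is only engaged for the aligned shader path, and only for
    // deep-K, small-M/N shapes; otherwise a single pass is used (assumed).
    if (aligned && k > 128 && (m < 128 || n < 128)) {
        return 4;
    }
    return 1;
}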
@@ -1602,15 +1588,15 @@ static int ggml_vk_guess_split_k(int m, int n, int k) {
static uint32_t ggml_vk_guess_matmul_pipeline_align(int m, int n) {
#ifdef VK_DEBUG
std::cerr << "ggml_vk_guess_matmul_pipeline_padding()" << std::endl;
std::cerr << "ggml_vk_guess_matmul_pipeline_align(" << m << ", " << n << ")" << std::endl;
#endif
if (m <= 32 || n <= 32) {
return vk_pipeline_matmul_f32_s.align;
return vk_pipeline_matmul_f32_aligned_s.align;
}
if (m <= 64 || n <= 64) {
return vk_pipeline_matmul_f32_m.align;
return vk_pipeline_matmul_f32_aligned_m.align;
}
return vk_pipeline_matmul_f32_l.align;
return vk_pipeline_matmul_f32_aligned_l.align;
}
static vk_pipeline* ggml_vk_guess_matmul_pipeline(bool bit16_x, bool bit16_y, int m, int n, bool aligned) {
@@ -1690,7 +1676,7 @@ static vk_sequence ggml_vk_matmul(vk_pipeline& pipeline, vk_subbuffer&& a, vk_su
}
// Synchronize the two submissions
const std::vector<int> pc1 = { m, n, k, stride_a, stride_b, stride_d, CEIL_DIV(stride_a, split_k) };
const std::vector<int> pc1 = { m, n, k, stride_a, stride_b, stride_d, CEIL_DIV(k, split_k) };
ggml_vk_dispatch_pipeline(s, pipeline, { a, b, d }, pc1.size() * sizeof(int), pc1.data(), { (uint32_t)m * split_k, (uint32_t)n, 1 });
ggml_vk_sync_buffers(s.buffer, { d }, q, vk::AccessFlagBits::eMemoryWrite, vk::AccessFlagBits::eShaderRead | vk::AccessFlagBits::eShaderWrite, true);
const std::vector<int> pc2 = { m, n, split_k };
@@ -1726,13 +1712,14 @@ static void ggml_vk_mul_mat_f32(const ggml_tensor * src0, const ggml_tensor * sr
const int y_ne = ne11 * ne10;
const int d_ne = ne11 * ne01;
const int split_k = ggml_vk_guess_split_k(ne01, ne11, ne10);
const int kpad = ggml_vk_align_size(ne10, ggml_vk_guess_matmul_pipeline_align(ne01, ne11));
const bool aligned = ne10 == kpad;
const int split_k = ggml_vk_guess_split_k(ne01, ne11, ne10, aligned);
const bool load_x = src0->backend == GGML_BACKEND_GPU;
vk_pipeline * pipeline = ggml_vk_guess_matmul_pipeline(false, false, ne01, ne11, ne10 == kpad);
vk_pipeline * pipeline = ggml_vk_guess_matmul_pipeline(false, false, ne01, ne11, aligned);
const uint32_t x_sz = ggml_vk_align_size(sizeof(float) * x_ne, vk_device.properties.limits.minStorageBufferOffsetAlignment);
const uint32_t y_sz = ggml_vk_align_size(sizeof(float) * y_ne, vk_device.properties.limits.minStorageBufferOffsetAlignment);
@@ -1848,11 +1835,12 @@ static void ggml_vk_mul_mat_q_f16(const ggml_tensor * src0, const ggml_tensor *
const int y_ne = ne11 * ne10;
const int d_ne = ne11 * ne01;
const int split_k = ggml_vk_guess_split_k(ne01, ne11, ne10);
const int kpad = ggml_vk_align_size(ne10, ggml_vk_guess_matmul_pipeline_align(ne01, ne11));
const bool aligned = ne10 == kpad;
vk_pipeline * pipeline = ggml_vk_guess_matmul_pipeline(true, !f16_f32_kernel, ne01, ne11, ne10 == kpad);
const int split_k = ggml_vk_guess_split_k(ne01, ne11, ne10, aligned);
vk_pipeline * pipeline = ggml_vk_guess_matmul_pipeline(true, !f16_f32_kernel, ne01, ne11, aligned);
const uint32_t qx_sz = ggml_vk_align_size(ggml_type_size(src0->type) * x_ne / ggml_blck_size(src0->type), vk_device.properties.limits.minStorageBufferOffsetAlignment);
const uint32_t qy_sz = ggml_vk_align_size(ggml_type_size(src1->type) * y_ne / ggml_blck_size(src1->type), vk_device.properties.limits.minStorageBufferOffsetAlignment);
@@ -2518,7 +2506,14 @@ void ggml_vk_preallocate_buffers_graph(ggml_tensor * node){
const bool f16_f32_kernel = use_src1 && src1->type == GGML_TYPE_F32;
const bool qy_needs_dequant = use_src1 && src1->type != GGML_TYPE_F16 && !f16_f32_kernel;
const int split_k = node->op == GGML_OP_MUL_MAT ? ggml_vk_guess_split_k(ne01, ne11, ne10) : 1;
int split_k;
if (node->op == GGML_OP_MUL_MAT) {
const int kpad = ggml_vk_align_size(ne10, ggml_vk_guess_matmul_pipeline_align(ne01, ne11));
const bool aligned = ne10 == kpad;
split_k = ggml_vk_guess_split_k(ne01, ne11, ne10, aligned);
} else {
split_k = 1;
}
const uint32_t x_ne = ne00 * ne01;
const uint32_t y_ne = ne10 * ne11;
const uint32_t d_ne = ne20 * ne21;
@@ -2981,6 +2976,22 @@ void ggml_vk_test_matmul_f32(size_t m, size_t n, size_t k, size_t num_it, int sp
vk_pipeline * p;
std::string shname;
if (shader_size == 0) {
p = &vk_pipeline_matmul_f32_aligned_s;
shname = "F32_ALIGNED_S";
} else if (shader_size == 1) {
p = &vk_pipeline_matmul_f32_aligned_m;
shname = "F32_ALIGNED_M";
} else if (shader_size == 2) {
p = &vk_pipeline_matmul_f32_aligned_l;
shname = "F32_ALIGNED_L";
} else {
GGML_ASSERT(0);
}
const size_t kpad = ggml_vk_align_size(k, p->align);
if (k != kpad) {
if (shader_size == 0) {
p = &vk_pipeline_matmul_f32_s;
shname = "F32_S";
@@ -2990,11 +3001,8 @@ void ggml_vk_test_matmul_f32(size_t m, size_t n, size_t k, size_t num_it, int sp
} else if (shader_size == 2) {
p = &vk_pipeline_matmul_f32_l;
shname = "F32_L";
} else {
GGML_ASSERT(0);
}
const size_t kpad = ggml_vk_align_size(k, p->align);
}
ggml_vk_pipeline_allocate_descriptor_sets(*p, num_it);
if (split_k > 1) {
@@ -3004,8 +3012,8 @@ void ggml_vk_test_matmul_f32(size_t m, size_t n, size_t k, size_t num_it, int sp
vk_buffer d_X;
vk_buffer d_Y;
vk_buffer d_D;
ggml_vk_pool_malloc(sizeof(float) * kpad * m, &d_X, {});
ggml_vk_pool_malloc(sizeof(float) * kpad * n, &d_Y, {});
ggml_vk_pool_malloc(sizeof(float) * k * m, &d_X, {});
ggml_vk_pool_malloc(sizeof(float) * k * n, &d_Y, {});
ggml_vk_pool_malloc(sizeof(float) * d_ne * split_k, &d_D, {});
float* x = (float *) malloc(sizeof(float) * x_ne);
@@ -3030,7 +3038,7 @@ void ggml_vk_test_matmul_f32(size_t m, size_t n, size_t k, size_t num_it, int sp
auto begin = std::chrono::high_resolution_clock::now();
for (size_t i = 0; i < num_it; i++) {
seq.push_back(ggml_vk_matmul(*p, ggml_vk_subbuffer(d_X), ggml_vk_subbuffer(d_Y), ggml_vk_subbuffer(d_D), m, n, k, kpad, kpad, m, split_k, vk_device.compute_queue, {}, {}));
seq.push_back(ggml_vk_matmul(*p, ggml_vk_subbuffer(d_X), ggml_vk_subbuffer(d_Y), ggml_vk_subbuffer(d_D), m, n, k, k, k, m, split_k, vk_device.compute_queue, {}, {}));
}
ggml_vk_submit(vk_device.compute_queue, seq, VK_NULL_HANDLE);
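All three test helpers get the same restructuring: start from the aligned pipeline, compute kpad from its alignment, and fall back to the unaligned variant when k is not already a multiple of that alignment, in which case buffers and strides use the raw k. Condensed for the small f32 case (the m/l sizes and the f16 variants below follow the same shape):

vk_pipeline * p = &vk_pipeline_matmul_f32_aligned_s;  // prefer aligned shader
const size_t kpad = ggml_vk_align_size(k, p->align);
if (k != kpad) {
    p = &vk_pipeline_matmul_f32_s;  // unaligned shader handles arbitrary k
}
// Buffers are then sized with k rather than kpad, and k is passed as both
// stride_a and stride_b to ggml_vk_matmul.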
@@ -3093,6 +3101,22 @@ void ggml_vk_test_matmul_f16(size_t m, size_t n, size_t k, size_t num_it, int sp
vk_pipeline * p;
std::string shname;
if (shader_size == 0) {
p = &vk_pipeline_matmul_f16_aligned_s;
shname = "F16_ALIGNED_S";
} else if (shader_size == 1) {
p = &vk_pipeline_matmul_f16_aligned_m;
shname = "F16_ALIGNED_M";
} else if (shader_size == 2) {
p = &vk_pipeline_matmul_f16_aligned_l;
shname = "F16_ALIGNED_L";
} else {
GGML_ASSERT(0);
}
const size_t kpad = ggml_vk_align_size(k, p->align);
if (k != kpad) {
if (shader_size == 0) {
p = &vk_pipeline_matmul_f16_s;
shname = "F16_S";
@@ -3102,11 +3126,8 @@ void ggml_vk_test_matmul_f16(size_t m, size_t n, size_t k, size_t num_it, int sp
} else if (shader_size == 2) {
p = &vk_pipeline_matmul_f16_l;
shname = "F16_L";
} else {
GGML_ASSERT(0);
}
const size_t kpad = ggml_vk_align_size(k, p->align);
}
ggml_vk_pipeline_allocate_descriptor_sets(*p, num_it);
if (split_k > 1) {
@@ -3116,8 +3137,8 @@ void ggml_vk_test_matmul_f16(size_t m, size_t n, size_t k, size_t num_it, int sp
vk_buffer d_X;
vk_buffer d_Y;
vk_buffer d_D;
ggml_vk_pool_malloc(sizeof(ggml_fp16_t) * kpad * m, &d_X, {});
ggml_vk_pool_malloc(sizeof(ggml_fp16_t) * kpad * n, &d_Y, {});
ggml_vk_pool_malloc(sizeof(ggml_fp16_t) * k * m, &d_X, {});
ggml_vk_pool_malloc(sizeof(ggml_fp16_t) * k * n, &d_Y, {});
ggml_vk_pool_malloc(sizeof(float) * d_ne * split_k, &d_D, {});
ggml_fp16_t* x = (ggml_fp16_t *) malloc(sizeof(ggml_fp16_t) * x_ne);
@@ -3142,7 +3163,7 @@ void ggml_vk_test_matmul_f16(size_t m, size_t n, size_t k, size_t num_it, int sp
auto begin = std::chrono::high_resolution_clock::now();
for (size_t i = 0; i < num_it; i++) {
seq.push_back(ggml_vk_matmul(*p, ggml_vk_subbuffer(d_X), ggml_vk_subbuffer(d_Y), ggml_vk_subbuffer(d_D), m, n, k, kpad, kpad, m, split_k, vk_device.compute_queue, {}, {}));
seq.push_back(ggml_vk_matmul(*p, ggml_vk_subbuffer(d_X), ggml_vk_subbuffer(d_Y), ggml_vk_subbuffer(d_D), m, n, k, k, k, m, split_k, vk_device.compute_queue, {}, {}));
}
ggml_vk_submit(vk_device.compute_queue, seq, VK_NULL_HANDLE);
@@ -3212,6 +3233,22 @@ void ggml_vk_test_matmul_f16_f32(size_t m, size_t n, size_t k, size_t num_it, in
vk_pipeline * p;
std::string shname;
if (shader_size == 0) {
p = &vk_pipeline_matmul_f16_f32_aligned_s;
shname = "F16_F32_ALIGNED_S";
} else if (shader_size == 1) {
p = &vk_pipeline_matmul_f16_f32_aligned_m;
shname = "F16_F32_ALIGNED_M";
} else if (shader_size == 2) {
p = &vk_pipeline_matmul_f16_f32_aligned_l;
shname = "F16_F32_ALIGNED_L";
} else {
GGML_ASSERT(0);
}
const size_t kpad = ggml_vk_align_size(k, p->align);
if (k != kpad) {
if (shader_size == 0) {
p = &vk_pipeline_matmul_f16_f32_s;
shname = "F16_F32_S";
@@ -3221,11 +3258,8 @@ void ggml_vk_test_matmul_f16_f32(size_t m, size_t n, size_t k, size_t num_it, in
} else if (shader_size == 2) {
p = &vk_pipeline_matmul_f16_f32_l;
shname = "F16_F32_L";
} else {
GGML_ASSERT(0);
}
const size_t kpad = ggml_vk_align_size(k, p->align);
}
ggml_vk_pipeline_allocate_descriptor_sets(*p, num_it);
if (split_k > 1) {
@@ -3235,8 +3269,8 @@ void ggml_vk_test_matmul_f16_f32(size_t m, size_t n, size_t k, size_t num_it, in
vk_buffer d_X;
vk_buffer d_Y;
vk_buffer d_D;
ggml_vk_pool_malloc(sizeof(ggml_fp16_t) * kpad * m, &d_X, {});
ggml_vk_pool_malloc(sizeof(ggml_fp16_t) * kpad * n, &d_Y, {});
ggml_vk_pool_malloc(sizeof(ggml_fp16_t) * k * m, &d_X, {});
ggml_vk_pool_malloc(sizeof(ggml_fp16_t) * k * n, &d_Y, {});
ggml_vk_pool_malloc(sizeof(float) * d_ne * split_k, &d_D, {});
ggml_fp16_t* x = (ggml_fp16_t *) malloc(sizeof(ggml_fp16_t) * x_ne);
@@ -3261,7 +3295,7 @@ void ggml_vk_test_matmul_f16_f32(size_t m, size_t n, size_t k, size_t num_it, in
auto begin = std::chrono::high_resolution_clock::now();
for (size_t i = 0; i < num_it; i++) {
seq.push_back(ggml_vk_matmul(*p, ggml_vk_subbuffer(d_X), ggml_vk_subbuffer(d_Y), ggml_vk_subbuffer(d_D), m, n, k, kpad, kpad, m, split_k, vk_device.compute_queue, {}, {}));
seq.push_back(ggml_vk_matmul(*p, ggml_vk_subbuffer(d_X), ggml_vk_subbuffer(d_Y), ggml_vk_subbuffer(d_D), m, n, k, k, k, m, split_k, vk_device.compute_queue, {}, {}));
}
ggml_vk_submit(vk_device.compute_queue, seq, VK_NULL_HANDLE);


@@ -304,17 +304,17 @@ void main() {
const int loadstride = int(gl_WorkGroupSize.x * LOAD_VEC) / BK;
const int start_k = ik * p.k_split;
const int end_k = (ik + 1) * p.k_split;
const int end_k = min(p.K, (ik + 1) * p.k_split);
int pos_a = ir * BM * p.stride_a / LOAD_VEC + start_k / LOAD_VEC;
int pos_b = ic * BN * p.stride_b / LOAD_VEC + start_k / LOAD_VEC;
D_TYPE sums[WMITER * TM * WNITER * TN];
FLOAT_TYPE sums[WMITER * TM * WNITER * TN];
FLOAT_TYPE cache_a[WMITER * TM];
FLOAT_TYPE cache_b[WNITER * TN];
[[unroll]] for (int i = 0; i < WMITER*TM*WNITER*TN; i++) {
sums[i] = 0.0f;
sums[i] = FLOAT_TYPE(0.0f);
}
[[unroll]] for (int block = start_k; block < end_k; block += BK) {
@@ -374,7 +374,7 @@ void main() {
pos_a += BK / LOAD_VEC;
pos_b += BK / LOAD_VEC;
for (int i = 0; i < min(BK, p.K - block); i++) {
for (int i = 0; i < BK; i++) {
// Load from shared into cache
[[unroll]] for (int wsir = 0; wsir < WMITER; wsir++) {
[[unroll]] for (int j = 0; j < TM; j++) {
@@ -391,7 +391,7 @@ void main() {
[[unroll]] for (int wsir = 0; wsir < WMITER; wsir++) {
[[unroll]] for (int cc = 0; cc < TN; cc++) {
[[unroll]] for (int cr = 0; cr < TM; cr++) {
sums[(wsic * TN + cc) * (WMITER * TM) + wsir * TM + cr] += D_TYPE(cache_a[wsir * TM + cr]) * D_TYPE(cache_b[wsic * TN + cc]);
sums[(wsic * TN + cc) * (WMITER * TM) + wsir * TM + cr] += FLOAT_TYPE(cache_a[wsir * TM + cr]) * FLOAT_TYPE(cache_b[wsic * TN + cc]);
}
}
}
@@ -414,7 +414,7 @@ void main() {
[[unroll]] for (int cc = 0; cc < TN; cc++) {
[[unroll]] for (int cr = 0; cr < TM; cr++) {
if (dr_warp + cr < p.M && dc_warp + cc < p.N) {
data_d[k_split_offset + (dc_warp + cc) * p.stride_d + dr_warp + cr] = sums[(wsic * TN + cc) * (WMITER * TM) + wsir * TM + cr];
data_d[k_split_offset + (dc_warp + cc) * p.stride_d + dr_warp + cr] = D_TYPE(sums[(wsic * TN + cc) * (WMITER * TM) + wsir * TM + cr]);
}
}
}
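The accumulator type also changes: sums are kept in FLOAT_TYPE and only cast to D_TYPE at the final store. A rough scalar analogue in C++, assuming the generator's fp16 configuration maps FLOAT_TYPE to float16_t (here _Float16, a compiler extension, stands in for it):

using FLOAT_TYPE = _Float16; // assumed shader compute type when fp16 is enabled
using D_TYPE     = float;    // destination type of the output buffer

// Accumulate in FLOAT_TYPE and convert on store, mirroring
// sums[i] = FLOAT_TYPE(0.0f); ... data_d[...] = D_TYPE(sums[...]);
D_TYPE dot(const FLOAT_TYPE * a, const FLOAT_TYPE * b, int n) {
    FLOAT_TYPE sum = FLOAT_TYPE(0.0f);
    for (int i = 0; i < n; i++) {
        sum += FLOAT_TYPE(a[i]) * FLOAT_TYPE(b[i]);
    }
    return D_TYPE(sum);
}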
@@ -1590,12 +1590,9 @@ async def main():
tasks.append(string_to_spv("matmul_f32_aligned_m", "".join(stream), {"LOAD_VEC": load_vec, "A_TYPE": vec_type, "B_TYPE": vec_type, "D_TYPE": "float"}, fp16))
tasks.append(string_to_spv("matmul_f32_aligned_s", "".join(stream), {"LOAD_VEC": load_vec, "A_TYPE": vec_type, "B_TYPE": vec_type, "D_TYPE": "float"}, fp16))
stream.clear();
stream.extend((mulmat_head, shader_float_type, mulmat_body));
tasks.append(string_to_spv("matmul_f16_l", "".join(stream), {"A_TYPE": "float16_t", "B_TYPE": "float16_t", "D_TYPE": "float"}, fp16))
tasks.append(string_to_spv("matmul_f16_m", "".join(stream), {"A_TYPE": "float16_t", "B_TYPE": "float16_t", "D_TYPE": "float"}, fp16))
tasks.append(string_to_spv("matmul_f16_s", "".join(stream), {"A_TYPE": "float16_t", "B_TYPE": "float16_t", "D_TYPE": "float"}, fp16))
tasks.append(string_to_spv("matmul_f16_aligned_l", "".join(stream), {"LOAD_VEC": load_vec, "A_TYPE": vec_type_f16, "B_TYPE": vec_type_f16, "D_TYPE": "float"}, fp16))
tasks.append(string_to_spv("matmul_f16_aligned_m", "".join(stream), {"LOAD_VEC": load_vec, "A_TYPE": vec_type_f16, "B_TYPE": vec_type_f16, "D_TYPE": "float"}, fp16))
tasks.append(string_to_spv("matmul_f16_aligned_s", "".join(stream), {"LOAD_VEC": load_vec, "A_TYPE": vec_type_f16, "B_TYPE": vec_type_f16, "D_TYPE": "float"}, fp16))