From e96944551a84995b482f465b757146ec0f8bc728 Mon Sep 17 00:00:00 2001
From: 0cc4m
Date: Thu, 16 Nov 2023 22:11:57 +0100
Subject: [PATCH] Add RMS Norm shader, rework op_f32 shader setup, fix matmul
 bug

---
 ggml-vulkan.cpp             | 598 +++++++++++++++++++++---------
 ggml_vk_generate_shaders.py | 707 ++++++++++++++++++------------------
 2 files changed, 794 insertions(+), 511 deletions(-)

diff --git a/ggml-vulkan.cpp b/ggml-vulkan.cpp
index 56954cad4..88159452f 100644
--- a/ggml-vulkan.cpp
+++ b/ggml-vulkan.cpp
@@ -127,13 +127,7 @@ struct vk_device {
 struct vk_op_push_constants {
     int M;
     int N;
-    int stride_x;
-    int stride_y;
-    int stride_d;
-    int x_offset;
-    int y_offset;
-    int d_offset;
-    float scale;
+    float param;
 };
 
 // Allow pre-recording command buffers
@@ -153,6 +147,7 @@ struct ggml_vk_tensor_extra_gpu {
     std::vector<vk_sequence> out_seqs;
 
     size_t tensor_size;
+    vk_buffer * gpu_buffer;
 };
 
 struct ggml_vk_garbage_collector {
@@ -179,6 +174,7 @@ vk_pipeline vk_pipeline_get_rows_f32[VK_NUM_TYPES];
 vk_pipeline vk_pipeline_mul_f32;
 vk_pipeline vk_pipeline_add_f32, vk_pipeline_add_f16_f32_f16;
 vk_pipeline vk_pipeline_scale_f32;
+vk_pipeline vk_pipeline_rms_norm_f32;
 
 static size_t vk_semaphore_idx;
 static ggml_vk_garbage_collector vk_gc;
@@ -602,7 +598,7 @@ static vk_buffer ggml_vk_create_buffer(size_t size, vk::MemoryPropertyFlags req_
     return buf;
 }
 
-static inline vk_subbuffer ggml_vk_subbuffer(vk_buffer& buf) {
+static vk_subbuffer ggml_vk_subbuffer(vk_buffer& buf) {
     return { buf, 0, (uint32_t) buf.size };
 }
 
@@ -665,7 +661,7 @@ static void ggml_vk_destroy_buffer(vk_buffer& buf) {
     }
 }
 
-static inline bool ggml_vk_build_shader(ggml_type type) {
+static bool ggml_vk_build_shader(ggml_type type) {
     switch(type) {
     case GGML_TYPE_F16:
     case GGML_TYPE_Q4_0:
@@ -717,6 +713,8 @@ static void ggml_vk_load_shaders() {
         vk_pipeline_matmul_f16_f32_aligned_m = ggml_vk_create_pipeline("matmul_f16_f32_aligned_m", matmul_f16_f32_aligned_m_len, matmul_f16_f32_aligned_m_data, "main", 3, 7 * sizeof(int), { 64, 64, 1}, warptile_m, 64);
         vk_pipeline_matmul_f16_f32_aligned_s = ggml_vk_create_pipeline("matmul_f16_f32_aligned_s", matmul_f16_f32_aligned_s_len, matmul_f16_f32_aligned_s_data, "main", 3, 7 * sizeof(int), { 32, 32, 1}, warptile_s, 32);
 
+        vk_pipeline_matmul_split_k_reduce = ggml_vk_create_pipeline("split_k_reduce", split_k_reduce_fp32_len, split_k_reduce_fp32_data, "main", 1, 3 * sizeof(int), {32, 32, 1}, {}, 1);
+
         // Build dequant shaders
         vk_pipeline_dequant[GGML_TYPE_F32] = ggml_vk_create_pipeline("f32_to_f16", f32_to_f16_len, f32_to_f16_data, "main", 2, 4 * sizeof(int), {64, 1, 1}, {}, 1);
@@ -756,15 +754,29 @@ static void ggml_vk_load_shaders() {
         vk_pipeline_dequant_mul_mat_vec_f32[GGML_TYPE_Q5_K] = ggml_vk_create_pipeline("mul_mat_vec_q5_K_f32", mul_mat_vec_q5_K_f32_len, mul_mat_vec_q5_K_f32_data, "main", 3, 1 * sizeof(int), {1, 1, 1}, {}, 1);
         vk_pipeline_dequant_mul_mat_vec_f32[GGML_TYPE_Q6_K] = ggml_vk_create_pipeline("mul_mat_vec_q6_K_f32", mul_mat_vec_q6_K_f32_len, mul_mat_vec_q6_K_f32_data, "main", 3, 1 * sizeof(int), {1, 1, 1}, {}, 1);
 
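The hunks above are the heart of the op_f32 rework: the per-op stride and offset fields are gone from vk_op_push_constants, and the element-wise pipelines below are created as flat 1-D dispatches with 512-wide workgroups instead of the old {32, 32, 1} grids. A CPU-side sketch of the contract the new layout implies — this is not the actual GLSL, and the names are illustrative:

    #include <cstdint>

    // Mirrors the reworked vk_op_push_constants: M and N carry op-specific
    // sizes (flat element counts for add/mul/scale, row width and row count
    // for RMS norm), and param carries the single float argument an op may
    // need (scale factor, epsilon, ...).
    struct op_push_constants {
        int   M;
        int   N;
        float param;
    };

    // What one invocation of a {512, 1, 1} element-wise shader is expected
    // to do, written as a plain loop: idx plays the role of
    // gl_GlobalInvocationID.x, and the bound check M guards the tail of a
    // dispatch rounded up to workgroup size.
    static void scale_f32_reference(const float * x, float * d, op_push_constants pc) {
        for (uint32_t idx = 0; idx < (uint32_t) pc.M; idx++) {
            d[idx] = x[idx] * pc.param;
        }
    }
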
get_rows_q4_1_data, "main", 3, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1); + vk_pipeline_get_rows[GGML_TYPE_Q5_0] = ggml_vk_create_pipeline("get_rows_q5_0", get_rows_q5_0_len, get_rows_q5_0_data, "main", 3, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1); + vk_pipeline_get_rows[GGML_TYPE_Q5_1] = ggml_vk_create_pipeline("get_rows_q5_1", get_rows_q5_1_len, get_rows_q5_1_data, "main", 3, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1); + vk_pipeline_get_rows[GGML_TYPE_Q8_0] = ggml_vk_create_pipeline("get_rows_q8_0", get_rows_q8_0_len, get_rows_q8_0_data, "main", 3, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1); + + vk_pipeline_get_rows_f32[GGML_TYPE_F16] = ggml_vk_create_pipeline("get_rows_f16_f32", get_rows_f16_f32_len, get_rows_f16_f32_data, "main", 3, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1); + vk_pipeline_get_rows_f32[GGML_TYPE_Q4_0] = ggml_vk_create_pipeline("get_rows_q4_0_f32", get_rows_q4_0_f32_len, get_rows_q4_0_f32_data, "main", 3, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1); + vk_pipeline_get_rows_f32[GGML_TYPE_Q4_1] = ggml_vk_create_pipeline("get_rows_q4_1_f32", get_rows_q4_1_f32_len, get_rows_q4_1_f32_data, "main", 3, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1); + vk_pipeline_get_rows_f32[GGML_TYPE_Q5_0] = ggml_vk_create_pipeline("get_rows_q5_0_f32", get_rows_q5_0_f32_len, get_rows_q5_0_f32_data, "main", 3, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1); + vk_pipeline_get_rows_f32[GGML_TYPE_Q5_1] = ggml_vk_create_pipeline("get_rows_q5_1_f32", get_rows_q5_1_f32_len, get_rows_q5_1_f32_data, "main", 3, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1); + vk_pipeline_get_rows_f32[GGML_TYPE_Q8_0] = ggml_vk_create_pipeline("get_rows_q8_0_f32", get_rows_q8_0_f32_len, get_rows_q8_0_f32_data, "main", 3, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1); + // add - vk_pipeline_add_f32 = ggml_vk_create_pipeline("add_f32", add_f32_len, add_f32_data, "main", 3, sizeof(vk_op_push_constants), {32, 32, 1}, {}, 1); - vk_pipeline_add_f16_f32_f16 = ggml_vk_create_pipeline("add_f16_f32_f16", add_f16_f32_f16_len, add_f16_f32_f16_data, "main", 3, sizeof(vk_op_push_constants), {32, 32, 1}, {}, 1); + vk_pipeline_add_f32 = ggml_vk_create_pipeline("add_f32", add_f32_len, add_f32_data, "main", 3, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1); + vk_pipeline_add_f16_f32_f16 = ggml_vk_create_pipeline("add_f16_f32_f16", add_f16_f32_f16_len, add_f16_f32_f16_data, "main", 3, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1); // Static shaders - vk_pipeline_matmul_split_k_reduce = ggml_vk_create_pipeline("split_k_reduce", split_k_reduce_len, split_k_reduce_data, "main", 1, 3 * sizeof(int), {32, 32, 1}, {}, 1); - vk_pipeline_mul_f32 = ggml_vk_create_pipeline("mul_f32", mul_f32_len, mul_f32_data, "main", 3, sizeof(vk_op_push_constants), {32, 32, 1}, {}, 1); + vk_pipeline_mul_f32 = ggml_vk_create_pipeline("mul_f32", mul_f32_len, mul_f32_data, "main", 3, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1); - vk_pipeline_scale_f32 = ggml_vk_create_pipeline("scale_f32", scale_f32_len, scale_f32_data, "main", 2, sizeof(vk_op_push_constants), {32, 32, 1}, {}, 1); + vk_pipeline_scale_f32 = ggml_vk_create_pipeline("scale_f32", scale_f32_len, scale_f32_data, "main", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1); } else { vk_pipeline_matmul_f32_l = ggml_vk_create_pipeline("matmul_f32_l", matmul_f32_l_fp32_len, matmul_f32_l_fp32_data, "main", 3, 7 * sizeof(int), {128, 128, 1}, warptile_l, 1); vk_pipeline_matmul_f32_m = ggml_vk_create_pipeline("matmul_f32_m", 
matmul_f32_m_fp32_len, matmul_f32_m_fp32_data, "main", 3, 7 * sizeof(int), { 64, 64, 1}, warptile_m, 1); @@ -788,6 +800,8 @@ static void ggml_vk_load_shaders() { vk_pipeline_matmul_f16_f32_aligned_m = ggml_vk_create_pipeline("matmul_f16_f32_aligned_m", matmul_f16_f32_aligned_m_fp32_len, matmul_f16_f32_aligned_m_fp32_data, "main", 3, 7 * sizeof(int), { 64, 64, 1}, warptile_m, 64); vk_pipeline_matmul_f16_f32_aligned_s = ggml_vk_create_pipeline("matmul_f16_f32_aligned_s", matmul_f16_f32_aligned_s_fp32_len, matmul_f16_f32_aligned_s_fp32_data, "main", 3, 7 * sizeof(int), { 32, 32, 1}, warptile_s, 32); + vk_pipeline_matmul_split_k_reduce = ggml_vk_create_pipeline("split_k_reduce", split_k_reduce_fp32_len, split_k_reduce_fp32_data, "main", 1, 3 * sizeof(int), {32, 32, 1}, {}, 1); + // Build dequant shaders vk_pipeline_dequant[GGML_TYPE_F32] = ggml_vk_create_pipeline("f32_to_f16", f32_to_f16_fp32_len, f32_to_f16_fp32_data, "main", 2, 4 * sizeof(int), {64, 1, 1}, {}, 1); @@ -828,30 +842,31 @@ static void ggml_vk_load_shaders() { vk_pipeline_dequant_mul_mat_vec_f32[GGML_TYPE_Q6_K] = ggml_vk_create_pipeline("mul_mat_vec_q6_K_f32", mul_mat_vec_q6_K_f32_fp32_len, mul_mat_vec_q6_K_f32_fp32_data, "main", 3, 1 * sizeof(int), {1, 1, 1}, {}, 1); // get_rows - vk_pipeline_get_rows[GGML_TYPE_F16] = ggml_vk_create_pipeline("get_rows_f16", get_rows_f16_fp32_len, get_rows_f16_fp32_data, "main", 3, 1 * sizeof(int), {1, 1, 1}, {}, 1); - vk_pipeline_get_rows[GGML_TYPE_Q4_0] = ggml_vk_create_pipeline("get_rows_q4_0", get_rows_q4_0_fp32_len, get_rows_q4_0_fp32_data, "main", 3, 1 * sizeof(int), {1, 1, 1}, {}, 1); - vk_pipeline_get_rows[GGML_TYPE_Q4_1] = ggml_vk_create_pipeline("get_rows_q4_1", get_rows_q4_1_fp32_len, get_rows_q4_1_fp32_data, "main", 3, 1 * sizeof(int), {1, 1, 1}, {}, 1); - vk_pipeline_get_rows[GGML_TYPE_Q5_0] = ggml_vk_create_pipeline("get_rows_q5_0", get_rows_q5_0_fp32_len, get_rows_q5_0_fp32_data, "main", 3, 1 * sizeof(int), {1, 1, 1}, {}, 1); - vk_pipeline_get_rows[GGML_TYPE_Q5_1] = ggml_vk_create_pipeline("get_rows_q5_1", get_rows_q5_1_fp32_len, get_rows_q5_1_fp32_data, "main", 3, 1 * sizeof(int), {1, 1, 1}, {}, 1); - vk_pipeline_get_rows[GGML_TYPE_Q8_0] = ggml_vk_create_pipeline("get_rows_q8_0", get_rows_q8_0_fp32_len, get_rows_q8_0_fp32_data, "main", 3, 1 * sizeof(int), {1, 1, 1}, {}, 1); + vk_pipeline_get_rows[GGML_TYPE_F16] = ggml_vk_create_pipeline("get_rows_f16", get_rows_f16_fp32_len, get_rows_f16_fp32_data, "main", 3, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1); + vk_pipeline_get_rows[GGML_TYPE_Q4_0] = ggml_vk_create_pipeline("get_rows_q4_0", get_rows_q4_0_fp32_len, get_rows_q4_0_fp32_data, "main", 3, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1); + vk_pipeline_get_rows[GGML_TYPE_Q4_1] = ggml_vk_create_pipeline("get_rows_q4_1", get_rows_q4_1_fp32_len, get_rows_q4_1_fp32_data, "main", 3, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1); + vk_pipeline_get_rows[GGML_TYPE_Q5_0] = ggml_vk_create_pipeline("get_rows_q5_0", get_rows_q5_0_fp32_len, get_rows_q5_0_fp32_data, "main", 3, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1); + vk_pipeline_get_rows[GGML_TYPE_Q5_1] = ggml_vk_create_pipeline("get_rows_q5_1", get_rows_q5_1_fp32_len, get_rows_q5_1_fp32_data, "main", 3, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1); + vk_pipeline_get_rows[GGML_TYPE_Q8_0] = ggml_vk_create_pipeline("get_rows_q8_0", get_rows_q8_0_fp32_len, get_rows_q8_0_fp32_data, "main", 3, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1); - vk_pipeline_get_rows_f32[GGML_TYPE_F16] = 
ggml_vk_create_pipeline("get_rows_f16_f32", get_rows_f16_f32_fp32_len, get_rows_f16_f32_fp32_data, "main", 3, 1 * sizeof(int), {1, 1, 1}, {}, 1); - vk_pipeline_get_rows_f32[GGML_TYPE_Q4_0] = ggml_vk_create_pipeline("get_rows_q4_0_f32", get_rows_q4_0_f32_fp32_len, get_rows_q4_0_f32_fp32_data, "main", 3, 1 * sizeof(int), {1, 1, 1}, {}, 1); - vk_pipeline_get_rows_f32[GGML_TYPE_Q4_1] = ggml_vk_create_pipeline("get_rows_q4_1_f32", get_rows_q4_1_f32_fp32_len, get_rows_q4_1_f32_fp32_data, "main", 3, 1 * sizeof(int), {1, 1, 1}, {}, 1); - vk_pipeline_get_rows_f32[GGML_TYPE_Q5_0] = ggml_vk_create_pipeline("get_rows_q5_0_f32", get_rows_q5_0_f32_fp32_len, get_rows_q5_0_f32_fp32_data, "main", 3, 1 * sizeof(int), {1, 1, 1}, {}, 1); - vk_pipeline_get_rows_f32[GGML_TYPE_Q5_1] = ggml_vk_create_pipeline("get_rows_q5_1_f32", get_rows_q5_1_f32_fp32_len, get_rows_q5_1_f32_fp32_data, "main", 3, 1 * sizeof(int), {1, 1, 1}, {}, 1); - vk_pipeline_get_rows_f32[GGML_TYPE_Q8_0] = ggml_vk_create_pipeline("get_rows_q8_0_f32", get_rows_q8_0_f32_fp32_len, get_rows_q8_0_f32_fp32_data, "main", 3, 1 * sizeof(int), {1, 1, 1}, {}, 1); + vk_pipeline_get_rows_f32[GGML_TYPE_F16] = ggml_vk_create_pipeline("get_rows_f16_f32", get_rows_f16_f32_fp32_len, get_rows_f16_f32_fp32_data, "main", 3, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1); + vk_pipeline_get_rows_f32[GGML_TYPE_Q4_0] = ggml_vk_create_pipeline("get_rows_q4_0_f32", get_rows_q4_0_f32_fp32_len, get_rows_q4_0_f32_fp32_data, "main", 3, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1); + vk_pipeline_get_rows_f32[GGML_TYPE_Q4_1] = ggml_vk_create_pipeline("get_rows_q4_1_f32", get_rows_q4_1_f32_fp32_len, get_rows_q4_1_f32_fp32_data, "main", 3, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1); + vk_pipeline_get_rows_f32[GGML_TYPE_Q5_0] = ggml_vk_create_pipeline("get_rows_q5_0_f32", get_rows_q5_0_f32_fp32_len, get_rows_q5_0_f32_fp32_data, "main", 3, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1); + vk_pipeline_get_rows_f32[GGML_TYPE_Q5_1] = ggml_vk_create_pipeline("get_rows_q5_1_f32", get_rows_q5_1_f32_fp32_len, get_rows_q5_1_f32_fp32_data, "main", 3, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1); + vk_pipeline_get_rows_f32[GGML_TYPE_Q8_0] = ggml_vk_create_pipeline("get_rows_q8_0_f32", get_rows_q8_0_f32_fp32_len, get_rows_q8_0_f32_fp32_data, "main", 3, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1); // add - vk_pipeline_add_f32 = ggml_vk_create_pipeline("add_f32", add_f32_fp32_len, add_f32_fp32_data, "main", 3, sizeof(vk_op_push_constants), {32, 32, 1}, {}, 1); - vk_pipeline_add_f16_f32_f16 = ggml_vk_create_pipeline("add_f16_f32_f16", add_f16_f32_f16_fp32_len, add_f16_f32_f16_fp32_data, "main", 3, sizeof(vk_op_push_constants), {32, 32, 1}, {}, 1); + vk_pipeline_add_f32 = ggml_vk_create_pipeline("add_f32", add_f32_fp32_len, add_f32_fp32_data, "main", 3, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1); + vk_pipeline_add_f16_f32_f16 = ggml_vk_create_pipeline("add_f16_f32_f16", add_f16_f32_f16_fp32_len, add_f16_f32_f16_fp32_data, "main", 3, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1); // Static shaders - vk_pipeline_matmul_split_k_reduce = ggml_vk_create_pipeline("split_k_reduce", split_k_reduce_fp32_len, split_k_reduce_fp32_data, "main", 1, 3 * sizeof(int), {32, 32, 1}, {}, 1); - vk_pipeline_mul_f32 = ggml_vk_create_pipeline("mul_f32", mul_f32_fp32_len, mul_f32_fp32_data, "main", 3, sizeof(vk_op_push_constants), {32, 32, 1}, {}, 1); + vk_pipeline_mul_f32 = ggml_vk_create_pipeline("mul_f32", mul_f32_fp32_len, mul_f32_fp32_data, "main", 3, 
sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1); - vk_pipeline_scale_f32 = ggml_vk_create_pipeline("scale_f32", scale_f32_fp32_len, scale_f32_fp32_data, "main", 2, sizeof(vk_op_push_constants), {32, 32, 1}, {}, 1); + vk_pipeline_scale_f32 = ggml_vk_create_pipeline("scale_f32", scale_f32_fp32_len, scale_f32_fp32_data, "main", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1); } + + vk_pipeline_rms_norm_f32 = ggml_vk_create_pipeline("rms_norm_f32", rms_norm_f32_len, rms_norm_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1); } void ggml_vk_test_transfer(size_t ne); @@ -1013,8 +1028,9 @@ std::cerr << "ggml_vulkan: Validation layers enabled" << std::endl; ggml_vk_test_transfer(1024 * 1024 * m); } const std::vector vals { + 4096, 2, 4096, + 623, 111, 128, 100, 46, 558, - 1024, 2, 4096, 512, 1, 256, 128, 110, 622, 511, 511, 127, @@ -1042,11 +1058,27 @@ std::cerr << "ggml_vulkan: Validation layers enabled" << std::endl; ggml_vk_test_matmul_f16(vals[i], vals[i + 1], vals[i + 2], num_it, 1, 2); ggml_vk_test_matmul_f16(vals[i], vals[i + 1], vals[i + 2], num_it, 4, 2); std::cerr << std::endl; + + ggml_vk_test_matmul_f32(vals[i], vals[i + 1], vals[i + 2], num_it, 1, 0); + ggml_vk_test_matmul_f32(vals[i], vals[i + 1], vals[i + 2], num_it, 4, 0); + ggml_vk_test_matmul_f32(vals[i], vals[i + 1], vals[i + 2], num_it, 1, 1); + ggml_vk_test_matmul_f32(vals[i], vals[i + 1], vals[i + 2], num_it, 4, 1); + ggml_vk_test_matmul_f32(vals[i], vals[i + 1], vals[i + 2], num_it, 1, 2); + ggml_vk_test_matmul_f32(vals[i], vals[i + 1], vals[i + 2], num_it, 4, 2); + std::cerr << std::endl; + + ggml_vk_test_matmul_f16_f32(vals[i], vals[i + 1], vals[i + 2], num_it, 1, 0); + ggml_vk_test_matmul_f16_f32(vals[i], vals[i + 1], vals[i + 2], num_it, 4, 0); + ggml_vk_test_matmul_f16_f32(vals[i], vals[i + 1], vals[i + 2], num_it, 1, 1); + ggml_vk_test_matmul_f16_f32(vals[i], vals[i + 1], vals[i + 2], num_it, 4, 1); + ggml_vk_test_matmul_f16_f32(vals[i], vals[i + 1], vals[i + 2], num_it, 1, 2); + ggml_vk_test_matmul_f16_f32(vals[i], vals[i + 1], vals[i + 2], num_it, 4, 2); + std::cerr << std::endl; } #endif } -static inline vk_pipeline* ggml_vk_get_to_fp16(ggml_type type) { +static vk_pipeline* ggml_vk_get_to_fp16(ggml_type type) { #ifdef VK_DEBUG std::cerr << "ggml_vk_get_to_fp16()" << std::endl; #endif @@ -1070,7 +1102,7 @@ static inline vk_pipeline* ggml_vk_get_to_fp16(ggml_type type) { return &vk_pipeline_dequant[type]; } -static inline vk_pipeline* ggml_vk_get_dequantize_mul_mat_vec(ggml_type type, bool f16_y) { +static vk_pipeline* ggml_vk_get_dequantize_mul_mat_vec(ggml_type type, bool f16_y) { #ifdef VK_DEBUG std::cerr << "ggml_vk_get_dequantize_mul_mat_vec()" << std::endl; #endif @@ -1401,7 +1433,7 @@ static void ggml_vk_buffer_write_2d(vk_buffer* dst, size_t offset, const void * } } -static inline size_t ggml_vk_align_size(size_t width, size_t align) { +static size_t ggml_vk_align_size(size_t width, size_t align) { return CEIL_DIV(width, align) * align; } @@ -1536,7 +1568,7 @@ static void ggml_vk_buffer_read(vk_buffer* src, size_t offset, void * dst, size_ } } -static vk_sequence ggml_vk_h2d_tensor_2d(vk_buffer* dst, size_t offset, const struct ggml_tensor * src, uint64_t i3, uint64_t i2, vk_queue& q, std::vector wait_semaphores, std::vector signal_semaphores, vk_submission* s = nullptr, std::vector* pre_staging = nullptr) { +static vk_sequence ggml_vk_h2d_tensor_2d(vk_buffer* dst, size_t offset, const struct ggml_tensor * src, uint64_t i3, uint64_t i2, uint64_t i1, vk_queue& q, std::vector 
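Unlike the {512, 1, 1} element-wise kernels, the new rms_norm_f32 pipeline is created with a {1, 1, 1} workgroup-size constant and dispatched with one workgroup per row, receiving the row width and epsilon through the same {M, N, param} push-constant block. The row math the shader has to reproduce, as a sequential CPU sketch (this mirrors ggml's CPU rms_norm semantics; names are illustrative):

    #include <cmath>

    // Per-row RMS normalization: y[i] = x[i] / sqrt(mean(x^2) + eps).
    // The Vulkan shader computes the same per-row sum of squares, here
    // written as a plain loop over one row of ne0 elements.
    static void rms_norm_row(const float * x, float * y, int ne0, float eps) {
        double sum = 0.0;
        for (int i = 0; i < ne0; i++) {
            sum += (double) x[i] * x[i];
        }
        const float scale = 1.0f / sqrtf((float) (sum / ne0) + eps);
        for (int i = 0; i < ne0; i++) {
            y[i] = x[i] * scale;
        }
    }
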
 
 void ggml_vk_test_transfer(size_t ne);
@@ -1013,8 +1028,9 @@ std::cerr << "ggml_vulkan: Validation layers enabled" << std::endl;
         ggml_vk_test_transfer(1024 * 1024 * m);
     }
     const std::vector<size_t> vals {
+        4096, 2, 4096,
+        623, 111, 128,
         100, 46, 558,
-        1024, 2, 4096,
         512, 1, 256,
         128, 110, 622,
         511, 511, 127,
@@ -1042,11 +1058,27 @@ std::cerr << "ggml_vulkan: Validation layers enabled" << std::endl;
         ggml_vk_test_matmul_f16(vals[i], vals[i + 1], vals[i + 2], num_it, 1, 2);
         ggml_vk_test_matmul_f16(vals[i], vals[i + 1], vals[i + 2], num_it, 4, 2);
         std::cerr << std::endl;
+
+        ggml_vk_test_matmul_f32(vals[i], vals[i + 1], vals[i + 2], num_it, 1, 0);
+        ggml_vk_test_matmul_f32(vals[i], vals[i + 1], vals[i + 2], num_it, 4, 0);
+        ggml_vk_test_matmul_f32(vals[i], vals[i + 1], vals[i + 2], num_it, 1, 1);
+        ggml_vk_test_matmul_f32(vals[i], vals[i + 1], vals[i + 2], num_it, 4, 1);
+        ggml_vk_test_matmul_f32(vals[i], vals[i + 1], vals[i + 2], num_it, 1, 2);
+        ggml_vk_test_matmul_f32(vals[i], vals[i + 1], vals[i + 2], num_it, 4, 2);
+        std::cerr << std::endl;
+
+        ggml_vk_test_matmul_f16_f32(vals[i], vals[i + 1], vals[i + 2], num_it, 1, 0);
+        ggml_vk_test_matmul_f16_f32(vals[i], vals[i + 1], vals[i + 2], num_it, 4, 0);
+        ggml_vk_test_matmul_f16_f32(vals[i], vals[i + 1], vals[i + 2], num_it, 1, 1);
+        ggml_vk_test_matmul_f16_f32(vals[i], vals[i + 1], vals[i + 2], num_it, 4, 1);
+        ggml_vk_test_matmul_f16_f32(vals[i], vals[i + 1], vals[i + 2], num_it, 1, 2);
+        ggml_vk_test_matmul_f16_f32(vals[i], vals[i + 1], vals[i + 2], num_it, 4, 2);
+        std::cerr << std::endl;
     }
 #endif
 }
 
-static inline vk_pipeline* ggml_vk_get_to_fp16(ggml_type type) {
+static vk_pipeline* ggml_vk_get_to_fp16(ggml_type type) {
 #ifdef VK_DEBUG
     std::cerr << "ggml_vk_get_to_fp16()" << std::endl;
 #endif
@@ -1070,7 +1102,7 @@ static inline vk_pipeline* ggml_vk_get_to_fp16(ggml_type type) {
     return &vk_pipeline_dequant[type];
 }
 
-static inline vk_pipeline* ggml_vk_get_dequantize_mul_mat_vec(ggml_type type, bool f16_y) {
+static vk_pipeline* ggml_vk_get_dequantize_mul_mat_vec(ggml_type type, bool f16_y) {
 #ifdef VK_DEBUG
     std::cerr << "ggml_vk_get_dequantize_mul_mat_vec()" << std::endl;
 #endif
@@ -1401,7 +1433,7 @@ static void ggml_vk_buffer_write_2d(vk_buffer* dst, size_t offset, const void *
     }
 }
 
-static inline size_t ggml_vk_align_size(size_t width, size_t align) {
+static size_t ggml_vk_align_size(size_t width, size_t align) {
     return CEIL_DIV(width, align) * align;
 }
 
@@ -1536,7 +1568,7 @@ static void ggml_vk_buffer_read(vk_buffer* src, size_t offset, void * dst, size_
     }
 }
 
-static vk_sequence ggml_vk_h2d_tensor_2d(vk_buffer* dst, size_t offset, const struct ggml_tensor * src, uint64_t i3, uint64_t i2, vk_queue& q, std::vector<vk_semaphore> wait_semaphores, std::vector<vk_semaphore> signal_semaphores, vk_submission* s = nullptr, std::vector<vk_staging_memcpy>* pre_staging = nullptr) {
+static vk_sequence ggml_vk_h2d_tensor_2d(vk_buffer* dst, size_t offset, const struct ggml_tensor * src, uint64_t i3, uint64_t i2, uint64_t i1, vk_queue& q, std::vector<vk_semaphore> wait_semaphores, std::vector<vk_semaphore> signal_semaphores, vk_submission* s = nullptr, std::vector<vk_staging_memcpy>* pre_staging = nullptr) {
 #ifdef VK_DEBUG
     std::cerr << "ggml_vk_h2d_tensor_2d()" << std::endl;
 #endif
@@ -1553,10 +1585,10 @@ static vk_sequence ggml_vk_h2d_tensor_2d(vk_buffer* dst, size_t offset, const st
     const void * x = (const void *) ((const char *) src->data + i2*nb2 + i3*nb3);
     if (nb0 == ts && nb1 == row_length) {
-        return ggml_vk_buffer_write_async(dst, offset, x, ne1*nb1, q, std::move(wait_semaphores), std::move(signal_semaphores), s, pre_staging);
+        return ggml_vk_buffer_write_async(dst, offset, x, i1*nb1, q, std::move(wait_semaphores), std::move(signal_semaphores), s, pre_staging);
     }
     if (nb0 == ts) {
-        return ggml_vk_buffer_write_2d_async(dst, offset, x, nb1, row_length, ne1, q, std::move(wait_semaphores), std::move(signal_semaphores), s, pre_staging);
+        return ggml_vk_buffer_write_2d_async(dst, offset, x, nb1, row_length, i1, q, std::move(wait_semaphores), std::move(signal_semaphores), s, pre_staging);
     }
     GGML_ASSERT(false);
     // TODO: also needs handling of staging buffers
@@ -1571,7 +1603,7 @@ static vk_sequence ggml_vk_h2d_tensor_2d(vk_buffer* dst, size_t offset, const st
 
 static int ggml_vk_guess_split_k(int m, int n, int k, bool aligned) {
 #ifdef VK_DEBUG
-    std::cerr << "ggml_vk_guess_split_k(" << m << ", " << n << ", " << k << ")";
+    std::cerr << "ggml_vk_guess_split_k(" << m << ", " << n << ", " << k << ", " << aligned << ")";
 #endif
     if (aligned && k > 128 && (m < 128 || n < 128)) {
 #ifdef VK_DEBUG
@@ -1688,9 +1720,9 @@ static vk_sequence ggml_vk_matmul(vk_pipeline& pipeline, vk_subbuffer&& a, vk_su
 
 static void ggml_vk_mul_mat_f32(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
 #ifdef VK_DEBUG
-    std::cerr << "ggml_vk_mul_mat_f32((type=" << src0->type << ", backend=" << src0->backend << ", ne0=" << src0->ne[0] << ", ne1=" << src0->ne[1] << ", ne2=" << src0->ne[2] << ", ne3=" << src0->ne[3];
-    std::cerr << "), (type=" << src1->type << ", backend=" << src0->backend << ", ne0=" << src1->ne[0] << ", ne1=" << src1->ne[1] << ", ne2=" << src1->ne[2] << ", ne3=" << src1->ne[3];
-    std::cerr << "), (type=" << dst->type << ", backend=" << src0->backend << ", ne0=" << dst->ne[0] << ", ne1=" << dst->ne[1] << ", ne2=" << dst->ne[2] << ", ne3=" << dst->ne[3] << "),)" << std::endl;
+    std::cerr << "ggml_vk_mul_mat_f32((" << src0 << ", type=" << src0->type << ", backend=" << src0->backend << ", ne0=" << src0->ne[0] << ", ne1=" << src0->ne[1] << ", ne2=" << src0->ne[2] << ", ne3=" << src0->ne[3] << ", nb0=" << src0->nb[0] << ", nb1=" << src0->nb[1] << ", nb2=" << src0->nb[2] << ", nb3=" << src0->nb[3];
+    std::cerr << "), (" << src1 << ", type=" << src1->type << ", backend=" << src1->backend << ", ne0=" << src1->ne[0] << ", ne1=" << src1->ne[1] << ", ne2=" << src1->ne[2] << ", ne3=" << src1->ne[3] << ", nb0=" << src1->nb[0] << ", nb1=" << src1->nb[1] << ", nb2=" << src1->nb[2] << ", nb3=" << src1->nb[3];
+    std::cerr << "), (" << dst << ", type=" << dst->type << ", backend=" << dst->backend << ", ne0=" << dst->ne[0] << ", ne1=" << dst->ne[1] << ", ne2=" << dst->ne[2] << ", ne3=" << dst->ne[3] << ", nb0=" << dst->nb[0] << ", nb1=" << dst->nb[1] << ", nb2=" << dst->nb[2] << ", nb3=" << dst->nb[3] << "),)" << std::endl;
 #endif
     const int64_t ne00 = src0->ne[0];
     const int64_t ne01 = src0->ne[1];
@@ -1760,7 +1792,7 @@ static void ggml_vk_mul_mat_f32(const ggml_tensor * src0, const ggml_tensor * sr
                 vk_semaphore * sem = ggml_vk_create_timeline_semaphore();
                 x_semaphores.push_back({ sem->s, sem->value + 1 });
                 // Wait for previous matmul to be done before writing to the input buffers again
-                extra->in0_seqs.push_back(ggml_vk_h2d_tensor_2d(d_X, x_offset, src0, i03, i02, vk_device.transfer_queues[0], {}, { { sem->s, sem->value + 1 } }, nullptr, &extra->memcpys));
+                extra->in0_seqs.push_back(ggml_vk_h2d_tensor_2d(d_X, x_offset, src0, i03, i02, ne01, vk_device.transfer_queues[0], {}, { { sem->s, sem->value + 1 } }, nullptr, &extra->memcpys));
                 sem->value += 1;
             }
         }
@@ -1769,7 +1801,7 @@ static void ggml_vk_mul_mat_f32(const ggml_tensor * src0, const ggml_tensor * sr
     for (int64_t i13 = 0; i13 < ne13; i13++) {
         const int64_t i03 = i13 / r3;
         for (int64_t i12 = 0; i12 < ne12; i12++) {
-            int64_t i02 = i12 / r2;
+            const int64_t i02 = i12 / r2;
 
             const uint32_t x_offset = load_x ? x_sz * (i03 * ne02 + i02) : 0;
             const uint32_t y_offset = y_sz * (i13 * ne12 + i12);
@@ -1782,7 +1814,7 @@ static void ggml_vk_mul_mat_f32(const ggml_tensor * src0, const ggml_tensor * sr
                 semaphores.push_back(x_semaphores[i03 * ne02 + i02]);
             }
 
-            extra->in1_seqs.push_back(ggml_vk_h2d_tensor_2d(d_Y, y_offset, src1, i13, i12, vk_device.transfer_queues[1], {}, { { sem->s, sem->value + 1 } }, nullptr, &extra->memcpys));
+            extra->in1_seqs.push_back(ggml_vk_h2d_tensor_2d(d_Y, y_offset, src1, i13, i12, ne11, vk_device.transfer_queues[1], {}, { { sem->s, sem->value + 1 } }, nullptr, &extra->memcpys));
 
             // compute
             extra->comp_seqs.push_back(ggml_vk_matmul(*pipeline, { *d_X, x_offset, x_sz }, { *d_Y, y_offset, y_sz }, { *d_D, d_offset, d_sz }, ne01, ne11, ne10, ne10, ne10, ne01, split_k, vk_device.compute_queue, std::move(semaphores), { { sem->s, sem->value + 2 } }));
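
ggml_vk_h2d_tensor_2d now takes the number of rows to upload (i1) as an explicit argument — ne01 for src0 and ne11 for src1 in the matmul paths above — instead of deriving it inside the function. A simplified sketch of its two contiguous fast paths under that signature, with buffers and queues stripped out and memcpy standing in for the async writes (names are illustrative):

    #include <cstdint>
    #include <cstring>

    static void h2d_tensor_2d_sketch(char * dst, const char * src,
                                     uint64_t i1,         // rows to upload
                                     size_t ts,           // element type size
                                     size_t nb0, size_t nb1,
                                     size_t row_length) { // packed row size in bytes
        if (nb0 == ts && nb1 == row_length) {
            // Fully contiguous: one flat copy of i1 rows.
            memcpy(dst, src, i1 * nb1);
            return;
        }
        if (nb0 == ts) {
            // Contiguous rows with a padded stride: copy row by row.
            for (uint64_t r = 0; r < i1; r++) {
                memcpy(dst + r * row_length, src + r * nb1, row_length);
            }
            return;
        }
        // Non-contiguous elements are not handled by this path.
    }
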
@@ -1800,9 +1832,9 @@
 static void ggml_vk_mul_mat_q_f16(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
 #ifdef VK_DEBUG
-    std::cerr << "ggml_vk_mul_mat_q_f16((type=" << src0->type << ", backend=" << src0->backend << ", ne0=" << src0->ne[0] << ", ne1=" << src0->ne[1] << ", ne2=" << src0->ne[2] << ", ne3=" << src0->ne[3];
-    std::cerr << "), (type=" << src1->type << ", backend=" << src0->backend << ", ne0=" << src1->ne[0] << ", ne1=" << src1->ne[1] << ", ne2=" << src1->ne[2] << ", ne3=" << src1->ne[3];
-    std::cerr << "), (type=" << dst->type << ", backend=" << src0->backend << ", ne0=" << dst->ne[0] << ", ne1=" << dst->ne[1] << ", ne2=" << dst->ne[2] << ", ne3=" << dst->ne[3] << "),)" << std::endl;
+    std::cerr << "ggml_vk_mul_mat_q_f16((" << src0 << ", type=" << src0->type << ", backend=" << src0->backend << ", ne0=" << src0->ne[0] << ", ne1=" << src0->ne[1] << ", ne2=" << src0->ne[2] << ", ne3=" << src0->ne[3] << ", nb0=" << src0->nb[0] << ", nb1=" << src0->nb[1] << ", nb2=" << src0->nb[2] << ", nb3=" << src0->nb[3];
+    std::cerr << "), (" << src1 << ", type=" << src1->type << ", backend=" << src1->backend << ", ne0=" << src1->ne[0] << ", ne1=" << src1->ne[1] << ", ne2=" << src1->ne[2] << ", ne3=" << src1->ne[3] << ", nb0=" << src1->nb[0] << ", nb1=" << src1->nb[1] << ", nb2=" << src1->nb[2] << ", nb3=" << src1->nb[3];
+    std::cerr << "), (" << dst << ", type=" << dst->type << ", backend=" << dst->backend << ", ne0=" << dst->ne[0] << ", ne1=" << dst->ne[1] << ", ne2=" << dst->ne[2] << ", ne3=" << dst->ne[3] << ", nb0=" << dst->nb[0] << ", nb1=" << dst->nb[1] << ", nb2=" << dst->nb[2] << ", nb3=" << dst->nb[3] << "),)" << std::endl;
 #endif
     const int64_t ne00 = src0->ne[0];
     const int64_t ne01 = src0->ne[1];
@@ -1848,6 +1880,13 @@ static void ggml_vk_mul_mat_q_f16(const ggml_tensor * src0, const ggml_tensor *
     const uint32_t y_sz = ggml_vk_align_size(f16_f32_kernel ? sizeof(float) * y_ne : sizeof(ggml_fp16_t) * y_ne, vk_device.properties.limits.minStorageBufferOffsetAlignment);
     const uint32_t d_sz = ggml_vk_align_size(sizeof(float) * d_ne * split_k, vk_device.properties.limits.minStorageBufferOffsetAlignment);
 
+    if (dst->backend == GGML_BACKEND_GPU) {
+        if (d_sz != nb2) {
+            std::cerr << "ERROR: incompatible tensor alignment d_sz=" << d_sz << " nb2=" << nb2 << std::endl;
+            GGML_ASSERT(false);
+        }
+    }
+
     ggml_vk_tensor_extra_gpu * extra = (ggml_vk_tensor_extra_gpu *) dst->extra;
 
     GGML_ASSERT(extra->comp_seqs.empty());
@@ -1858,33 +1897,33 @@ static void ggml_vk_mul_mat_q_f16(const ggml_tensor * src0, const ggml_tensor *
     } else {
         d_D = &vk_prealloc_d;
     }
-    GGML_ASSERT(d_D->size >= d_sz);
+    GGML_ASSERT(d_D->size >= d_sz * ne02 * ne03);
     vk_buffer* d_Qx;
     vk_buffer* d_Qy;
     vk_buffer* d_X;
     vk_buffer* d_Y;
     if (load_x) {
         d_Qx = &vk_prealloc_qx;
-        GGML_ASSERT(d_Qx->size >= qx_sz);
+        GGML_ASSERT(d_Qx->size >= qx_sz * ne02 * ne03);
     } else {
         d_Qx = (vk_buffer *) src0->data;
     }
     if (load_y) {
         d_Qy = &vk_prealloc_qy;
-        GGML_ASSERT(d_Qy->size >= qy_sz);
+        GGML_ASSERT(d_Qy->size >= qy_sz * ne02 * ne03);
     } else {
         d_Qy = (vk_buffer *) src1->data;
     }
     if (qx_needs_dequant) {
         d_X = &vk_prealloc_x;
-        GGML_ASSERT(d_X->size >= x_sz);
+        GGML_ASSERT(d_X->size >= x_sz * ne02 * ne03);
     } else {
         d_X = d_Qx;
         GGML_ASSERT(qx_sz == x_sz);  // NOLINT
     }
     if (qy_needs_dequant) {
         d_Y = &vk_prealloc_y;
-        GGML_ASSERT(d_Y->size >= y_sz);
+        GGML_ASSERT(d_Y->size >= y_sz * ne02 * ne03);
     } else {
         d_Y = d_Qy;
         GGML_ASSERT(qy_sz == y_sz);
@@ -1919,7 +1958,7 @@ static void ggml_vk_mul_mat_q_f16(const ggml_tensor * src0, const ggml_tensor *
 
             if (load_x) {
                 // copy data to device
-                extra->in0_seqs.push_back(ggml_vk_h2d_tensor_2d(d_Qx, qx_offset, src0, i03, i02, tr0q, {}, { { sem->s, sem->value + 1 } }, nullptr, &extra->memcpys));
+                extra->in0_seqs.push_back(ggml_vk_h2d_tensor_2d(d_Qx, qx_offset, src0, i03, i02, ne01, tr0q, {}, { { sem->s, sem->value + 1 } }, nullptr, &extra->memcpys));
             }
 
             if (qx_needs_dequant) {
@@ -1950,7 +1989,7 @@ static void ggml_vk_mul_mat_q_f16(const ggml_tensor * src0, const ggml_tensor *
     for (int64_t i13 = 0; i13 < ne13; i13++) {
         const int64_t i03 = i13 / r3;
         for (int64_t i12 = 0; i12 < ne12; i12++) {
-            int64_t i02 = i12 / r2;
+            const int64_t i02 = i12 / r2;
 
             const uint32_t it_idx0 = (i03 * ne02 + i02);
             const uint32_t it_idx1 = (i13 * ne12 + i12);
@@ -1968,7 +2007,7 @@ static void ggml_vk_mul_mat_q_f16(const ggml_tensor * src0, const ggml_tensor *
             }
             if (load_y) {
                 // Set semaphore to 1
-                extra->in1_seqs.push_back(ggml_vk_h2d_tensor_2d(d_Qy, qy_offset, src1, i13, i12, tr1q, {}, { { sem->s, sem->value + 1 }}, nullptr, &extra->memcpys));
+                extra->in1_seqs.push_back(ggml_vk_h2d_tensor_2d(d_Qy, qy_offset, src1, i13, i12, ne11, tr1q, {}, { { sem->s, sem->value + 1 } }, nullptr, &extra->memcpys));
                 // Wait for semaphore val 1
                 mm_semaphores.push_back({ sem->s, sem->value + 1 });
             }
@@ -1989,9 +2028,9 @@ static void ggml_vk_mul_mat_q_f16(const ggml_tensor * src0, const ggml_tensor *
 
 static void ggml_vk_mul_mat_vec_q_f16(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
 #ifdef VK_DEBUG
-    std::cerr << "ggml_vk_mul_mat_vec_q_f16((type=" << src0->type << ", backend=" << src0->backend << ", ne0=" << src0->ne[0] << ", ne1=" << src0->ne[1] << ", ne2=" << src0->ne[2] << ", ne3=" << src0->ne[3];
-    std::cerr << "), (type=" << src1->type << ", backend=" << src0->backend << ", ne0=" << src1->ne[0] << ", ne1=" << src1->ne[1] << ", ne2=" << src1->ne[2] << ", ne3=" << src1->ne[3];
-    std::cerr << "), (type=" << dst->type << ", backend=" << src0->backend << ", ne0=" << dst->ne[0] << ", ne1=" << dst->ne[1] << ", ne2=" << dst->ne[2] << ", ne3=" << dst->ne[3] << "),)" << std::endl;
+    std::cerr << "ggml_vk_mul_mat_vec_q_f16((" << src0 << ", type=" << src0->type << ", backend=" << src0->backend << ", ne0=" << src0->ne[0] << ", ne1=" << src0->ne[1] << ", ne2=" << src0->ne[2] << ", ne3=" << src0->ne[3] << ", nb0=" << src0->nb[0] << ", nb1=" << src0->nb[1] << ", nb2=" << src0->nb[2] << ", nb3=" << src0->nb[3];
+    std::cerr << "), (" << src1 << ", type=" << src1->type << ", backend=" << src1->backend << ", ne0=" << src1->ne[0] << ", ne1=" << src1->ne[1] << ", ne2=" << src1->ne[2] << ", ne3=" << src1->ne[3] << ", nb0=" << src1->nb[0] << ", nb1=" << src1->nb[1] << ", nb2=" << src1->nb[2] << ", nb3=" << src1->nb[3];
+    std::cerr << "), (" << dst << ", type=" << dst->type << ", backend=" << dst->backend << ", ne0=" << dst->ne[0] << ", ne1=" << dst->ne[1] << ", ne2=" << dst->ne[2] << ", ne3=" << dst->ne[3] << ", nb0=" << dst->nb[0] << ", nb1=" << dst->nb[1] << ", nb2=" << dst->nb[2] << ", nb3=" << dst->nb[3] << "),)" << std::endl;
 #endif
     const int64_t ne00 = src0->ne[0];
     const int64_t ne01 = src0->ne[1];
@@ -2028,6 +2067,8 @@ static void ggml_vk_mul_mat_vec_q_f16(const ggml_tensor * src0, const ggml_tenso
     const uint32_t y_sz = ggml_vk_align_size(f16_f32_kernel ? sizeof(float) * y_ne : sizeof(ggml_fp16_t) * y_ne, vk_device.properties.limits.minStorageBufferOffsetAlignment);
     const uint32_t d_sz = ggml_vk_align_size(sizeof(float) * d_ne, vk_device.properties.limits.minStorageBufferOffsetAlignment);
 
+    GGML_ASSERT(qy_sz == src1->nb[2]);
+
     ggml_vk_tensor_extra_gpu * extra = (ggml_vk_tensor_extra_gpu *) dst->extra;
 
     GGML_ASSERT(extra->comp_seqs.empty());
@@ -2068,7 +2109,7 @@ static void ggml_vk_mul_mat_vec_q_f16(const ggml_tensor * src0, const ggml_tenso
     for (int64_t i13 = 0; i13 < ne13; i13++) {
         const int64_t i03 = i13 / r3;
         for (int64_t i12 = 0; i12 < ne12; i12++) {
-            int64_t i02 = i12 / r2;
+            const int64_t i02 = i12 / r2;
 
             const uint32_t it_idx0 = (i03 * ne02 + i02);
             const uint32_t it_idx1 = (i13 * ne12 + i12);
@@ -2081,7 +2122,7 @@ static void ggml_vk_mul_mat_vec_q_f16(const ggml_tensor * src0, const ggml_tenso
             vk_semaphore s_x;
             if (load_y) {
-                ggml_vk_h2d_tensor_2d(d_Qy, qy_offset, src1, i13, i12, compq, {}, {}, &s, &extra->memcpys);
+                ggml_vk_h2d_tensor_2d(d_Qy, qy_offset, src1, i13, i12, ne11, compq, {}, {}, &s, &extra->memcpys);
             }
 
             if (qy_needs_dequant) {
@@ -2190,7 +2231,6 @@ static void ggml_vk_op_repeat(const ggml_tensor * src0, const ggml_tensor * src1
 
     std::vector<vk::BufferCopy> copies;
 
-    // TODO: very inefficient, implement in a kernel, or fewer cudaMemcpyAsync calls for contiguous tensors
     for (int i3 = 0; i3 < nr3; i3++) {
         for (int k3 = 0; k3 < ne03; k3++) {
             for (int i2 = 0; i2 < nr2; i2++) {
@@ -2249,6 +2289,11 @@ static vk_pipeline* ggml_vk_op_get_pipeline(const ggml_tensor * src0, const ggml
             return &vk_pipeline_scale_f32;
         }
         return nullptr;
+    case GGML_OP_RMS_NORM:
+        if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
+            return &vk_pipeline_rms_norm_f32;
+        }
+        return nullptr;
     default:
         return nullptr;
     }
@@ -2263,15 +2308,24 @@ static ggml_vk_func_t ggml_vk_op_get_func(ggml_op op) {
     }
 }
 
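The new checks in ggml_vk_mul_mat_q_f16 above encode two sizing rules: when dst lives on the GPU, the aligned per-slice size d_sz must equal dst->nb[2] so later ops find each dim-2 slice where they expect it, and every preallocated buffer must hold ne02 * ne03 aligned slices rather than just one. A small sketch of that arithmetic, where align stands for minStorageBufferOffsetAlignment and the helper names are illustrative:

    #include <cstddef>
    #include <cstdint>

    // CEIL_DIV(width, align) * align, as in ggml_vk_align_size.
    static size_t align_size(size_t width, size_t align) {
        return ((width + align - 1) / align) * align;
    }

    // Each of the ne02 * ne03 batch slices gets its own aligned region,
    // so the preallocated buffer must cover slice size times slice count.
    static size_t required_prealloc(size_t slice_bytes, int64_t ne02, int64_t ne03, size_t align) {
        return align_size(slice_bytes, align) * (size_t) (ne02 * ne03);
    }
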
-static void ggml_vk_op_f32(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, ggml_op op, float scale=1.0f) {
+static bool ggml_vk_dim01_contiguous(const ggml_tensor * tensor) {
+    return
+        tensor->nb[0] == ggml_type_size(tensor->type) &&
+        tensor->nb[1] == (tensor->nb[0]*tensor->ne[0])/ggml_blck_size(tensor->type) &&
+        tensor->nb[3] == tensor->nb[2]*tensor->ne[2];
+}
+
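The new ggml_vk_dim01_contiguous above is deliberately weaker than ggml_is_contiguous: rows (dims 0 and 1) must be packed, but dim 2 may carry a padded stride — exactly the layout ggml_vk_realign_tensor produces further down. An illustration for plain f32 tensors (the real predicate also divides by ggml_blck_size to handle quantized block types; the struct here is a hypothetical stand-in for the ne/nb fields of ggml_tensor):

    #include <cstddef>
    #include <cstdint>

    struct tensor_sketch {
        int64_t ne[4];  // elements per dimension
        size_t  nb[4];  // byte stride per dimension
    };

    // Rows packed (dims 0 and 1), dim 2 possibly padded: accepted here,
    // rejected by full contiguity checks.
    static bool dim01_contiguous_f32(const tensor_sketch & t) {
        return t.nb[0] == sizeof(float)
            && t.nb[1] == t.nb[0] * (size_t) t.ne[0]
            && t.nb[3] == t.nb[2] * (size_t) t.ne[2];
    }

    // Example: a 4096 x 32 x 2 x 1 f32 tensor whose nb[2] was rounded up
    // for minStorageBufferOffsetAlignment still passes this predicate.
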
+static void ggml_vk_op_f32(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, ggml_op op, float param=1.0f) {
 #ifdef VK_DEBUG
-    std::cerr << "ggml_vk_op_f32((type=" << src0->type << ", backend=" << src0->backend << ", ne0=" << src0->ne[0] << ", ne1=" << src0->ne[1] << ", ne2=" << src0->ne[2] << ", ne3=" << src0->ne[3];
+    std::cerr << "ggml_vk_op_f32((" << src0 << ", type=" << src0->type << ", backend=" << src0->backend << ", ne0=" << src0->ne[0] << ", ne1=" << src0->ne[1] << ", ne2=" << src0->ne[2] << ", ne3=" << src0->ne[3] << ", nb0=" << src0->nb[0] << ", nb1=" << src0->nb[1] << ", nb2=" << src0->nb[2] << ", nb3=" << src0->nb[3];
     if (src1 != nullptr) {
-        std::cerr << "), (type=" << src1->type << ", backend=" << src1->backend << ", ne0=" << src1->ne[0] << ", ne1=" << src1->ne[1] << ", ne2=" << src1->ne[2] << ", ne3=" << src1->ne[3];
+        std::cerr << "), (" << src1 << ", type=" << src1->type << ", backend=" << src1->backend << ", ne0=" << src1->ne[0] << ", ne1=" << src1->ne[1] << ", ne2=" << src1->ne[2] << ", ne3=" << src1->ne[3] << ", nb0=" << src1->nb[0] << ", nb1=" << src1->nb[1] << ", nb2=" << src1->nb[2] << ", nb3=" << src1->nb[3];
     }
-    std::cerr << "), (type=" << dst->type << ", backend=" << dst->backend << ", ne0=" << dst->ne[0] << ", ne1=" << dst->ne[1] << ", ne2=" << dst->ne[2] << ", ne3=" << dst->ne[3] << "), " << ggml_op_name(op) << ")" << std::endl;
+    std::cerr << "), (" << dst << ", type=" << dst->type << ", backend=" << dst->backend << ", ne0=" << dst->ne[0] << ", ne1=" << dst->ne[1] << ", ne2=" << dst->ne[2] << ", ne3=" << dst->ne[3] << ", nb0=" << dst->nb[0] << ", nb1=" << dst->nb[1] << ", nb2=" << dst->nb[2] << ", nb3=" << dst->nb[3] << "), " << ggml_op_name(op) << ")" << std::endl;
 #endif
     GGML_ASSERT(!ggml_is_quantized(src0->type) && (src1 == nullptr || !ggml_is_quantized(src1->type)));  // NOLINT
+    GGML_ASSERT(ggml_vk_dim01_contiguous(src0));
+    GGML_ASSERT(src1 == nullptr || ggml_vk_dim01_contiguous(src1));  // NOLINT
     GGML_ASSERT(dst->extra != nullptr);
     const int64_t ne00 = src0->ne[0];
     const int64_t ne01 = src0->ne[1];
@@ -2288,7 +2342,7 @@ static void ggml_vk_op_f32(const ggml_tensor * src0, const ggml_tensor * src1, g
     const int nb2 = dst->nb[2];
     const int nb3 = dst->nb[3];
 
-    GGML_ASSERT(dst->ne[0] * dst->ne[1] * dst->ne[2] * dst->ne[3] == ne0 * ne02 * ne03);
+    GGML_ASSERT(dst->ne[0] * dst->ne[1] == ne0);
     GGML_ASSERT(!use_src1 || nb10 == sizeof(float));  // NOLINT
 
     vk_pipeline * pipeline = ggml_vk_op_get_pipeline(src0, src1, dst, op);
@@ -2332,80 +2386,131 @@ static void ggml_vk_op_f32(const ggml_tensor * src0, const ggml_tensor * src1, g
         d_Y = (vk_buffer *) src1->data;
     }
 
-    // Allocate descriptor sets
-    ggml_vk_pipeline_allocate_descriptor_sets(*pipeline, ne02 * ne03);
+    vk_op_push_constants pc;
+    std::array<uint32_t, 3> elements;
 
-    vk_op_push_constants pc = { (int)ne00, (int)ne01, (int)ne00, (int)ne00, (int)ne00, 0, 0, 0, scale };
+    std::vector<vk_semaphore> transfer_semaphores;
+    // copy src0 to device
+    if (transfer_src0) {
+        vk_semaphore * sem_x = ggml_vk_create_timeline_semaphore();
+        extra->in0_seqs.push_back(ggml_vk_h2d_tensor_2d(d_X, 0, src0, 0, 0, ggml_nrows(src0), vk_device.transfer_queues[0], {}, { { sem_x->s, sem_x->value + 1 } }, nullptr, &extra->memcpys));
+        transfer_semaphores.push_back({ sem_x->s, sem_x->value + 1});
+        sem_x->value += 1;
+    }
+    if (transfer_src1) {
+        vk_semaphore * sem_y = ggml_vk_create_timeline_semaphore();
+        extra->in1_seqs.push_back(ggml_vk_h2d_tensor_2d(d_Y, 0, src1, 0, 0, ggml_nrows(src1), vk_device.transfer_queues[1], {}, { { sem_y->s, sem_y->value + 1 } }, nullptr, &extra->memcpys));
+        transfer_semaphores.push_back({ sem_y->s, sem_y->value + 1 });
+        sem_y->value += 1;
+    }
 
-    for (int64_t i03 = 0; i03 < ne03; i03++) {
-        for (int64_t i02 = 0; i02 < ne02; i02++) {
-            const uint32_t it_idx = (i03 * ne02 + i02);
+    // Single call if dimension 2 is contiguous
+    if (ggml_is_contiguous(src0) && (src1 == nullptr || ggml_is_contiguous(src1))) {
+        ggml_vk_pipeline_allocate_descriptor_sets(*pipeline, 1);
 
-            const uint32_t x_offset = transfer_src0 ? x_sz * it_idx : 0;
-            const uint32_t y_offset = transfer_src1 ? y_sz * it_idx : 0;
-            const uint32_t d_offset = d_sz * it_idx;
+        switch (dst->op) {
+        case GGML_OP_RMS_NORM:
+            pc = { (int)src0->ne[0], (int)src0->ne[1], param };
+            elements = { (uint32_t)ggml_nrows(src0), 1, 1 };
+            break;
+        default:
+            pc = { (int)ggml_nelements(src0), (int)(src1 != nullptr ? ggml_nelements(src1) : 0), param };
+            elements = { (uint32_t)ggml_nelements(src0), 1, 1 };
+            break;
+        }
 
-            vk_semaphore * sem = ggml_vk_create_timeline_semaphore();
-            vk_semaphore * sem_x;
-            std::vector<vk_semaphore> transfer_semaphores;
-            // copy src0 to device
-            if (transfer_src0) {
-                sem_x = ggml_vk_create_binary_semaphore();
-                extra->in0_seqs.push_back(ggml_vk_h2d_tensor_2d(d_X, x_offset, src0, i03, i02, vk_device.transfer_queues[0], {}, { *sem_x }, nullptr, &extra->memcpys));
-                transfer_semaphores.push_back(*sem_x);
-            }
-            if (transfer_src1) {
-                extra->in1_seqs.push_back(ggml_vk_h2d_tensor_2d(d_Y, y_offset, src1, i03, i02, vk_device.transfer_queues[1], {}, { { sem->s, sem->value + 1 } }, nullptr, &extra->memcpys));
-                transfer_semaphores.push_back({ sem->s, sem->value + 1 });
-            }
-
-            const int64_t i13 = use_src1 ? i03%ne13 : i03;
-            const int64_t i12 = use_src1 ? i02%ne12 : i02;
-            pc.y_offset = (i13*ne12*ne11 + i12*ne11) * ne10;
-
-            vk_submission s = ggml_vk_begin_submission(vk_device.compute_queue);
-            ggml_vk_sync_buffers(s.buffer, { ggml_vk_subbuffer(*d_D) }, vk_device.compute_queue, vk::AccessFlagBits::eTransferRead, vk::AccessFlagBits::eShaderWrite, false);
-            if (use_src1) {
-                ggml_vk_sync_buffers(s.buffer, { ggml_vk_subbuffer(*d_X), ggml_vk_subbuffer(*d_Y) }, vk_device.compute_queue, vk::AccessFlagBits::eTransferWrite, vk::AccessFlagBits::eShaderRead, false);
-                ggml_vk_dispatch_pipeline(s, *pipeline, { { *d_X, x_offset, x_sz }, { *d_Y, y_offset, y_sz }, { *d_D, d_offset, d_sz } }, sizeof(vk_op_push_constants), &pc, { (uint32_t)ne00, (uint32_t)ne01, 1});
-            } else {
-                ggml_vk_sync_buffers(s.buffer, { ggml_vk_subbuffer(*d_X) }, vk_device.compute_queue, vk::AccessFlagBits::eTransferWrite, vk::AccessFlagBits::eShaderRead, false);
-                ggml_vk_dispatch_pipeline(s, *pipeline, { { *d_X, x_offset, x_sz }, { *d_D, d_offset, d_sz } }, sizeof(vk_op_push_constants), &pc, { (uint32_t)ne00, (uint32_t)ne01, 1});
-            }
-            ggml_vk_end_submission(s, std::move(transfer_semaphores), { { sem->s, sem->value + 2 } });
+        vk_submission s = ggml_vk_begin_submission(vk_device.compute_queue);
+        ggml_vk_sync_buffers(s.buffer, { ggml_vk_subbuffer(*d_D) }, vk_device.compute_queue, vk::AccessFlagBits::eTransferRead, vk::AccessFlagBits::eShaderWrite, false);
+        if (use_src1) {
+            ggml_vk_sync_buffers(s.buffer, { ggml_vk_subbuffer(*d_X), ggml_vk_subbuffer(*d_Y) }, vk_device.compute_queue, vk::AccessFlagBits::eTransferWrite, vk::AccessFlagBits::eShaderRead, false);
+            ggml_vk_dispatch_pipeline(s, *pipeline, { ggml_vk_subbuffer(*d_X), ggml_vk_subbuffer(*d_Y), ggml_vk_subbuffer(*d_D) }, sizeof(vk_op_push_constants), &pc, elements);
+        } else {
+            ggml_vk_sync_buffers(s.buffer, { ggml_vk_subbuffer(*d_X) }, vk_device.compute_queue, vk::AccessFlagBits::eTransferWrite, vk::AccessFlagBits::eShaderRead, false);
+            ggml_vk_dispatch_pipeline(s, *pipeline, { ggml_vk_subbuffer(*d_X), ggml_vk_subbuffer(*d_D) }, sizeof(vk_op_push_constants), &pc, elements);
+        }
+        if (dst->backend == GGML_BACKEND_CPU) {
+            vk_semaphore * fsem = ggml_vk_create_binary_semaphore();
+            ggml_vk_end_submission(s, std::move(transfer_semaphores), { *fsem });
             extra->comp_seqs.push_back({ s });
 
-            if (dst->backend == GGML_BACKEND_CPU) {
-                // copy dst to host
-                float * d = (float *) ((char *) dst->data + i02*nb2 + i03*nb3);
-                extra->out_seqs.push_back(ggml_vk_buffer_read_async(d_D, d_offset, d, sizeof(float) * ne00 * ne01, vk_device.transfer_queues[1], { { sem->s, sem->value + 2 } }, {}));
-            }
+            // copy dst to host
+            float * d = (float *) dst->data;
+            extra->out_seqs.push_back(ggml_vk_buffer_read_async(d_D, 0, d, sizeof(float) * ggml_nelements(src0), vk_device.transfer_queues[1], { *fsem }, {}));
+        } else {
+            ggml_vk_end_submission(s, std::move(transfer_semaphores), {});
+            extra->comp_seqs.push_back({ s });
+        }
 
-            sem->value += 2;
+    } else {
+        ggml_vk_pipeline_allocate_descriptor_sets(*pipeline, ne02 * ne03);
+
+        switch (dst->op) {
+        case GGML_OP_RMS_NORM:
+            pc = { (int)src0->ne[0], (int)src0->ne[1], param };
+            elements = { (uint32_t)ne01, 1, 1 };
+            break;
+        default:
+            pc = { (int)ne0, (int)(src1 != nullptr ? ne1 : 0), param };
+            elements = { (uint32_t)ne0, 1, 1 };
+            break;
+        }
+
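Both branches fill the same {M, N, param} block and differ only in granularity: the contiguous path issues one dispatch for the whole tensor, while the fallback loop below issues one dispatch per dim-2/3 slice. A worked example of what ends up in pc and elements, under illustrative shapes (the workgroup count is assumed to be derived from elements by the dispatch helper, rounding up to the pipeline's workgroup size):

    #include <cstdint>

    struct op_push_constants { int M; int N; float param; };

    // GGML_OP_RMS_NORM on a contiguous 4096 x 32 x 1 x 1 f32 tensor with
    // eps = 1e-5f: each of 32 workgroups normalizes one 4096-element row.
    //   pc       = { 4096, 32, 1e-5f }
    //   elements = { 32, 1, 1 }        // ggml_nrows(src0)
    //
    // GGML_OP_ADD on the same shape: the flat element count drives the
    // dispatch and the 512-wide workgroups tile it.
    //   pc       = { 4096 * 32, 4096 * 32, 1.0f }
    //   elements = { 4096 * 32, 1, 1 } // ggml_nelements(src0)
    static_assert(sizeof(op_push_constants) == 12,
                  "three 4-byte push constants on common ABIs");
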
+        for (int64_t i03 = 0; i03 < ne03; i03++) {
+            for (int64_t i02 = 0; i02 < ne02; i02++) {
+                const uint32_t it_idx0 = (i03 * ne02 + i02);
+                const uint32_t it_idx1 = use_src1 ? ((i03 % ne13) * ne12 + (i02 % ne12)) : 0;
+                const uint32_t x_offset = x_sz * it_idx0;
+                const uint32_t y_offset = y_sz * it_idx1;
+                const uint32_t d_offset = d_sz * it_idx0;
+
+                vk_submission s = ggml_vk_begin_submission(vk_device.compute_queue);
+                ggml_vk_sync_buffers(s.buffer, { { *d_D, d_offset, d_sz } }, vk_device.compute_queue, vk::AccessFlagBits::eTransferRead, vk::AccessFlagBits::eShaderWrite, false);
+                if (use_src1) {
+                    ggml_vk_sync_buffers(s.buffer, { { *d_X, x_offset, x_sz }, { *d_Y, y_offset, y_sz } }, vk_device.compute_queue, vk::AccessFlagBits::eTransferWrite, vk::AccessFlagBits::eShaderRead, false);
+                    ggml_vk_dispatch_pipeline(s, *pipeline, { { *d_X, x_offset, x_sz }, { *d_Y, y_offset, y_sz }, { *d_D, d_offset, d_sz } }, sizeof(vk_op_push_constants), &pc, elements);
+                } else {
+                    ggml_vk_sync_buffers(s.buffer, { { *d_X, x_offset, x_sz } }, vk_device.compute_queue, vk::AccessFlagBits::eTransferWrite, vk::AccessFlagBits::eShaderRead, false);
+                    ggml_vk_dispatch_pipeline(s, *pipeline, { { *d_X, x_offset, x_sz }, { *d_D, d_offset, d_sz } }, sizeof(vk_op_push_constants), &pc, elements);
+                }
+                if (dst->backend == GGML_BACKEND_CPU) {
+                    vk_semaphore * fsem = ggml_vk_create_binary_semaphore();
+                    ggml_vk_end_submission(s, std::move(transfer_semaphores), { *fsem });
+                    extra->comp_seqs.push_back({ s });
+
+                    // copy dst to host
+                    extra->out_seqs.push_back(ggml_vk_buffer_read_async(d_D, d_offset, (char *) dst->data + i02*nb2 + i03*nb3, sizeof(float) * ne0, vk_device.transfer_queues[1], { *fsem }, {}));
+                } else {
+                    ggml_vk_end_submission(s, std::move(transfer_semaphores), {});
+                    extra->comp_seqs.push_back({ s });
+                }
+            }
+        }
+    }
 }
 
-static void ggml_vk_repeat(const struct ggml_tensor * src0, const struct ggml_tensor * src1, struct ggml_tensor * dst) {
+static void ggml_vk_repeat(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
     ggml_vk_op_f32(src0, src1, dst, GGML_OP_REPEAT);
 }
 
-static void ggml_vk_get_rows(const struct ggml_tensor * src0, const struct ggml_tensor * src1, struct ggml_tensor * dst) {
+static void ggml_vk_get_rows(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
     ggml_vk_op_f32(src0, src1, dst, GGML_OP_GET_ROWS);
 }
 
-static void ggml_vk_add(const struct ggml_tensor * src0, const struct ggml_tensor * src1, struct ggml_tensor * dst) {
+static void ggml_vk_add(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
     ggml_vk_op_f32(src0, src1, dst, GGML_OP_ADD);
 }
 
-static void ggml_vk_mul(const struct ggml_tensor * src0, const struct ggml_tensor * src1, struct ggml_tensor * dst) {
+static void ggml_vk_mul(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
     ggml_vk_op_f32(src0, src1, dst, GGML_OP_MUL);
 }
 
-static void ggml_vk_scale(const struct ggml_tensor * src0, const struct ggml_tensor * src1, struct ggml_tensor * dst) {
+static void ggml_vk_scale(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
     ggml_vk_op_f32(src0, nullptr, dst, GGML_OP_SCALE, ((float *)src1->data)[0]);
 }
 
+static void ggml_vk_rms_norm(const ggml_tensor * src0, ggml_tensor * dst) {
+    ggml_vk_op_f32(src0, nullptr, dst, GGML_OP_RMS_NORM, ((float *)src0->op_params)[0]);
+}
+
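The epsilon forwarded by ggml_vk_rms_norm is the float that ggml_rms_norm packed into the node's op_params scratchpad; it reaches the shader as the param push constant (note the CPU checking path further down reads it from tensor->op_params). A hypothetical stand-alone illustration of that round trip, with a simplified struct standing in for ggml_tensor:

    #include <cstring>

    struct node_sketch {
        char op_params[16];  // ggml_tensor carries a small op_params scratchpad
    };

    // Equivalent to ((float *) node.op_params)[0], written with memcpy to
    // avoid aliasing concerns in a stand-alone example.
    static float rms_norm_eps(const node_sketch & node) {
        float eps;
        memcpy(&eps, node.op_params, sizeof(float));
        return eps;
    }
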
 void ggml_vk_transform_tensor(void * data, ggml_tensor * tensor) {
 #ifdef VK_DEBUG
     std::cerr << "ggml_vk_transform_tensor(" << data << ", " << tensor << ")" << std::endl;
 #endif
@@ -2426,7 +2531,7 @@ void ggml_vk_transform_tensor(void * data, ggml_tensor * tensor) {
     tensor->data = data;
     // copy tensor to device
-    seqs.push_back(ggml_vk_h2d_tensor_2d(&dst, 0, tensor, 0, 0, vk_device.transfer_queues[0], {}, {}));
+    seqs.push_back(ggml_vk_h2d_tensor_2d(&dst, 0, tensor, 0, 0, ne1, vk_device.transfer_queues[0], {}, {}));
 
     ggml_vk_submit(vk_device.transfer_queues[0], seqs, VK_NULL_HANDLE);
     vk_device.transfer_queues[0].queue.waitIdle();
@@ -2436,10 +2541,21 @@ void ggml_vk_transform_tensor(void * data, ggml_tensor * tensor) {
     GGML_ASSERT(tensor->backend == GGML_BACKEND_GPU);
 }
 
+static void ggml_vk_realign_tensor(ggml_tensor * tensor) {
+    // Handle split-k, which needs more space per MM
+    ggml_vk_tensor_extra_gpu * extra = (ggml_vk_tensor_extra_gpu *) tensor->extra;
+
+    tensor->nb[2] = ggml_vk_align_size(std::max(extra->tensor_size / tensor->ne[3] / tensor->ne[2], tensor->nb[1]*tensor->ne[1]), vk_device.properties.limits.minStorageBufferOffsetAlignment);
+    for (int i = 3; i < GGML_MAX_DIMS; i++) {
+        tensor->nb[i] = tensor->nb[i - 1]*tensor->ne[i - 1];
+    }
+}
+
 static ggml_vk_tensor_extra_gpu * ggml_vk_preallocate_buffers(uint32_t d_size, uint32_t qx_size, uint32_t qy_size, uint32_t x_size, uint32_t y_size) {
     ggml_vk_tensor_extra_gpu * extra = new ggml_vk_tensor_extra_gpu;
     extra->tensor_size = d_size;
+    extra->gpu_buffer = nullptr;
 
     // Check if buffer already exists, increase size if required
     if (vk_prealloc_size_d < d_size) {
@@ -2469,8 +2585,8 @@ void ggml_vk_preallocate_buffers_graph(ggml_tensor * node){
 #endif
     node->extra = nullptr;
 
-    const bool src0_gpu = false; // node->src[0] != nullptr && node->src[0]->extra != nullptr && node->src[0]->backend == GGML_BACKEND_CPU;
-    const bool src1_gpu = false; // node->src[1] != nullptr && node->src[1]->extra != nullptr && node->src[1]->backend == GGML_BACKEND_CPU;
+    const bool src0_gpu = node->src[0] != nullptr && node->src[0]->ne[1] > 32 && node->src[0]->extra != nullptr && node->src[0]->backend == GGML_BACKEND_CPU;
+    const bool src1_gpu = node->src[1] != nullptr && node->src[1]->ne[1] > 32 && node->src[1]->extra != nullptr && node->src[1]->backend == GGML_BACKEND_CPU;
 
     const bool any_on_device = node->backend == GGML_BACKEND_GPU
         || (node->src[0] != nullptr && (node->src[0]->backend == GGML_BACKEND_GPU || node->src[0]->backend == GGML_BACKEND_GPU_SPLIT || src0_gpu))
@@ -2488,7 +2604,7 @@ void ggml_vk_preallocate_buffers_graph(ggml_tensor * node){
     const int64_t ne01 = use_src0 ? src0->ne[1] : 0;
     const int64_t ne02 = use_src0 ? src0->ne[2] : 0;
     const int64_t ne03 = use_src0 ? src0->ne[3] : 0;
-    const bool use_src1 = src1 != nullptr;
+    const bool use_src1 = src1 != nullptr && node->op != GGML_OP_SCALE;
    const int64_t ne10 = use_src1 ? src1->ne[0] : 0;
    const int64_t ne11 = use_src1 ? src1->ne[1] : 0;
    const int64_t ne12 = use_src1 ? src1->ne[2] : 0;
@@ -2534,6 +2650,7 @@ void ggml_vk_preallocate_buffers_graph(ggml_tensor * node){
     case GGML_OP_ADD:
     case GGML_OP_SCALE:
     case GGML_OP_MUL:
+    case GGML_OP_RMS_NORM:
         node->extra = ggml_vk_preallocate_buffers(d_sz, transfer_src0 ? qx_sz : 0, transfer_src1 ? qy_sz : 0, 0, 0);
         break;
     case GGML_OP_MUL_MAT:
@@ -2548,22 +2665,27 @@ void ggml_vk_preallocate_buffers_graph(ggml_tensor * node){
     }
 
     // Reuse GPU buffer if previous op is also on GPU
-    // if (src0_gpu) {
-    //     src0->backend = GGML_BACKEND_GPU;
-    //     ggml_vk_tensor_extra_gpu * src0_extra = (ggml_vk_tensor_extra_gpu *) src0->extra;
+    if (src0_gpu) {
+        src0->backend = GGML_BACKEND_GPU;
+        ggml_vk_tensor_extra_gpu * src0_extra = (ggml_vk_tensor_extra_gpu *) src0->extra;
 
-    //     // Replace with data GPU tensor
-    //     src0->data = malloc(sizeof(vk_buffer));
-    //     *(vk_buffer*) src0->data = ggml_vk_create_buffer(src0_extra->tensor_size, vk::MemoryPropertyFlagBits::eDeviceLocal);
-    // }
-    // if (src1_gpu) {
-    //     src1->backend = GGML_BACKEND_GPU;
-    //     ggml_vk_tensor_extra_gpu * src1_extra = (ggml_vk_tensor_extra_gpu *) src1->extra;
+        // Replace with data GPU tensor
+        src0->data = malloc(sizeof(vk_buffer));
+        ggml_vk_pool_malloc(src0_extra->tensor_size, (vk_buffer *)src0->data, vk::MemoryPropertyFlagBits::eDeviceLocal);
 
-    //     // Replace with data GPU tensor
-    //     src1->data = malloc(sizeof(vk_buffer));
-    //     *(vk_buffer*) src1->data = ggml_vk_create_buffer(src1_extra->tensor_size, vk::MemoryPropertyFlagBits::eDeviceLocal);
-    // }
+        // Handle buffer offset alignment issues in 2nd and 3rd dimensions early by changing stride
+        ggml_vk_realign_tensor(src0);
+    }
+    if (src1_gpu) {
+        src1->backend = GGML_BACKEND_GPU;
+        ggml_vk_tensor_extra_gpu * src1_extra = (ggml_vk_tensor_extra_gpu *) src1->extra;
+
+        // Replace with data GPU tensor
+        src1->data = malloc(sizeof(vk_buffer));
+        ggml_vk_pool_malloc(src1_extra->tensor_size, (vk_buffer *)src1->data, vk::MemoryPropertyFlagBits::eDeviceLocal);
+
+        ggml_vk_realign_tensor(src1);
+    }
 }
 
 void ggml_vk_preallocate_buffers() {
@@ -2639,6 +2761,10 @@ void ggml_vk_build_graph(ggml_tensor * node){
     case GGML_OP_SCALE:
         ggml_vk_scale(node->src[0], node->src[1], node);
 
+        break;
+    case GGML_OP_RMS_NORM:
+        ggml_vk_rms_norm(node->src[0], node);
+
         break;
     case GGML_OP_MUL_MAT:
         if (!any_on_device && !ggml_vk_can_mul_mat(node->src[0], node->src[1], node)) {
@@ -2649,6 +2775,10 @@ void ggml_vk_build_graph(ggml_tensor * node){
 
         break;
     default:
+        if (any_on_device) {
+            std::cerr << "ggml_vulkan: Error: Missing op: " << ggml_op_name(node->op) << std::endl;
+            GGML_ASSERT(false);
+        }
         return;
     }
 }
@@ -2669,6 +2799,7 @@ bool ggml_vk_compute_forward(ggml_compute_params * params, ggml_tensor * tensor)
     case GGML_OP_GET_ROWS:
     case GGML_OP_MUL:
     case GGML_OP_SCALE:
+    case GGML_OP_RMS_NORM:
         extra = (ggml_vk_tensor_extra_gpu *) tensor->extra;
 
         break;
@@ -2747,13 +2878,82 @@ void ggml_vk_graph_cleanup() {
     vk_gc.tl_semaphores.clear();
 
     for (auto * extra : vk_gc.extras) {
+        if (extra->gpu_buffer != nullptr) {
+            ggml_vk_pool_free(*extra->gpu_buffer);
+        }
         delete extra;
     }
     vk_gc.extras.clear();
 }
 
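The now-active buffer-reuse path above swaps a CPU tensor's data pointer for a pooled vk_buffer (freed again in ggml_vk_graph_cleanup) and then rewrites the tensor's dim-2/3 strides so every slice starts at a storage-buffer-offset-aligned boundary, leaving room for split-k matmul output that can need more than nb[1] * ne[1] bytes per slice. A sketch of the stride rewrite performed by ggml_vk_realign_tensor, under illustrative names:

    #include <algorithm>
    #include <cstddef>
    #include <cstdint>

    static size_t align_up(size_t width, size_t align) {
        return ((width + align - 1) / align) * align;
    }

    // nb[2] becomes the aligned per-slice stride: at least one slice of
    // the preallocated tensor_size, and at least the packed row block
    // nb[1] * ne[1]; nb[3] then follows from nb[2].
    static void realign_sketch(const int64_t ne[4], size_t nb[4],
                               size_t tensor_size, size_t min_align) {
        nb[2] = align_up(std::max(tensor_size / (size_t) ne[3] / (size_t) ne[2],
                                  nb[1] * (size_t) ne[1]),
                         min_align);
        nb[3] = nb[2] * (size_t) ne[2];
    }
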
 #ifdef GGML_VULKAN_CHECK_RESULTS
+void ggml_vk_print_graph_origin(const ggml_tensor * tensor, std::vector<const ggml_tensor *>& done, int level = 0) {
+    if (std::find(done.begin(), done.end(), tensor) != done.end()) {
+        return;
+    }
+    for (int j = 0; j < level; j++) {
+        std::cerr << " ";
+    }
+    std::cerr << ggml_op_name(tensor->op) << " " << tensor->backend << std::endl;
+
+    done.push_back(tensor);
+
+    for (int i = 0; i < GGML_MAX_SRC; i++) {
+        if (tensor->src[i] != nullptr) {
+            ggml_vk_print_graph_origin(tensor->src[i], done, level + 1);
+        }
+    }
+}
+
+void ggml_vk_check_tensor(const std::string& name, const ggml_tensor * tensor) {
+    if (tensor->type != GGML_TYPE_F32) {
+        return;
+    }
+    for (int i3 = 0; i3 < tensor->ne[3]; i3++) {
+        for (int i2 = 0; i2 < tensor->ne[2]; i2++) {
+            for (int i1 = 0; i1 < tensor->ne[1]; i1++) {
+                for (int i0 = 0; i0 < tensor->ne[0]; i0++) {
+                    const float val = *(float *) ((char *) tensor->data + i3*tensor->nb[3] + i2*tensor->nb[2] + i1*tensor->nb[1] + i0*tensor->nb[0]);
+                    if (std::isnan(val)) {
+                        std::cerr << "ERROR: TENSOR CHECK " << name << ": Invalid value in " << ggml_op_name(tensor->op) << " i3=" << i3 << " i2=" << i2 << " i1=" << i1 << " i0=" << i0 << " val=" << val << std::endl;
+                        std::cerr << "tensor->backend: " << tensor->backend << " ne0=" << tensor->ne[0] << " nb0=" << tensor->nb[0] << " ne1=" << tensor->ne[1] << " nb1=" << tensor->nb[1] << " ne2=" << tensor->ne[2] << " nb2=" << tensor->nb[2] << " ne3=" << tensor->ne[3] << " nb3=" << tensor->nb[3] << std::endl;
+                        std::cerr << std::endl;
+                        std::vector<const ggml_tensor *> done;
+                        ggml_vk_print_graph_origin(tensor, done);
+                        GGML_ASSERT(false);
+                    }
+                }
+            }
+        }
+    }
+}
+
+void ggml_vk_print_tensor_area(const ggml_tensor * tensor, const void * data, int i0, int i1, int i2, int i3) {
+    fprintf(stderr, " ");
+    for (int idx1 = i1 - 5; idx1 < i1 + 5; idx1++) {
+        fprintf(stderr, "%7d ", idx1);
+    }
+    fprintf(stderr, "\n");
+    for (int idx0 = i0 - 5; idx0 < i0 + 5; idx0++) {
+        fprintf(stderr, "%7d: ", idx0);
+        for (int idx1 = i1 - 5; idx1 < i1 + 5; idx1++) {
+            if (idx0 >= 0 && idx0 < tensor->ne[0] && idx1 >= 0 && idx1 < tensor->ne[1]) {
+                float val = *(float *) ((char *) data + i3*tensor->nb[3] + i2*tensor->nb[2] + idx1*tensor->nb[1] + idx0*tensor->nb[0]);
+                fprintf(stderr, "% 7.2f ", val);
+            } else {
+                fprintf(stderr, " ");
+            }
+        }
+        fprintf(stderr, "\n");
+    }
+}
+
+size_t ggml_vk_tensor_size(const ggml_tensor * tensor) {
+    return std::max(tensor->ne[3]*tensor->nb[3], tensor->nb[1] * tensor->ne[1] * tensor->ne[2] * tensor->ne[3]);
+}
+
 void * comp_result;
+size_t comp_nb[GGML_MAX_DIMS];
 void ggml_vk_check_results_0(ggml_compute_params * params, ggml_tensor * tensor) {
     if (params->ith != 0) {
         return;
     }
@@ -2783,35 +2983,68 @@ void ggml_vk_check_results_0(ggml_compute_params * params, ggml_tensor * tensor)
 
     if (src0 != nullptr) {
         src0_clone = ggml_dup_tensor(ctx, src0);
 
-        // Some tensors have wrong strides for some reason
-        src0_size = src0->nb[1] * src0->ne[1] * src0->ne[2] * src0->ne[3];
+        src0_size = ggml_vk_tensor_size(src0);
 
         src0_clone->data = malloc(src0_size);
         if (src0->backend == GGML_BACKEND_CPU) {
             memcpy(src0_clone->data, src0->data, src0_size);
+            memcpy(src0_clone->nb, src0->nb, sizeof(size_t) * GGML_MAX_DIMS);
         } else if (src0->backend == GGML_BACKEND_GPU) {
-            ggml_vk_buffer_read((vk_buffer *)src0->data, 0, src0_clone->data, src0_size, vk_device.transfer_queues[0]);
+            if (!ggml_is_contiguous(src0) && ggml_vk_dim01_contiguous(src0)) {
+                for (int i3 = 0; i3 < src0->ne[3]; i3++) {
+                    for (int i2 = 0; i2 < src0->ne[2]; i2++) {
+                        const int idx = i3*src0->ne[2] + i2;
+                        ggml_vk_buffer_read((vk_buffer *)src0->data, (idx) * src0->nb[2], ((char *)src0_clone->data + idx * src0_clone->nb[2]), src0->ne[1] * src0->nb[1], vk_device.transfer_queues[0]);
+                    }
+                }
+
+                src0_clone->nb[0] = src0->nb[0];
+                src0_clone->nb[1] = src0->nb[1];
+                for (int i = 2; i < GGML_MAX_DIMS; i++) {
+                    src0_clone->nb[i] = src0_clone->nb[i - 1]*src0_clone->ne[i - 1];
+                }
+            } else {
+                ggml_vk_buffer_read((vk_buffer *)src0->data, 0, src0_clone->data, src0_size, vk_device.transfer_queues[0]);
+                memcpy(src0_clone->nb, src0->nb, sizeof(size_t) * GGML_MAX_DIMS);
+            }
         } else {
             GGML_ASSERT(false);
         }
-        memcpy(src0_clone->nb, src0->nb, sizeof(size_t) * GGML_MAX_DIMS);
+
+        ggml_vk_check_tensor(std::string(ggml_op_name(tensor->op)) + "->src0", src0_clone);
     }
     if (src1 != nullptr) {
         src1_clone = ggml_dup_tensor(ctx, src1);
 
-        src1_size = src1->ne[3] * src1->nb[3];
+        src1_size = ggml_vk_tensor_size(src1);
 
         src1_clone->data = malloc(src1_size);
         if (src1->backend == GGML_BACKEND_CPU) {
             memcpy(src1_clone->data, src1->data, src1_size);
+            memcpy(src1_clone->nb, src1->nb, sizeof(size_t) * GGML_MAX_DIMS);
         } else if (src1->backend == GGML_BACKEND_GPU) {
-            ggml_vk_buffer_read((vk_buffer *)src1->data, 0, src1_clone->data, src1_size, vk_device.transfer_queues[0]);
+            if (!ggml_is_contiguous(src1) && ggml_vk_dim01_contiguous(src1)) {
+                for (int i3 = 0; i3 < src1->ne[3]; i3++) {
+                    for (int i2 = 0; i2 < src1->ne[2]; i2++) {
+                        const int idx = i3*src1->ne[2] + i2;
+                        ggml_vk_buffer_read((vk_buffer *)src1->data, (idx) * src1->nb[2], ((char *)src1_clone->data + idx * src1_clone->nb[2]), src1->ne[1] * src1->nb[1], vk_device.transfer_queues[0]);
+                    }
+                }
+
+                src1_clone->nb[0] = src1->nb[0];
+                src1_clone->nb[1] = src1->nb[1];
+                for (int i = 2; i < GGML_MAX_DIMS; i++) {
+                    src1_clone->nb[i] = src1_clone->nb[i - 1]*src1_clone->ne[i - 1];
+                }
+            } else {
+                ggml_vk_buffer_read((vk_buffer *)src1->data, 0, src1_clone->data, src1_size, vk_device.transfer_queues[0]);
+                memcpy(src1_clone->nb, src1->nb, sizeof(size_t) * GGML_MAX_DIMS);
+            }
         } else {
             GGML_ASSERT(false);
         }
-        memcpy(src1_clone->nb, src1->nb, sizeof(size_t) * GGML_MAX_DIMS);
+
+        ggml_vk_check_tensor(std::string(ggml_op_name(tensor->op)) + "->src1", src1_clone);
     }
 
     if (tensor->op == GGML_OP_MUL_MAT) {
@@ -2821,8 +3054,10 @@ void ggml_vk_check_results_0(ggml_compute_params * params, ggml_tensor * tensor)
     } else if (tensor->op == GGML_OP_SCALE) {
         tensor_clone = ggml_scale(ctx, src0_clone, src1_clone);
     } else if (tensor->op == GGML_OP_ADD) {
-        tensor_clone = ggml_add(ctx, src1_clone, src1_clone);
-    }else {
+        tensor_clone = ggml_add(ctx, src0_clone, src1_clone);
+    } else if (tensor->op == GGML_OP_RMS_NORM) {
+        tensor_clone = ggml_rms_norm(ctx, src0_clone, *(float *)tensor->op_params);
+    } else {
         std::cerr << "Missing vk_check_results OP: " << ggml_op_name(tensor->op) << std::endl;
         GGML_ASSERT(false);
     }
@@ -2831,13 +3066,20 @@ void ggml_vk_check_results_0(ggml_compute_params * params, ggml_tensor * tensor)
 
     ggml_graph_compute_with_ctx(ctx, &cgraph, 8);
 
-    size_t tensor_size = tensor_clone->ne[3] * tensor_clone->nb[3];
+    ggml_vk_check_tensor(ggml_op_name(tensor->op), tensor_clone);
+
+    size_t tensor_size = ggml_vk_tensor_size(tensor);
 
     comp_result = malloc(tensor_size);
     memcpy(comp_result, tensor_clone->data, tensor_size);
+    memcpy(comp_nb, tensor_clone->nb, sizeof(size_t) * GGML_MAX_DIMS);
 
-    free(src0_clone->data);
-    free(src1_clone->data);
+    if (src0 != nullptr) {
+        free(src0_clone->data);
+    }
+    if (src1 != nullptr) {
+        free(src1_clone->data);
+    }
 
     ggml_free(ctx);
 }
 
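ggml_vk_tensor_size and the stored comp_nb strides exist because the two sides of the comparison no longer share a layout: a realigned GPU tensor can carry padded dim-2/3 strides while the CPU clone stays packed, so each side must be indexed with its own nb array and sized conservatively. The size rule as a sketch (assumed f32; shapes are illustrative):

    #include <algorithm>
    #include <cstddef>
    #include <cstdint>

    // Byte size that safely covers a tensor whether its dim-2/3 strides
    // are padded (realigned GPU tensor) or packed: the larger of the
    // stride-derived extent and the packed row-block extent.
    static size_t tensor_bytes(const int64_t ne[4], const size_t nb[4]) {
        return std::max((size_t) ne[3] * nb[3],
                        nb[1] * (size_t) (ne[1] * ne[2] * ne[3]));
    }
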
@@ -2856,54 +3098,82 @@ void ggml_vk_check_results_1(ggml_compute_params * params, ggml_tensor * tensor)
     void * tensor_data = tensor->data;
 
     if (tensor->backend == GGML_BACKEND_GPU) {
-        const size_t tensor_size = tensor->nb[1] * tensor->ne[1] * tensor->ne[2] * tensor->ne[3];
+        const size_t tensor_size = ggml_vk_tensor_size(tensor);
         tensor_data = malloc(tensor_size);
 
         ggml_vk_buffer_read((vk_buffer *)tensor->data, 0, tensor_data, tensor_size, vk_device.transfer_queues[0]);
     }
 
-    double avg_err = 0.0f;
+    float first_error_result = -1.0f;
+    float first_error_correct = -1.0f;
+    std::array<int, 4> first_error = { -1, -1, -1, -1 };
+    double avg_err = 0.0;
+    size_t counter = 0;
 
     for (int i3 = 0; i3 < tensor->ne[3]; i3++) {
         for (int i2 = 0; i2 < tensor->ne[2]; i2++) {
             for (int i1 = 0; i1 < tensor->ne[1]; i1++) {
                 for (int i0 = 0; i0 < tensor->ne[0]; i0++) {
                     if (tensor->type == GGML_TYPE_F32) {
-                        float correct = *(float *) ((char *) comp_result + i3*tensor->nb[3] + i2*tensor->nb[2] + i1*tensor->nb[1] + i0*tensor->nb[0]);
+                        float correct = *(float *) ((char *) comp_result + i3*comp_nb[3] + i2*comp_nb[2] + i1*comp_nb[1] + i0*comp_nb[0]);
                         float result = *(float *) ((char *) tensor_data + i3*tensor->nb[3] + i2*tensor->nb[2] + i1*tensor->nb[1] + i0*tensor->nb[0]);
 
-                        if (std::isnan(correct) || std::isnan(result)) {
-                            std::cerr << "ERROR: NaN value in " << ggml_op_name(tensor->op) << " i3=" << i3 << " i2=" << i2 << " i1=" << i1 << " i0=" << i0 << " result=" << result << " correct=" << correct << std::endl;
-                            std::cerr << "tensor->backend: " << tensor->backend << " ne0=" << tensor->ne[0] << " nb0=" << tensor->nb[0] << " ne1=" << tensor->ne[1] << " nb1=" << tensor->nb[1] << std::endl;
-                            if (tensor->src[0] != nullptr) {
-                                std::cerr << "src0 " << ggml_op_name(src0->op) << " type=" << ggml_type_name(src0->type) << " backend=" << src0->backend << " ne0=" << src0->ne[0] << " nb0=" << src0->nb[0] << " ne1=" << src0->ne[1] << " nb1=" << src0->nb[1] << std::endl;
+                        if (std::isnan(correct) || std::isnan(result) || std::isnan(avg_err)) {
+                            std::cerr << "ERROR: Invalid value in " << ggml_op_name(tensor->op) << " i3=" << i3 << " i2=" << i2 << " i1=" << i1 << " i0=" << i0 << " result=" << result << " correct=" << correct << " avg_err=" << (avg_err / counter) << std::endl;
+                            std::cerr << "tensor=" << tensor << " tensor->backend: " << tensor->backend << " ne0=" << tensor->ne[0] << " nb0=" << tensor->nb[0] << " ne1=" << tensor->ne[1] << " nb1=" << tensor->nb[1] << " ne2=" << tensor->ne[2] << " nb2=" << tensor->nb[2] << " ne3=" << tensor->ne[3] << " nb3=" << tensor->nb[3] << std::endl;
+                            if (src0 != nullptr) {
+                                std::cerr << "src0=" << src0 << " op=" << ggml_op_name(src0->op) << " type=" << ggml_type_name(src0->type) << " backend=" << src0->backend << " ne0=" << src0->ne[0] << " nb0=" << src0->nb[0] << " ne1=" << src0->ne[1] << " nb1=" << src0->nb[1] << " ne2=" << src0->ne[2] << " nb2=" << src0->nb[2] << " ne3=" << src0->ne[3] << " nb3=" << src0->nb[3] << std::endl;
                             }
-                            if (tensor->src[1] != nullptr) {
-                                std::cerr << "src1 " << ggml_op_name(src1->op) << " type=" << ggml_type_name(src1->type) << " backend=" << src1->backend << " ne0=" << src1->ne[0] << " nb0=" << src1->nb[0] << " ne1=" << src1->ne[1] << " nb1=" << src1->nb[1] << std::endl;
+                            if (src1 != nullptr) {
+                                std::cerr << "src1=" << src1 << " op=" << ggml_op_name(src1->op) << " type=" << ggml_type_name(src1->type) << " backend=" << src1->backend << " ne0=" << src1->ne[0] << " nb0=" << src1->nb[0] << " ne1=" << src1->ne[1] << " nb1=" << src1->nb[1] << " ne2=" << src1->ne[2] << " nb2=" << src1->nb[2] << " ne3=" << src1->ne[3] << " nb3=" << src1->nb[3] << std::endl;
                             }
+                            std::cerr << "First error: result=" << first_error_result << " correct=" << first_error_correct << " i3=" << first_error[3] << " i2=" << first_error[2] << " i1=" << first_error[1] << " i0=" << first_error[0] << std::endl;
+                            std::cerr << std::endl;
+                            ggml_vk_print_tensor_area(tensor, tensor_data, i0, i1, i2, i3);
+                            std::cerr << std::endl;
+                            std::vector<const ggml_tensor *> done;
+                            ggml_vk_print_graph_origin(tensor, done);
                             GGML_ASSERT(false);
                         }
+                        if (first_error[0] == -1 && std::fabs(correct - result) > 0.1f) {
+                            first_error[0] = i0;
+                            first_error[1] = i1;
+                            first_error[2] = i2;
+                            first_error[3] = i3;
+                            first_error_result = result;
+                            first_error_correct = correct;
+                        }
 
-                        avg_err += std::fabs(correct - result);
+                        // Special case, value is infinite, avoid NaN result in avg_err
+                        if (!std::isinf(correct) || !std::isinf(result) || correct != result) {
+                            avg_err += std::fabs(correct - result);
+                        }
                     } else {
                         GGML_ASSERT(false);
                     }
+                    counter++;
                 }
             }
         }
     }
 
-    avg_err /= tensor->ne[3] * tensor->ne[2] * tensor->ne[1] * tensor->ne[0];
+    avg_err /= counter;
 
     if (avg_err > 0.1 || std::isnan(avg_err)) {
         std::cerr << "ERROR: avg_err=" << avg_err << " in " << ggml_op_name(tensor->op) << std::endl;
-        std::cerr << "tensor->backend: " << tensor->backend << " ne0=" << tensor->ne[0] << " nb0=" << tensor->nb[0] << " ne1=" << tensor->ne[1] << " nb1=" << tensor->nb[1] << std::endl;
-        if (tensor->src[0] != nullptr) {
-            std::cerr << "src0 type=" << ggml_type_name(src0->type) << " backend=" << src0->backend << " ne0=" << src0->ne[0] << " nb0=" << src0->nb[0] << " ne1=" << src0->ne[1] << " nb1=" << src0->nb[1] << std::endl;
+        std::cerr << "tensor=" << tensor << " tensor->backend: " << tensor->backend << " ne0=" << tensor->ne[0] << " nb0=" << tensor->nb[0] << " ne1=" << tensor->ne[1] << " nb1=" << tensor->nb[1] << " ne2=" << tensor->ne[2] << " nb2=" << tensor->nb[2] << " ne3=" << tensor->ne[3] << " nb3=" << tensor->nb[3] << std::endl;
+        if (src0 != nullptr) {
+            std::cerr << "src0=" << src0 << " op=" << ggml_op_name(src0->op) << " type=" << ggml_type_name(src0->type) << " backend=" << src0->backend << " ne0=" << src0->ne[0] << " nb0=" << src0->nb[0] << " ne1=" << src0->ne[1] << " nb1=" << src0->nb[1] << " ne2=" << src0->ne[2] << " nb2=" << src0->nb[2] << " ne3=" << src0->ne[3] << " nb3=" << src0->nb[3] << std::endl;
         }
-        if (tensor->src[1] != nullptr) {
-            std::cerr << "src1 type=" << ggml_type_name(src1->type) << " backend=" << src1->backend << " ne0=" << src1->ne[0] << " nb0=" << src1->nb[0] << " ne1=" << src1->ne[1] << " nb1=" << src1->nb[1] << std::endl;
+        if (src1 != nullptr) {
+            std::cerr << "src1=" << src1 << " op=" << ggml_op_name(src1->op) << " type=" << ggml_type_name(src1->type) << " backend=" << src1->backend << " ne0=" << src1->ne[0] << " nb0=" << src1->nb[0] << " ne1=" << src1->ne[1] << " nb1=" << src1->nb[1] << " ne2=" << src1->ne[2] << " nb2=" << src1->nb[2] << " ne3=" << src1->ne[3] << " nb3=" << src1->nb[3] << std::endl;
         }
+        std::cerr << "First error: result=" << first_error_result << " correct=" << first_error_correct << " i3=" << first_error[3] << " i2=" << first_error[2] << " i1=" << first_error[1] << " i0=" << first_error[0] << std::endl;
+        std::cerr << std::endl;
+        ggml_vk_print_tensor_area(tensor, tensor_data, first_error[0], first_error[1], first_error[2], first_error[3]);
+        std::cerr << std::endl;
+        std::vector<const ggml_tensor *> done;
+        ggml_vk_print_graph_origin(tensor, done);
         GGML_ASSERT(false);
     }
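The comparison scheme above, condensed into a self-contained editorial sketch (struct and function names are invented): accumulate the mean absolute error over every element, remember where the first deviation above 0.1 occurred, and treat two equal infinities as a match so that inf - inf cannot turn the running average into NaN:

    #include <array>
    #include <cmath>
    #include <cstddef>

    struct check_state {
        double             sum_abs_err = 0.0;
        size_t             count       = 0;
        std::array<int, 4> first_error = { -1, -1, -1, -1 };
    };

    void check_value(check_state & s, float correct, float result, int i0, int i1, int i2, int i3) {
        const float err = std::fabs(correct - result);
        if (s.first_error[0] == -1 && err > 0.1f) {
            s.first_error = { i0, i1, i2, i3 };  // where it first went wrong
        }
        // equal infinities count as a match; their difference would be NaN
        if (!(std::isinf(correct) && std::isinf(result) && correct == result)) {
            s.sum_abs_err += err;
        }
        s.count++;  // mean = sum_abs_err / count, compared against the 0.1 threshold
    }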
diff --git a/ggml_vk_generate_shaders.py b/ggml_vk_generate_shaders.py
index 8569806ae..ec1fd55f3 100644
--- a/ggml_vk_generate_shaders.py
+++ b/ggml_vk_generate_shaders.py
@@ -157,82 +157,82 @@ struct block_q6_K
 
 # Dequant functions
 shader_f16_dequant_func = """
-#define DEQUANT_FUNC f16vec2 v = f16vec2(x[ib + 0], x[ib + 1]);
+#define DEQUANT_FUNC f16vec2 v = f16vec2(data_a[ib + 0], data_a[ib + 1]);
 """
 shader_f16_dequant_func_compat = """
-#define DEQUANT_FUNC vec2 v = vec2(x[ib + 0], x[ib + 1]);
+#define DEQUANT_FUNC vec2 v = vec2(data_a[ib + 0], data_a[ib + 1]);
 """
 
 shader_q4_0_dequant_func = """
-#define DEQUANT_FUNC const float16_t d = x[ib].d; \
-const uint8_t vui = 
x[ib].qs[iqs]; \ +#define DEQUANT_FUNC const float16_t d = data_a[ib].d; \ +const uint8_t vui = data_a[ib].qs[iqs]; \ f16vec2 v = f16vec2(vui & 0xF, vui >> 4); \ v = (v - 8.0hf)*d; """ shader_q4_0_dequant_func_compat = """ -#define DEQUANT_FUNC const float d = float(x[ib].d); \ -const uint vui = uint(x[ib].qs[iqs]); \ +#define DEQUANT_FUNC const float d = float(data_a[ib].d); \ +const uint vui = uint(data_a[ib].qs[iqs]); \ vec2 v = vec2(vui & 0xF, vui >> 4); \ v = (v - 8.0f)*d; """ shader_q4_1_dequant_func = """ -#define DEQUANT_FUNC const float16_t d = x[ib].d; \ -const float16_t m = x[ib].m; \ -const uint8_t vui = x[ib].qs[iqs]; \ +#define DEQUANT_FUNC const float16_t d = data_a[ib].d; \ +const float16_t m = data_a[ib].m; \ +const uint8_t vui = data_a[ib].qs[iqs]; \ f16vec2 v = f16vec2(vui & 0xF, vui >> 4); \ v = v*d + m; """ shader_q4_1_dequant_func_compat = """ -#define DEQUANT_FUNC const float d = float(x[ib].d); \ -const float m = float(x[ib].m); \ -const uint vui = uint(x[ib].qs[iqs]); \ +#define DEQUANT_FUNC const float d = float(data_a[ib].d); \ +const float m = float(data_a[ib].m); \ +const uint vui = uint(data_a[ib].qs[iqs]); \ vec2 v = vec2(vui & 0xF, vui >> 4); \ v = v*d + m; """ shader_q5_0_dequant_func = """ -#define DEQUANT_FUNC const float16_t d = x[ib].d; \ -const uint uint_qh = uint(x[ib].qh[1]) << 16 | x[ib].qh[0]; \ +#define DEQUANT_FUNC const float16_t d = data_a[ib].d; \ +const uint uint_qh = uint(data_a[ib].qh[1]) << 16 | data_a[ib].qh[0]; \ const ivec2 qh = ivec2(((uint_qh >> iqs) << 4) & 0x10, (uint_qh >> (iqs + 12)) & 0x10); \ -const uint8_t vui = x[ib].qs[iqs]; \ +const uint8_t vui = data_a[ib].qs[iqs]; \ f16vec2 v = f16vec2((vui & 0xF) | qh.x, (vui >> 4) | qh.y); \ v = (v - 16.0hf) * d; """ shader_q5_0_dequant_func_compat = """ -#define DEQUANT_FUNC const float d = float(x[ib].d); \ -const uint uint_qh = uint(x[ib].qh[1]) << 16 | x[ib].qh[0]; \ +#define DEQUANT_FUNC const float d = float(data_a[ib].d); \ +const uint uint_qh = uint(data_a[ib].qh[1]) << 16 | data_a[ib].qh[0]; \ const ivec2 qh = ivec2(((uint_qh >> iqs) << 4) & 0x10, (uint_qh >> (iqs + 12)) & 0x10); \ -const uint vui = uint(x[ib].qs[iqs]); \ +const uint vui = uint(data_a[ib].qs[iqs]); \ vec2 v = vec2((vui & 0xF) | qh.x, (vui >> 4) | qh.y); \ v = (v - 16.0f) * d; """ shader_q5_1_dequant_func = """ -#define DEQUANT_FUNC const float16_t d = x[ib].d; \ -const float16_t m = x[ib].m; \ -const ivec2 qh = ivec2(((x[ib].qh >> iqs) << 4) & 0x10, (x[ib].qh >> (iqs + 12)) & 0x10); \ -const uint8_t vui = x[ib].qs[iqs]; \ +#define DEQUANT_FUNC const float16_t d = data_a[ib].d; \ +const float16_t m = data_a[ib].m; \ +const ivec2 qh = ivec2(((data_a[ib].qh >> iqs) << 4) & 0x10, (data_a[ib].qh >> (iqs + 12)) & 0x10); \ +const uint8_t vui = data_a[ib].qs[iqs]; \ f16vec2 v = f16vec2((vui & 0xF) | qh.x, (vui >> 4) | qh.y); \ v = v*d + m; """ shader_q5_1_dequant_func_compat = """ -#define DEQUANT_FUNC const float d = float(x[ib].d); \ -const float m = float(x[ib].m); \ -const ivec2 qh = ivec2(((x[ib].qh >> iqs) << 4) & 0x10, (x[ib].qh >> (iqs + 12)) & 0x10); \ -const uint vui = uint(x[ib].qs[iqs]); \ +#define DEQUANT_FUNC const float d = float(data_a[ib].d); \ +const float m = float(data_a[ib].m); \ +const ivec2 qh = ivec2(((data_a[ib].qh >> iqs) << 4) & 0x10, (data_a[ib].qh >> (iqs + 12)) & 0x10); \ +const uint vui = uint(data_a[ib].qs[iqs]); \ vec2 v = vec2((vui & 0xF) | qh.x, (vui >> 4) | qh.y); \ v = v*d + m; """ shader_q8_0_dequant_func = """ -#define DEQUANT_FUNC const float16_t d = x[ib].d; \ -f16vec2 v = 
f16vec2(x[ib].qs[iqs], x[ib].qs[iqs + 1]); \ +#define DEQUANT_FUNC const float16_t d = data_a[ib].d; \ +f16vec2 v = f16vec2(data_a[ib].qs[iqs], data_a[ib].qs[iqs + 1]); \ v = v * d; """ shader_q8_0_dequant_func_compat = """ -#define DEQUANT_FUNC const float d = float(x[ib].d); \ -vec2 v = vec2(int(x[ib].qs[iqs]), int(x[ib].qs[iqs + 1])); \ +#define DEQUANT_FUNC const float d = float(data_a[ib].d); \ +vec2 v = vec2(int(data_a[ib].qs[iqs]), int(data_a[ib].qs[iqs + 1])); \ v = v * d; """ @@ -309,12 +309,12 @@ void main() { int pos_a = ir * BM * p.stride_a / LOAD_VEC + start_k / LOAD_VEC; int pos_b = ic * BN * p.stride_b / LOAD_VEC + start_k / LOAD_VEC; - FLOAT_TYPE sums[WMITER * TM * WNITER * TN]; + float sums[WMITER * TM * WNITER * TN]; FLOAT_TYPE cache_a[WMITER * TM]; FLOAT_TYPE cache_b[WNITER * TN]; [[unroll]] for (int i = 0; i < WMITER*TM*WNITER*TN; i++) { - sums[i] = FLOAT_TYPE(0.0f); + sums[i] = 0.0f; } [[unroll]] for (int block = start_k; block < end_k; block += BK) { @@ -391,7 +391,7 @@ void main() { [[unroll]] for (int wsir = 0; wsir < WMITER; wsir++) { [[unroll]] for (int cc = 0; cc < TN; cc++) { [[unroll]] for (int cr = 0; cr < TM; cr++) { - sums[(wsic * TN + cc) * (WMITER * TM) + wsir * TM + cr] += FLOAT_TYPE(cache_a[wsir * TM + cr]) * FLOAT_TYPE(cache_b[wsic * TN + cc]); + sums[(wsic * TN + cc) * (WMITER * TM) + wsir * TM + cr] += float(cache_a[wsir * TM + cr]) * float(cache_b[wsic * TN + cc]); } } } @@ -466,8 +466,8 @@ dequant_head = """#version 450 dequant_body = """ layout(local_size_x = 256, local_size_y = 1, local_size_z = 1) in; -layout (binding = 0) readonly buffer A {A_TYPE x[];}; -layout (binding = 1) writeonly buffer D {D_TYPE y[];}; +layout (binding = 0) readonly buffer A {A_TYPE data_a[];}; +layout (binding = 1) writeonly buffer D {D_TYPE data_b[];}; layout (push_constant) uniform parameter { @@ -498,8 +498,8 @@ void main() { [[unroll]] for (int iqs = 0; iqs < QUANT_K/QUANT_R; iqs += step) { DEQUANT_FUNC - y[col * p.stride_b + row*QUANT_K + iqs + 0 ] = D_TYPE(v.x); - y[col * p.stride_b + row*QUANT_K + iqs + y_offset] = D_TYPE(v.y); + data_b[col * p.stride_b + row*QUANT_K + iqs + 0 ] = D_TYPE(v.x); + data_b[col * p.stride_b + row*QUANT_K + iqs + y_offset] = D_TYPE(v.y); } } """ @@ -508,8 +508,8 @@ void main() { dequant_q2_K_body = """ layout(local_size_x = 64, local_size_y = 1, local_size_z = 1) in; -layout (binding = 0) readonly buffer A {A_TYPE x[];}; -layout (binding = 1) writeonly buffer D {D_TYPE y[];}; +layout (binding = 0) readonly buffer A {A_TYPE data_a[];}; +layout (binding = 1) writeonly buffer D {D_TYPE data_b[];}; layout (push_constant) uniform parameter { @@ -534,22 +534,22 @@ void main() { const int y_idx = i * QUANT_K + 128 * ip + il; const int ql_idx = 32 * ip + il; - const uint8_t qs = x[i].qs[32 * ip + il]; + const uint8_t qs = data_a[i].qs[32 * ip + il]; - FLOAT_TYPE dall = FLOAT_TYPE(x[i].d.x); - FLOAT_TYPE dmin = FLOAT_TYPE(x[i].d.y); - y[y_idx + 0] = D_TYPE(dall * FLOAT_TYPE((x[i].scales[is+0] & 0xF) * ((qs >> 0) & 3)) - dmin * FLOAT_TYPE(x[i].scales[is+0] >> 4)); - y[y_idx + 32] = D_TYPE(dall * FLOAT_TYPE((x[i].scales[is+2] & 0xF) * ((qs >> 2) & 3)) - dmin * FLOAT_TYPE(x[i].scales[is+2] >> 4)); - y[y_idx + 64] = D_TYPE(dall * FLOAT_TYPE((x[i].scales[is+4] & 0xF) * ((qs >> 4) & 3)) - dmin * FLOAT_TYPE(x[i].scales[is+4] >> 4)); - y[y_idx + 96] = D_TYPE(dall * FLOAT_TYPE((x[i].scales[is+6] & 0xF) * ((qs >> 6) & 3)) - dmin * FLOAT_TYPE(x[i].scales[is+6] >> 4)); + FLOAT_TYPE dall = FLOAT_TYPE(data_a[i].d.x); + FLOAT_TYPE dmin = 
FLOAT_TYPE(data_a[i].d.y); + data_b[y_idx + 0] = D_TYPE(dall * FLOAT_TYPE((data_a[i].scales[is+0] & 0xF) * ((qs >> 0) & 3)) - dmin * FLOAT_TYPE(data_a[i].scales[is+0] >> 4)); + data_b[y_idx + 32] = D_TYPE(dall * FLOAT_TYPE((data_a[i].scales[is+2] & 0xF) * ((qs >> 2) & 3)) - dmin * FLOAT_TYPE(data_a[i].scales[is+2] >> 4)); + data_b[y_idx + 64] = D_TYPE(dall * FLOAT_TYPE((data_a[i].scales[is+4] & 0xF) * ((qs >> 4) & 3)) - dmin * FLOAT_TYPE(data_a[i].scales[is+4] >> 4)); + data_b[y_idx + 96] = D_TYPE(dall * FLOAT_TYPE((data_a[i].scales[is+6] & 0xF) * ((qs >> 6) & 3)) - dmin * FLOAT_TYPE(data_a[i].scales[is+6] >> 4)); } } """ dequant_q3_K_body = """ layout(local_size_x = 64, local_size_y = 1, local_size_z = 1) in; -layout (binding = 0) readonly buffer A {A_TYPE x[];}; -layout (binding = 1) writeonly buffer D {D_TYPE y[];}; +layout (binding = 0) readonly buffer A {A_TYPE data_a[];}; +layout (binding = 1) writeonly buffer D {D_TYPE data_b[];}; layout (push_constant) uniform parameter { @@ -577,18 +577,18 @@ void main() { const int is = 8*n + 2*j + is0; const int shift = 2*j; - const int8_t us = int8_t(is < 4 ? (x[i].scales[is-0] & 0xF) | (((x[i].scales[is+8] >> 0) & 3) << 4) : - is < 8 ? (x[i].scales[is-0] & 0xF) | (((x[i].scales[is+4] >> 2) & 3) << 4) : - is < 12 ? (x[i].scales[is-8] >> 4) | (((x[i].scales[is+0] >> 4) & 3) << 4) : - (x[i].scales[is-8] >> 4) | (((x[i].scales[is-4] >> 6) & 3) << 4)); - const FLOAT_TYPE d_all = FLOAT_TYPE(x[i].d); + const int8_t us = int8_t(is < 4 ? (data_a[i].scales[is-0] & 0xF) | (((data_a[i].scales[is+8] >> 0) & 3) << 4) : + is < 8 ? (data_a[i].scales[is-0] & 0xF) | (((data_a[i].scales[is+4] >> 2) & 3) << 4) : + is < 12 ? (data_a[i].scales[is-8] >> 4) | (((data_a[i].scales[is+0] >> 4) & 3) << 4) : + (data_a[i].scales[is-8] >> 4) | (((data_a[i].scales[is-4] >> 6) & 3) << 4)); + const FLOAT_TYPE d_all = FLOAT_TYPE(data_a[i].d); const FLOAT_TYPE dl = d_all * FLOAT_TYPE(us - 32); const int y_idx = i * QUANT_K + 128 * n + 32 * j; const int qs_idx = 32*n; for (int l = l0; l < l0 + 4; ++l) { - y[y_idx + l] = D_TYPE(dl * FLOAT_TYPE(int8_t((x[i].qs[qs_idx + l] >> shift) & 3) - (((x[i].hmask[l] & m) != 0) ? 0 : 4))); + data_b[y_idx + l] = D_TYPE(dl * FLOAT_TYPE(int8_t((data_a[i].qs[qs_idx + l] >> shift) & 3) - (((data_a[i].hmask[l] & m) != 0) ? 
0 : 4))); } } } @@ -596,8 +596,8 @@ void main() { dequant_q4_K_body = """ layout(local_size_x = 32, local_size_y = 1, local_size_z = 1) in; -layout (binding = 0) readonly buffer A {A_TYPE x[];}; -layout (binding = 1) writeonly buffer D {D_TYPE y[];}; +layout (binding = 0) readonly buffer A {A_TYPE data_a[];}; +layout (binding = 1) writeonly buffer D {D_TYPE data_b[];}; layout (push_constant) uniform parameter { @@ -620,8 +620,8 @@ void main() { const int is = 2 * il; const int n = 4; - const FLOAT_TYPE dall = FLOAT_TYPE(x[i].d.x); - const FLOAT_TYPE dmin = FLOAT_TYPE(x[i].d.y); + const FLOAT_TYPE dall = FLOAT_TYPE(data_a[i].d.x); + const FLOAT_TYPE dmin = FLOAT_TYPE(data_a[i].d.y); const int y_idx = i * QUANT_K + 64 * il + n * ir; const int qs_idx = 32*il + n * ir; @@ -629,28 +629,28 @@ void main() { uint8_t sc; uint8_t m; if (is < 4) { - sc = uint8_t(x[i].scales[is] & 63); - m = uint8_t(x[i].scales[is + 4] & 63); + sc = uint8_t(data_a[i].scales[is] & 63); + m = uint8_t(data_a[i].scales[is + 4] & 63); } else { - sc = uint8_t((x[i].scales[is + 4] & 0xF) | ((x[i].scales[is - 4] >> 6) << 4)); - m = uint8_t((x[i].scales[is + 4] >> 4) | ((x[i].scales[is ] >> 6) << 4)); + sc = uint8_t((data_a[i].scales[is + 4] & 0xF) | ((data_a[i].scales[is - 4] >> 6) << 4)); + m = uint8_t((data_a[i].scales[is + 4] >> 4) | ((data_a[i].scales[is ] >> 6) << 4)); } const FLOAT_TYPE d1 = dall * sc; const FLOAT_TYPE m1 = dmin * m; if (is < 4) { - sc = uint8_t(x[i].scales[is + 1] & 63); - m = uint8_t(x[i].scales[is + 5] & 63); + sc = uint8_t(data_a[i].scales[is + 1] & 63); + m = uint8_t(data_a[i].scales[is + 5] & 63); } else { - sc = uint8_t((x[i].scales[is + 5] & 0xF) | ((x[i].scales[is - 3] >> 6) << 4)); - m = uint8_t((x[i].scales[is + 5] >> 4) | ((x[i].scales[is + 1] >> 6) << 4)); + sc = uint8_t((data_a[i].scales[is + 5] & 0xF) | ((data_a[i].scales[is - 3] >> 6) << 4)); + m = uint8_t((data_a[i].scales[is + 5] >> 4) | ((data_a[i].scales[is + 1] >> 6) << 4)); } const FLOAT_TYPE d2 = dall * sc; const FLOAT_TYPE m2 = dmin * m; [[unroll]] for (int l = 0; l < n; ++l) { - y[y_idx + l ] = D_TYPE(d1 * FLOAT_TYPE(x[i].qs[qs_idx + l] & 0xF) - m1); - y[y_idx + l + 32] = D_TYPE(d2 * FLOAT_TYPE(x[i].qs[qs_idx + l] >> 4) - m2); + data_b[y_idx + l ] = D_TYPE(d1 * FLOAT_TYPE(data_a[i].qs[qs_idx + l] & 0xF) - m1); + data_b[y_idx + l + 32] = D_TYPE(d2 * FLOAT_TYPE(data_a[i].qs[qs_idx + l] >> 4) - m2); } } } @@ -658,8 +658,8 @@ void main() { dequant_q5_K_body = """ layout(local_size_x = 64, local_size_y = 1, local_size_z = 1) in; -layout (binding = 0) readonly buffer A {A_TYPE x[];}; -layout (binding = 1) writeonly buffer D {D_TYPE y[];}; +layout (binding = 0) readonly buffer A {A_TYPE data_a[];}; +layout (binding = 1) writeonly buffer D {D_TYPE data_b[];}; layout (push_constant) uniform parameter { @@ -681,8 +681,8 @@ void main() { const int ir = tid % 16; const int is = 2 * il; - const FLOAT_TYPE dall = FLOAT_TYPE(x[i].d.x); - const FLOAT_TYPE dmin = FLOAT_TYPE(x[i].d.y); + const FLOAT_TYPE dall = FLOAT_TYPE(data_a[i].d.x); + const FLOAT_TYPE dmin = FLOAT_TYPE(data_a[i].d.y); const int y_idx = i * QUANT_K + 64 * il + 2 * ir; const int qs_idx = 32*il + 2 * ir; @@ -691,39 +691,39 @@ void main() { uint8_t sc; uint8_t m; if (is < 4) { - sc = uint8_t(x[i].scales[is] & 63); - m = uint8_t(x[i].scales[is + 4] & 63); + sc = uint8_t(data_a[i].scales[is] & 63); + m = uint8_t(data_a[i].scales[is + 4] & 63); } else { - sc = uint8_t((x[i].scales[is + 4] & 0xF) | ((x[i].scales[is - 4] >> 6) << 4)); - m = uint8_t((x[i].scales[is + 4] >> 4) | 
((x[i].scales[is ] >> 6) << 4)); + sc = uint8_t((data_a[i].scales[is + 4] & 0xF) | ((data_a[i].scales[is - 4] >> 6) << 4)); + m = uint8_t((data_a[i].scales[is + 4] >> 4) | ((data_a[i].scales[is ] >> 6) << 4)); } const FLOAT_TYPE d1 = dall * sc; const FLOAT_TYPE m1 = dmin * m; if (is < 4) { - sc = uint8_t(x[i].scales[is + 1] & 63); - m = uint8_t(x[i].scales[is + 5] & 63); + sc = uint8_t(data_a[i].scales[is + 1] & 63); + m = uint8_t(data_a[i].scales[is + 5] & 63); } else { - sc = uint8_t((x[i].scales[is + 5] & 0xF) | ((x[i].scales[is - 3] >> 6) << 4)); - m = uint8_t((x[i].scales[is + 5] >> 4) | ((x[i].scales[is + 1] >> 6) << 4)); + sc = uint8_t((data_a[i].scales[is + 5] & 0xF) | ((data_a[i].scales[is - 3] >> 6) << 4)); + m = uint8_t((data_a[i].scales[is + 5] >> 4) | ((data_a[i].scales[is + 1] >> 6) << 4)); } const FLOAT_TYPE d2 = dall * sc; const FLOAT_TYPE m2 = dmin * m; const uint8_t hm1 = uint8_t(1 << (2 * il )); const uint8_t hm2 = uint8_t(1 << (2 * il + 1)); - y[y_idx ] = D_TYPE(d1 * FLOAT_TYPE((x[i].qs[qs_idx ] & 0xF) + (((x[i].qh[qh_idx ] & hm1) != 0) ? 16 : 0)) - m1); - y[y_idx + 1] = D_TYPE(d1 * FLOAT_TYPE((x[i].qs[qs_idx + 1] & 0xF) + (((x[i].qh[qh_idx + 1] & hm1) != 0) ? 16 : 0)) - m1); - y[y_idx + 32] = D_TYPE(d2 * FLOAT_TYPE((x[i].qs[qs_idx ] >> 4) + (((x[i].qh[qh_idx ] & hm2) != 0) ? 16 : 0)) - m2); - y[y_idx + 33] = D_TYPE(d2 * FLOAT_TYPE((x[i].qs[qs_idx + 1] >> 4) + (((x[i].qh[qh_idx + 1] & hm2) != 0) ? 16 : 0)) - m2); + data_b[y_idx ] = D_TYPE(d1 * FLOAT_TYPE((data_a[i].qs[qs_idx ] & 0xF) + (((data_a[i].qh[qh_idx ] & hm1) != 0) ? 16 : 0)) - m1); + data_b[y_idx + 1] = D_TYPE(d1 * FLOAT_TYPE((data_a[i].qs[qs_idx + 1] & 0xF) + (((data_a[i].qh[qh_idx + 1] & hm1) != 0) ? 16 : 0)) - m1); + data_b[y_idx + 32] = D_TYPE(d2 * FLOAT_TYPE((data_a[i].qs[qs_idx ] >> 4) + (((data_a[i].qh[qh_idx ] & hm2) != 0) ? 16 : 0)) - m2); + data_b[y_idx + 33] = D_TYPE(d2 * FLOAT_TYPE((data_a[i].qs[qs_idx + 1] >> 4) + (((data_a[i].qh[qh_idx + 1] & hm2) != 0) ? 
16 : 0)) - m2); } } """ dequant_q6_K_body = """ layout(local_size_x = 64, local_size_y = 1, local_size_z = 1) in; -layout (binding = 0) readonly buffer A {A_TYPE x[];}; -layout (binding = 1) writeonly buffer D {D_TYPE y[];}; +layout (binding = 0) readonly buffer A {A_TYPE data_a[];}; +layout (binding = 1) writeonly buffer D {D_TYPE data_b[];}; layout (push_constant) uniform parameter { @@ -747,14 +747,14 @@ void main() { const int y_idx = i * QUANT_K + 128 * ip + il; const int ql_idx = 64 * ip + il; - const uint8_t qh = x[i].qh[32 * ip + il]; + const uint8_t qh = data_a[i].qh[32 * ip + il]; - const FLOAT_TYPE d = FLOAT_TYPE(x[i].d); + const FLOAT_TYPE d = FLOAT_TYPE(data_a[i].d); - y[y_idx + 0] = D_TYPE(d * FLOAT_TYPE(x[i].scales[is + 0] * (int8_t((x[i].ql[ql_idx + 0] & 0xF) | (((qh >> 0) & 3) << 4)) - 32))); - y[y_idx + 32] = D_TYPE(d * FLOAT_TYPE(x[i].scales[is + 2] * (int8_t((x[i].ql[ql_idx + 32] & 0xF) | (((qh >> 2) & 3) << 4)) - 32))); - y[y_idx + 64] = D_TYPE(d * FLOAT_TYPE(x[i].scales[is + 4] * (int8_t((x[i].ql[ql_idx + 0] >> 4) | (((qh >> 4) & 3) << 4)) - 32))); - y[y_idx + 96] = D_TYPE(d * FLOAT_TYPE(x[i].scales[is + 6] * (int8_t((x[i].ql[ql_idx + 32] >> 4) | (((qh >> 6) & 3) << 4)) - 32))); + data_b[y_idx + 0] = D_TYPE(d * FLOAT_TYPE(data_a[i].scales[is + 0] * (int8_t((data_a[i].ql[ql_idx + 0] & 0xF) | (((qh >> 0) & 3) << 4)) - 32))); + data_b[y_idx + 32] = D_TYPE(d * FLOAT_TYPE(data_a[i].scales[is + 2] * (int8_t((data_a[i].ql[ql_idx + 32] & 0xF) | (((qh >> 2) & 3) << 4)) - 32))); + data_b[y_idx + 64] = D_TYPE(d * FLOAT_TYPE(data_a[i].scales[is + 4] * (int8_t((data_a[i].ql[ql_idx + 0] >> 4) | (((qh >> 4) & 3) << 4)) - 32))); + data_b[y_idx + 96] = D_TYPE(d * FLOAT_TYPE(data_a[i].scales[is + 6] * (int8_t((data_a[i].ql[ql_idx + 32] >> 4) | (((qh >> 6) & 3) << 4)) - 32))); } } """ @@ -770,8 +770,8 @@ mul_mat_vec_head = """#version 450 mul_mat_vec_body = """ layout(local_size_x = QUANT_K, local_size_y = 1, local_size_z = 1) in; -layout (binding = 0) readonly buffer A {A_TYPE x[];}; -layout (binding = 1) readonly buffer B {B_TYPE y[];}; +layout (binding = 0) readonly buffer A {A_TYPE data_a[];}; +layout (binding = 1) readonly buffer B {B_TYPE data_b[];}; layout (binding = 2) writeonly buffer D {D_TYPE dst[];}; layout (push_constant) uniform parameter @@ -799,8 +799,8 @@ void main() { DEQUANT_FUNC // matrix multiplication - tmp[tid] += FLOAT_TYPE(v.x) * FLOAT_TYPE(y[iybs + iqs + 0]); - tmp[tid] += FLOAT_TYPE(v.y) * FLOAT_TYPE(y[iybs + iqs + y_offset]); + tmp[tid] += FLOAT_TYPE(v.x) * FLOAT_TYPE(data_b[iybs + iqs + 0]); + tmp[tid] += FLOAT_TYPE(v.y) * FLOAT_TYPE(data_b[iybs + iqs + y_offset]); } // sum up partial sums and write back result @@ -821,8 +821,8 @@ void main() { mul_mat_vec_q2_K_body = """ layout(local_size_x = 32, local_size_y = 1, local_size_z = 1) in; -layout (binding = 0) readonly buffer A {A_TYPE x[];}; -layout (binding = 1) readonly buffer B {B_TYPE y[];}; +layout (binding = 0) readonly buffer A {A_TYPE data_a[];}; +layout (binding = 1) readonly buffer B {B_TYPE data_b[];}; layout (binding = 2) writeonly buffer D {D_TYPE dst[];}; layout (push_constant) uniform parameter @@ -856,28 +856,28 @@ void main() { [[unroll]] for (int i = ix; i < num_blocks_per_row; i += K_QUANTS_PER_ITERATION) { const int y_idx = i * QUANT_K + y_offset; - const FLOAT_TYPE dall = FLOAT_TYPE(x[ib0 + i].d.x); - const FLOAT_TYPE dmin = FLOAT_TYPE(x[ib0 + i].d.y); + const FLOAT_TYPE dall = FLOAT_TYPE(data_a[ib0 + i].d.x); + const FLOAT_TYPE dmin = FLOAT_TYPE(data_a[ib0 + i].d.y); FLOAT_TYPE sum1 = 
FLOAT_TYPE(0.0); FLOAT_TYPE sum2 = FLOAT_TYPE(0.0); for (int l = 0; l < K_QUANTS_PER_ITERATION; ++l) { - sum1 += FLOAT_TYPE(y[y_idx + l + 0]) * FLOAT_TYPE(x[ib0 + i].scales[s_offset + 0] & 0xF) * FLOAT_TYPE((x[ib0 + i].qs[q_offset + l + 0] >> 0) & 3) - + FLOAT_TYPE(y[y_idx + l + 16]) * FLOAT_TYPE(x[ib0 + i].scales[s_offset + 1] & 0xF) * FLOAT_TYPE((x[ib0 + i].qs[q_offset + l +16] >> 0) & 3) - + FLOAT_TYPE(y[y_idx + l + 32]) * FLOAT_TYPE(x[ib0 + i].scales[s_offset + 2] & 0xF) * FLOAT_TYPE((x[ib0 + i].qs[q_offset + l + 0] >> 2) & 3) - + FLOAT_TYPE(y[y_idx + l + 48]) * FLOAT_TYPE(x[ib0 + i].scales[s_offset + 3] & 0xF) * FLOAT_TYPE((x[ib0 + i].qs[q_offset + l +16] >> 2) & 3) - + FLOAT_TYPE(y[y_idx + l + 64]) * FLOAT_TYPE(x[ib0 + i].scales[s_offset + 4] & 0xF) * FLOAT_TYPE((x[ib0 + i].qs[q_offset + l + 0] >> 4) & 3) - + FLOAT_TYPE(y[y_idx + l + 80]) * FLOAT_TYPE(x[ib0 + i].scales[s_offset + 5] & 0xF) * FLOAT_TYPE((x[ib0 + i].qs[q_offset + l +16] >> 4) & 3) - + FLOAT_TYPE(y[y_idx + l + 96]) * FLOAT_TYPE(x[ib0 + i].scales[s_offset + 6] & 0xF) * FLOAT_TYPE((x[ib0 + i].qs[q_offset + l + 0] >> 6) & 3) - + FLOAT_TYPE(y[y_idx + l +112]) * FLOAT_TYPE(x[ib0 + i].scales[s_offset + 7] & 0xF) * FLOAT_TYPE((x[ib0 + i].qs[q_offset + l +16] >> 6) & 3); - sum2 += FLOAT_TYPE(y[y_idx + l + 0]) * FLOAT_TYPE((x[ib0 + i].scales[s_offset + 0] >> 4) & 0xF) - + FLOAT_TYPE(y[y_idx + l + 16]) * FLOAT_TYPE((x[ib0 + i].scales[s_offset + 1] >> 4) & 0xF) - + FLOAT_TYPE(y[y_idx + l + 32]) * FLOAT_TYPE((x[ib0 + i].scales[s_offset + 2] >> 4) & 0xF) - + FLOAT_TYPE(y[y_idx + l + 48]) * FLOAT_TYPE((x[ib0 + i].scales[s_offset + 3] >> 4) & 0xF) - + FLOAT_TYPE(y[y_idx + l + 64]) * FLOAT_TYPE((x[ib0 + i].scales[s_offset + 4] >> 4) & 0xF) - + FLOAT_TYPE(y[y_idx + l + 80]) * FLOAT_TYPE((x[ib0 + i].scales[s_offset + 5] >> 4) & 0xF) - + FLOAT_TYPE(y[y_idx + l + 96]) * FLOAT_TYPE((x[ib0 + i].scales[s_offset + 6] >> 4) & 0xF) - + FLOAT_TYPE(y[y_idx + l +112]) * FLOAT_TYPE((x[ib0 + i].scales[s_offset + 7] >> 4) & 0xF); + sum1 += FLOAT_TYPE(data_b[y_idx + l + 0]) * FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 0] & 0xF) * FLOAT_TYPE((data_a[ib0 + i].qs[q_offset + l + 0] >> 0) & 3) + + FLOAT_TYPE(data_b[y_idx + l + 16]) * FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 1] & 0xF) * FLOAT_TYPE((data_a[ib0 + i].qs[q_offset + l +16] >> 0) & 3) + + FLOAT_TYPE(data_b[y_idx + l + 32]) * FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 2] & 0xF) * FLOAT_TYPE((data_a[ib0 + i].qs[q_offset + l + 0] >> 2) & 3) + + FLOAT_TYPE(data_b[y_idx + l + 48]) * FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 3] & 0xF) * FLOAT_TYPE((data_a[ib0 + i].qs[q_offset + l +16] >> 2) & 3) + + FLOAT_TYPE(data_b[y_idx + l + 64]) * FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 4] & 0xF) * FLOAT_TYPE((data_a[ib0 + i].qs[q_offset + l + 0] >> 4) & 3) + + FLOAT_TYPE(data_b[y_idx + l + 80]) * FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 5] & 0xF) * FLOAT_TYPE((data_a[ib0 + i].qs[q_offset + l +16] >> 4) & 3) + + FLOAT_TYPE(data_b[y_idx + l + 96]) * FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 6] & 0xF) * FLOAT_TYPE((data_a[ib0 + i].qs[q_offset + l + 0] >> 6) & 3) + + FLOAT_TYPE(data_b[y_idx + l +112]) * FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 7] & 0xF) * FLOAT_TYPE((data_a[ib0 + i].qs[q_offset + l +16] >> 6) & 3); + sum2 += FLOAT_TYPE(data_b[y_idx + l + 0]) * FLOAT_TYPE((data_a[ib0 + i].scales[s_offset + 0] >> 4) & 0xF) + + FLOAT_TYPE(data_b[y_idx + l + 16]) * FLOAT_TYPE((data_a[ib0 + i].scales[s_offset + 1] >> 4) & 0xF) + + FLOAT_TYPE(data_b[y_idx + l + 32]) * FLOAT_TYPE((data_a[ib0 + 
i].scales[s_offset + 2] >> 4) & 0xF) + + FLOAT_TYPE(data_b[y_idx + l + 48]) * FLOAT_TYPE((data_a[ib0 + i].scales[s_offset + 3] >> 4) & 0xF) + + FLOAT_TYPE(data_b[y_idx + l + 64]) * FLOAT_TYPE((data_a[ib0 + i].scales[s_offset + 4] >> 4) & 0xF) + + FLOAT_TYPE(data_b[y_idx + l + 80]) * FLOAT_TYPE((data_a[ib0 + i].scales[s_offset + 5] >> 4) & 0xF) + + FLOAT_TYPE(data_b[y_idx + l + 96]) * FLOAT_TYPE((data_a[ib0 + i].scales[s_offset + 6] >> 4) & 0xF) + + FLOAT_TYPE(data_b[y_idx + l +112]) * FLOAT_TYPE((data_a[ib0 + i].scales[s_offset + 7] >> 4) & 0xF); } tmp[16 * ix + tid] += dall * sum1 - dmin * sum2; } @@ -898,8 +898,8 @@ void main() { mul_mat_vec_q3_K_body = """ layout(local_size_x = 32, local_size_y = 1, local_size_z = 1) in; -layout (binding = 0) readonly buffer A {A_TYPE x[];}; -layout (binding = 1) readonly buffer B {B_TYPE y[];}; +layout (binding = 0) readonly buffer A {A_TYPE data_a[];}; +layout (binding = 1) readonly buffer B {B_TYPE data_b[];}; layout (binding = 2) writeonly buffer D {D_TYPE dst[];}; layout (push_constant) uniform parameter @@ -936,18 +936,18 @@ void main() { [[unroll]] for (int i = ix; i < num_blocks_per_row; i += K_QUANTS_PER_ITERATION) { const int y_idx = i * QUANT_K + y_offset; - const FLOAT_TYPE d = FLOAT_TYPE(x[ib0 + i].d); + const FLOAT_TYPE d = FLOAT_TYPE(data_a[ib0 + i].d); FLOAT_TYPE sum = FLOAT_TYPE(0.0); for (int l = 0; l < K_QUANTS_PER_ITERATION; ++l) { - sum += FLOAT_TYPE(y[y_idx + l + 0]) * FLOAT_TYPE(int8_t(((x[ib0 + i].scales[0] >> s_shift) & 0xF) | ((x[ib0 + i].scales[ 8] >> (s_shift + 0) & 0x3) << 4)) - 32) * FLOAT_TYPE(((x[ib0 + i].qs[q_offset + l ] ) & 3) - (((x[ib0 + i].hmask[l0 + l ] & (m << 0)) != 0) ? 0 : 4)) - + FLOAT_TYPE(y[y_idx + l + 32]) * FLOAT_TYPE(int8_t(((x[ib0 + i].scales[2] >> s_shift) & 0xF) | ((x[ib0 + i].scales[10] >> (s_shift + 0) & 0x3) << 4)) - 32) * FLOAT_TYPE(((x[ib0 + i].qs[q_offset + l ] >> 2) & 3) - (((x[ib0 + i].hmask[l0 + l ] & (m << 1)) != 0) ? 0 : 4)) - + FLOAT_TYPE(y[y_idx + l + 64]) * FLOAT_TYPE(int8_t(((x[ib0 + i].scales[4] >> s_shift) & 0xF) | ((x[ib0 + i].scales[ 8] >> (s_shift + 2) & 0x3) << 4)) - 32) * FLOAT_TYPE(((x[ib0 + i].qs[q_offset + l ] >> 4) & 3) - (((x[ib0 + i].hmask[l0 + l ] & (m << 2)) != 0) ? 0 : 4)) - + FLOAT_TYPE(y[y_idx + l + 96]) * FLOAT_TYPE(int8_t(((x[ib0 + i].scales[6] >> s_shift) & 0xF) | ((x[ib0 + i].scales[10] >> (s_shift + 2) & 0x3) << 4)) - 32) * FLOAT_TYPE(((x[ib0 + i].qs[q_offset + l ] >> 6) & 3) - (((x[ib0 + i].hmask[l0 + l ] & (m << 3)) != 0) ? 0 : 4)) - + FLOAT_TYPE(y[y_idx + l + 16]) * FLOAT_TYPE(int8_t(((x[ib0 + i].scales[1] >> s_shift) & 0xF) | ((x[ib0 + i].scales[ 9] >> (s_shift + 0) & 0x3) << 4)) - 32) * FLOAT_TYPE(((x[ib0 + i].qs[q_offset + l+16] ) & 3) - (((x[ib0 + i].hmask[l0 + l+16] & (m << 0)) != 0) ? 0 : 4)) - + FLOAT_TYPE(y[y_idx + l + 48]) * FLOAT_TYPE(int8_t(((x[ib0 + i].scales[3] >> s_shift) & 0xF) | ((x[ib0 + i].scales[11] >> (s_shift + 0) & 0x3) << 4)) - 32) * FLOAT_TYPE(((x[ib0 + i].qs[q_offset + l+16] >> 2) & 3) - (((x[ib0 + i].hmask[l0 + l+16] & (m << 1)) != 0) ? 0 : 4)) - + FLOAT_TYPE(y[y_idx + l + 80]) * FLOAT_TYPE(int8_t(((x[ib0 + i].scales[5] >> s_shift) & 0xF) | ((x[ib0 + i].scales[ 9] >> (s_shift + 2) & 0x3) << 4)) - 32) * FLOAT_TYPE(((x[ib0 + i].qs[q_offset + l+16] >> 4) & 3) - (((x[ib0 + i].hmask[l0 + l+16] & (m << 2)) != 0) ? 
0 : 4)) - + FLOAT_TYPE(y[y_idx + l +112]) * FLOAT_TYPE(int8_t(((x[ib0 + i].scales[7] >> s_shift) & 0xF) | ((x[ib0 + i].scales[11] >> (s_shift + 2) & 0x3) << 4)) - 32) * FLOAT_TYPE(((x[ib0 + i].qs[q_offset + l+16] >> 6) & 3) - (((x[ib0 + i].hmask[l0 + l+16] & (m << 3)) != 0) ? 0 : 4)); + sum += FLOAT_TYPE(data_b[y_idx + l + 0]) * FLOAT_TYPE(int8_t(((data_a[ib0 + i].scales[0] >> s_shift) & 0xF) | ((data_a[ib0 + i].scales[ 8] >> (s_shift + 0) & 0x3) << 4)) - 32) * FLOAT_TYPE(((data_a[ib0 + i].qs[q_offset + l ] ) & 3) - (((data_a[ib0 + i].hmask[l0 + l ] & (m << 0)) != 0) ? 0 : 4)) + + FLOAT_TYPE(data_b[y_idx + l + 32]) * FLOAT_TYPE(int8_t(((data_a[ib0 + i].scales[2] >> s_shift) & 0xF) | ((data_a[ib0 + i].scales[10] >> (s_shift + 0) & 0x3) << 4)) - 32) * FLOAT_TYPE(((data_a[ib0 + i].qs[q_offset + l ] >> 2) & 3) - (((data_a[ib0 + i].hmask[l0 + l ] & (m << 1)) != 0) ? 0 : 4)) + + FLOAT_TYPE(data_b[y_idx + l + 64]) * FLOAT_TYPE(int8_t(((data_a[ib0 + i].scales[4] >> s_shift) & 0xF) | ((data_a[ib0 + i].scales[ 8] >> (s_shift + 2) & 0x3) << 4)) - 32) * FLOAT_TYPE(((data_a[ib0 + i].qs[q_offset + l ] >> 4) & 3) - (((data_a[ib0 + i].hmask[l0 + l ] & (m << 2)) != 0) ? 0 : 4)) + + FLOAT_TYPE(data_b[y_idx + l + 96]) * FLOAT_TYPE(int8_t(((data_a[ib0 + i].scales[6] >> s_shift) & 0xF) | ((data_a[ib0 + i].scales[10] >> (s_shift + 2) & 0x3) << 4)) - 32) * FLOAT_TYPE(((data_a[ib0 + i].qs[q_offset + l ] >> 6) & 3) - (((data_a[ib0 + i].hmask[l0 + l ] & (m << 3)) != 0) ? 0 : 4)) + + FLOAT_TYPE(data_b[y_idx + l + 16]) * FLOAT_TYPE(int8_t(((data_a[ib0 + i].scales[1] >> s_shift) & 0xF) | ((data_a[ib0 + i].scales[ 9] >> (s_shift + 0) & 0x3) << 4)) - 32) * FLOAT_TYPE(((data_a[ib0 + i].qs[q_offset + l+16] ) & 3) - (((data_a[ib0 + i].hmask[l0 + l+16] & (m << 0)) != 0) ? 0 : 4)) + + FLOAT_TYPE(data_b[y_idx + l + 48]) * FLOAT_TYPE(int8_t(((data_a[ib0 + i].scales[3] >> s_shift) & 0xF) | ((data_a[ib0 + i].scales[11] >> (s_shift + 0) & 0x3) << 4)) - 32) * FLOAT_TYPE(((data_a[ib0 + i].qs[q_offset + l+16] >> 2) & 3) - (((data_a[ib0 + i].hmask[l0 + l+16] & (m << 1)) != 0) ? 0 : 4)) + + FLOAT_TYPE(data_b[y_idx + l + 80]) * FLOAT_TYPE(int8_t(((data_a[ib0 + i].scales[5] >> s_shift) & 0xF) | ((data_a[ib0 + i].scales[ 9] >> (s_shift + 2) & 0x3) << 4)) - 32) * FLOAT_TYPE(((data_a[ib0 + i].qs[q_offset + l+16] >> 4) & 3) - (((data_a[ib0 + i].hmask[l0 + l+16] & (m << 2)) != 0) ? 0 : 4)) + + FLOAT_TYPE(data_b[y_idx + l +112]) * FLOAT_TYPE(int8_t(((data_a[ib0 + i].scales[7] >> s_shift) & 0xF) | ((data_a[ib0 + i].scales[11] >> (s_shift + 2) & 0x3) << 4)) - 32) * FLOAT_TYPE(((data_a[ib0 + i].qs[q_offset + l+16] >> 6) & 3) - (((data_a[ib0 + i].hmask[l0 + l+16] & (m << 3)) != 0) ? 
0 : 4)); } tmp[16 * ix + tid] += d * sum; } @@ -968,8 +968,8 @@ void main() { mul_mat_vec_q4_K_body = """ layout(local_size_x = 32, local_size_y = 1, local_size_z = 1) in; -layout (binding = 0) readonly buffer A {A_TYPE x[];}; -layout (binding = 1) readonly buffer B {B_TYPE y[];}; +layout (binding = 0) readonly buffer A {A_TYPE data_a[];}; +layout (binding = 1) readonly buffer B {B_TYPE data_b[];}; layout (binding = 2) writeonly buffer D {D_TYPE dst[];}; layout (push_constant) uniform parameter @@ -1007,67 +1007,67 @@ void main() { const int y1_idx = i * QUANT_K + y_offset; const int y2_idx = y1_idx + 128; - const FLOAT_TYPE dall = FLOAT_TYPE(x[ib0 + i].d.x); - const FLOAT_TYPE dmin = FLOAT_TYPE(x[ib0 + i].d.y); + const FLOAT_TYPE dall = FLOAT_TYPE(data_a[ib0 + i].d.x); + const FLOAT_TYPE dmin = FLOAT_TYPE(data_a[ib0 + i].d.y); - const uint8_t sc0 = uint8_t( x[ib0 + i].scales[v_im * 2 ] & 0x3f); - const uint8_t sc1 = uint8_t( x[ib0 + i].scales[v_im * 2 + 1] & 0x3f); - const uint8_t sc2 = uint8_t( x[ib0 + i].scales[v_im * 2 + 4] & 0x3f); - const uint8_t sc3 = uint8_t( x[ib0 + i].scales[v_im * 2 + 5] & 0x3f); - const uint8_t sc4 = uint8_t(( x[ib0 + i].scales[v_im * 2 + 8] & 0x0f) | ((x[ib0 + i].scales[v_im * 2 ] & 0xc0) >> 2)); - const uint8_t sc5 = uint8_t(( x[ib0 + i].scales[v_im * 2 + 9] & 0x0f) | ((x[ib0 + i].scales[v_im * 2 + 1] & 0xc0) >> 2)); - const uint8_t sc6 = uint8_t(((x[ib0 + i].scales[v_im * 2 + 8] >> 4) & 0x0f) | ((x[ib0 + i].scales[v_im * 2 + 4] & 0xc0) >> 2)); - const uint8_t sc7 = uint8_t(((x[ib0 + i].scales[v_im * 2 + 9] >> 4) & 0x0f) | ((x[ib0 + i].scales[v_im * 2 + 5] & 0xc0) >> 2)); + const uint8_t sc0 = uint8_t( data_a[ib0 + i].scales[v_im * 2 ] & 0x3f); + const uint8_t sc1 = uint8_t( data_a[ib0 + i].scales[v_im * 2 + 1] & 0x3f); + const uint8_t sc2 = uint8_t( data_a[ib0 + i].scales[v_im * 2 + 4] & 0x3f); + const uint8_t sc3 = uint8_t( data_a[ib0 + i].scales[v_im * 2 + 5] & 0x3f); + const uint8_t sc4 = uint8_t(( data_a[ib0 + i].scales[v_im * 2 + 8] & 0x0f) | ((data_a[ib0 + i].scales[v_im * 2 ] & 0xc0) >> 2)); + const uint8_t sc5 = uint8_t(( data_a[ib0 + i].scales[v_im * 2 + 9] & 0x0f) | ((data_a[ib0 + i].scales[v_im * 2 + 1] & 0xc0) >> 2)); + const uint8_t sc6 = uint8_t(((data_a[ib0 + i].scales[v_im * 2 + 8] >> 4) & 0x0f) | ((data_a[ib0 + i].scales[v_im * 2 + 4] & 0xc0) >> 2)); + const uint8_t sc7 = uint8_t(((data_a[ib0 + i].scales[v_im * 2 + 9] >> 4) & 0x0f) | ((data_a[ib0 + i].scales[v_im * 2 + 5] & 0xc0) >> 2)); #if K_QUANTS_PER_ITERATION == 2 - const uint8_t q4_0 = uint8_t(x[ib0 + i].qs[q_offset ] & 0xf); - const uint8_t q4_1 = uint8_t(x[ib0 + i].qs[q_offset + 1] & 0xf); - const uint8_t q4_2 = uint8_t(x[ib0 + i].qs[q_offset + 2] & 0xf); - const uint8_t q4_3 = uint8_t(x[ib0 + i].qs[q_offset + 3] & 0xf); - const uint8_t q4_4 = uint8_t(x[ib0 + i].qs[q_offset ] >> 4); - const uint8_t q4_5 = uint8_t(x[ib0 + i].qs[q_offset + 1] >> 4); - const uint8_t q4_6 = uint8_t(x[ib0 + i].qs[q_offset + 2] >> 4); - const uint8_t q4_7 = uint8_t(x[ib0 + i].qs[q_offset + 3] >> 4); - const uint8_t q4_8 = uint8_t(x[ib0 + i].qs[q_offset + 64] & 0xf); - const uint8_t q4_9 = uint8_t(x[ib0 + i].qs[q_offset + 65] & 0xf); - const uint8_t q4_10 = uint8_t(x[ib0 + i].qs[q_offset + 66] & 0xf); - const uint8_t q4_11 = uint8_t(x[ib0 + i].qs[q_offset + 67] & 0xf); - const uint8_t q4_12 = uint8_t(x[ib0 + i].qs[q_offset + 64] >> 4); - const uint8_t q4_13 = uint8_t(x[ib0 + i].qs[q_offset + 65] >> 4); - const uint8_t q4_14 = uint8_t(x[ib0 + i].qs[q_offset + 66] >> 4); - const uint8_t q4_15 = uint8_t(x[ib0 
+ i].qs[q_offset + 67] >> 4); + const uint8_t q4_0 = uint8_t(data_a[ib0 + i].qs[q_offset ] & 0xf); + const uint8_t q4_1 = uint8_t(data_a[ib0 + i].qs[q_offset + 1] & 0xf); + const uint8_t q4_2 = uint8_t(data_a[ib0 + i].qs[q_offset + 2] & 0xf); + const uint8_t q4_3 = uint8_t(data_a[ib0 + i].qs[q_offset + 3] & 0xf); + const uint8_t q4_4 = uint8_t(data_a[ib0 + i].qs[q_offset ] >> 4); + const uint8_t q4_5 = uint8_t(data_a[ib0 + i].qs[q_offset + 1] >> 4); + const uint8_t q4_6 = uint8_t(data_a[ib0 + i].qs[q_offset + 2] >> 4); + const uint8_t q4_7 = uint8_t(data_a[ib0 + i].qs[q_offset + 3] >> 4); + const uint8_t q4_8 = uint8_t(data_a[ib0 + i].qs[q_offset + 64] & 0xf); + const uint8_t q4_9 = uint8_t(data_a[ib0 + i].qs[q_offset + 65] & 0xf); + const uint8_t q4_10 = uint8_t(data_a[ib0 + i].qs[q_offset + 66] & 0xf); + const uint8_t q4_11 = uint8_t(data_a[ib0 + i].qs[q_offset + 67] & 0xf); + const uint8_t q4_12 = uint8_t(data_a[ib0 + i].qs[q_offset + 64] >> 4); + const uint8_t q4_13 = uint8_t(data_a[ib0 + i].qs[q_offset + 65] >> 4); + const uint8_t q4_14 = uint8_t(data_a[ib0 + i].qs[q_offset + 66] >> 4); + const uint8_t q4_15 = uint8_t(data_a[ib0 + i].qs[q_offset + 67] >> 4); - const FLOAT_TYPE sx = FLOAT_TYPE(y[y1_idx] * q4_0 + y[y1_idx + 1] * q4_1 + y[y1_idx + 2] * q4_2 + y[y1_idx + 3] * q4_3); - const FLOAT_TYPE sy = FLOAT_TYPE(y[y1_idx + 32] * q4_4 + y[y1_idx + 33] * q4_5 + y[y1_idx + 34] * q4_6 + y[y1_idx + 35] * q4_7); - const FLOAT_TYPE sz = FLOAT_TYPE(y[y2_idx] * q4_8 + y[y2_idx + 1] * q4_9 + y[y2_idx + 2] * q4_10 + y[y2_idx + 3] * q4_11); - const FLOAT_TYPE sw = FLOAT_TYPE(y[y2_idx + 32] * q4_12 + y[y2_idx + 33] * q4_13 + y[y2_idx + 34] * q4_14 + y[y2_idx + 35] * q4_15); + const FLOAT_TYPE sx = FLOAT_TYPE(data_b[y1_idx] * q4_0 + data_b[y1_idx + 1] * q4_1 + data_b[y1_idx + 2] * q4_2 + data_b[y1_idx + 3] * q4_3); + const FLOAT_TYPE sy = FLOAT_TYPE(data_b[y1_idx + 32] * q4_4 + data_b[y1_idx + 33] * q4_5 + data_b[y1_idx + 34] * q4_6 + data_b[y1_idx + 35] * q4_7); + const FLOAT_TYPE sz = FLOAT_TYPE(data_b[y2_idx] * q4_8 + data_b[y2_idx + 1] * q4_9 + data_b[y2_idx + 2] * q4_10 + data_b[y2_idx + 3] * q4_11); + const FLOAT_TYPE sw = FLOAT_TYPE(data_b[y2_idx + 32] * q4_12 + data_b[y2_idx + 33] * q4_13 + data_b[y2_idx + 34] * q4_14 + data_b[y2_idx + 35] * q4_15); const FLOAT_TYPE smin = FLOAT_TYPE( - y[y1_idx ] * sc2 + y[y1_idx + 32] * sc3 + y[y2_idx ] * sc6 + y[y2_idx + 32] * sc7 - + y[y1_idx + 1] * sc2 + y[y1_idx + 33] * sc3 + y[y2_idx + 1] * sc6 + y[y2_idx + 33] * sc7 - + y[y1_idx + 2] * sc2 + y[y1_idx + 34] * sc3 + y[y2_idx + 2] * sc6 + y[y2_idx + 34] * sc7 - + y[y1_idx + 3] * sc2 + y[y1_idx + 35] * sc3 + y[y2_idx + 3] * sc6 + y[y2_idx + 35] * sc7 + data_b[y1_idx ] * sc2 + data_b[y1_idx + 32] * sc3 + data_b[y2_idx ] * sc6 + data_b[y2_idx + 32] * sc7 + + data_b[y1_idx + 1] * sc2 + data_b[y1_idx + 33] * sc3 + data_b[y2_idx + 1] * sc6 + data_b[y2_idx + 33] * sc7 + + data_b[y1_idx + 2] * sc2 + data_b[y1_idx + 34] * sc3 + data_b[y2_idx + 2] * sc6 + data_b[y2_idx + 34] * sc7 + + data_b[y1_idx + 3] * sc2 + data_b[y1_idx + 35] * sc3 + data_b[y2_idx + 3] * sc6 + data_b[y2_idx + 35] * sc7 ); tmp[16 * ix + tid] += FLOAT_TYPE(dall * (sx * sc0 + sy * sc1 + sz * sc4 + sw * sc5) - dmin * smin); #else - const uint8_t q4_0 = uint8_t(x[ib0 + i].qs[q_offset ] & 0xf); - const uint8_t q4_1 = uint8_t(x[ib0 + i].qs[q_offset + 1] & 0xf); - const uint8_t q4_2 = uint8_t(x[ib0 + i].qs[q_offset ] >> 4); - const uint8_t q4_3 = uint8_t(x[ib0 + i].qs[q_offset + 1] >> 4); - const uint8_t q4_4 = uint8_t(x[ib0 + i].qs[q_offset + 64] 
& 0xf); - const uint8_t q4_5 = uint8_t(x[ib0 + i].qs[q_offset + 65] & 0xf); - const uint8_t q4_6 = uint8_t(x[ib0 + i].qs[q_offset + 64] >> 4); - const uint8_t q4_7 = uint8_t(x[ib0 + i].qs[q_offset + 65] >> 4); + const uint8_t q4_0 = uint8_t(data_a[ib0 + i].qs[q_offset ] & 0xf); + const uint8_t q4_1 = uint8_t(data_a[ib0 + i].qs[q_offset + 1] & 0xf); + const uint8_t q4_2 = uint8_t(data_a[ib0 + i].qs[q_offset ] >> 4); + const uint8_t q4_3 = uint8_t(data_a[ib0 + i].qs[q_offset + 1] >> 4); + const uint8_t q4_4 = uint8_t(data_a[ib0 + i].qs[q_offset + 64] & 0xf); + const uint8_t q4_5 = uint8_t(data_a[ib0 + i].qs[q_offset + 65] & 0xf); + const uint8_t q4_6 = uint8_t(data_a[ib0 + i].qs[q_offset + 64] >> 4); + const uint8_t q4_7 = uint8_t(data_a[ib0 + i].qs[q_offset + 65] >> 4); - const FLOAT_TYPE sx = FLOAT_TYPE(y[y1_idx ] * q4_0 + y[y1_idx + 1] * q4_1); - const FLOAT_TYPE sy = FLOAT_TYPE(y[y1_idx + 32] * q4_2 + y[y1_idx + 33] * q4_3); - const FLOAT_TYPE sz = FLOAT_TYPE(y[y2_idx ] * q4_4 + y[y2_idx + 1] * q4_5); - const FLOAT_TYPE sw = FLOAT_TYPE(y[y2_idx + 32] * q4_6 + y[y2_idx + 33] * q4_7); + const FLOAT_TYPE sx = FLOAT_TYPE(data_b[y1_idx ] * q4_0 + data_b[y1_idx + 1] * q4_1); + const FLOAT_TYPE sy = FLOAT_TYPE(data_b[y1_idx + 32] * q4_2 + data_b[y1_idx + 33] * q4_3); + const FLOAT_TYPE sz = FLOAT_TYPE(data_b[y2_idx ] * q4_4 + data_b[y2_idx + 1] * q4_5); + const FLOAT_TYPE sw = FLOAT_TYPE(data_b[y2_idx + 32] * q4_6 + data_b[y2_idx + 33] * q4_7); const FLOAT_TYPE smin = FLOAT_TYPE( - y[y1_idx] * sc2 + y[y1_idx + 32] * sc3 + y[y2_idx] * sc6 + y[y2_idx + 32] * sc7 - + y[y1_idx + 1] * sc2 + y[y1_idx + 33] * sc3 + y[y2_idx + 1] * sc6 + y[y2_idx + 33] * sc7 + data_b[y1_idx] * sc2 + data_b[y1_idx + 32] * sc3 + data_b[y2_idx] * sc6 + data_b[y2_idx + 32] * sc7 + + data_b[y1_idx + 1] * sc2 + data_b[y1_idx + 33] * sc3 + data_b[y2_idx + 1] * sc6 + data_b[y2_idx + 33] * sc7 ); - tmp[16 * ix + tid] += FLOAT_TYPE(dall * (sx * FLOAT_TYPE(x[ib0 + i].scales[v_im] & 0x3f) + sy * FLOAT_TYPE(x[ib0 + i].scales[v_im + 1] & 0x3f) + sz * FLOAT_TYPE((x[ib0 + i].scales[v_im + 4] & 0x0f) | ((x[ib0 + i].scales[v_im] & 0xc0) >> 2)) + sw * FLOAT_TYPE((x[ib0 + i].scales[v_im + 5] & 0x0f) | ((x[ib0 + i].scales[v_im + 1] & 0xc0) >> 2))) - dmin * smin); + tmp[16 * ix + tid] += FLOAT_TYPE(dall * (sx * FLOAT_TYPE(data_a[ib0 + i].scales[v_im] & 0x3f) + sy * FLOAT_TYPE(data_a[ib0 + i].scales[v_im + 1] & 0x3f) + sz * FLOAT_TYPE((data_a[ib0 + i].scales[v_im + 4] & 0x0f) | ((data_a[ib0 + i].scales[v_im] & 0xc0) >> 2)) + sw * FLOAT_TYPE((data_a[ib0 + i].scales[v_im + 5] & 0x0f) | ((data_a[ib0 + i].scales[v_im + 1] & 0xc0) >> 2))) - dmin * smin); #endif } @@ -1087,8 +1087,8 @@ void main() { mul_mat_vec_q5_K_body = """ layout(local_size_x = 32, local_size_y = 1, local_size_z = 1) in; -layout (binding = 0) readonly buffer A {A_TYPE x[];}; -layout (binding = 1) readonly buffer B {B_TYPE y[];}; +layout (binding = 0) readonly buffer A {A_TYPE data_a[];}; +layout (binding = 1) readonly buffer B {B_TYPE data_b[];}; layout (binding = 2) writeonly buffer D {D_TYPE dst[];}; layout (push_constant) uniform parameter @@ -1126,62 +1126,62 @@ void main() { const int y1_idx = i * QUANT_K + y_offset; const int y2_idx = y1_idx + 128; - const FLOAT_TYPE dall = FLOAT_TYPE(x[ib0 + i].d.x); - const FLOAT_TYPE dmin = FLOAT_TYPE(x[ib0 + i].d.y); + const FLOAT_TYPE dall = FLOAT_TYPE(data_a[ib0 + i].d.x); + const FLOAT_TYPE dmin = FLOAT_TYPE(data_a[ib0 + i].d.y); - const uint8_t sc0 = uint8_t( x[ib0 + i].scales[v_im * 2 ] & 0x3f); - const uint8_t sc1 = uint8_t( 
x[ib0 + i].scales[v_im * 2 + 1] & 0x3f); - const uint8_t sc2 = uint8_t( x[ib0 + i].scales[v_im * 2 + 4] & 0x3f); - const uint8_t sc3 = uint8_t( x[ib0 + i].scales[v_im * 2 + 5] & 0x3f); - const uint8_t sc4 = uint8_t(( x[ib0 + i].scales[v_im * 2 + 8] & 0x0f) | ((x[ib0 + i].scales[v_im * 2 ] & 0xc0) >> 2)); - const uint8_t sc5 = uint8_t(( x[ib0 + i].scales[v_im * 2 + 9] & 0x0f) | ((x[ib0 + i].scales[v_im * 2 + 1] & 0xc0) >> 2)); - const uint8_t sc6 = uint8_t(((x[ib0 + i].scales[v_im * 2 + 8] >> 4) & 0x0f) | ((x[ib0 + i].scales[v_im * 2 + 4] & 0xc0) >> 2)); - const uint8_t sc7 = uint8_t(((x[ib0 + i].scales[v_im * 2 + 9] >> 4) & 0x0f) | ((x[ib0 + i].scales[v_im * 2 + 5] & 0xc0) >> 2)); + const uint8_t sc0 = uint8_t( data_a[ib0 + i].scales[v_im * 2 ] & 0x3f); + const uint8_t sc1 = uint8_t( data_a[ib0 + i].scales[v_im * 2 + 1] & 0x3f); + const uint8_t sc2 = uint8_t( data_a[ib0 + i].scales[v_im * 2 + 4] & 0x3f); + const uint8_t sc3 = uint8_t( data_a[ib0 + i].scales[v_im * 2 + 5] & 0x3f); + const uint8_t sc4 = uint8_t(( data_a[ib0 + i].scales[v_im * 2 + 8] & 0x0f) | ((data_a[ib0 + i].scales[v_im * 2 ] & 0xc0) >> 2)); + const uint8_t sc5 = uint8_t(( data_a[ib0 + i].scales[v_im * 2 + 9] & 0x0f) | ((data_a[ib0 + i].scales[v_im * 2 + 1] & 0xc0) >> 2)); + const uint8_t sc6 = uint8_t(((data_a[ib0 + i].scales[v_im * 2 + 8] >> 4) & 0x0f) | ((data_a[ib0 + i].scales[v_im * 2 + 4] & 0xc0) >> 2)); + const uint8_t sc7 = uint8_t(((data_a[ib0 + i].scales[v_im * 2 + 9] >> 4) & 0x0f) | ((data_a[ib0 + i].scales[v_im * 2 + 5] & 0xc0) >> 2)); - const uint8_t q4_0 = uint8_t(x[ib0 + i].qs[q_offset ] & 0xf); - const uint8_t q4_1 = uint8_t(x[ib0 + i].qs[q_offset + 1] & 0xf); - const uint8_t q4_2 = uint8_t(x[ib0 + i].qs[q_offset + 16] & 0xf); - const uint8_t q4_3 = uint8_t(x[ib0 + i].qs[q_offset + 17] & 0xf); - const uint8_t q4_4 = uint8_t(x[ib0 + i].qs[q_offset ] >> 4); - const uint8_t q4_5 = uint8_t(x[ib0 + i].qs[q_offset + 1] >> 4); - const uint8_t q4_6 = uint8_t(x[ib0 + i].qs[q_offset + 16] >> 4); - const uint8_t q4_7 = uint8_t(x[ib0 + i].qs[q_offset + 17] >> 4); - const uint8_t q4_8 = uint8_t(x[ib0 + i].qs[q_offset + 64] & 0xf); - const uint8_t q4_9 = uint8_t(x[ib0 + i].qs[q_offset + 65] & 0xf); - const uint8_t q4_10 = uint8_t(x[ib0 + i].qs[q_offset + 80] & 0xf); - const uint8_t q4_11 = uint8_t(x[ib0 + i].qs[q_offset + 81] & 0xf); - const uint8_t q4_12 = uint8_t(x[ib0 + i].qs[q_offset + 64] >> 4); - const uint8_t q4_13 = uint8_t(x[ib0 + i].qs[q_offset + 65] >> 4); - const uint8_t q4_14 = uint8_t(x[ib0 + i].qs[q_offset + 80] >> 4); - const uint8_t q4_15 = uint8_t(x[ib0 + i].qs[q_offset + 81] >> 4); + const uint8_t q4_0 = uint8_t(data_a[ib0 + i].qs[q_offset ] & 0xf); + const uint8_t q4_1 = uint8_t(data_a[ib0 + i].qs[q_offset + 1] & 0xf); + const uint8_t q4_2 = uint8_t(data_a[ib0 + i].qs[q_offset + 16] & 0xf); + const uint8_t q4_3 = uint8_t(data_a[ib0 + i].qs[q_offset + 17] & 0xf); + const uint8_t q4_4 = uint8_t(data_a[ib0 + i].qs[q_offset ] >> 4); + const uint8_t q4_5 = uint8_t(data_a[ib0 + i].qs[q_offset + 1] >> 4); + const uint8_t q4_6 = uint8_t(data_a[ib0 + i].qs[q_offset + 16] >> 4); + const uint8_t q4_7 = uint8_t(data_a[ib0 + i].qs[q_offset + 17] >> 4); + const uint8_t q4_8 = uint8_t(data_a[ib0 + i].qs[q_offset + 64] & 0xf); + const uint8_t q4_9 = uint8_t(data_a[ib0 + i].qs[q_offset + 65] & 0xf); + const uint8_t q4_10 = uint8_t(data_a[ib0 + i].qs[q_offset + 80] & 0xf); + const uint8_t q4_11 = uint8_t(data_a[ib0 + i].qs[q_offset + 81] & 0xf); + const uint8_t q4_12 = uint8_t(data_a[ib0 + i].qs[q_offset + 64] >> 4); + 
const uint8_t q4_13 = uint8_t(data_a[ib0 + i].qs[q_offset + 65] >> 4); + const uint8_t q4_14 = uint8_t(data_a[ib0 + i].qs[q_offset + 80] >> 4); + const uint8_t q4_15 = uint8_t(data_a[ib0 + i].qs[q_offset + 81] >> 4); const FLOAT_TYPE sx = FLOAT_TYPE( - y[y1_idx ] * (q4_0 + (((x[ib0 + i].qh[l0 ] & hm1) != 0) ? 16 : 0)) - + y[y1_idx + 1] * (q4_1 + (((x[ib0 + i].qh[l0 + 1] & hm1) != 0) ? 16 : 0)) - + y[y1_idx + 16] * (q4_2 + (((x[ib0 + i].qh[l0 + 16] & hm1) != 0) ? 16 : 0)) - + y[y1_idx + 17] * (q4_3 + (((x[ib0 + i].qh[l0 + 17] & hm1) != 0) ? 16 : 0)) + data_b[y1_idx ] * (q4_0 + (((data_a[ib0 + i].qh[l0 ] & hm1) != 0) ? 16 : 0)) + + data_b[y1_idx + 1] * (q4_1 + (((data_a[ib0 + i].qh[l0 + 1] & hm1) != 0) ? 16 : 0)) + + data_b[y1_idx + 16] * (q4_2 + (((data_a[ib0 + i].qh[l0 + 16] & hm1) != 0) ? 16 : 0)) + + data_b[y1_idx + 17] * (q4_3 + (((data_a[ib0 + i].qh[l0 + 17] & hm1) != 0) ? 16 : 0)) ); const FLOAT_TYPE sy = FLOAT_TYPE( - y[y1_idx + 32] * (q4_4 + (((x[ib0 + i].qh[l0 ] & (hm1 << 1)) != 0) ? 16 : 0)) - + y[y1_idx + 33] * (q4_5 + (((x[ib0 + i].qh[l0 + 1] & (hm1 << 1)) != 0) ? 16 : 0)) - + y[y1_idx + 48] * (q4_6 + (((x[ib0 + i].qh[l0 + 16] & (hm1 << 1)) != 0) ? 16 : 0)) - + y[y1_idx + 49] * (q4_7 + (((x[ib0 + i].qh[l0 + 17] & (hm1 << 1)) != 0) ? 16 : 0)) + data_b[y1_idx + 32] * (q4_4 + (((data_a[ib0 + i].qh[l0 ] & (hm1 << 1)) != 0) ? 16 : 0)) + + data_b[y1_idx + 33] * (q4_5 + (((data_a[ib0 + i].qh[l0 + 1] & (hm1 << 1)) != 0) ? 16 : 0)) + + data_b[y1_idx + 48] * (q4_6 + (((data_a[ib0 + i].qh[l0 + 16] & (hm1 << 1)) != 0) ? 16 : 0)) + + data_b[y1_idx + 49] * (q4_7 + (((data_a[ib0 + i].qh[l0 + 17] & (hm1 << 1)) != 0) ? 16 : 0)) ); const FLOAT_TYPE sz = FLOAT_TYPE( - y[y2_idx ] * (q4_8 + (((x[ib0 + i].qh[l0 ] & hm2) != 0) ? 16 : 0)) - + y[y2_idx + 1] * (q4_9 + (((x[ib0 + i].qh[l0 + 1] & hm2) != 0) ? 16 : 0)) - + y[y2_idx + 16] * (q4_10 + (((x[ib0 + i].qh[l0 + 16] & hm2) != 0) ? 16 : 0)) - + y[y2_idx + 17] * (q4_11 + (((x[ib0 + i].qh[l0 + 17] & hm2) != 0) ? 16 : 0)) + data_b[y2_idx ] * (q4_8 + (((data_a[ib0 + i].qh[l0 ] & hm2) != 0) ? 16 : 0)) + + data_b[y2_idx + 1] * (q4_9 + (((data_a[ib0 + i].qh[l0 + 1] & hm2) != 0) ? 16 : 0)) + + data_b[y2_idx + 16] * (q4_10 + (((data_a[ib0 + i].qh[l0 + 16] & hm2) != 0) ? 16 : 0)) + + data_b[y2_idx + 17] * (q4_11 + (((data_a[ib0 + i].qh[l0 + 17] & hm2) != 0) ? 16 : 0)) ); const FLOAT_TYPE sw = FLOAT_TYPE( - y[y2_idx + 32] * (q4_12 + (((x[ib0 + i].qh[l0 ] & (hm2 << 1)) != 0) ? 16 : 0)) - + y[y2_idx + 33] * (q4_13 + (((x[ib0 + i].qh[l0 + 1] & (hm2 << 1)) != 0) ? 16 : 0)) - + y[y2_idx + 48] * (q4_14 + (((x[ib0 + i].qh[l0 + 16] & (hm2 << 1)) != 0) ? 16 : 0)) - + y[y2_idx + 49] * (q4_15 + (((x[ib0 + i].qh[l0 + 17] & (hm2 << 1)) != 0) ? 16 : 0)) + data_b[y2_idx + 32] * (q4_12 + (((data_a[ib0 + i].qh[l0 ] & (hm2 << 1)) != 0) ? 16 : 0)) + + data_b[y2_idx + 33] * (q4_13 + (((data_a[ib0 + i].qh[l0 + 1] & (hm2 << 1)) != 0) ? 16 : 0)) + + data_b[y2_idx + 48] * (q4_14 + (((data_a[ib0 + i].qh[l0 + 16] & (hm2 << 1)) != 0) ? 16 : 0)) + + data_b[y2_idx + 49] * (q4_15 + (((data_a[ib0 + i].qh[l0 + 17] & (hm2 << 1)) != 0) ? 
16 : 0)) ); const FLOAT_TYPE smin = FLOAT_TYPE( - (y[y1_idx] + y[y1_idx + 1] + y[y1_idx + 16] + y[y1_idx + 17]) * sc2 + (y[y1_idx + 32] + y[y1_idx + 33] + y[y1_idx + 48] + y[y1_idx + 49]) * sc3 - + (y[y2_idx] + y[y2_idx + 1] + y[y2_idx + 16] + y[y2_idx + 17]) * sc6 + (y[y2_idx + 32] + y[y2_idx + 33] + y[y2_idx + 48] + y[y2_idx + 49]) * sc7 + (data_b[y1_idx] + data_b[y1_idx + 1] + data_b[y1_idx + 16] + data_b[y1_idx + 17]) * sc2 + (data_b[y1_idx + 32] + data_b[y1_idx + 33] + data_b[y1_idx + 48] + data_b[y1_idx + 49]) * sc3 + + (data_b[y2_idx] + data_b[y2_idx + 1] + data_b[y2_idx + 16] + data_b[y2_idx + 17]) * sc6 + (data_b[y2_idx + 32] + data_b[y2_idx + 33] + data_b[y2_idx + 48] + data_b[y2_idx + 49]) * sc7 ); tmp[16 * ix + tid] += FLOAT_TYPE(dall * (sx * sc0 + sy * sc1 + sz * sc4 + sw * sc5) - dmin * smin); } @@ -1202,8 +1202,8 @@ void main() { mul_mat_vec_q6_K_body = """ layout(local_size_x = 32, local_size_y = 1, local_size_z = 1) in; -layout (binding = 0) readonly buffer A {A_TYPE x[];}; -layout (binding = 1) readonly buffer B {B_TYPE y[];}; +layout (binding = 0) readonly buffer A {A_TYPE data_a[];}; +layout (binding = 1) readonly buffer B {B_TYPE data_b[];}; layout (binding = 2) writeonly buffer D {D_TYPE dst[];}; layout (push_constant) uniform parameter @@ -1245,25 +1245,25 @@ void main() { [[unroll]] for (int i = ix; i < num_blocks_per_row; i += K_QUANTS_PER_ITERATION) { const int y_idx = i * QUANT_K + y_offset; - const FLOAT_TYPE d = FLOAT_TYPE(x[ib0 + i].d); + const FLOAT_TYPE d = FLOAT_TYPE(data_a[ib0 + i].d); #if K_QUANTS_PER_ITERATION == 1 - FLOAT_TYPE sum = FLOAT_TYPE(y[y_idx + 0]) * FLOAT_TYPE(x[ib0 + i].scales[s_offset + 0]) * d * FLOAT_TYPE(int8_t((x[ib0 + i].ql[ql_offset + 0] & 0xF) | ((x[ib0 + i].qh[qh_offset + 0] & 0x03) << 4)) - 32) - + FLOAT_TYPE(y[y_idx + 16]) * FLOAT_TYPE(x[ib0 + i].scales[s_offset + 1]) * d * FLOAT_TYPE(int8_t((x[ib0 + i].ql[ql_offset + 16] & 0xF) | ((x[ib0 + i].qh[qh_offset + 16] & 0x03) << 4)) - 32) - + FLOAT_TYPE(y[y_idx + 32]) * FLOAT_TYPE(x[ib0 + i].scales[s_offset + 2]) * d * FLOAT_TYPE(int8_t((x[ib0 + i].ql[ql_offset + 32] & 0xF) | ((x[ib0 + i].qh[qh_offset + 0] & 0x0c) << 2)) - 32) - + FLOAT_TYPE(y[y_idx + 48]) * FLOAT_TYPE(x[ib0 + i].scales[s_offset + 3]) * d * FLOAT_TYPE(int8_t((x[ib0 + i].ql[ql_offset + 48] & 0xF) | ((x[ib0 + i].qh[qh_offset + 16] & 0x0c) << 2)) - 32) - + FLOAT_TYPE(y[y_idx + 64]) * FLOAT_TYPE(x[ib0 + i].scales[s_offset + 4]) * d * FLOAT_TYPE(int8_t((x[ib0 + i].ql[ql_offset + 0] >> 4) | ((x[ib0 + i].qh[qh_offset + 0] & 0x30) >> 0)) - 32) - + FLOAT_TYPE(y[y_idx + 80]) * FLOAT_TYPE(x[ib0 + i].scales[s_offset + 5]) * d * FLOAT_TYPE(int8_t((x[ib0 + i].ql[ql_offset + 16] >> 4) | ((x[ib0 + i].qh[qh_offset + 16] & 0x30) >> 0)) - 32) - + FLOAT_TYPE(y[y_idx + 96]) * FLOAT_TYPE(x[ib0 + i].scales[s_offset + 6]) * d * FLOAT_TYPE(int8_t((x[ib0 + i].ql[ql_offset + 32] >> 4) | ((x[ib0 + i].qh[qh_offset + 0] & 0xc0) >> 2)) - 32) - + FLOAT_TYPE(y[y_idx +112]) * FLOAT_TYPE(x[ib0 + i].scales[s_offset + 7]) * d * FLOAT_TYPE(int8_t((x[ib0 + i].ql[ql_offset + 48] >> 4) | ((x[ib0 + i].qh[qh_offset + 16] & 0xc0) >> 2)) - 32); + FLOAT_TYPE sum = FLOAT_TYPE(data_b[y_idx + 0]) * FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 0]) * d * FLOAT_TYPE(int8_t((data_a[ib0 + i].ql[ql_offset + 0] & 0xF) | ((data_a[ib0 + i].qh[qh_offset + 0] & 0x03) << 4)) - 32) + + FLOAT_TYPE(data_b[y_idx + 16]) * FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 1]) * d * FLOAT_TYPE(int8_t((data_a[ib0 + i].ql[ql_offset + 16] & 0xF) | ((data_a[ib0 + i].qh[qh_offset + 16] & 0x03) << 
4)) - 32) + + FLOAT_TYPE(data_b[y_idx + 32]) * FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 2]) * d * FLOAT_TYPE(int8_t((data_a[ib0 + i].ql[ql_offset + 32] & 0xF) | ((data_a[ib0 + i].qh[qh_offset + 0] & 0x0c) << 2)) - 32) + + FLOAT_TYPE(data_b[y_idx + 48]) * FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 3]) * d * FLOAT_TYPE(int8_t((data_a[ib0 + i].ql[ql_offset + 48] & 0xF) | ((data_a[ib0 + i].qh[qh_offset + 16] & 0x0c) << 2)) - 32) + + FLOAT_TYPE(data_b[y_idx + 64]) * FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 4]) * d * FLOAT_TYPE(int8_t((data_a[ib0 + i].ql[ql_offset + 0] >> 4) | ((data_a[ib0 + i].qh[qh_offset + 0] & 0x30) >> 0)) - 32) + + FLOAT_TYPE(data_b[y_idx + 80]) * FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 5]) * d * FLOAT_TYPE(int8_t((data_a[ib0 + i].ql[ql_offset + 16] >> 4) | ((data_a[ib0 + i].qh[qh_offset + 16] & 0x30) >> 0)) - 32) + + FLOAT_TYPE(data_b[y_idx + 96]) * FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 6]) * d * FLOAT_TYPE(int8_t((data_a[ib0 + i].ql[ql_offset + 32] >> 4) | ((data_a[ib0 + i].qh[qh_offset + 0] & 0xc0) >> 2)) - 32) + + FLOAT_TYPE(data_b[y_idx +112]) * FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 7]) * d * FLOAT_TYPE(int8_t((data_a[ib0 + i].ql[ql_offset + 48] >> 4) | ((data_a[ib0 + i].qh[qh_offset + 16] & 0xc0) >> 2)) - 32); tmp[16 * ix + tid] += sum; #else FLOAT_TYPE sum = FLOAT_TYPE(0.0); [[unroll]] for (int l = 0; l < 4; ++l) { - sum += FLOAT_TYPE(y[y_idx + l+ 0]) * FLOAT_TYPE(x[ib0 + i].scales[s_offset + 0]) * d * FLOAT_TYPE(int8_t((x[ib0 + i].ql[ql_offset + l+ 0] & 0xF) | (((x[ib0 + i].qh[qh_offset + l] >> 0) & 3) << 4)) - 32) - + FLOAT_TYPE(y[y_idx + l+32]) * FLOAT_TYPE(x[ib0 + i].scales[s_offset + 2]) * d * FLOAT_TYPE(int8_t((x[ib0 + i].ql[ql_offset + l+32] & 0xF) | (((x[ib0 + i].qh[qh_offset + l] >> 2) & 3) << 4)) - 32) - + FLOAT_TYPE(y[y_idx + l+64]) * FLOAT_TYPE(x[ib0 + i].scales[s_offset + 4]) * d * FLOAT_TYPE(int8_t((x[ib0 + i].ql[ql_offset + l+ 0] >> 4) | (((x[ib0 + i].qh[qh_offset + l] >> 4) & 3) << 4)) - 32) - + FLOAT_TYPE(y[y_idx + l+96]) * FLOAT_TYPE(x[ib0 + i].scales[s_offset + 6]) * d * FLOAT_TYPE(int8_t((x[ib0 + i].ql[ql_offset + l+32] >> 4) | (((x[ib0 + i].qh[qh_offset + l] >> 6) & 3) << 4)) - 32); + sum += FLOAT_TYPE(data_b[y_idx + l+ 0]) * FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 0]) * d * FLOAT_TYPE(int8_t((data_a[ib0 + i].ql[ql_offset + l+ 0] & 0xF) | (((data_a[ib0 + i].qh[qh_offset + l] >> 0) & 3) << 4)) - 32) + + FLOAT_TYPE(data_b[y_idx + l+32]) * FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 2]) * d * FLOAT_TYPE(int8_t((data_a[ib0 + i].ql[ql_offset + l+32] & 0xF) | (((data_a[ib0 + i].qh[qh_offset + l] >> 2) & 3) << 4)) - 32) + + FLOAT_TYPE(data_b[y_idx + l+64]) * FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 4]) * d * FLOAT_TYPE(int8_t((data_a[ib0 + i].ql[ql_offset + l+ 0] >> 4) | (((data_a[ib0 + i].qh[qh_offset + l] >> 4) & 3) << 4)) - 32) + + FLOAT_TYPE(data_b[y_idx + l+96]) * FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 6]) * d * FLOAT_TYPE(int8_t((data_a[ib0 + i].ql[ql_offset + l+32] >> 4) | (((data_a[ib0 + i].qh[qh_offset + l] >> 6) & 3) << 4)) - 32); } tmp[16 * ix + tid] += sum; #endif @@ -1311,151 +1311,118 @@ void main() { } """ -# MUL F32 -mul_f32_src = """#version 450 - -layout(local_size_x = 32, local_size_y = 32, local_size_z = 1) in; - -layout (binding = 0) buffer X {X_TYPE data_x[];}; -layout (binding = 1) buffer Y {Y_TYPE data_y[];}; -layout (binding = 2) buffer D {D_TYPE data_d[];}; - -layout (push_constant) uniform parameter -{ - int M; - int N; - int stride_x; - int stride_y; - int stride_d; - int x_offset; 
-    int y_offset;
-    int d_offset;
-    float scale;
-} p;
-
-void main() {
-    const int x = int(gl_GlobalInvocationID.x);
-    const int y = int(gl_GlobalInvocationID.y);
-
-    if (x >= p.M || y >= p.N) {
-        return;
-    }
-
-    data_d[p.d_offset + y * p.stride_d + x] = D_TYPE(data_x[p.x_offset + y * p.stride_x + x]) * D_TYPE(data_y[p.y_offset + x]);
-}
-"""
-
-# ADD
-add_head = """
+generic_head = """
 #version 450

 #extension GL_EXT_shader_16bit_storage : require
 """

-add_body = """
-layout(local_size_x = 32, local_size_y = 32, local_size_z = 1) in;
-
-layout (binding = 0) buffer X {X_TYPE data_x[];};
-layout (binding = 1) buffer Y {Y_TYPE data_y[];};
-layout (binding = 2) buffer D {D_TYPE data_d[];};
+# MUL F32
+mul_body = """layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;
+
+layout (binding = 0) readonly buffer X {A_TYPE data_a[];};
+layout (binding = 1) readonly buffer Y {B_TYPE data_b[];};
+layout (binding = 2) writeonly buffer D {D_TYPE data_d[];};

 layout (push_constant) uniform parameter
 {
-    int M;
-    int N;
-    int stride_x;
-    int stride_y;
-    int stride_d;
-    int x_offset;
-    int y_offset;
-    int d_offset;
-    float scale;
+    int KX;
+    int KY;
+    float param;
 } p;

 void main() {
-    const int x = int(gl_GlobalInvocationID.x);
-    const int y = int(gl_GlobalInvocationID.y);
+    const int idx = int(gl_GlobalInvocationID.x);

-    if (x >= p.M || y >= p.N) {
+    if (idx >= p.KX) {
         return;
     }

-    data_d[p.d_offset + y * p.stride_d + x] = D_TYPE(FLOAT_TYPE(data_x[p.x_offset + y * p.stride_x + x]) + FLOAT_TYPE(data_y[p.y_offset + x]));
+    data_d[idx] = D_TYPE(FLOAT_TYPE(data_a[idx]) * FLOAT_TYPE(data_b[idx % p.KY]));
+}
+"""
+
+# ADD
+add_body = """
+layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;
+
+layout (binding = 0) readonly buffer X {A_TYPE data_a[];};
+layout (binding = 1) readonly buffer Y {B_TYPE data_b[];};
+layout (binding = 2) writeonly buffer D {D_TYPE data_d[];};
+
+layout (push_constant) uniform parameter
+{
+    int KX;
+    int KY;
+    float param;
+} p;
+
+void main() {
+    const int idx = int(gl_GlobalInvocationID.x);
+
+    if (idx >= p.KX) {
+        return;
+    }
+
+    data_d[idx] = D_TYPE(FLOAT_TYPE(data_a[idx]) + FLOAT_TYPE(data_b[idx % p.KY]));
 }
 """

 # SCALE
-scale_src = """#version 450
+scale_body = """layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;

-layout(local_size_x = 32, local_size_y = 32, local_size_z = 1) in;
-
-layout (binding = 0) buffer X {X_TYPE data_x[];};
-layout (binding = 1) buffer D {D_TYPE data_d[];};
+layout (binding = 0) readonly buffer X {A_TYPE data_a[];};
+layout (binding = 1) writeonly buffer D {D_TYPE data_d[];};

 layout (push_constant) uniform parameter
 {
-    int M;
-    int N;
-    int stride_x;
-    int stride_y;
-    int stride_d;
-    int x_offset;
-    int y_offset;
-    int d_offset;
-    float scale;
+    int KX;
+    int KY;
+    float param;
 } p;

 void main() {
-    const int x = int(gl_GlobalInvocationID.x);
-    const int y = int(gl_GlobalInvocationID.y);
+    const int idx = int(gl_GlobalInvocationID.x);

-    if (x >= p.M || y >= p.N) {
+    if (idx >= p.KX) {
         return;
     }

-    data_d[p.d_offset + y * p.stride_d + x] = D_TYPE(data_x[p.x_offset + y * p.stride_x + x]) * D_TYPE(p.scale);
+    data_d[idx] = D_TYPE(FLOAT_TYPE(data_a[idx]) * FLOAT_TYPE(p.param));
 }
 """

 # GET_ROWS
-get_rows_head = """#version 450
-
+get_rows_body = """
 #extension GL_EXT_control_flow_attributes : enable
-#extension GL_EXT_shader_16bit_storage : require
 #extension GL_EXT_shader_8bit_storage : require
-"""

-get_rows_body = """layout(local_size_x = 32, local_size_y = 32, local_size_z = 1) in;
+layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;

-layout (binding = 0) buffer X {A_TYPE x[];};
-layout (binding = 1) buffer Y {int y[];};
-layout (binding = 2) buffer D {D_TYPE dst[];};
+layout (binding = 0) readonly buffer X {A_TYPE data_a[];};
+layout (binding = 1) readonly buffer Y {int data_b[];};
+layout (binding = 2) writeonly buffer D {D_TYPE dst[];};

 layout (push_constant) uniform parameter
 {
     int M;
     int N;
-    int stride_x;
-    int stride_y;
-    int stride_d;
-    int x_offset;
-    int y_offset;
-    int d_offset;
-    float scale;
+    float param;
 } p;

 void main() {
     const int col = int(gl_GlobalInvocationID.x) * 2;
     const int row = int(gl_GlobalInvocationID.y);

-    if (col >= p.M) {
+    if (col >= p.N) {
         return;
     }

-    const int r = y[row];
+    const int r = data_b[row];

-    // copy x[r*p.M + col] to dst[row*p.M + col]
-    const int xi = r*p.M + col;
-    const int di = row*p.M + col;
+    // copy data_a[r*p.N + col] to dst[row*p.N + col]
+    const int xi = r*p.N + col;
+    const int di = row*p.N + col;

     const int ib = xi/QUANT_K; // block index
     const int iqs = (xi%QUANT_K)/QUANT_R; // quant index
@@ -1469,6 +1436,53 @@ void main() {
 }
 """

+rms_norm_body = """
+#extension GL_EXT_control_flow_attributes : enable
+#define BLOCK_SIZE 512
+
+layout(local_size_x = BLOCK_SIZE, local_size_y = 1, local_size_z = 1) in;
+
+layout (binding = 0) readonly buffer X {A_TYPE data_a[];};
+layout (binding = 1) writeonly buffer D {D_TYPE data_d[];};
+
+layout (push_constant) uniform parameter
+{
+    int M;
+    int N;
+    float param;
+} p;
+
+shared FLOAT_TYPE sum[BLOCK_SIZE];
+
+void main() {
+    const uint row = uint(gl_WorkGroupID.x);
+    const uint tid = uint(gl_LocalInvocationID.x);
+
+    sum[tid] = FLOAT_TYPE(0.0f); // partial sum of squares for this thread
+
+    [[unroll]] for (uint col = tid; col < p.M; col += BLOCK_SIZE) {
+        const FLOAT_TYPE xi = FLOAT_TYPE(data_a[row*p.M + col]);
+        sum[tid] += xi * xi;
+    }
+
+    // sum up partial sums and write back result
+    barrier();
+    [[unroll]] for (int s = BLOCK_SIZE / 2; s > 0; s >>= 1) {
+        if (tid < s) {
+            sum[tid] += sum[tid + s];
+        }
+        barrier();
+    }
+
+    const FLOAT_TYPE mean = sum[0] / FLOAT_TYPE(p.M);
+    const FLOAT_TYPE scale = inversesqrt(mean + FLOAT_TYPE(p.param));
+
+    for (uint col = tid; col < p.M; col += BLOCK_SIZE) {
+        data_d[row*p.M + col] = D_TYPE(scale * FLOAT_TYPE(data_a[row*p.M + col]));
+    }
+}
+"""
+
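The rms_norm_body shader above assigns one workgroup per row: each thread accumulates a strided partial sum of squares into shared memory, a tree reduction collapses the BLOCK_SIZE partials, and the row is rescaled by inversesqrt(mean + param). As a host-side reference for checking the shader's output, here is a minimal NumPy sketch of the same computation; the function name and the eps argument (which corresponds to the param push constant) are illustrative and not part of the patch:

import numpy as np

def rms_norm_ref(x: np.ndarray, eps: float) -> np.ndarray:
    # x has shape (rows, M); accumulate in float32, matching FLOAT_TYPE = float
    x = x.astype(np.float32)
    mean = np.mean(x * x, axis=-1, keepdims=True)  # sum[0] / p.M, per row
    return x / np.sqrt(mean + eps)                 # scale * x, per element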
 GLSLC = "glslc"

 VK_NUM_TYPES = 16
@@ -1582,7 +1596,7 @@ async def main():
             vec_type = "vec4"

         stream = []
-        stream.extend((mulmat_head, shader_float_type, mulmat_body));
+        stream.extend((mulmat_head, shader_float_type, mulmat_body))
         tasks.append(string_to_spv("matmul_f32_l", "".join(stream), {"A_TYPE": "float", "B_TYPE": "float", "D_TYPE": "float"}, fp16))
         tasks.append(string_to_spv("matmul_f32_m", "".join(stream), {"A_TYPE": "float", "B_TYPE": "float", "D_TYPE": "float"}, fp16))
         tasks.append(string_to_spv("matmul_f32_s", "".join(stream), {"A_TYPE": "float", "B_TYPE": "float", "D_TYPE": "float"}, fp16))
@@ -1608,9 +1622,9 @@ async def main():
     tasks.append(string_to_spv("f32_to_f16", f32_to_f16_src, {}, fp16))

     for i in range(0, VK_NUM_TYPES):
-        stream.clear();
+        stream.clear()

-        stream.extend((dequant_head, shader_int8_ext, shader_float_type));
+        stream.extend((dequant_head, shader_int8_ext, shader_float_type))

         if i == GGML_TYPE_F16:
             stream.extend((shader_f16_defines, shader_f16_dequant_func_compat if not fp16 else shader_f16_dequant_func, dequant_body))
@@ -1641,7 +1655,7 @@ async def main():

     # mul mat vec
     for i in range(0, VK_NUM_TYPES):
-        stream.clear();
+        stream.clear()
         stream.extend((mul_mat_vec_head, shader_int8_ext, shader_float_type))

         if i == GGML_TYPE_F16:
@@ -1674,8 +1688,8 @@ async def main():

     # get_rows
     for i in range(0, VK_NUM_TYPES):
-        stream.clear();
-        stream.extend((get_rows_head, shader_int8_ext, shader_float_type))
+        stream.clear()
+        stream.extend((generic_head, shader_int8_ext, shader_float_type))

         if i == GGML_TYPE_F16:
             stream.extend((shader_f16_defines, shader_f16_dequant_func_compat if not fp16 else shader_f16_dequant_func, get_rows_body))
@@ -1696,20 +1710,19 @@ async def main():
         tasks.append(string_to_spv(f"get_rows_{type_names[i]}_f32", "".join(stream), {"B_TYPE": "float", "D_TYPE": "float"}, fp16))

     # add
-    stream.clear();
-
-    stream.extend((add_head, shader_float_type, add_body))
-    tasks.append(string_to_spv("add_f32", "".join(stream), {"X_TYPE": "float", "Y_TYPE": "float", "D_TYPE": "float"}, fp16))
-
-    stream.clear();
-    stream.extend((add_head, shader_float_type, add_body))
-    tasks.append(string_to_spv("add_f16_f32_f16", "".join(stream), {"X_TYPE": "float16_t", "Y_TYPE": "float", "D_TYPE": "float16_t"}, fp16))
+    stream.clear()
+    stream.extend((generic_head, shader_float_type, add_body))
+    tasks.append(string_to_spv("add_f32", "".join(stream), {"A_TYPE": "float", "B_TYPE": "float", "D_TYPE": "float"}, fp16))
+    tasks.append(string_to_spv("add_f16_f32_f16", "".join(stream), {"A_TYPE": "float16_t", "B_TYPE": "float", "D_TYPE": "float16_t"}, fp16))

     # Static shaders
     tasks.append(string_to_spv("split_k_reduce", mulmat_split_k_reduce_src, {}, fp16))
-    tasks.append(string_to_spv("mul_f32", mul_f32_src, {"X_TYPE": "float", "Y_TYPE": "float", "D_TYPE": "float"}, fp16))
+    tasks.append(string_to_spv("mul_f32", f"{generic_head}\n{shader_float_type}\n{mul_body}", {"A_TYPE": "float", "B_TYPE": "float", "D_TYPE": "float"}, fp16))

-    tasks.append(string_to_spv("scale_f32", scale_src, {"X_TYPE": "float", "D_TYPE": "float"}, fp16))
+    tasks.append(string_to_spv("scale_f32", f"{generic_head}\n{shader_float_type}\n{scale_body}", {"A_TYPE": "float", "D_TYPE": "float"}, fp16))
+
+    # Shaders that need full precision, so no fp16 version
+    tasks.append(string_to_spv("rms_norm_f32", f"{generic_head}\n{shader_f32}\n{rms_norm_body}", {"A_TYPE": "float", "D_TYPE": "float"}, True))

     await asyncio.gather(*tasks)
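The reworked op_f32 shaders (mul_body, add_body, scale_body) all index the output with a single flat invocation id: KX is the number of output elements and KY the period of the second operand, so the second tensor is broadcast via idx % KY, replacing the old stride/offset push constants (scale_f32 applies the scalar param instead of a second buffer). A plain-Python sketch of those semantics; the helper name and sample values are illustrative, not part of the patch:

def elementwise_ref(a, b, KX, KY, op):
    # one output element per shader invocation idx; the second operand is
    # broadcast with period KY, like data_b[idx % p.KY] in the shaders
    return [op(a[i], b[i % KY]) for i in range(KX)]

# mirrors add_f32 / mul_f32 for a 4-element a and a 2-element (broadcast) b:
out_add = elementwise_ref([1, 2, 3, 4], [10, 20], 4, 2, lambda x, y: x + y)  # [11, 22, 13, 24]
out_mul = elementwise_ref([1, 2, 3, 4], [10, 20], 4, 2, lambda x, y: x * y)  # [10, 40, 30, 80]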