dequant_q4_0 kernel

This commit is contained in:
0cc4m 2023-06-30 20:48:42 +02:00
parent cb5cb4d6e2
commit c8ff09bdc7
3 changed files with 69 additions and 4 deletions

View file

@ -71,7 +71,7 @@ vk::Device vk_device;
vk::CommandPool vk_command_pool_compute, vk_command_pool_transfer; vk::CommandPool vk_command_pool_compute, vk_command_pool_transfer;
VmaAllocator vk_allocator; VmaAllocator vk_allocator;
vk_pipeline vk_pipeline_matmul_f32, vk_pipeline_matmul_f16; vk_pipeline vk_pipeline_matmul_f32, vk_pipeline_matmul_f16;
vk_pipeline vk_pipeline_f16_to_f32; vk_pipeline vk_pipeline_f16_to_f32, vk_pipeline_dequant_q4_0;
VmaAllocation vk_buffer_qa_alloc, vk_buffer_a_alloc, vk_buffer_b_alloc, vk_buffer_c_alloc; VmaAllocation vk_buffer_qa_alloc, vk_buffer_a_alloc, vk_buffer_b_alloc, vk_buffer_c_alloc;
vk::Buffer vk_buffer_qa, vk_buffer_a, vk_buffer_b, vk_buffer_c; vk::Buffer vk_buffer_qa, vk_buffer_a, vk_buffer_b, vk_buffer_c;
@ -332,6 +332,7 @@ void ggml_vk_init(void) {
} }
vk_pipeline_f16_to_f32 = ggml_vk_create_pipeline("vk_shaders/f16_to_f32.spv", "main", 2, 1, {32, 1, 1}); vk_pipeline_f16_to_f32 = ggml_vk_create_pipeline("vk_shaders/f16_to_f32.spv", "main", 2, 1, {32, 1, 1});
vk_pipeline_dequant_q4_0 = ggml_vk_create_pipeline("vk_shaders/dequant_q4_0.spv", "main", 2, 1, {32, 1, 1});
// Command pools // Command pools
vk::CommandPoolCreateInfo command_pool_create_info_compute(vk::CommandPoolCreateFlags(), vk_compute_queue_family_index); vk::CommandPoolCreateInfo command_pool_create_info_compute(vk::CommandPoolCreateFlags(), vk_compute_queue_family_index);
@ -359,8 +360,8 @@ void ggml_vk_init(void) {
static vk_pipeline* ggml_get_to_fp32_vk(ggml_type type) { static vk_pipeline* ggml_get_to_fp32_vk(ggml_type type) {
switch (type) { switch (type) {
// case GGML_TYPE_Q4_0: case GGML_TYPE_Q4_0:
// return &dequantize_row_q4_0_cl; return &vk_pipeline_dequant_q4_0;
// case GGML_TYPE_Q4_1: // case GGML_TYPE_Q4_1:
// return &dequantize_row_q4_1_cl; // return &dequantize_row_q4_1_cl;
// case GGML_TYPE_Q5_0: // case GGML_TYPE_Q5_0:
@ -1022,7 +1023,7 @@ bool ggml_vk_can_mul_mat(const struct ggml_tensor * src0, const struct ggml_tens
const int64_t ne1 = dst->ne[1]; const int64_t ne1 = dst->ne[1];
// TODO: find the optimal values for these // TODO: find the optimal values for these
if ((src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16 /*|| ggml_is_quantized(src0->type)*/) && if ((src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16 || ggml_is_quantized(src0->type)) &&
src1->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32 &&
dst->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32 &&
((ne0 >= 128 && ne1 >= 32 && ne10 >= 128) || src0->backend == GGML_BACKEND_GPU)) { ((ne0 >= 128 && ne1 >= 32 && ne10 >= 128) || src0->backend == GGML_BACKEND_GPU)) {

7
ggml.c
View file

@ -11044,6 +11044,13 @@ static void ggml_compute_forward_mul_mat_q_f32(
} }
return; return;
} }
#elif defined(GGML_USE_VULKAN)
if (ggml_vk_can_mul_mat(src0, src1, dst)) {
if (params->ith == 0 && params->type == GGML_TASK_COMPUTE) {
ggml_vk_mul_mat(src0, src1, dst, params->wdata, params->wsize);
}
return;
}
#endif #endif
#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) #if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS)

View file

@ -0,0 +1,57 @@
#version 450
// Compute shader that expands Q4_0-quantized blocks (buffer A) into 32-bit
// floats (buffer D). See main() for the per-invocation element mapping.

// 16-bit storage is needed for the float16_t member inside the storage buffer;
// the explicit arithmetic type extensions provide float16_t / uint8_t / int8_t.
#extension GL_EXT_shader_16bit_storage : require
#extension GL_EXT_shader_explicit_arithmetic_types_float16 : require
#extension GL_EXT_shader_explicit_arithmetic_types_int8 : require
// Number of dequantized values covered by one block_q4_0; main() divides the
// element index by this to obtain the block index.
#define QUANT_K 32
// Number of values unpacked from each byte of qs; main() divides the in-block
// offset by this to obtain the byte index.
#define QUANT_R 2
// 32 invocations per workgroup along x.
// NOTE(review): presumably has to agree with the workgroup size the host passes
// when creating this pipeline — confirm against the host code.
layout(local_size_x = 32, local_size_y = 1, local_size_z = 1) in;
// One quantized block: a half-precision scale plus 16 bytes of packed quants.
struct block_q4_0
{
// Per-block scale; main() multiplies every unpacked value by it.
float16_t d;
// Packed quants: main() reads two 4-bit values (low and high nibble) per byte.
uint8_t qs[16];
};
// Input: array of quantized blocks (read-only).
layout (binding = 0) readonly buffer A { block_q4_0 x[]; };
// Output: dequantized floats (write-only).
layout (binding = 1) writeonly buffer D { float y[]; };
layout (push_constant) uniform parameter
{
// Bound used by main() to discard out-of-range invocations.
// NOTE(review): presumably the total number of output elements — confirm
// against the host-side dispatch.
int N;
} p;
// Dequantizes Q4_0 data: every invocation unpacks one byte of a block's qs[]
// into two floats and stores them in y[].
//
// Within a block of QUANT_K values, byte `iqs` holds value `iqs` in its low
// nibble and value `iqs + QUANT_K/2` in its high nibble; both are re-centred by
// subtracting 8 and scaled by the block's `d`.
//
// NOTE(review): assumes p.N is the number of output elements and that the host
// dispatches at least p.N/2 invocations — confirm against the host dispatch.
void main() {
    // Each invocation produces two output values, so the element cursor moves
    // by two per global invocation. Deriving it from the global id (rather than
    // workgroup base + 2*local id) keeps the ranges of neighbouring workgroups
    // disjoint instead of letting the upper half of one workgroup spill into
    // the elements of the next.
    const int i = 2 * int(gl_GlobalInvocationID.x);

    // Guard on the element index that is actually used for addressing; a guard
    // on the raw invocation id would let i run past p.N and index x[] / y[]
    // out of range.
    if (i >= p.N) {
        return;
    }

    const int qk = QUANT_K;
    const int qr = QUANT_R;

    const int ib = i/qk; // block index
    const int iqs = (i%qk)/qr; // byte index inside the block's qs[]
    const int iybs = i - i%qk; // index in y[] where this block's output starts
    const int y_offset = qr == 1 ? 1 : qk/2; // distance between the two values of one byte

    // dequantize
    const float d = float(x[ib].d);
    const uint8_t vui = x[ib].qs[iqs];

    const int8_t vi0 = int8_t(vui & 0xF); // low nibble
    const int8_t vi1 = int8_t(vui >> 4);  // high nibble

    // Stored nibbles are unsigned 0..15; subtracting 8 maps them to -8..7.
    const float v0 = (vi0 - 8)*d;
    const float v1 = (vi1 - 8)*d;

    y[iybs + iqs + 0] = v0;
    y[iybs + iqs + y_offset] = v1;
}