dequant_q4_0 kernel
This commit is contained in:
parent
cb5cb4d6e2
commit
c8ff09bdc7
3 changed files with 69 additions and 4 deletions
|
@ -71,7 +71,7 @@ vk::Device vk_device;
|
|||
vk::CommandPool vk_command_pool_compute, vk_command_pool_transfer;
|
||||
VmaAllocator vk_allocator;
|
||||
vk_pipeline vk_pipeline_matmul_f32, vk_pipeline_matmul_f16;
|
||||
vk_pipeline vk_pipeline_f16_to_f32;
|
||||
vk_pipeline vk_pipeline_f16_to_f32, vk_pipeline_dequant_q4_0;
|
||||
VmaAllocation vk_buffer_qa_alloc, vk_buffer_a_alloc, vk_buffer_b_alloc, vk_buffer_c_alloc;
|
||||
vk::Buffer vk_buffer_qa, vk_buffer_a, vk_buffer_b, vk_buffer_c;
|
||||
|
||||
|
@ -332,6 +332,7 @@ void ggml_vk_init(void) {
|
|||
}
|
||||
|
||||
vk_pipeline_f16_to_f32 = ggml_vk_create_pipeline("vk_shaders/f16_to_f32.spv", "main", 2, 1, {32, 1, 1});
|
||||
vk_pipeline_dequant_q4_0 = ggml_vk_create_pipeline("vk_shaders/dequant_q4_0.spv", "main", 2, 1, {32, 1, 1});
|
||||
|
||||
// Command pools
|
||||
vk::CommandPoolCreateInfo command_pool_create_info_compute(vk::CommandPoolCreateFlags(), vk_compute_queue_family_index);
|
||||
|
@ -359,8 +360,8 @@ void ggml_vk_init(void) {
|
|||
|
||||
static vk_pipeline* ggml_get_to_fp32_vk(ggml_type type) {
|
||||
switch (type) {
|
||||
// case GGML_TYPE_Q4_0:
|
||||
// return &dequantize_row_q4_0_cl;
|
||||
case GGML_TYPE_Q4_0:
|
||||
return &vk_pipeline_dequant_q4_0;
|
||||
// case GGML_TYPE_Q4_1:
|
||||
// return &dequantize_row_q4_1_cl;
|
||||
// case GGML_TYPE_Q5_0:
|
||||
|
@ -1022,7 +1023,7 @@ bool ggml_vk_can_mul_mat(const struct ggml_tensor * src0, const struct ggml_tens
|
|||
const int64_t ne1 = dst->ne[1];
|
||||
|
||||
// TODO: find the optimal values for these
|
||||
if ((src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16 /*|| ggml_is_quantized(src0->type)*/) &&
|
||||
if ((src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16 || ggml_is_quantized(src0->type)) &&
|
||||
src1->type == GGML_TYPE_F32 &&
|
||||
dst->type == GGML_TYPE_F32 &&
|
||||
((ne0 >= 128 && ne1 >= 32 && ne10 >= 128) || src0->backend == GGML_BACKEND_GPU)) {
|
||||
|
|
7
ggml.c
7
ggml.c
|
@ -11044,6 +11044,13 @@ static void ggml_compute_forward_mul_mat_q_f32(
|
|||
}
|
||||
return;
|
||||
}
|
||||
#elif defined(GGML_USE_VULKAN)
|
||||
if (ggml_vk_can_mul_mat(src0, src1, dst)) {
|
||||
if (params->ith == 0 && params->type == GGML_TASK_COMPUTE) {
|
||||
ggml_vk_mul_mat(src0, src1, dst, params->wdata, params->wsize);
|
||||
}
|
||||
return;
|
||||
}
|
||||
#endif
|
||||
|
||||
#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS)
|
||||
|
|
57
vk_shaders/dequant_q4_0.glsl
Normal file
57
vk_shaders/dequant_q4_0.glsl
Normal file
|
@ -0,0 +1,57 @@
|
|||
#version 450
|
||||
|
||||
#extension GL_EXT_shader_16bit_storage : require
|
||||
#extension GL_EXT_shader_explicit_arithmetic_types_float16 : require
|
||||
#extension GL_EXT_shader_explicit_arithmetic_types_int8 : require
|
||||
|
||||
#define QUANT_K 32
|
||||
#define QUANT_R 2
|
||||
|
||||
layout(local_size_x = 32, local_size_y = 1, local_size_z = 1) in;
|
||||
|
||||
// One quantized block as read by main(): a half-precision scale followed by
// 16 packed bytes. main() splits every byte into a low and a high 4-bit value,
// so one block yields 2 * 16 = 32 outputs, matching QUANT_K.
struct block_q4_0
|
||||
{
|
||||
float16_t d; // per-block scale; main() converts it to float and multiplies each (nibble - 8) by it
|
||||
uint8_t qs[16]; // packed quants: low nibble (& 0xF) and high nibble (>> 4) of each byte are separate values
|
||||
};
|
||||
|
||||
layout (binding = 0) readonly buffer A { block_q4_0 x[]; };
|
||||
layout (binding = 1) writeonly buffer D { float y[]; };
|
||||
|
||||
layout (push_constant) uniform parameter
|
||||
{
|
||||
int N;
|
||||
} p;
|
||||
|
||||
// Compute entry point: dequantizes Q4_0 blocks from buffer `x` into floats in buffer `y`.
//
// Each invocation unpacks ONE byte of a block's `qs` array and writes two outputs:
//   low nibble  -> y[block_start + iqs]
//   high nibble -> y[block_start + iqs + qk/2]
// Both values are re-centred by subtracting 8 and multiplied by the block scale `d`.
//
// NOTE(review): assumes p.N is the total number of output floats and a multiple of
// QUANT_K -- confirm against the host-side dispatch / push-constant setup.
void main() {
    // First element index of the pair handled by this invocation. Because the local
    // id is doubled while the workgroup stride stays at gl_WorkGroupSize.x, `i` can be
    // up to gl_WorkGroupSize.x - 1 larger than gl_GlobalInvocationID.x.
    const int i = int(gl_WorkGroupID.x * gl_WorkGroupSize.x + gl_LocalInvocationID.x*2);

    // Guard on the element index that is actually dereferenced. The previous check used
    // gl_GlobalInvocationID.x, which let the upper half of the last workgroup pass with
    // i >= p.N and then read x[] / write y[] beyond the end of the buffers.
    if (i >= p.N) {
        return;
    }

    const int qk = QUANT_K;
    const int qr = QUANT_R;

    const int ib = i/qk; // block index
    const int iqs = (i%qk)/qr; // quant index (byte within the block)
    const int iybs = i - i%qk; // y block start index
    // Distance between the two outputs produced from one byte: adjacent when a byte
    // holds a single value (qr == 1), otherwise half a block apart.
    const int y_offset = qr == 1 ? 1 : qk/2;

    // dequantize
    const float d = float(x[ib].d);

    const uint8_t vui = x[ib].qs[iqs];

    const int8_t vi0 = int8_t(vui & 0xF);
    const int8_t vi1 = int8_t(vui >> 4);

    // Stored nibbles are in [0, 15]; subtracting 8 maps them to [-8, 7] before scaling.
    const float v0 = (vi0 - 8)*d;
    const float v1 = (vi1 - 8)*d;

    y[iybs + iqs + 0] = v0;
    y[iybs + iqs + y_offset] = v1;
}
|
Loading…
Add table
Add a link
Reference in a new issue