metal : add dequantize_q8_0 kernel
This commit is contained in:
parent
c3e53b421a
commit
46a0881c7f
2 changed files with 26 additions and 6 deletions
|
@ -63,6 +63,7 @@ struct ggml_metal_context {
|
||||||
GGML_METAL_DECL_KERNEL(get_rows_f16);
|
GGML_METAL_DECL_KERNEL(get_rows_f16);
|
||||||
GGML_METAL_DECL_KERNEL(get_rows_q4_0);
|
GGML_METAL_DECL_KERNEL(get_rows_q4_0);
|
||||||
GGML_METAL_DECL_KERNEL(get_rows_q4_1);
|
GGML_METAL_DECL_KERNEL(get_rows_q4_1);
|
||||||
|
GGML_METAL_DECL_KERNEL(get_rows_q8_0);
|
||||||
GGML_METAL_DECL_KERNEL(get_rows_q2_K);
|
GGML_METAL_DECL_KERNEL(get_rows_q2_K);
|
||||||
GGML_METAL_DECL_KERNEL(get_rows_q3_K);
|
GGML_METAL_DECL_KERNEL(get_rows_q3_K);
|
||||||
GGML_METAL_DECL_KERNEL(get_rows_q4_K);
|
GGML_METAL_DECL_KERNEL(get_rows_q4_K);
|
||||||
|
@ -188,6 +189,7 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
|
||||||
GGML_METAL_ADD_KERNEL(get_rows_f16);
|
GGML_METAL_ADD_KERNEL(get_rows_f16);
|
||||||
GGML_METAL_ADD_KERNEL(get_rows_q4_0);
|
GGML_METAL_ADD_KERNEL(get_rows_q4_0);
|
||||||
GGML_METAL_ADD_KERNEL(get_rows_q4_1);
|
GGML_METAL_ADD_KERNEL(get_rows_q4_1);
|
||||||
|
GGML_METAL_ADD_KERNEL(get_rows_q8_0);
|
||||||
GGML_METAL_ADD_KERNEL(get_rows_q2_K);
|
GGML_METAL_ADD_KERNEL(get_rows_q2_K);
|
||||||
GGML_METAL_ADD_KERNEL(get_rows_q3_K);
|
GGML_METAL_ADD_KERNEL(get_rows_q3_K);
|
||||||
GGML_METAL_ADD_KERNEL(get_rows_q4_K);
|
GGML_METAL_ADD_KERNEL(get_rows_q4_K);
|
||||||
|
@ -899,6 +901,7 @@ void ggml_metal_graph_compute(
|
||||||
case GGML_TYPE_F16: [encoder setComputePipelineState:ctx->pipeline_get_rows_f16]; break;
|
case GGML_TYPE_F16: [encoder setComputePipelineState:ctx->pipeline_get_rows_f16]; break;
|
||||||
case GGML_TYPE_Q4_0: [encoder setComputePipelineState:ctx->pipeline_get_rows_q4_0]; break;
|
case GGML_TYPE_Q4_0: [encoder setComputePipelineState:ctx->pipeline_get_rows_q4_0]; break;
|
||||||
case GGML_TYPE_Q4_1: [encoder setComputePipelineState:ctx->pipeline_get_rows_q4_1]; break;
|
case GGML_TYPE_Q4_1: [encoder setComputePipelineState:ctx->pipeline_get_rows_q4_1]; break;
|
||||||
|
case GGML_TYPE_Q8_0: [encoder setComputePipelineState:ctx->pipeline_get_rows_q8_0]; break;
|
||||||
case GGML_TYPE_Q2_K: [encoder setComputePipelineState:ctx->pipeline_get_rows_q2_K]; break;
|
case GGML_TYPE_Q2_K: [encoder setComputePipelineState:ctx->pipeline_get_rows_q2_K]; break;
|
||||||
case GGML_TYPE_Q3_K: [encoder setComputePipelineState:ctx->pipeline_get_rows_q3_K]; break;
|
case GGML_TYPE_Q3_K: [encoder setComputePipelineState:ctx->pipeline_get_rows_q3_K]; break;
|
||||||
case GGML_TYPE_Q4_K: [encoder setComputePipelineState:ctx->pipeline_get_rows_q4_K]; break;
|
case GGML_TYPE_Q4_K: [encoder setComputePipelineState:ctx->pipeline_get_rows_q4_K]; break;
|
||||||
|
|
|
@ -18,6 +18,12 @@ typedef struct {
|
||||||
uint8_t qs[QK4_1 / 2]; // nibbles / quants
|
uint8_t qs[QK4_1 / 2]; // nibbles / quants
|
||||||
} block_q4_1;
|
} block_q4_1;
|
||||||
|
|
||||||
|
#define QK8_0 32
|
||||||
|
typedef struct {
|
||||||
|
half d; // delta
|
||||||
|
int8_t qs[QK8_0]; // quants
|
||||||
|
} block_q8_0;
|
||||||
|
|
||||||
kernel void kernel_add(
|
kernel void kernel_add(
|
||||||
device const float * src0,
|
device const float * src0,
|
||||||
device const float * src1,
|
device const float * src1,
|
||||||
|
@ -1621,12 +1627,12 @@ template <typename type4x4>
|
||||||
void dequantize_q4_0(device const block_q4_0 *xb, short il, thread type4x4 & reg) {
|
void dequantize_q4_0(device const block_q4_0 *xb, short il, thread type4x4 & reg) {
|
||||||
device const uint16_t * qs = ((device const uint16_t *)xb + 1);
|
device const uint16_t * qs = ((device const uint16_t *)xb + 1);
|
||||||
const half d = il ? (xb->d / 16.h) : xb->d;
|
const half d = il ? (xb->d / 16.h) : xb->d;
|
||||||
const half m = il ? (-8.h * 16.h) : -8.h;
|
const half m = il ? ( -8.h * 16.h) : -8.h;
|
||||||
const ushort mask0 = il ? 0x00F0 : 0x000F;
|
const ushort mask0 = il ? 0x00F0 : 0x000F;
|
||||||
const ushort mask1 = il ? 0xF000 : 0x0F00;
|
const ushort mask1 = il ? 0xF000 : 0x0F00;
|
||||||
|
|
||||||
for (int i=0;i<8;i++) {
|
for (int i=0;i<8;i++) {
|
||||||
reg[i/2][2*(i%2)] = (((qs[i] & mask0)) + m) * d;
|
reg[i/2][2*(i%2)] = (((qs[i] & mask0) ) + m) * d;
|
||||||
reg[i/2][2*(i%2)+1] = (((qs[i] & mask1) >> 8) + m) * d;
|
reg[i/2][2*(i%2)+1] = (((qs[i] & mask1) >> 8) + m) * d;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -1640,11 +1646,21 @@ void dequantize_q4_1(device const block_q4_1 *xb, short il, thread type4x4 & reg
|
||||||
const ushort mask1 = il ? 0xF000 : 0x0F00;
|
const ushort mask1 = il ? 0xF000 : 0x0F00;
|
||||||
|
|
||||||
for (int i=0;i<8;i++) {
|
for (int i=0;i<8;i++) {
|
||||||
reg[i/2][2*(i%2)] = (((qs[i] & mask0)) * d) + m;
|
reg[i/2][2*(i%2)] = (((qs[i] & mask0) ) * d) + m;
|
||||||
reg[i/2][2*(i%2)+1] = (((qs[i] & mask1) >> 8) * d) + m;
|
reg[i/2][2*(i%2)+1] = (((qs[i] & mask1) >> 8) * d) + m;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
template <typename type4x4>
|
||||||
|
void dequantize_q8_0(device const block_q8_0 *xb, short il, thread type4x4 & reg) {
|
||||||
|
device const uint8_t * qs = ((device const uint8_t *)xb->qs);
|
||||||
|
const half d = xb->d;
|
||||||
|
|
||||||
|
for (int i=0;i<16;i++) {
|
||||||
|
reg[i/4][i%4] = (qs[i + 16*il] * d);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
template <typename type4x4>
|
template <typename type4x4>
|
||||||
void dequantize_q2_K(device const block_q2_K *xb, short il, thread type4x4 & reg) {
|
void dequantize_q2_K(device const block_q2_K *xb, short il, thread type4x4 & reg) {
|
||||||
const half d = xb->d;
|
const half d = xb->d;
|
||||||
|
@ -1950,6 +1966,7 @@ typedef void (get_rows_t)(device const void *, device const int *, device float
|
||||||
template [[host_name("kernel_get_rows_f16")]] kernel get_rows_t kernel_get_rows<half4x4, 1, dequantize_f16>;
|
template [[host_name("kernel_get_rows_f16")]] kernel get_rows_t kernel_get_rows<half4x4, 1, dequantize_f16>;
|
||||||
template [[host_name("kernel_get_rows_q4_0")]] kernel get_rows_t kernel_get_rows<block_q4_0, 2, dequantize_q4_0>;
|
template [[host_name("kernel_get_rows_q4_0")]] kernel get_rows_t kernel_get_rows<block_q4_0, 2, dequantize_q4_0>;
|
||||||
template [[host_name("kernel_get_rows_q4_1")]] kernel get_rows_t kernel_get_rows<block_q4_1, 2, dequantize_q4_1>;
|
template [[host_name("kernel_get_rows_q4_1")]] kernel get_rows_t kernel_get_rows<block_q4_1, 2, dequantize_q4_1>;
|
||||||
|
template [[host_name("kernel_get_rows_q8_0")]] kernel get_rows_t kernel_get_rows<block_q8_0, 2, dequantize_q8_0>;
|
||||||
template [[host_name("kernel_get_rows_q2_K")]] kernel get_rows_t kernel_get_rows<block_q2_K, QK_NL, dequantize_q2_K>;
|
template [[host_name("kernel_get_rows_q2_K")]] kernel get_rows_t kernel_get_rows<block_q2_K, QK_NL, dequantize_q2_K>;
|
||||||
template [[host_name("kernel_get_rows_q3_K")]] kernel get_rows_t kernel_get_rows<block_q3_K, QK_NL, dequantize_q3_K>;
|
template [[host_name("kernel_get_rows_q3_K")]] kernel get_rows_t kernel_get_rows<block_q3_K, QK_NL, dequantize_q3_K>;
|
||||||
template [[host_name("kernel_get_rows_q4_K")]] kernel get_rows_t kernel_get_rows<block_q4_K, QK_NL, dequantize_q4_K>;
|
template [[host_name("kernel_get_rows_q4_K")]] kernel get_rows_t kernel_get_rows<block_q4_K, QK_NL, dequantize_q4_K>;
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue