k_quants: call them _K, not _k, also on Metal
This commit is contained in:
parent
285eeb1531
commit
8b98d01e31
2 changed files with 67 additions and 67 deletions
60
ggml-metal.m
60
ggml-metal.m
|
@ -51,21 +51,21 @@ struct ggml_metal_context {
|
||||||
GGML_METAL_DECL_KERNEL(get_rows_f16);
|
GGML_METAL_DECL_KERNEL(get_rows_f16);
|
||||||
GGML_METAL_DECL_KERNEL(get_rows_q4_0);
|
GGML_METAL_DECL_KERNEL(get_rows_q4_0);
|
||||||
GGML_METAL_DECL_KERNEL(get_rows_q4_1);
|
GGML_METAL_DECL_KERNEL(get_rows_q4_1);
|
||||||
GGML_METAL_DECL_KERNEL(get_rows_q2_k);
|
GGML_METAL_DECL_KERNEL(get_rows_q2_K);
|
||||||
GGML_METAL_DECL_KERNEL(get_rows_q3_k);
|
GGML_METAL_DECL_KERNEL(get_rows_q3_K);
|
||||||
GGML_METAL_DECL_KERNEL(get_rows_q4_k);
|
GGML_METAL_DECL_KERNEL(get_rows_q4_K);
|
||||||
GGML_METAL_DECL_KERNEL(get_rows_q5_k);
|
GGML_METAL_DECL_KERNEL(get_rows_q5_K);
|
||||||
GGML_METAL_DECL_KERNEL(get_rows_q6_k);
|
GGML_METAL_DECL_KERNEL(get_rows_q6_K);
|
||||||
GGML_METAL_DECL_KERNEL(rms_norm);
|
GGML_METAL_DECL_KERNEL(rms_norm);
|
||||||
GGML_METAL_DECL_KERNEL(norm);
|
GGML_METAL_DECL_KERNEL(norm);
|
||||||
GGML_METAL_DECL_KERNEL(mul_mat_f16_f32);
|
GGML_METAL_DECL_KERNEL(mul_mat_f16_f32);
|
||||||
GGML_METAL_DECL_KERNEL(mul_mat_q4_0_f32);
|
GGML_METAL_DECL_KERNEL(mul_mat_q4_0_f32);
|
||||||
GGML_METAL_DECL_KERNEL(mul_mat_q4_1_f32);
|
GGML_METAL_DECL_KERNEL(mul_mat_q4_1_f32);
|
||||||
GGML_METAL_DECL_KERNEL(mul_mat_q2_k_f32);
|
GGML_METAL_DECL_KERNEL(mul_mat_q2_K_f32);
|
||||||
GGML_METAL_DECL_KERNEL(mul_mat_q3_k_f32);
|
GGML_METAL_DECL_KERNEL(mul_mat_q3_K_f32);
|
||||||
GGML_METAL_DECL_KERNEL(mul_mat_q4_k_f32);
|
GGML_METAL_DECL_KERNEL(mul_mat_q4_K_f32);
|
||||||
GGML_METAL_DECL_KERNEL(mul_mat_q5_k_f32);
|
GGML_METAL_DECL_KERNEL(mul_mat_q5_K_f32);
|
||||||
GGML_METAL_DECL_KERNEL(mul_mat_q6_k_f32);
|
GGML_METAL_DECL_KERNEL(mul_mat_q6_K_f32);
|
||||||
GGML_METAL_DECL_KERNEL(rope);
|
GGML_METAL_DECL_KERNEL(rope);
|
||||||
GGML_METAL_DECL_KERNEL(alibi_f32);
|
GGML_METAL_DECL_KERNEL(alibi_f32);
|
||||||
GGML_METAL_DECL_KERNEL(cpy_f32_f16);
|
GGML_METAL_DECL_KERNEL(cpy_f32_f16);
|
||||||
|
@ -165,21 +165,21 @@ struct ggml_metal_context * ggml_metal_init(void) {
|
||||||
GGML_METAL_ADD_KERNEL(get_rows_f16);
|
GGML_METAL_ADD_KERNEL(get_rows_f16);
|
||||||
GGML_METAL_ADD_KERNEL(get_rows_q4_0);
|
GGML_METAL_ADD_KERNEL(get_rows_q4_0);
|
||||||
GGML_METAL_ADD_KERNEL(get_rows_q4_1);
|
GGML_METAL_ADD_KERNEL(get_rows_q4_1);
|
||||||
GGML_METAL_ADD_KERNEL(get_rows_q2_k);
|
GGML_METAL_ADD_KERNEL(get_rows_q2_K);
|
||||||
GGML_METAL_ADD_KERNEL(get_rows_q3_k);
|
GGML_METAL_ADD_KERNEL(get_rows_q3_K);
|
||||||
GGML_METAL_ADD_KERNEL(get_rows_q4_k);
|
GGML_METAL_ADD_KERNEL(get_rows_q4_K);
|
||||||
GGML_METAL_ADD_KERNEL(get_rows_q5_k);
|
GGML_METAL_ADD_KERNEL(get_rows_q5_K);
|
||||||
GGML_METAL_ADD_KERNEL(get_rows_q6_k);
|
GGML_METAL_ADD_KERNEL(get_rows_q6_K);
|
||||||
GGML_METAL_ADD_KERNEL(rms_norm);
|
GGML_METAL_ADD_KERNEL(rms_norm);
|
||||||
GGML_METAL_ADD_KERNEL(norm);
|
GGML_METAL_ADD_KERNEL(norm);
|
||||||
GGML_METAL_ADD_KERNEL(mul_mat_f16_f32);
|
GGML_METAL_ADD_KERNEL(mul_mat_f16_f32);
|
||||||
GGML_METAL_ADD_KERNEL(mul_mat_q4_0_f32);
|
GGML_METAL_ADD_KERNEL(mul_mat_q4_0_f32);
|
||||||
GGML_METAL_ADD_KERNEL(mul_mat_q4_1_f32);
|
GGML_METAL_ADD_KERNEL(mul_mat_q4_1_f32);
|
||||||
GGML_METAL_ADD_KERNEL(mul_mat_q2_k_f32);
|
GGML_METAL_ADD_KERNEL(mul_mat_q2_K_f32);
|
||||||
GGML_METAL_ADD_KERNEL(mul_mat_q3_k_f32);
|
GGML_METAL_ADD_KERNEL(mul_mat_q3_K_f32);
|
||||||
GGML_METAL_ADD_KERNEL(mul_mat_q4_k_f32);
|
GGML_METAL_ADD_KERNEL(mul_mat_q4_K_f32);
|
||||||
GGML_METAL_ADD_KERNEL(mul_mat_q5_k_f32);
|
GGML_METAL_ADD_KERNEL(mul_mat_q5_K_f32);
|
||||||
GGML_METAL_ADD_KERNEL(mul_mat_q6_k_f32);
|
GGML_METAL_ADD_KERNEL(mul_mat_q6_K_f32);
|
||||||
GGML_METAL_ADD_KERNEL(rope);
|
GGML_METAL_ADD_KERNEL(rope);
|
||||||
GGML_METAL_ADD_KERNEL(alibi_f32);
|
GGML_METAL_ADD_KERNEL(alibi_f32);
|
||||||
GGML_METAL_ADD_KERNEL(cpy_f32_f16);
|
GGML_METAL_ADD_KERNEL(cpy_f32_f16);
|
||||||
|
@ -668,7 +668,7 @@ void ggml_metal_graph_compute(
|
||||||
|
|
||||||
nth0 = 4;
|
nth0 = 4;
|
||||||
nth1 = 16;
|
nth1 = 16;
|
||||||
[encoder setComputePipelineState:ctx->pipeline_mul_mat_q2_k_f32];
|
[encoder setComputePipelineState:ctx->pipeline_mul_mat_q2_K_f32];
|
||||||
} break;
|
} break;
|
||||||
case GGML_TYPE_Q3_K:
|
case GGML_TYPE_Q3_K:
|
||||||
{
|
{
|
||||||
|
@ -677,7 +677,7 @@ void ggml_metal_graph_compute(
|
||||||
|
|
||||||
nth0 = 4;
|
nth0 = 4;
|
||||||
nth1 = 16;
|
nth1 = 16;
|
||||||
[encoder setComputePipelineState:ctx->pipeline_mul_mat_q3_k_f32];
|
[encoder setComputePipelineState:ctx->pipeline_mul_mat_q3_K_f32];
|
||||||
} break;
|
} break;
|
||||||
case GGML_TYPE_Q4_K:
|
case GGML_TYPE_Q4_K:
|
||||||
{
|
{
|
||||||
|
@ -686,7 +686,7 @@ void ggml_metal_graph_compute(
|
||||||
|
|
||||||
nth0 = 4;
|
nth0 = 4;
|
||||||
nth1 = 16;
|
nth1 = 16;
|
||||||
[encoder setComputePipelineState:ctx->pipeline_mul_mat_q4_k_f32];
|
[encoder setComputePipelineState:ctx->pipeline_mul_mat_q4_K_f32];
|
||||||
} break;
|
} break;
|
||||||
case GGML_TYPE_Q5_K:
|
case GGML_TYPE_Q5_K:
|
||||||
{
|
{
|
||||||
|
@ -695,7 +695,7 @@ void ggml_metal_graph_compute(
|
||||||
|
|
||||||
nth0 = 4;
|
nth0 = 4;
|
||||||
nth1 = 16;
|
nth1 = 16;
|
||||||
[encoder setComputePipelineState:ctx->pipeline_mul_mat_q5_k_f32];
|
[encoder setComputePipelineState:ctx->pipeline_mul_mat_q5_K_f32];
|
||||||
} break;
|
} break;
|
||||||
case GGML_TYPE_Q6_K:
|
case GGML_TYPE_Q6_K:
|
||||||
{
|
{
|
||||||
|
@ -704,7 +704,7 @@ void ggml_metal_graph_compute(
|
||||||
|
|
||||||
nth0 = 4;
|
nth0 = 4;
|
||||||
nth1 = 16;
|
nth1 = 16;
|
||||||
[encoder setComputePipelineState:ctx->pipeline_mul_mat_q6_k_f32];
|
[encoder setComputePipelineState:ctx->pipeline_mul_mat_q6_K_f32];
|
||||||
} break;
|
} break;
|
||||||
default:
|
default:
|
||||||
{
|
{
|
||||||
|
@ -756,11 +756,11 @@ void ggml_metal_graph_compute(
|
||||||
case GGML_TYPE_F16: [encoder setComputePipelineState:ctx->pipeline_get_rows_f16]; break;
|
case GGML_TYPE_F16: [encoder setComputePipelineState:ctx->pipeline_get_rows_f16]; break;
|
||||||
case GGML_TYPE_Q4_0: [encoder setComputePipelineState:ctx->pipeline_get_rows_q4_0]; break;
|
case GGML_TYPE_Q4_0: [encoder setComputePipelineState:ctx->pipeline_get_rows_q4_0]; break;
|
||||||
case GGML_TYPE_Q4_1: [encoder setComputePipelineState:ctx->pipeline_get_rows_q4_1]; break;
|
case GGML_TYPE_Q4_1: [encoder setComputePipelineState:ctx->pipeline_get_rows_q4_1]; break;
|
||||||
case GGML_TYPE_Q2_K: [encoder setComputePipelineState:ctx->pipeline_get_rows_q2_k]; break;
|
case GGML_TYPE_Q2_K: [encoder setComputePipelineState:ctx->pipeline_get_rows_q2_K]; break;
|
||||||
case GGML_TYPE_Q3_K: [encoder setComputePipelineState:ctx->pipeline_get_rows_q3_k]; break;
|
case GGML_TYPE_Q3_K: [encoder setComputePipelineState:ctx->pipeline_get_rows_q3_K]; break;
|
||||||
case GGML_TYPE_Q4_K: [encoder setComputePipelineState:ctx->pipeline_get_rows_q4_k]; break;
|
case GGML_TYPE_Q4_K: [encoder setComputePipelineState:ctx->pipeline_get_rows_q4_K]; break;
|
||||||
case GGML_TYPE_Q5_K: [encoder setComputePipelineState:ctx->pipeline_get_rows_q5_k]; break;
|
case GGML_TYPE_Q5_K: [encoder setComputePipelineState:ctx->pipeline_get_rows_q5_K]; break;
|
||||||
case GGML_TYPE_Q6_K: [encoder setComputePipelineState:ctx->pipeline_get_rows_q6_k]; break;
|
case GGML_TYPE_Q6_K: [encoder setComputePipelineState:ctx->pipeline_get_rows_q6_K]; break;
|
||||||
default: GGML_ASSERT(false && "not implemented");
|
default: GGML_ASSERT(false && "not implemented");
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -792,7 +792,7 @@ typedef struct {
|
||||||
uint8_t qs[QK_K/4]; // quants
|
uint8_t qs[QK_K/4]; // quants
|
||||||
half d; // super-block scale for quantized scales
|
half d; // super-block scale for quantized scales
|
||||||
half dmin; // super-block scale for quantized mins
|
half dmin; // super-block scale for quantized mins
|
||||||
} block_q2_k;
|
} block_q2_K;
|
||||||
// 84 bytes / block
|
// 84 bytes / block
|
||||||
|
|
||||||
typedef struct {
|
typedef struct {
|
||||||
|
@ -804,20 +804,20 @@ typedef struct {
|
||||||
uint8_t scales[K_SCALE_SIZE]; // scales, quantized with 6 bits
|
uint8_t scales[K_SCALE_SIZE]; // scales, quantized with 6 bits
|
||||||
#endif
|
#endif
|
||||||
half d; // super-block scale
|
half d; // super-block scale
|
||||||
} block_q3_k;
|
} block_q3_K;
|
||||||
|
|
||||||
#if QK_K == 64
|
#if QK_K == 64
|
||||||
typedef struct {
|
typedef struct {
|
||||||
half4 d; // super-block scales/mins
|
half4 d; // super-block scales/mins
|
||||||
uint8_t qs[QK_K/2]; // 4-bit quants
|
uint8_t qs[QK_K/2]; // 4-bit quants
|
||||||
} block_q4_k;
|
} block_q4_K;
|
||||||
#else
|
#else
|
||||||
typedef struct {
|
typedef struct {
|
||||||
half d; // super-block scale for quantized scales
|
half d; // super-block scale for quantized scales
|
||||||
half dmin; // super-block scale for quantized mins
|
half dmin; // super-block scale for quantized mins
|
||||||
uint8_t scales[K_SCALE_SIZE]; // scales and mins, quantized with 6 bits
|
uint8_t scales[K_SCALE_SIZE]; // scales and mins, quantized with 6 bits
|
||||||
uint8_t qs[QK_K/2]; // 4--bit quants
|
uint8_t qs[QK_K/2]; // 4--bit quants
|
||||||
} block_q4_k;
|
} block_q4_K;
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
#if QK_K == 64
|
#if QK_K == 64
|
||||||
|
@ -825,7 +825,7 @@ typedef struct {
|
||||||
half4 d; // super-block scales/mins
|
half4 d; // super-block scales/mins
|
||||||
uint8_t qh[QK_K/8]; // quants, high bit
|
uint8_t qh[QK_K/8]; // quants, high bit
|
||||||
uint8_t qs[QK_K/2]; // quants, low 4 bits
|
uint8_t qs[QK_K/2]; // quants, low 4 bits
|
||||||
} block_q5_k;
|
} block_q5_K;
|
||||||
#else
|
#else
|
||||||
typedef struct {
|
typedef struct {
|
||||||
half d; // super-block scale for quantized scales
|
half d; // super-block scale for quantized scales
|
||||||
|
@ -833,7 +833,7 @@ typedef struct {
|
||||||
uint8_t scales[3*QK_K/64]; // scales and mins, quantized with 6 bits
|
uint8_t scales[3*QK_K/64]; // scales and mins, quantized with 6 bits
|
||||||
uint8_t qh[QK_K/8]; // quants, high bit
|
uint8_t qh[QK_K/8]; // quants, high bit
|
||||||
uint8_t qs[QK_K/2]; // quants, low 4 bits
|
uint8_t qs[QK_K/2]; // quants, low 4 bits
|
||||||
} block_q5_k;
|
} block_q5_K;
|
||||||
// 176 bytes / block
|
// 176 bytes / block
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
|
@ -842,7 +842,7 @@ typedef struct {
|
||||||
uint8_t qh[QK_K/4]; // quants, upper 2 bits
|
uint8_t qh[QK_K/4]; // quants, upper 2 bits
|
||||||
int8_t scales[QK_K/16]; // scales, quantized with 8 bits
|
int8_t scales[QK_K/16]; // scales, quantized with 8 bits
|
||||||
half d; // super-block scale
|
half d; // super-block scale
|
||||||
} block_q6_k;
|
} block_q6_K;
|
||||||
// 210 bytes / block
|
// 210 bytes / block
|
||||||
|
|
||||||
static inline uchar4 get_scale_min_k4(int j, device const uint8_t * q) {
|
static inline uchar4 get_scale_min_k4(int j, device const uint8_t * q) {
|
||||||
|
@ -863,7 +863,7 @@ static inline uchar4 get_scale_min_k4(int j, device const uint8_t * q) {
|
||||||
|
|
||||||
//========================================== dequantization =============================
|
//========================================== dequantization =============================
|
||||||
|
|
||||||
static void dequantize_row_q2_k(device const block_q2_k * x, device float * y, int k) {
|
static void dequantize_row_q2_K(device const block_q2_K * x, device float * y, int k) {
|
||||||
assert(k % QK_K == 0);
|
assert(k % QK_K == 0);
|
||||||
const int nb = k / QK_K;
|
const int nb = k / QK_K;
|
||||||
|
|
||||||
|
@ -910,7 +910,7 @@ static void dequantize_row_q2_k(device const block_q2_k * x, device float * y, i
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
static void dequantize_row_q3_k(device const block_q3_k * x, device float * y, int k) {
|
static void dequantize_row_q3_K(device const block_q3_K * x, device float * y, int k) {
|
||||||
assert(k % QK_K == 0);
|
assert(k % QK_K == 0);
|
||||||
const int nb = k / QK_K;
|
const int nb = k / QK_K;
|
||||||
|
|
||||||
|
@ -992,7 +992,7 @@ static void dequantize_row_q3_k(device const block_q3_k * x, device float * y, i
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
static void dequantize_row_q4_k(device const block_q4_k * x, device float * y, int k) {
|
static void dequantize_row_q4_K(device const block_q4_K * x, device float * y, int k) {
|
||||||
assert(k % QK_K == 0);
|
assert(k % QK_K == 0);
|
||||||
const int nb = k / QK_K;
|
const int nb = k / QK_K;
|
||||||
|
|
||||||
|
@ -1028,7 +1028,7 @@ static void dequantize_row_q4_k(device const block_q4_k * x, device float * y, i
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
static void dequantize_row_q5_k(device const block_q5_k * x, device float * y, int k) {
|
static void dequantize_row_q5_K(device const block_q5_K * x, device float * y, int k) {
|
||||||
assert(k % QK_K == 0);
|
assert(k % QK_K == 0);
|
||||||
const int nb = k / QK_K;
|
const int nb = k / QK_K;
|
||||||
|
|
||||||
|
@ -1077,7 +1077,7 @@ static void dequantize_row_q5_k(device const block_q5_k * x, device float * y, i
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
static void dequantize_row_q6_k(device const block_q6_k * x, device float * y, int k) {
|
static void dequantize_row_q6_K(device const block_q6_K * x, device float * y, int k) {
|
||||||
assert(k % QK_K == 0);
|
assert(k % QK_K == 0);
|
||||||
const int nb = k / QK_K;
|
const int nb = k / QK_K;
|
||||||
|
|
||||||
|
@ -1123,7 +1123,7 @@ static void dequantize_row_q6_k(device const block_q6_k * x, device float * y, i
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
kernel void kernel_get_rows_q2_k(
|
kernel void kernel_get_rows_q2_K(
|
||||||
device const void * src0,
|
device const void * src0,
|
||||||
device const int * src1,
|
device const int * src1,
|
||||||
device float * dst,
|
device float * dst,
|
||||||
|
@ -1134,12 +1134,12 @@ kernel void kernel_get_rows_q2_k(
|
||||||
const int i = tpig;
|
const int i = tpig;
|
||||||
const int r = ((device int32_t *) src1)[i];
|
const int r = ((device int32_t *) src1)[i];
|
||||||
|
|
||||||
dequantize_row_q2_k(
|
dequantize_row_q2_K(
|
||||||
(device const block_q2_k *) ((device char *) src0 + r*nb01),
|
(device const block_q2_K *) ((device char *) src0 + r*nb01),
|
||||||
(device float *) ((device char *) dst + i*nb1), ne00);
|
(device float *) ((device char *) dst + i*nb1), ne00);
|
||||||
}
|
}
|
||||||
|
|
||||||
kernel void kernel_get_rows_q3_k(
|
kernel void kernel_get_rows_q3_K(
|
||||||
device const void * src0,
|
device const void * src0,
|
||||||
device const int * src1,
|
device const int * src1,
|
||||||
device float * dst,
|
device float * dst,
|
||||||
|
@ -1150,12 +1150,12 @@ kernel void kernel_get_rows_q3_k(
|
||||||
const int i = tpig;
|
const int i = tpig;
|
||||||
const int r = ((device int32_t *) src1)[i];
|
const int r = ((device int32_t *) src1)[i];
|
||||||
|
|
||||||
dequantize_row_q3_k(
|
dequantize_row_q3_K(
|
||||||
(device const block_q3_k *) ((device char *) src0 + r*nb01),
|
(device const block_q3_K *) ((device char *) src0 + r*nb01),
|
||||||
(device float *) ((device char *) dst + i*nb1), ne00);
|
(device float *) ((device char *) dst + i*nb1), ne00);
|
||||||
}
|
}
|
||||||
|
|
||||||
kernel void kernel_get_rows_q4_k(
|
kernel void kernel_get_rows_q4_K(
|
||||||
device const void * src0,
|
device const void * src0,
|
||||||
device const int * src1,
|
device const int * src1,
|
||||||
device float * dst,
|
device float * dst,
|
||||||
|
@ -1166,12 +1166,12 @@ kernel void kernel_get_rows_q4_k(
|
||||||
const int i = tpig;
|
const int i = tpig;
|
||||||
const int r = ((device int32_t *) src1)[i];
|
const int r = ((device int32_t *) src1)[i];
|
||||||
|
|
||||||
dequantize_row_q4_k(
|
dequantize_row_q4_K(
|
||||||
(device const block_q4_k *) ((device char *) src0 + r*nb01),
|
(device const block_q4_K *) ((device char *) src0 + r*nb01),
|
||||||
(device float *) ((device char *) dst + i*nb1), ne00);
|
(device float *) ((device char *) dst + i*nb1), ne00);
|
||||||
}
|
}
|
||||||
|
|
||||||
kernel void kernel_get_rows_q5_k(
|
kernel void kernel_get_rows_q5_K(
|
||||||
device const void * src0,
|
device const void * src0,
|
||||||
device const int * src1,
|
device const int * src1,
|
||||||
device float * dst,
|
device float * dst,
|
||||||
|
@ -1182,12 +1182,12 @@ kernel void kernel_get_rows_q5_k(
|
||||||
const int i = tpig;
|
const int i = tpig;
|
||||||
const int r = ((device int32_t *) src1)[i];
|
const int r = ((device int32_t *) src1)[i];
|
||||||
|
|
||||||
dequantize_row_q5_k(
|
dequantize_row_q5_K(
|
||||||
(device const block_q5_k *) ((device char *) src0 + r*nb01),
|
(device const block_q5_K *) ((device char *) src0 + r*nb01),
|
||||||
(device float *) ((device char *) dst + i*nb1), ne00);
|
(device float *) ((device char *) dst + i*nb1), ne00);
|
||||||
}
|
}
|
||||||
|
|
||||||
kernel void kernel_get_rows_q6_k(
|
kernel void kernel_get_rows_q6_K(
|
||||||
device const void * src0,
|
device const void * src0,
|
||||||
device const int * src1,
|
device const int * src1,
|
||||||
device float * dst,
|
device float * dst,
|
||||||
|
@ -1198,14 +1198,14 @@ kernel void kernel_get_rows_q6_k(
|
||||||
const int i = tpig;
|
const int i = tpig;
|
||||||
const int r = ((device int32_t *) src1)[i];
|
const int r = ((device int32_t *) src1)[i];
|
||||||
|
|
||||||
dequantize_row_q6_k(
|
dequantize_row_q6_K(
|
||||||
(device const block_q6_k *) ((device char *) src0 + r*nb01),
|
(device const block_q6_K *) ((device char *) src0 + r*nb01),
|
||||||
(device float *) ((device char *) dst + i*nb1), ne00);
|
(device float *) ((device char *) dst + i*nb1), ne00);
|
||||||
}
|
}
|
||||||
|
|
||||||
//====================================== dot products =========================
|
//====================================== dot products =========================
|
||||||
|
|
||||||
kernel void kernel_mul_mat_q2_k_f32(
|
kernel void kernel_mul_mat_q2_K_f32(
|
||||||
device const void * src0,
|
device const void * src0,
|
||||||
device const float * src1,
|
device const float * src1,
|
||||||
device float * dst,
|
device float * dst,
|
||||||
|
@ -1222,7 +1222,7 @@ kernel void kernel_mul_mat_q2_k_f32(
|
||||||
const int64_t r0 = tgpig.x;
|
const int64_t r0 = tgpig.x;
|
||||||
const int64_t r1 = tgpig.y;
|
const int64_t r1 = tgpig.y;
|
||||||
|
|
||||||
device const block_q2_k * x = (device const block_q2_k *) src0 + r0*nb;
|
device const block_q2_K * x = (device const block_q2_K *) src0 + r0*nb;
|
||||||
device const float * yy = (device const float *) src1 + r1*ne10;
|
device const float * yy = (device const float *) src1 + r1*ne10;
|
||||||
|
|
||||||
const int nth = tptg.x*tptg.y;
|
const int nth = tptg.x*tptg.y;
|
||||||
|
@ -1317,7 +1317,7 @@ kernel void kernel_mul_mat_q2_k_f32(
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
kernel void kernel_mul_mat_q3_k_f32(
|
kernel void kernel_mul_mat_q3_K_f32(
|
||||||
device const void * src0,
|
device const void * src0,
|
||||||
device const float * src1,
|
device const float * src1,
|
||||||
device float * dst,
|
device float * dst,
|
||||||
|
@ -1341,7 +1341,7 @@ kernel void kernel_mul_mat_q3_k_f32(
|
||||||
const int64_t r0 = tgpig.x;
|
const int64_t r0 = tgpig.x;
|
||||||
const int64_t r1 = tgpig.y;
|
const int64_t r1 = tgpig.y;
|
||||||
|
|
||||||
device const block_q3_k * x = (device const block_q3_k *) src0 + r0*nb;
|
device const block_q3_K * x = (device const block_q3_K *) src0 + r0*nb;
|
||||||
device const float * yy = (device const float *) src1 + r1*ne10;
|
device const float * yy = (device const float *) src1 + r1*ne10;
|
||||||
|
|
||||||
const int nth = tptg.x*tptg.y;
|
const int nth = tptg.x*tptg.y;
|
||||||
|
@ -1455,7 +1455,7 @@ kernel void kernel_mul_mat_q3_k_f32(
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
kernel void kernel_mul_mat_q4_k_f32(
|
kernel void kernel_mul_mat_q4_K_f32(
|
||||||
device const void * src0,
|
device const void * src0,
|
||||||
device const float * src1,
|
device const float * src1,
|
||||||
device float * dst,
|
device float * dst,
|
||||||
|
@ -1475,7 +1475,7 @@ kernel void kernel_mul_mat_q4_k_f32(
|
||||||
const int nth = tptg.x*tptg.y;
|
const int nth = tptg.x*tptg.y;
|
||||||
const int ith = tptg.y*tpitg.x + tpitg.y;
|
const int ith = tptg.y*tpitg.x + tpitg.y;
|
||||||
|
|
||||||
device const block_q4_k * x = (device const block_q4_k *) src0 + r0*nb;
|
device const block_q4_K * x = (device const block_q4_K *) src0 + r0*nb;
|
||||||
device const float * yy = (device const float *) src1 + r1*ne10;
|
device const float * yy = (device const float *) src1 + r1*ne10;
|
||||||
|
|
||||||
float sumf = 0;
|
float sumf = 0;
|
||||||
|
@ -1580,7 +1580,7 @@ kernel void kernel_mul_mat_q4_k_f32(
|
||||||
//}
|
//}
|
||||||
}
|
}
|
||||||
|
|
||||||
kernel void kernel_mul_mat_q5_k_f32(
|
kernel void kernel_mul_mat_q5_K_f32(
|
||||||
device const void * src0,
|
device const void * src0,
|
||||||
device const float * src1,
|
device const float * src1,
|
||||||
device float * dst,
|
device float * dst,
|
||||||
|
@ -1597,7 +1597,7 @@ kernel void kernel_mul_mat_q5_k_f32(
|
||||||
const int64_t r0 = tgpig.x;
|
const int64_t r0 = tgpig.x;
|
||||||
const int64_t r1 = tgpig.y;
|
const int64_t r1 = tgpig.y;
|
||||||
|
|
||||||
device const block_q5_k * x = (device const block_q5_k *) src0 + r0*nb;
|
device const block_q5_K * x = (device const block_q5_K *) src0 + r0*nb;
|
||||||
device const float * yy = (device const float *) src1 + r1*ne10;
|
device const float * yy = (device const float *) src1 + r1*ne10;
|
||||||
|
|
||||||
const int nth = tptg.x*tptg.y;
|
const int nth = tptg.x*tptg.y;
|
||||||
|
@ -1704,7 +1704,7 @@ kernel void kernel_mul_mat_q5_k_f32(
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
kernel void kernel_mul_mat_q6_k_f32(
|
kernel void kernel_mul_mat_q6_K_f32(
|
||||||
device const void * src0,
|
device const void * src0,
|
||||||
device const float * src1,
|
device const float * src1,
|
||||||
device float * dst,
|
device float * dst,
|
||||||
|
@ -1726,7 +1726,7 @@ kernel void kernel_mul_mat_q6_k_f32(
|
||||||
const int64_t r0 = tgpig.x;
|
const int64_t r0 = tgpig.x;
|
||||||
const int64_t r1 = tgpig.y;
|
const int64_t r1 = tgpig.y;
|
||||||
|
|
||||||
device const block_q6_k * x = (device const block_q6_k *) src0 + r0*nb;
|
device const block_q6_K * x = (device const block_q6_K *) src0 + r0*nb;
|
||||||
device const float * yy = (device const float *) src1 + r1*ne10;
|
device const float * yy = (device const float *) src1 + r1*ne10;
|
||||||
|
|
||||||
const int nth = tptg.x*tptg.y;
|
const int nth = tptg.x*tptg.y;
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue