ggml : extend quantize_fns_t with "vec_dot_type"
This commit is contained in:
parent
46fc696dea
commit
91bfa51dca
2 changed files with 21 additions and 12 deletions
32
ggml.c
32
ggml.c
|
@ -1838,6 +1838,7 @@ static const quantize_fns_t quantize_fns[GGML_TYPE_COUNT] = {
|
||||||
.quantize_row_q_reference = (quantize_row_q_t) quantize_row_q4_0_reference,
|
.quantize_row_q_reference = (quantize_row_q_t) quantize_row_q4_0_reference,
|
||||||
.quantize_row_q_dot = quantize_row_q8_0,
|
.quantize_row_q_dot = quantize_row_q8_0,
|
||||||
.vec_dot_q = ggml_vec_dot_q4_0_q8_0,
|
.vec_dot_q = ggml_vec_dot_q4_0_q8_0,
|
||||||
|
.vec_dot_type = GGML_TYPE_Q8_0,
|
||||||
},
|
},
|
||||||
[GGML_TYPE_Q4_1] = {
|
[GGML_TYPE_Q4_1] = {
|
||||||
.dequantize_row_q = dequantize_row_q4_1,
|
.dequantize_row_q = dequantize_row_q4_1,
|
||||||
|
@ -1845,6 +1846,7 @@ static const quantize_fns_t quantize_fns[GGML_TYPE_COUNT] = {
|
||||||
.quantize_row_q_reference = (quantize_row_q_t) quantize_row_q4_1_reference,
|
.quantize_row_q_reference = (quantize_row_q_t) quantize_row_q4_1_reference,
|
||||||
.quantize_row_q_dot = quantize_row_q8_1,
|
.quantize_row_q_dot = quantize_row_q8_1,
|
||||||
.vec_dot_q = ggml_vec_dot_q4_1_q8_1,
|
.vec_dot_q = ggml_vec_dot_q4_1_q8_1,
|
||||||
|
.vec_dot_type = GGML_TYPE_Q8_1,
|
||||||
},
|
},
|
||||||
[GGML_TYPE_Q4_2] = {
|
[GGML_TYPE_Q4_2] = {
|
||||||
.dequantize_row_q = dequantize_row_q4_2,
|
.dequantize_row_q = dequantize_row_q4_2,
|
||||||
|
@ -1852,6 +1854,7 @@ static const quantize_fns_t quantize_fns[GGML_TYPE_COUNT] = {
|
||||||
.quantize_row_q_reference = (quantize_row_q_t) quantize_row_q4_2_reference,
|
.quantize_row_q_reference = (quantize_row_q_t) quantize_row_q4_2_reference,
|
||||||
.quantize_row_q_dot = quantize_row_q8_0,
|
.quantize_row_q_dot = quantize_row_q8_0,
|
||||||
.vec_dot_q = ggml_vec_dot_q4_2_q8_0,
|
.vec_dot_q = ggml_vec_dot_q4_2_q8_0,
|
||||||
|
.vec_dot_type = GGML_TYPE_Q8_0,
|
||||||
},
|
},
|
||||||
[GGML_TYPE_Q4_3] = {
|
[GGML_TYPE_Q4_3] = {
|
||||||
.dequantize_row_q = dequantize_row_q4_3,
|
.dequantize_row_q = dequantize_row_q4_3,
|
||||||
|
@ -1859,6 +1862,7 @@ static const quantize_fns_t quantize_fns[GGML_TYPE_COUNT] = {
|
||||||
.quantize_row_q_reference = (quantize_row_q_t) quantize_row_q4_3_reference,
|
.quantize_row_q_reference = (quantize_row_q_t) quantize_row_q4_3_reference,
|
||||||
.quantize_row_q_dot = quantize_row_q8_1,
|
.quantize_row_q_dot = quantize_row_q8_1,
|
||||||
.vec_dot_q = ggml_vec_dot_q4_3_q8_1,
|
.vec_dot_q = ggml_vec_dot_q4_3_q8_1,
|
||||||
|
.vec_dot_type = GGML_TYPE_Q8_1,
|
||||||
},
|
},
|
||||||
[GGML_TYPE_Q8_0] = {
|
[GGML_TYPE_Q8_0] = {
|
||||||
.dequantize_row_q = dequantize_row_q8_0,
|
.dequantize_row_q = dequantize_row_q8_0,
|
||||||
|
@ -1866,6 +1870,7 @@ static const quantize_fns_t quantize_fns[GGML_TYPE_COUNT] = {
|
||||||
.quantize_row_q_reference = (quantize_row_q_t) quantize_row_q8_0_reference,
|
.quantize_row_q_reference = (quantize_row_q_t) quantize_row_q8_0_reference,
|
||||||
.quantize_row_q_dot = quantize_row_q8_0,
|
.quantize_row_q_dot = quantize_row_q8_0,
|
||||||
.vec_dot_q = ggml_vec_dot_q8_0_q8_0,
|
.vec_dot_q = ggml_vec_dot_q8_0_q8_0,
|
||||||
|
.vec_dot_type = GGML_TYPE_Q8_0,
|
||||||
},
|
},
|
||||||
[GGML_TYPE_Q8_1] = {
|
[GGML_TYPE_Q8_1] = {
|
||||||
.dequantize_row_q = NULL, // TODO
|
.dequantize_row_q = NULL, // TODO
|
||||||
|
@ -1873,6 +1878,7 @@ static const quantize_fns_t quantize_fns[GGML_TYPE_COUNT] = {
|
||||||
.quantize_row_q_reference = (quantize_row_q_t) quantize_row_q8_1_reference,
|
.quantize_row_q_reference = (quantize_row_q_t) quantize_row_q8_1_reference,
|
||||||
.quantize_row_q_dot = quantize_row_q8_1,
|
.quantize_row_q_dot = quantize_row_q8_1,
|
||||||
.vec_dot_q = NULL, // TODO
|
.vec_dot_q = NULL, // TODO
|
||||||
|
.vec_dot_type = GGML_TYPE_Q8_1,
|
||||||
},
|
},
|
||||||
};
|
};
|
||||||
|
|
||||||
|
@ -2476,9 +2482,9 @@ inline static void ggml_vec_dot_f16(const int n, float * restrict s, ggml_fp16_t
|
||||||
}
|
}
|
||||||
|
|
||||||
static void ggml_vec_dot_q4_0_q8_0(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
|
static void ggml_vec_dot_q4_0_q8_0(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
|
||||||
const int nb = n / QK8_1;
|
const int nb = n / QK8_0;
|
||||||
|
|
||||||
assert(n % QK8_1 == 0);
|
assert(n % QK8_0 == 0);
|
||||||
assert(nb % 2 == 0);
|
assert(nb % 2 == 0);
|
||||||
|
|
||||||
const block_q4_0 * restrict x = vx;
|
const block_q4_0 * restrict x = vx;
|
||||||
|
@ -2627,7 +2633,7 @@ static void ggml_vec_dot_q4_0_q8_0(const int n, float * restrict s, const void *
|
||||||
const int8_t * restrict p1 = y[i].qs;
|
const int8_t * restrict p1 = y[i].qs;
|
||||||
|
|
||||||
int sumi = 0;
|
int sumi = 0;
|
||||||
for (int j = 0; j < QK8_1/2; j++) {
|
for (int j = 0; j < QK8_0/2; j++) {
|
||||||
const uint8_t v0 = p0[j];
|
const uint8_t v0 = p0[j];
|
||||||
|
|
||||||
const int i0 = (int8_t) (v0 & 0xf) - 8;
|
const int i0 = (int8_t) (v0 & 0xf) - 8;
|
||||||
|
@ -2779,11 +2785,11 @@ static void ggml_vec_dot_q4_1_q8_1(const int n, float * restrict s, const void *
|
||||||
}
|
}
|
||||||
|
|
||||||
static void ggml_vec_dot_q4_2_q8_0(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
|
static void ggml_vec_dot_q4_2_q8_0(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
|
||||||
const int nb = n / QK8_1;
|
const int nb = n / QK8_0;
|
||||||
|
|
||||||
assert(n % QK8_1 == 0);
|
assert(n % QK8_0 == 0);
|
||||||
assert(nb % 2 == 0);
|
assert(nb % 2 == 0);
|
||||||
assert(QK8_1 == 2*QK4_2);
|
assert(QK8_0 == 2*QK4_2);
|
||||||
|
|
||||||
const block_q4_2 * restrict x = vx;
|
const block_q4_2 * restrict x = vx;
|
||||||
const block_q8_0 * restrict y = vy;
|
const block_q8_0 * restrict y = vy;
|
||||||
|
@ -2908,7 +2914,7 @@ static void ggml_vec_dot_q4_2_q8_0(const int n, float * restrict s, const void *
|
||||||
int sumi_0 = 0;
|
int sumi_0 = 0;
|
||||||
int sumi_1 = 0;
|
int sumi_1 = 0;
|
||||||
|
|
||||||
for (int j = 0; j < QK8_1/4; j++) {
|
for (int j = 0; j < QK8_0/4; j++) {
|
||||||
const uint8_t v0 = x0[j];
|
const uint8_t v0 = x0[j];
|
||||||
const uint8_t v1 = x1[j];
|
const uint8_t v1 = x1[j];
|
||||||
|
|
||||||
|
@ -2921,8 +2927,8 @@ static void ggml_vec_dot_q4_2_q8_0(const int n, float * restrict s, const void *
|
||||||
const int i2_0 = y0[2*j + 0];
|
const int i2_0 = y0[2*j + 0];
|
||||||
const int i3_0 = y0[2*j + 1];
|
const int i3_0 = y0[2*j + 1];
|
||||||
|
|
||||||
const int i2_1 = y0[2*(j + QK8_1/4) + 0];
|
const int i2_1 = y0[2*(j + QK8_0/4) + 0];
|
||||||
const int i3_1 = y0[2*(j + QK8_1/4) + 1];
|
const int i3_1 = y0[2*(j + QK8_0/4) + 1];
|
||||||
|
|
||||||
sumi_0 += i0_0*i2_0 + i1_0*i3_0;
|
sumi_0 += i0_0*i2_0 + i1_0*i3_0;
|
||||||
sumi_1 += i0_1*i2_1 + i1_1*i3_1;
|
sumi_1 += i0_1*i2_1 + i1_1*i3_1;
|
||||||
|
@ -8099,6 +8105,7 @@ static void ggml_compute_forward_mul_mat_q_f32(
|
||||||
const enum ggml_type type = src0->type;
|
const enum ggml_type type = src0->type;
|
||||||
quantize_row_q_t const quantize_row_q_dot = quantize_fns[type].quantize_row_q_dot;
|
quantize_row_q_t const quantize_row_q_dot = quantize_fns[type].quantize_row_q_dot;
|
||||||
vec_dot_q_t const vec_dot_q = quantize_fns[type].vec_dot_q;
|
vec_dot_q_t const vec_dot_q = quantize_fns[type].vec_dot_q;
|
||||||
|
enum ggml_type const vec_dot_type = quantize_fns[type].vec_dot_type;
|
||||||
|
|
||||||
// we don't support permuted src0 or src1
|
// we don't support permuted src0 or src1
|
||||||
GGML_ASSERT(nb00 == (int) GGML_TYPE_SIZE[type]);
|
GGML_ASSERT(nb00 == (int) GGML_TYPE_SIZE[type]);
|
||||||
|
@ -8235,7 +8242,7 @@ static void ggml_compute_forward_mul_mat_q_f32(
|
||||||
|
|
||||||
if (params->type == GGML_TASK_INIT) {
|
if (params->type == GGML_TASK_INIT) {
|
||||||
char * wdata = params->wdata;
|
char * wdata = params->wdata;
|
||||||
const size_t row_size = ne10*GGML_TYPE_SIZE[GGML_TYPE_Q8_1]/GGML_BLCK_SIZE[GGML_TYPE_Q8_1];
|
const size_t row_size = ne10*GGML_TYPE_SIZE[vec_dot_type]/GGML_BLCK_SIZE[vec_dot_type];
|
||||||
|
|
||||||
for (int64_t i13 = 0; i13 < ne13; ++i13) {
|
for (int64_t i13 = 0; i13 < ne13; ++i13) {
|
||||||
for (int64_t i12 = 0; i12 < ne12; ++i12) {
|
for (int64_t i12 = 0; i12 < ne12; ++i12) {
|
||||||
|
@ -8266,7 +8273,7 @@ static void ggml_compute_forward_mul_mat_q_f32(
|
||||||
const int ir1 = MIN(ir0 + dr, nr);
|
const int ir1 = MIN(ir0 + dr, nr);
|
||||||
|
|
||||||
void * wdata = params->wdata;
|
void * wdata = params->wdata;
|
||||||
const size_t row_size = ne00*GGML_TYPE_SIZE[GGML_TYPE_Q8_1]/GGML_BLCK_SIZE[GGML_TYPE_Q8_1];
|
const size_t row_size = ne00*GGML_TYPE_SIZE[vec_dot_type]/GGML_BLCK_SIZE[vec_dot_type];
|
||||||
|
|
||||||
for (int ir = ir0; ir < ir1; ++ir) {
|
for (int ir = ir0; ir < ir1; ++ir) {
|
||||||
// src0 indices
|
// src0 indices
|
||||||
|
@ -11069,7 +11076,8 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph)
|
||||||
} else
|
} else
|
||||||
#endif
|
#endif
|
||||||
{
|
{
|
||||||
cur = GGML_TYPE_SIZE[GGML_TYPE_Q8_1]*ggml_nelements(node->src1)/GGML_BLCK_SIZE[GGML_TYPE_Q8_1];
|
const enum ggml_type type_q = quantize_fns[node->src0->type].vec_dot_type;
|
||||||
|
cur = GGML_TYPE_SIZE[type_q]*ggml_nelements(node->src1)/GGML_BLCK_SIZE[type_q];
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
GGML_ASSERT(false);
|
GGML_ASSERT(false);
|
||||||
|
|
1
ggml.h
1
ggml.h
|
@ -878,6 +878,7 @@ extern "C" {
|
||||||
quantize_row_q_t quantize_row_q_reference;
|
quantize_row_q_t quantize_row_q_reference;
|
||||||
quantize_row_q_t quantize_row_q_dot;
|
quantize_row_q_t quantize_row_q_dot;
|
||||||
vec_dot_q_t vec_dot_q;
|
vec_dot_q_t vec_dot_q;
|
||||||
|
enum ggml_type vec_dot_type;
|
||||||
} quantize_fns_t;
|
} quantize_fns_t;
|
||||||
|
|
||||||
quantize_fns_t ggml_internal_get_quantize_fn(size_t i);
|
quantize_fns_t ggml_internal_get_quantize_fn(size_t i);
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue