add op GGML_OP_MUL_MAT_ID for Q4_0_N_M with runtime repack
This commit is contained in:
parent
3a042b4872
commit
0a2be72dca
2 changed files with 175 additions and 46 deletions
|
@ -3684,7 +3684,7 @@ static int repack_q4_0_to_q4_0_4_bl(struct ggml_tensor * t, int interleave_block
|
|||
block_q4_0x4 * dst = (block_q4_0x4 *)t->data;
|
||||
const block_q4_0 * src = (const block_q4_0 *)data;
|
||||
block_q4_0 dst_tmp[4];
|
||||
int nrow = t->ne[1]; // Number of rows
|
||||
int nrow = t->ne[1]*t->ne[2]*t->ne[3]; // Number of rows
|
||||
int nrows_interleaved = 4;
|
||||
int nblocks = t->ne[0] / QK4_0;
|
||||
|
||||
|
@ -3715,7 +3715,7 @@ static int repack_q4_0_to_q4_0_8_bl(struct ggml_tensor *t, int interleave_block,
|
|||
block_q4_0x8 * dst = (block_q4_0x8*)t->data;
|
||||
const block_q4_0 * src = (const block_q4_0*) data;
|
||||
block_q4_0 dst_tmp[8];
|
||||
int nrow = t->ne[1]; // Number of rows
|
||||
int nrow = t->ne[1]*t->ne[2]*t->ne[3]; // Number of rows
|
||||
int nrows_interleaved = 8;
|
||||
int nblocks = t->ne[0] / QK4_0;
|
||||
|
||||
|
@ -3779,7 +3779,7 @@ static int repack_iq4_nl_to_iq4_nl_4_bl(struct ggml_tensor * t, int interleave_b
|
|||
block_iq4_nlx4 * dst = (block_iq4_nlx4 *)t->data;
|
||||
const block_iq4_nl * src = (const block_iq4_nl *)data;
|
||||
block_iq4_nl dst_tmp[4];
|
||||
int nrow = t->ne[1]; // Number of rows
|
||||
int nrow = t->ne[1]*t->ne[2]*t->ne[3]; // Number of rows
|
||||
int nrows_interleaved = 4;
|
||||
int nblocks = t->ne[0] / QK4_0;
|
||||
|
||||
|
@ -3877,16 +3877,44 @@ class tensor_traits_base : public ggml::cpu::tensor_traits {
|
|||
};
|
||||
|
||||
template <typename BLOC_TYPE, int64_t INTER_SIZE, int64_t NB_COLS> class tensor_traits : public tensor_traits_base {
|
||||
|
||||
bool work_size(int /* n_threads */, const struct ggml_tensor * op, size_t & size) override {
|
||||
// not realy a GGML_TYPE_Q8_0 but same size.
|
||||
size = ggml_row_size(GGML_TYPE_Q8_0, ggml_nelements(op->src[1]));
|
||||
return true;
|
||||
switch (op->op) {
|
||||
case GGML_OP_MUL_MAT:
|
||||
size = ggml_row_size(GGML_TYPE_Q8_0, ggml_nelements(op->src[1]));
|
||||
return true;
|
||||
case GGML_OP_MUL_MAT_ID:
|
||||
size = ggml_row_size(GGML_TYPE_Q8_0, ggml_nelements(op->src[1]));
|
||||
size = GGML_PAD(size, sizeof(int64_t)); // + padding for next bloc.
|
||||
size += sizeof(int64_t) * (1+op->src[0]->ne[2]) * op->src[1]->ne[2];
|
||||
return true;
|
||||
default:
|
||||
// GGML_ABORT("fatal error");
|
||||
break;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
bool compute_forward(struct ggml_compute_params * params, struct ggml_tensor * op) override {
|
||||
const struct ggml_tensor * src0 = op->src[0];
|
||||
const struct ggml_tensor * src1 = op->src[1];
|
||||
struct ggml_tensor * dst = op;
|
||||
switch (op->op) {
|
||||
case GGML_OP_MUL_MAT:
|
||||
forward_mul_mat(params, op);
|
||||
return true;
|
||||
case GGML_OP_MUL_MAT_ID:
|
||||
forward_mul_mat_id(params, op);
|
||||
return true;
|
||||
default:
|
||||
// GGML_ABORT("fatal error");
|
||||
break;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
void forward_mul_mat(ggml_compute_params * params, ggml_tensor * op) {
|
||||
const ggml_tensor * src0 = op->src[0];
|
||||
const ggml_tensor * src1 = op->src[1];
|
||||
ggml_tensor * dst = op;
|
||||
|
||||
GGML_TENSOR_BINARY_OP_LOCALS
|
||||
|
||||
|
@ -3935,7 +3963,7 @@ template <typename BLOC_TYPE, int64_t INTER_SIZE, int64_t NB_COLS> class tensor_
|
|||
src0_start = (src0_start % NB_COLS) ? src0_start + NB_COLS - (src0_start % NB_COLS) : src0_start;
|
||||
src0_end = (src0_end % NB_COLS) ? src0_end + NB_COLS - (src0_end % NB_COLS) : src0_end;
|
||||
if (src0_start >= src0_end) {
|
||||
return true;
|
||||
return;
|
||||
}
|
||||
|
||||
// If there are more than three rows in src1, use gemm; otherwise, use gemv.
|
||||
|
@ -3950,8 +3978,129 @@ template <typename BLOC_TYPE, int64_t INTER_SIZE, int64_t NB_COLS> class tensor_
|
|||
(const char *) src1_wdata + (src1_col_stride * iter), 1,
|
||||
src0_end - src0_start);
|
||||
}
|
||||
}
|
||||
|
||||
return true;
|
||||
void forward_mul_mat_id(ggml_compute_params * params, ggml_tensor * op) {
|
||||
const ggml_tensor * src0 = op->src[0];
|
||||
const ggml_tensor * src1 = op->src[1];
|
||||
const ggml_tensor * ids = op->src[2];
|
||||
ggml_tensor * dst = op;
|
||||
|
||||
GGML_TENSOR_BINARY_OP_LOCALS
|
||||
|
||||
const int ith = params->ith;
|
||||
const int nth = params->nth;
|
||||
|
||||
const ggml_from_float_t from_float = ggml_get_type_traits_cpu(GGML_TYPE_Q8_0)->from_float;
|
||||
|
||||
// we don't support permuted src0 or src1
|
||||
GGML_ASSERT(nb00 == ggml_type_size(src0->type));
|
||||
GGML_ASSERT(nb10 == ggml_type_size(src1->type));
|
||||
|
||||
// dst cannot be transposed or permuted
|
||||
GGML_ASSERT(nb0 == sizeof(float));
|
||||
GGML_ASSERT(nb0 <= nb1);
|
||||
GGML_ASSERT(nb1 <= nb2);
|
||||
GGML_ASSERT(nb2 <= nb3);
|
||||
|
||||
GGML_ASSERT(ne03 == 1);
|
||||
GGML_ASSERT(ne13 == 1);
|
||||
GGML_ASSERT(ne3 == 1);
|
||||
|
||||
GGML_ASSERT(src1->type == GGML_TYPE_F32);
|
||||
|
||||
// row groups
|
||||
const int n_ids = ids->ne[0]; // n_expert_used
|
||||
const int n_as = ne02; // n_expert
|
||||
|
||||
const size_t nbw1 = ggml_row_size(GGML_TYPE_Q8_0, ne10);
|
||||
const size_t nbw2 = nbw1*ne11;
|
||||
const size_t nbw3 = nbw2*ne12;
|
||||
|
||||
struct mmid_row_mapping {
|
||||
int32_t i1;
|
||||
int32_t i2;
|
||||
};
|
||||
|
||||
GGML_ASSERT(params->wsize >= (GGML_PAD(nbw3, sizeof(int64_t)) + n_as * sizeof(int64_t) +
|
||||
n_as * ne12 * sizeof(mmid_row_mapping)));
|
||||
|
||||
auto wdata = (char *) params->wdata;
|
||||
auto wdata_src1_end = (char *) wdata + GGML_PAD(nbw3, sizeof(int64_t));
|
||||
int64_t * matrix_row_counts = (int64_t *) (wdata_src1_end); // [n_as]
|
||||
struct mmid_row_mapping * matrix_rows = (struct mmid_row_mapping *) (matrix_row_counts + n_as); // [n_as][ne12]
|
||||
|
||||
// src1: float32 => block_q8_0
|
||||
for (int64_t i12 = 0; i12 < ne12; ++i12) {
|
||||
for (int64_t i11 = ith; i11 < ne11; i11 += nth) {
|
||||
from_float((float *)((char *) src1->data + i12 * nb12 + i11 * nb11),
|
||||
(void *) (wdata + i12 * nbw2 + i11 * nbw1),
|
||||
ne10);
|
||||
}
|
||||
}
|
||||
|
||||
#define MMID_MATRIX_ROW(row_id, i1) matrix_rows[(row_id) * ne12 + (i1)]
|
||||
|
||||
if (ith == 0) {
|
||||
// initialize matrix_row_counts
|
||||
memset(matrix_row_counts, 0, n_as * sizeof(int64_t));
|
||||
|
||||
// group rows by src0 matrix
|
||||
for (int32_t iid1 = 0; iid1 < ids->ne[1]; ++iid1) {
|
||||
for (int32_t id = 0; id < n_ids; ++id) {
|
||||
const int32_t i02 =
|
||||
*(const int32_t *) ((const char *) ids->data + iid1 * ids->nb[1] + id * ids->nb[0]);
|
||||
|
||||
GGML_ASSERT(i02 >= 0 && i02 < n_as);
|
||||
|
||||
MMID_MATRIX_ROW(i02, matrix_row_counts[i02]) = { id, iid1 };
|
||||
matrix_row_counts[i02] += 1;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
ggml_barrier(params->threadpool);
|
||||
|
||||
// compute each matrix multiplication in sequence
|
||||
for (int cur_a = 0; cur_a < n_as; ++cur_a) {
|
||||
const int64_t cne1 = matrix_row_counts[cur_a];
|
||||
|
||||
if (cne1 == 0) {
|
||||
continue;
|
||||
}
|
||||
|
||||
auto src0_cur = (const char *) src0->data + cur_a*nb02;
|
||||
|
||||
//const int64_t nr0 = ne01; // src0 rows
|
||||
const int64_t nr1 = cne1; // src1 rows
|
||||
|
||||
int64_t src0_cur_start = (ith * ne01) / nth;
|
||||
int64_t src0_cur_end = ((ith + 1) * ne01) / nth;
|
||||
src0_cur_start =
|
||||
(src0_cur_start % NB_COLS) ? src0_cur_start + NB_COLS - (src0_cur_start % NB_COLS) : src0_cur_start;
|
||||
src0_cur_end = (src0_cur_end % NB_COLS) ? src0_cur_end + NB_COLS - (src0_cur_end % NB_COLS) : src0_cur_end;
|
||||
|
||||
if (src0_cur_start >= src0_cur_end) return;
|
||||
|
||||
for (int ir1 = 0; ir1 < nr1; ir1++) {
|
||||
struct mmid_row_mapping row_mapping = MMID_MATRIX_ROW(cur_a, ir1);
|
||||
const int id = row_mapping.i1; // selected expert index
|
||||
|
||||
const int64_t i11 = id % ne11;
|
||||
const int64_t i12 = row_mapping.i2; // row index in src1
|
||||
|
||||
const int64_t i1 = id; // selected expert index
|
||||
const int64_t i2 = i12; // row
|
||||
|
||||
auto src1_col = (const char *) wdata + (i11 * nbw1 + i12 * nbw2);
|
||||
|
||||
gemv<BLOC_TYPE, INTER_SIZE, NB_COLS>(
|
||||
ne00, (float *)((char *) dst->data + (i1 * nb1 + i2 * nb2)) + src0_cur_start,
|
||||
ne01, src0_cur + src0_cur_start * nb01,
|
||||
src1_col, 1, src0_cur_end - src0_cur_start);
|
||||
}
|
||||
}
|
||||
#undef MMID_MATRIX_ROW
|
||||
}
|
||||
|
||||
int repack(struct ggml_tensor * t, const void * data, size_t data_size) override {
|
||||
|
@ -4048,16 +4197,28 @@ class extra_buffer_type : ggml::cpu::extra_buffer_type {
|
|||
// return true;
|
||||
//}
|
||||
// may be possible if Q8_0 packed...
|
||||
} else if (op->op == GGML_OP_MUL_MAT_ID && op->src[0]->buffer && (ggml_n_dims(op->src[0]) == 3) &&
|
||||
op->src[0]->buffer->buft == ggml_backend_cpu_aarch64_buffer_type() &&
|
||||
ggml_aarch64_get_optimal_repack_type(op->src[0])) {
|
||||
if (op->src[1]->buffer && !ggml_backend_buft_is_host(op->src[1]->buffer->buft)) {
|
||||
return false;
|
||||
}
|
||||
if (op->src[1]->type == GGML_TYPE_F32) {
|
||||
return true;
|
||||
}
|
||||
//if (op->src[1]->type == GGML_TYPE_Q8_0) {
|
||||
// return true;
|
||||
//}
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
ggml::cpu::tensor_traits * get_tensor_traits(const struct ggml_tensor * op) override {
|
||||
if (op->op == GGML_OP_MUL_MAT && op->src[0]->buffer &&
|
||||
op->src[0]->buffer->buft == ggml_backend_cpu_aarch64_buffer_type()) {
|
||||
return (ggml::cpu::tensor_traits *) op->src[0]->extra;
|
||||
if (op->op == GGML_OP_MUL_MAT || op->op == GGML_OP_MUL_MAT_ID) {
|
||||
if (op->src[0]->buffer && op->src[0]->buffer->buft == ggml_backend_cpu_aarch64_buffer_type()) {
|
||||
return (ggml::cpu::tensor_traits *) op->src[0]->extra;
|
||||
}
|
||||
}
|
||||
|
||||
return nullptr;
|
||||
}
|
||||
};
|
||||
|
|
|
@ -7576,8 +7576,6 @@ static void ggml_compute_forward_mul_mat_id(
|
|||
ggml_vec_dot_t const vec_dot = type_traits_cpu[type].vec_dot;
|
||||
enum ggml_type const vec_dot_type = type_traits_cpu[type].vec_dot_type;
|
||||
ggml_from_float_t const from_float = type_traits_cpu[vec_dot_type].from_float;
|
||||
//int64_t const matmul_num_cols = type_traits_cpu[type].ncols;
|
||||
//ggml_gemv_t const gemv = type_traits_cpu[type].gemv;
|
||||
|
||||
// we don't support permuted src0 or src1
|
||||
GGML_ASSERT(nb00 == ggml_type_size(type));
|
||||
|
@ -7663,36 +7661,6 @@ static void ggml_compute_forward_mul_mat_id(
|
|||
const int64_t nr0 = ne01; // src0 rows
|
||||
const int64_t nr1 = cne1; // src1 rows
|
||||
|
||||
#if 0
|
||||
// see what to do with that. the dynamic rescale have not be write.
|
||||
if (((ggml_n_dims(src0) - 1) == 2) && gemv) {
|
||||
int64_t src0_cur_start = (ith * ne01) / nth;
|
||||
int64_t src0_cur_end = ((ith + 1) * ne01) / nth;
|
||||
src0_cur_start = (src0_cur_start % matmul_num_cols) ? src0_cur_start + matmul_num_cols - (src0_cur_start % matmul_num_cols): src0_cur_start;
|
||||
src0_cur_end = (src0_cur_end % matmul_num_cols) ? src0_cur_end + matmul_num_cols - (src0_cur_end % matmul_num_cols): src0_cur_end;
|
||||
if (src0_cur_start >= src0_cur_end) return;
|
||||
|
||||
for (int ir1 = 0; ir1 < nr1; ir1++) {
|
||||
struct mmid_row_mapping row_mapping = MMID_MATRIX_ROW(cur_a, ir1);
|
||||
const int id = row_mapping.i1; // selected expert index
|
||||
|
||||
const int64_t i11 = id % ne11;
|
||||
const int64_t i12 = row_mapping.i2; // row index in src1
|
||||
|
||||
const int64_t i1 = id; // selected expert index
|
||||
const int64_t i2 = i12; // row
|
||||
|
||||
const char * src1_col = (const char *) wdata +
|
||||
(src1_cont || src1->type != vec_dot_type
|
||||
? (i11 + i12 * ne11) * row_size
|
||||
: (i11 * nb11 + i12 * nb12));
|
||||
|
||||
gemv(ne00, (float *)((char *) dst->data + (i1 * nb1 + i2 * nb2)) + src0_cur_start, ne01,
|
||||
(const char *) src0_cur + src0_cur_start * nb01, src1_col, 1, src0_cur_end - src0_cur_start);
|
||||
}
|
||||
continue;
|
||||
}
|
||||
#endif
|
||||
// distribute the thread work across the inner or outer loop based on which one is larger
|
||||
|
||||
const int64_t nth0 = nr0 > nr1 ? nth : 1; // parallelize by src0 rows
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue