ggml : group all experts in a single ggml_mul_mat_id

cuda : improve mmid row copy
This commit is contained in:
slaren 2024-04-03 21:03:45 +02:00
parent a307375c02
commit ea2b79534e
7 changed files with 405 additions and 162 deletions

View file

@ -1231,7 +1231,7 @@ static void ggml_cuda_op_mul_mat_cublas(
if (compute_capability >= CC_VOLTA && (src0->type == GGML_TYPE_F16 || ggml_is_quantized(src0->type)) && ggml_is_contiguous(src0) && row_diff == src0->ne[1] && dst->op_params[0] == GGML_PREC_DEFAULT) {
// convert src0 and src1 to fp16, multiply as fp16, convert dst to fp32
ggml_cuda_pool_alloc<half> src0_as_f16(ctx.pool());
ggml_cuda_pool_alloc<half> src0_as_f16(ctx.pool(id));
if (src0->type != GGML_TYPE_F16) {
const to_fp16_cuda_t to_fp16_cuda = ggml_get_to_fp16_cuda(src0->type);
GGML_ASSERT(to_fp16_cuda != nullptr);
@ -1241,7 +1241,7 @@ static void ggml_cuda_op_mul_mat_cublas(
}
const half * src0_ptr = src0->type == GGML_TYPE_F16 ? (const half *) src0_dd_i : src0_as_f16.get();
ggml_cuda_pool_alloc<half> src1_as_f16(ctx.pool());
ggml_cuda_pool_alloc<half> src1_as_f16(ctx.pool(id));
if (src1->type != GGML_TYPE_F16) {
const to_fp16_cuda_t to_fp16_cuda = ggml_get_to_fp16_cuda(src1->type);
GGML_ASSERT(to_fp16_cuda != nullptr);
@ -1250,7 +1250,7 @@ static void ggml_cuda_op_mul_mat_cublas(
to_fp16_cuda(src1_ddf_i, src1_as_f16.get(), ne, stream);
}
const half * src1_ptr = src1->type == GGML_TYPE_F16 ? (const half *) src1_ddf_i : src1_as_f16.get();
ggml_cuda_pool_alloc<half> dst_f16(ctx.pool(), row_diff*src1_ncols);
ggml_cuda_pool_alloc<half> dst_f16(ctx.pool(id), row_diff*src1_ncols);
const half alpha_f16 = 1.0f;
const half beta_f16 = 0.0f;
@ -1960,20 +1960,84 @@ static void ggml_cuda_mul_mat(ggml_backend_cuda_context & ctx, const ggml_tensor
}
}
struct mmid_row_mapping {
int64_t i1;
int64_t i2;
};
static __global__ void k_copy_src1_to_contiguous(const char * src1_original, char * src1_contiguous,
int * cur_src1_row, mmid_row_mapping * row_mapping,
const char * ids_dev, int64_t i02, int64_t ids_nb1, int64_t ids_nb0,
int64_t ids_ne1, int64_t n_ids,
int64_t ne11,
size_t nb11, size_t nb12) {
int64_t iid1 = blockIdx.x;
int64_t id = blockIdx.y;
if (iid1 >= ids_ne1 || id >= n_ids) {
return;
}
const int32_t row_id_i = *(const int32_t *) (ids_dev + iid1*ids_nb1 + id*ids_nb0);
if (row_id_i != i02) {
return;
}
const int64_t i11 = id % ne11;
const int64_t i12 = iid1;
__shared__ int src1_row;
if (threadIdx.x == 0) {
src1_row = atomicAdd(cur_src1_row, 1);
row_mapping[src1_row] = {id, iid1};
}
__syncthreads();
const char * src1_row_original = src1_original + i11*nb11 + i12*nb12;
char * src1_row_contiguous = src1_contiguous + src1_row*nb11;
for (int i = threadIdx.x; i < nb11; i += blockDim.x) {
src1_row_contiguous[i] = src1_row_original[i];
}
}
static __global__ void k_copy_dst_from_contiguous(char * dst_original, const char * dst_contiguous,
const mmid_row_mapping * row_mapping,
int64_t n_rows,
int64_t nb1, int64_t nb2) {
int64_t i = blockIdx.x;
if (i >= n_rows) {
return;
}
const int64_t i1 = row_mapping[i].i1;
const int64_t i2 = row_mapping[i].i2;
const char * dst_row_contiguous = dst_contiguous + i*nb1;
char * dst_row_original = dst_original + i1*nb1 + i2*nb2;
for (int j = threadIdx.x; j < nb1; j += blockDim.x) {
dst_row_original[j] = dst_row_contiguous[j];
}
}
//#define MMID_MEMCPY
static void ggml_cuda_mul_mat_id(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
const ggml_tensor * src0 = dst->src[0];
const ggml_tensor * src1 = dst->src[1];
const ggml_tensor * ids = dst->src[2];
GGML_TENSOR_BINARY_OP_LOCALS
GGML_ASSERT(!ggml_backend_buffer_is_cuda_split(src0->buffer) && "mul_mat_id does not support split buffers");
cudaStream_t stream = ctx.stream();
const size_t nb11 = src1->nb[1];
const size_t nb1 = dst->nb[1];
const int32_t id = ((int32_t *) dst->op_params)[0];
const int32_t n_as = src0->ne[2];
const int64_t n_as = ne02;
const int64_t n_ids = ids->ne[0];
std::vector<char> ids_host(ggml_nbytes(ids));
const char * ids_dev = (const char *) ids->data;
@ -1982,7 +2046,7 @@ static void ggml_cuda_mul_mat_id(ggml_backend_cuda_context & ctx, ggml_tensor *
ggml_tensor src0_row = *src0;
ggml_tensor src1_row = *src1;
ggml_tensor dst_row = *dst;
ggml_tensor dst_row = *dst;
char * src0_original = (char *) src0->data;
char * src1_original = (char *) src1->data;
@ -1990,19 +2054,39 @@ static void ggml_cuda_mul_mat_id(ggml_backend_cuda_context & ctx, ggml_tensor *
src0_row.ne[2] = 1;
src0_row.ne[3] = 1;
src0_row.nb[3] = src0->nb[2];
src0_row.nb[3] = nb02;
if (src1->ne[1] == 1) {
for (int64_t i01 = 0; i01 < ids->ne[1]; i01++) {
const int32_t row_id = *(const int32_t *) (ids_host.data() + i01*ids->nb[1] + id*ids->nb[0]);
src1_row.ne[1] = 1;
src1_row.ne[2] = 1;
src1_row.ne[3] = 1;
src1_row.nb[2] = nb11;
src1_row.nb[3] = nb11;
GGML_ASSERT(row_id >= 0 && row_id < n_as);
dst_row.ne[1] = 1;
dst_row.ne[2] = 1;
dst_row.ne[3] = 1;
dst_row.nb[2] = nb1;
dst_row.nb[3] = nb1;
src0_row.data = src0_original + row_id*src0->nb[2];
src1_row.data = src1_original + i01*src1->nb[1];
dst_row.data = dst_original + i01*dst->nb[1];
if (ne12 == 1) {
for (int64_t iid1 = 0; iid1 < ids->ne[1]; iid1++) {
for (int64_t id = 0; id < n_ids; id++) {
const int32_t i02 = *(const int32_t *) (ids_host.data() + iid1*ids->nb[1] + id*ids->nb[0]);
ggml_cuda_mul_mat(ctx, &src0_row, &src1_row, &dst_row);
GGML_ASSERT(i02 >= 0 && i02 < n_as);
const int64_t i11 = id % ne11;
const int64_t i12 = iid1;
const int64_t i1 = id;
const int64_t i2 = i12;
src0_row.data = src0_original + i02*nb02;
src1_row.data = src1_original + i11*nb11 + i12*nb12;
dst_row.data = dst_original + i1*nb1 + i2*nb2;
ggml_cuda_mul_mat(ctx, &src0_row, &src1_row, &dst_row);
}
}
} else {
ggml_cuda_pool_alloc<char> src1_contiguous(ctx.pool(), sizeof(float)*ggml_nelements(src1));
@ -2011,55 +2095,104 @@ static void ggml_cuda_mul_mat_id(ggml_backend_cuda_context & ctx, ggml_tensor *
src1_row.data = src1_contiguous.get();
dst_row.data = dst_contiguous.get();
for (int32_t row_id = 0; row_id < n_as; ++row_id) {
for (int64_t i02 = 0; i02 < n_as; i02++) {
int64_t num_src1_rows = 0;
for (int64_t i01 = 0; i01 < ids->ne[1]; i01++) {
const int32_t row_id_i = *(const int32_t *) (ids_host.data() + i01*ids->nb[1] + id*ids->nb[0]);
if (row_id_i != row_id) {
continue;
for (int64_t iid1 = 0; iid1 < ids->ne[1]; iid1++) {
for (int64_t id = 0; id < n_ids; id++) {
const int32_t row_id_i = *(const int32_t *) (ids_host.data() + iid1*ids->nb[1] + id*ids->nb[0]);
if (row_id_i != i02) {
continue;
}
GGML_ASSERT(i02 >= 0 && i02 < n_as);
#ifdef MMID_MEMCPY
const int64_t i11 = id % ne11;
const int64_t i12 = iid1;
CUDA_CHECK(cudaMemcpyAsync(src1_contiguous.get() + num_src1_rows*nb11,
src1_original + i11*nb11 + i12*nb12,
nb11, cudaMemcpyDeviceToDevice, stream));
#endif
num_src1_rows++;
}
GGML_ASSERT(row_id >= 0 && row_id < n_as);
CUDA_CHECK(cudaMemcpyAsync(src1_contiguous.get() + num_src1_rows*nb11, src1_original + i01*nb11,
nb11, cudaMemcpyDeviceToDevice, stream));
num_src1_rows++;
}
if (num_src1_rows == 0) {
continue;
}
src0_row.data = src0_original + row_id*src0->nb[2];
#ifndef MMID_MEMCPY
ggml_cuda_pool_alloc<int> dev_cur_src1_row(ctx.pool(), 1);
ggml_cuda_pool_alloc<mmid_row_mapping> dev_row_mapping(ctx.pool(), num_src1_rows);
CUDA_CHECK(cudaMemsetAsync(dev_cur_src1_row.get(), 0, sizeof(int), stream));
{
dim3 block_dims(std::min((uint)nb11, 1024u));
dim3 grid_dims(ids->ne[1], n_ids);
k_copy_src1_to_contiguous<<<grid_dims, block_dims, 0, stream>>>(
src1_original, src1_contiguous.get(),
dev_cur_src1_row.get(), dev_row_mapping.get(),
ids_dev, i02, ids->nb[1], ids->nb[0],
ids->ne[1], n_ids,
ne11,
nb11, nb12);
CUDA_CHECK(cudaGetLastError());
}
#endif
src0_row.data = src0_original + i02*nb02;
GGML_ASSERT(nb11 == sizeof(float)*ne10);
GGML_ASSERT(nb1 == sizeof(float)*ne0);
src1_row.ne[1] = num_src1_rows;
dst_row.ne[1] = num_src1_rows;
src1_row.nb[1] = nb11;
src1_row.nb[2] = num_src1_rows*nb11;
src1_row.nb[3] = num_src1_rows*nb11;
dst_row.ne[1] = num_src1_rows;
dst_row.nb[1] = nb1;
dst_row.nb[2] = num_src1_rows*nb1;
dst_row.nb[3] = num_src1_rows*nb1;
ggml_cuda_mul_mat(ctx, &src0_row, &src1_row, &dst_row);
num_src1_rows = 0;
for (int64_t i01 = 0; i01 < ids->ne[1]; i01++) {
const int32_t row_id_i = *(const int32_t *) (ids_host.data() + i01*ids->nb[1] + id*ids->nb[0]);
if (row_id_i != row_id) {
continue;
}
GGML_ASSERT(row_id >= 0 && row_id < n_as);
CUDA_CHECK(cudaMemcpyAsync(dst_original + i01*nb1, dst_contiguous.get() + num_src1_rows*nb1,
nb1, cudaMemcpyDeviceToDevice, stream));
num_src1_rows++;
#ifndef MMID_MEMCPY
{
dim3 block_dims(std::min((uint)nb1, 1024u));
dim3 grid_dims(num_src1_rows);
k_copy_dst_from_contiguous<<<grid_dims, block_dims, 0, stream>>>(
dst_original, dst_contiguous.get(),
dev_row_mapping.get(),
num_src1_rows, nb1, nb2);
CUDA_CHECK(cudaGetLastError());
}
#endif
#ifdef MMID_MEMCPY
num_src1_rows = 0;
for (int64_t iid1 = 0; iid1 < ids->ne[1]; iid1++) {
for (int64_t id = 0; id < n_ids; id++) {
const int32_t row_id_i = *(const int32_t *) (ids_host.data() + iid1*ids->nb[1] + id*ids->nb[0]);
if (row_id_i != i02) {
continue;
}
GGML_ASSERT(i02 >= 0 && i02 < n_as);
const int64_t i1 = id;
const int64_t i2 = iid1;
CUDA_CHECK(cudaMemcpyAsync(dst_original + i1*nb1 + i2*nb2,
dst_contiguous.get() + num_src1_rows*nb1,
nb1, cudaMemcpyDeviceToDevice, stream));
num_src1_rows++;
}
}
#endif
}
}
}
@ -2487,7 +2620,8 @@ GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, cons
GGML_CALL static bool ggml_backend_cuda_offload_op(ggml_backend_t backend, const ggml_tensor * op) {
const int min_batch_size = 32;
return op->ne[1] >= min_batch_size && op->op != GGML_OP_GET_ROWS;
return (op->ne[1] >= min_batch_size && op->op != GGML_OP_GET_ROWS) ||
(op->ne[2] >= min_batch_size && op->op == GGML_OP_MUL_MAT_ID);
GGML_UNUSED(backend);
}

View file

@ -45,6 +45,8 @@ static __global__ void dequantize_block_q8_0_f16(const void * __restrict__ vx, h
vals[ix] = x0[ix];
}
__syncthreads();
#pragma unroll
for (int iy = 0; iy < CUDA_Q8_0_NE_ALIGN; iy += 2*WARP_SIZE) {
if (need_check && i0 + iy + 2*threadIdx.x >= k) {

127
ggml.c
View file

@ -4573,21 +4573,32 @@ void ggml_mul_mat_set_prec(
// ggml_mul_mat_id
// NOTE: id will be removed in the future and instead all the experts listed in ids will be computed
// this will allow computing all the used experts in a single matrix multiplication
/*
c = ggml_mul_mat_id(ctx, as, b, ids);
as -> [cols, rows, n_expert]
ids -> [n_experts_used, n_tokens] (i32)
b -> [cols, n_expert_used, n_tokens]
c -> [cols, n_expert_used, n_tokens]
in b, n_experts_used can be broadcasted to match the n_expert_used of ids
c ~= as[:,:,i] @ b[:,i%r,t], i = ids[e,t] for all e in ids
*/
struct ggml_tensor * ggml_mul_mat_id(
struct ggml_context * ctx,
struct ggml_tensor * as,
struct ggml_tensor * ids,
int id,
struct ggml_tensor * b) {
struct ggml_tensor * b,
struct ggml_tensor * ids) {
GGML_ASSERT(!ggml_is_transposed(as));
GGML_ASSERT(ids->type == GGML_TYPE_I32);
GGML_ASSERT(as->ne[3] == 1); // as is 3d (one matrix per expert)
GGML_ASSERT(b->ne[3] == 1); // b is 3d
GGML_ASSERT(ids->ne[2] == 1 && ids->ne[3] == 1); // ids is 2d
GGML_ASSERT(ids->ne[1] == b->ne[1]); // must have an expert per b row
GGML_ASSERT(ids->ne[2] == b->ne[2] && ids->ne[3] == b->ne[3]);
GGML_ASSERT(id >= 0 && id < ids->ne[0]); // valid id
GGML_ASSERT(ids->ne[1] == b->ne[2]); // must have an expert list per b row
GGML_ASSERT(as->ne[0] == b->ne[0]); // can_mul_mat
GGML_ASSERT(ids->ne[0] % b->ne[1] == 0); // can broadcast
bool is_node = false;
@ -4595,11 +4606,9 @@ struct ggml_tensor * ggml_mul_mat_id(
is_node = true;
}
const int64_t ne[4] = { as->ne[1], b->ne[1], b->ne[2], b->ne[3] };
const int64_t ne[4] = { as->ne[1], ids->ne[0], b->ne[2], 1 };
struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne);
ggml_set_op_params_i32(result, 0, id);
result->op = GGML_OP_MUL_MAT_ID;
result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
result->src[0] = as;
@ -10958,10 +10967,10 @@ static void ggml_compute_forward_mul_mat_id(
enum ggml_type const vec_dot_type = type_traits[type].vec_dot_type;
ggml_from_float_t const from_float_to_vec_dot = type_traits[vec_dot_type].from_float;
GGML_ASSERT(ne0 == ne01);
GGML_ASSERT(ne1 == ne11);
GGML_ASSERT(ne2 == ne12);
GGML_ASSERT(ne3 == ne13);
//GGML_ASSERT(ne0 == ne01);
//GGML_ASSERT(ne1 == ne11);
//GGML_ASSERT(ne2 == ne12);
//GGML_ASSERT(ne3 == ne13);
// we don't support permuted src0 or src1
GGML_ASSERT(nb00 == ggml_type_size(type));
@ -10973,13 +10982,9 @@ static void ggml_compute_forward_mul_mat_id(
GGML_ASSERT(nb1 <= nb2);
GGML_ASSERT(nb2 <= nb3);
// broadcast is not supported with mmid
assert(ne12 == 1);
assert(ne13 == 1);
// row groups
const int id = ggml_get_op_params_i32(dst, 0);
const int n_as = src0->ne[2];
const int n_ids = ids->ne[0]; // n_expert_used
const int n_as = ne02; // n_expert
char * wdata_src1_end = (src1->type == vec_dot_type) ?
(char *) params->wdata :
@ -10988,8 +10993,6 @@ static void ggml_compute_forward_mul_mat_id(
int64_t * matrix_row_counts = (int64_t *) (wdata_src1_end); // [n_as]
int64_t * matrix_rows = matrix_row_counts + n_as; // [n_as][ne11]
#define MMID_MATRIX_ROW(row_id, i1) matrix_rows[(row_id)*ne11 + (i1)]
if (params->type == GGML_TASK_TYPE_INIT) {
if (ith != 0) {
return;
@ -11015,13 +11018,21 @@ static void ggml_compute_forward_mul_mat_id(
GGML_ASSERT(wdata == wdata_src1_end);
memset(matrix_row_counts, 0, n_as*sizeof(int64_t));
// group rows by src0 matrix
for (int64_t i01 = 0; i01 < ids->ne[1]; i01++) {
const int32_t row_id = *(const int32_t *) ((const char *) ids->data + i01*ids->nb[1] + id*ids->nb[0]);
#define MMID_MATRIX_ROW(row_id, i1) matrix_rows[(row_id)*ne12 + (i1)]
#define MAKE_I64(lo, hi) (((int64_t)(lo)) | (((int64_t)(hi)) << 32))
#define LO_I64(i64) ((int32_t)(i64))
#define HI_I64(i64) ((int32_t)((i64) >> 32))
GGML_ASSERT(row_id >= 0 && row_id < n_as);
MMID_MATRIX_ROW(row_id, matrix_row_counts[row_id]) = i01;
matrix_row_counts[row_id] += 1;
// group rows by src0 matrix
for (int64_t iid1 = 0; iid1 < ids->ne[1]; ++iid1) {
for (int id = 0; id < n_ids; ++id) {
const int32_t i02 = *(const int32_t *) ((const char *) ids->data + iid1*ids->nb[1] + id*ids->nb[0]);
assert(i02 >= 0 && i02 < n_as);
MMID_MATRIX_ROW(i02, matrix_row_counts[i02]) = MAKE_I64(iid1, id);
matrix_row_counts[i02] += 1;
}
}
return;
@ -11039,15 +11050,13 @@ static void ggml_compute_forward_mul_mat_id(
continue;
}
size_t src0_offset = cur_a*src0->nb[2];
const char * src0_cur = (const char *) src0->data + cur_a*nb02;
const void * wdata = (src1->type == vec_dot_type) ? src1->data : params->wdata;
const size_t row_size = ggml_row_size(vec_dot_type, ne10);
const int64_t nr0 = ne01; // src0 rows
const int64_t nr1 = cne1*ne12*ne13; // src1 rows
//printf("nr0 = %lld, nr1 = %lld\n", nr0, nr1);
const int64_t nr0 = ne01; // src0 rows
const int64_t nr1 = cne1; // src1 rows
// distribute the thread work across the inner or outer loop based on which one is larger
@ -11066,13 +11075,11 @@ static void ggml_compute_forward_mul_mat_id(
const int64_t ir110 = dr1*ith1;
const int64_t ir111 = MIN(ir110 + dr1, nr1);
//printf("ir010 = %6lld, ir011 = %6lld, ir110 = %6lld, ir111 = %6lld\n", ir010, ir011, ir110, ir111);
// threads with no work simply yield (not sure if it helps)
if (ir010 >= ir011 || ir110 >= ir111) {
sched_yield();
continue;
}
//if (ir010 >= ir011 || ir110 >= ir111) {
// sched_yield();
// continue;
//}
// block-tiling attempt
const int64_t blck_0 = 16;
@ -11084,20 +11091,15 @@ static void ggml_compute_forward_mul_mat_id(
for (int64_t iir1 = ir110; iir1 < ir111; iir1 += blck_1) {
for (int64_t iir0 = ir010; iir0 < ir011; iir0 += blck_0) {
for (int64_t ir1 = iir1; ir1 < iir1 + blck_1 && ir1 < ir111; ++ir1) {
const int64_t i13 = (ir1/(ne12*cne1)); // Note: currently, src1 is always a matrix
const int64_t i12 = (ir1 - i13*ne12*cne1)/cne1;
const int64_t _i11 = (ir1 - i13*ne12*cne1 - i12*cne1);
const int64_t i11 = MMID_MATRIX_ROW(cur_a, _i11);
const int64_t _i12 = ir1; // logical row index for this expert
// broadcast src0 into src1
//const int64_t i03 = i13/r3;
//const int64_t i02 = i12/r2;
const int id = HI_I64(MMID_MATRIX_ROW(cur_a, _i12)); // selected expert index
const int64_t i1 = i11;
const int64_t i2 = i12;
const int64_t i3 = i13;
const int64_t i11 = id % ne11;
const int64_t i12 = LO_I64(MMID_MATRIX_ROW(cur_a, _i12)); // row index in src1
const char * src0_row = (const char *) src0->data + src0_offset;
const int64_t i1 = id; // selected expert index
const int64_t i2 = i12; // row
// desc: when src1 is not a contiguous memory block we have to calculate the offset using the strides
// if it is, then we have either copied the data to params->wdata and made it contiguous or we are using
@ -11105,25 +11107,30 @@ static void ggml_compute_forward_mul_mat_id(
// TODO: this is a bit of a hack, we should probably have a better way to handle this
const char * src1_col = (const char *) wdata +
(src1_cont || src1->type != vec_dot_type
? (i11 + i12*ne11 + i13*ne12*ne11)*row_size
: (i11*nb11 + i12*nb12 + i13*nb13));
? (i11 + i12*ne11)*row_size
: (i11*nb11 + i12*nb12));
float * dst_col = (float *) ((char *) dst->data + (i1*nb1 + i2*nb2 + i3*nb3));
float * dst_col = (float *) ((char *) dst->data + (i1*nb1 + i2*nb2));
//for (int64_t ir0 = iir0; ir0 < iir0 + blck_0 && ir0 < ir011; ++ir0) {
// vec_dot(ne00, &dst_col[ir0], src0_row + ir0*nb01, src1_col);
//}
for (int64_t ir0 = iir0; ir0 < iir0 + blck_0 && ir0 < ir011; ++ir0) {
vec_dot(ne00, &tmp[ir0 - iir0], 0, src0_row + ir0*nb01, 0, src1_col, 0, 1);
vec_dot(ne00, &tmp[ir0 - iir0], 0, src0_cur + ir0*nb01, 0, src1_col, 0, 1);
}
memcpy(&dst_col[iir0], tmp, (MIN(iir0 + blck_0, ir011) - iir0)*sizeof(float));
}
}
}
}
#undef MMID_MATRIX_ROW
#undef MMID_MATRIX_ROW
#undef MAKE_I64
#undef LO_I64
#undef HI_I64
}
// ggml_compute_forward_out_prod
@ -18462,7 +18469,7 @@ struct ggml_cplan ggml_graph_plan(const struct ggml_cgraph * cgraph, int n_threa
const int n_as = src0->ne[2];
cur += GGML_PAD(cur, sizeof(int64_t)); // align
cur += n_as * sizeof(int64_t); // matrix_row_counts
cur += n_as * src1->ne[1] * sizeof(int64_t); // matrix_rows
cur += n_as * src1->ne[2] * sizeof(int64_t); // matrix_rows
} break;
case GGML_OP_OUT_PROD:
{
@ -20862,12 +20869,12 @@ struct gguf_context * gguf_init_from_file(const char * fname, struct gguf_init_p
ok = ok && cur != NULL;
ggml_set_name(cur, ctx->infos[i].name.data);
if (!ok) {
break;
}
ggml_set_name(cur, ctx->infos[i].name.data);
// point the data member to the appropriate location in the binary blob using the tensor infos
if (!params.no_alloc) {
//cur->data = (char *) data->data + ctx->infos[i].offset - ctx->offset; // offset from start of file

7
ggml.h
View file

@ -1161,13 +1161,12 @@ extern "C" {
enum ggml_prec prec);
// indirect matrix multiplication
// ggml_mul_mat_id(ctx, as, ids, id, b) ~= ggml_mul_mat(as[ids[id]], b)
// TODO: document
GGML_API struct ggml_tensor * ggml_mul_mat_id(
struct ggml_context * ctx,
struct ggml_tensor * as,
struct ggml_tensor * ids,
int id,
struct ggml_tensor * b);
struct ggml_tensor * b,
struct ggml_tensor * ids);
// A: m columns, n rows,
// B: p columns, n rows,

View file

@ -6392,52 +6392,56 @@ struct llm_build_context {
LLM_NORM_RMS, cb, il);
cb(cur, "ffn_norm", il);
ggml_tensor * logits = ggml_mul_mat(ctx0, model.layers[il].ffn_gate_inp, cur); // [n_tokens, num_experts]
ggml_tensor * logits = ggml_mul_mat(ctx0, model.layers[il].ffn_gate_inp, cur); // [n_expert, n_tokens]
cb(logits, "ffn_moe_logits", il);
ggml_tensor * probs = ggml_soft_max(ctx0, logits); // [n_tokens, num_experts]
ggml_tensor * probs = ggml_soft_max(ctx0, logits); // [n_expert, n_tokens]
cb(probs, "ffn_moe_probs", il);
// select experts
ggml_tensor * selected_experts = ggml_top_k(ctx0, probs, n_expert_used); // [n_tokens, num_experts_per_tok]
ggml_tensor * selected_experts = ggml_top_k(ctx0, probs, n_expert_used); // [n_expert_used, n_tokens]
cb(selected_experts->src[0], "ffn_moe_argsort", il);
cb(selected_experts, "ffn_moe_topk", il);
ggml_tensor * weights = ggml_get_rows(ctx0,
ggml_reshape_3d(ctx0, probs, 1, n_expert, n_tokens), selected_experts);
ggml_reshape_3d(ctx0, probs, 1, n_expert, n_tokens), selected_experts); // [1, n_expert_used, n_tokens]
cb(weights, "ffn_moe_weights", il);
weights = ggml_reshape_2d(ctx0, weights, n_expert_used, n_tokens); // [n_tokens, num_experts_per_tok]
weights = ggml_reshape_2d(ctx0, weights, n_expert_used, n_tokens);
ggml_tensor * weights_sum = ggml_sum_rows(ctx0, weights);
ggml_tensor * weights_sum = ggml_sum_rows(ctx0, weights); // [1, n_tokens]
cb(weights_sum, "ffn_moe_weights_sum", il);
weights = ggml_div(ctx0, weights, weights_sum); // [n_tokens, num_experts_per_tok]
weights = ggml_div(ctx0, weights, weights_sum); // [n_expert_used, n_tokens]
cb(weights, "ffn_moe_weights_norm", il);
// compute expert outputs
cur = ggml_reshape_3d(ctx0, cur, n_embd, 1, n_tokens);
ggml_tensor * up = ggml_mul_mat_id(ctx0, model.layers[il].ffn_up_exps, cur, selected_experts); // [n_ff, n_expert_used, n_tokens]
cb(up, "ffn_moe_up", il);
ggml_tensor * gate = ggml_mul_mat_id(ctx0, model.layers[il].ffn_gate_exps, cur, selected_experts); // [n_ff, n_expert_used, n_tokens]
cb(gate, "ffn_moe_gate", il);
gate = ggml_silu(ctx0, gate);
cb(gate, "ffn_moe_silu", il);
ggml_tensor * par = ggml_mul(ctx0, up, gate); // [n_ff, n_expert_used, n_tokens]
cb(par, "ffn_moe_gate_par", il);
ggml_tensor * experts = ggml_mul_mat_id(ctx0, model.layers[il].ffn_down_exps, par, selected_experts); // [n_embd, n_expert_used, n_tokens]
cb(experts, "ffn_moe_down", il);
experts = ggml_mul(ctx0, experts,
ggml_reshape_3d(ctx0, weights, 1, n_expert_used, n_tokens));
// aggregate experts
ggml_tensor * moe_out = nullptr;
for (int i = 0; i < n_expert_used; ++i) {
ggml_tensor * cur_expert;
ggml_tensor * cur_expert = ggml_view_2d(ctx0, experts, n_embd, n_tokens,
experts->nb[2], i*experts->nb[1]);
ggml_tensor * cur_up = ggml_mul_mat_id(ctx0, model.layers[il].ffn_up_exps, selected_experts, i, cur);
cb(cur_up, "ffn_moe_up", il);
ggml_tensor * cur_gate = ggml_mul_mat_id(ctx0, model.layers[il].ffn_gate_exps, selected_experts, i, cur);
cb(cur_gate, "ffn_moe_gate", il);
cur_gate = ggml_silu(ctx0, cur_gate);
cb(cur_gate, "ffn_moe_silu", il);
cur_expert = ggml_mul(ctx0, cur_up, cur_gate);
cb(cur_expert, "ffn_moe_gate_par", il);
cur_expert = ggml_mul_mat_id(ctx0, model.layers[il].ffn_down_exps, selected_experts, i, cur_expert); // [n_tokens, n_embd]
cb(cur_expert, "ffn_moe_down", il);
cur_expert = ggml_mul(ctx0, cur_expert,
ggml_view_2d(ctx0, weights, 1, n_tokens, weights->nb[1], i*weights->nb[0]));
cb(cur_expert, "ffn_moe_weighted", il);
// FIXME: non-contiguous add broken in cuda
cur_expert = ggml_cont(ctx0, cur_expert);
if (i == 0) {
moe_out = cur_expert;
@ -6953,11 +6957,12 @@ struct llm_build_context {
for (int i = 0; i < n_expert_used; ++i) {
ggml_tensor * cur_expert;
// FIXME
ggml_tensor * cur_up = ggml_mul_mat_id(ctx0, model.layers[il].ffn_up_exps, selected_experts, i, cur);
ggml_tensor * cur_up = ggml_mul_mat_id(ctx0, model.layers[il].ffn_up_exps, selected_experts, cur);
cb(cur_up, "ffn_moe_up", il);
ggml_tensor * cur_gate = ggml_mul_mat_id(ctx0, model.layers[il].ffn_gate_exps, selected_experts, i, cur);
ggml_tensor * cur_gate = ggml_mul_mat_id(ctx0, model.layers[il].ffn_gate_exps, selected_experts, cur);
cb(cur_gate, "ffn_moe_gate", il);
//GeLU
@ -6967,7 +6972,7 @@ struct llm_build_context {
cur_expert = ggml_mul(ctx0, cur_up, cur_gate);
cb(cur_expert, "ffn_moe_gate_par", il);
cur_expert = ggml_mul_mat_id(ctx0, model.layers[il].ffn_down_exps, selected_experts, i, cur_expert); // [n_tokens, n_embd]
cur_expert = ggml_mul_mat_id(ctx0, model.layers[il].ffn_down_exps, selected_experts, cur_expert); // [n_tokens, n_embd]
cb(cur_expert, "ffn_moe_down", il);
cur_expert = ggml_mul(ctx0, cur_expert,

View file

@ -22,9 +22,9 @@ fi
make_opts=""
if [[ "$backend" == "cuda" ]]; then
make_opts="LLAMA_CUDA=1"
fi
#if [[ "$backend" == "cuda" ]]; then
# make_opts="LLAMA_CUDA=1"
#fi
git checkout $1
make clean && make -j32 $make_opts llama-bench

View file

@ -478,8 +478,9 @@ struct test_case {
}
double err = nmse(f1.data(), f2.data(), f1.size());
printf("[%s] NMSE = %.9f > %.9f \n", ggml_op_desc(t1), err, ud->max_err);
if (err > ud->max_err) {
printf("[%s] NMSE = %.9f > %.9f ", ggml_op_desc(t1), err, ud->max_err);
//printf("[%s] NMSE = %.9f > %.9f ", ggml_op_desc(t1), err, ud->max_err);
//for (int i = 0; i < (int) f1.size(); i++) {
// printf("%5d %9.6f %9.6f, diff = %9.6f\n", i, f1[i], f2[i], f1[i] - f2[i]);
//}
@ -948,14 +949,14 @@ struct test_mul_mat_id : public test_case {
const ggml_type type_a;
const ggml_type type_b;
const int n_mats;
const int id;
const int n_used;
const bool b; // brodcast b matrix
const int64_t m;
const int64_t n;
const int64_t k;
const bool v; // view (non-contiguous ids)
std::string vars() override {
return VARS_TO_STR8(type_a, type_b, n_mats, id, m, n, k, v);
return VARS_TO_STR7(type_a, type_b, n_mats, b, m, n, k);
}
double max_nmse_err() override {
@ -972,20 +973,18 @@ struct test_mul_mat_id : public test_case {
}
test_mul_mat_id(ggml_type type_a = GGML_TYPE_F32, ggml_type type_b = GGML_TYPE_F32,
int n_mats = 2, int id = 0,
int64_t m = 32, int64_t n = 32, int64_t k = 32, bool v = false)
: type_a(type_a), type_b(type_b), n_mats(n_mats), id(id),
m(m), n(n), k(k), v(v) {}
int n_mats = 8, int n_used = 2, bool b = false,
int64_t m = 32, int64_t n = 32, int64_t k = 32)
: type_a(type_a), type_b(type_b), n_mats(n_mats), n_used(n_used), b(b),
m(m), n(n), k(k) {}
ggml_tensor * build_graph(ggml_context * ctx) override {
// C^T = A * B^T: (k, m) * (k, n) => (m, n)
ggml_tensor * mats = ggml_new_tensor_3d(ctx, type_a, k, m, n_mats);
ggml_tensor * as = ggml_new_tensor_3d(ctx, type_a, k, m, n_mats);
ggml_tensor * ids = ggml_new_tensor_2d(ctx, GGML_TYPE_I32, n_mats, n);
if (v) {
ids = ggml_view_2d(ctx, ids, n_mats/2, ids->ne[1], ids->nb[1], 0);
}
ggml_tensor * b = ggml_new_tensor_2d(ctx, type_b, k, n);
ggml_tensor * out = ggml_mul_mat_id(ctx, mats, ids, v ? id/2 : id, b);
ids = ggml_view_2d(ctx, ids, n_used, n, ids->nb[1], 0);
ggml_tensor * b = ggml_new_tensor_3d(ctx, type_b, k, this->b ? 1 : n_used, n);
ggml_tensor * out = ggml_mul_mat_id(ctx, as, ids, b);
return out;
}
@ -1858,6 +1857,92 @@ struct test_falcon : public test_llm {
}
};
// Mixtral MOE
struct test_moe : public test_case {
const int n_expert;
const int n_expert_used;
const int n_tokens;
const int n_embd;
const int n_ff;
std::string op_desc(ggml_tensor * t) override {
return "MOE";
GGML_UNUSED(t);
}
std::string vars() override {
return VARS_TO_STR5(n_expert, n_expert_used, n_tokens, n_embd, n_ff);
}
test_moe(int n_experts = 8, int n_experts_per_tok = 2, int n_tokens = 1, int n_embd = 4096, int n_ff = 14336)
: n_expert(n_experts), n_expert_used(n_experts_per_tok), n_tokens(n_tokens), n_embd(n_embd), n_ff(n_ff) {
}
ggml_tensor * build_graph(ggml_context * ctx) override {
ggml_type wtype = GGML_TYPE_F32;
ggml_tensor * ffn_gate_inp = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_embd, n_expert);
ggml_tensor * ffn_gate_exps = ggml_new_tensor_3d(ctx, wtype, n_embd, n_ff, n_expert);
ggml_tensor * ffn_down_exps = ggml_new_tensor_3d(ctx, wtype, n_ff, n_embd, n_expert);
ggml_tensor * ffn_up_exps = ggml_new_tensor_3d(ctx, wtype, n_embd, n_ff, n_expert);
ggml_tensor * cur = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_embd, n_tokens);
ggml_tensor * logits = ggml_mul_mat(ctx, ffn_gate_inp, cur); // [n_expert, n_tokens]
//ggml_tensor * probs = ggml_soft_max(ctx, logits); // [n_expert, n_tokens]
ggml_tensor * probs = ggml_soft_max_ext(ctx, logits, nullptr, nullptr, 1.0f/sqrtf(n_embd), 0.0f);
// select experts
ggml_tensor * selected_experts = ggml_top_k(ctx, probs, n_expert_used); // [n_expert_used, n_tokens]
ggml_tensor * weights = ggml_get_rows(ctx,
ggml_reshape_3d(ctx, probs, 1, n_expert, n_tokens), selected_experts); // [1, n_expert_used, n_tokens]
weights = ggml_reshape_2d(ctx, weights, n_expert_used, n_tokens);
ggml_tensor * weights_sum = ggml_sum_rows(ctx, weights); // [1, n_tokens]
weights = ggml_div(ctx, weights, weights_sum); // [n_expert_used, n_tokens]
cur = ggml_reshape_3d(ctx, cur, n_embd, 1, n_tokens);
ggml_tensor * up = ggml_mul_mat_id(ctx, ffn_up_exps, selected_experts, cur); // [n_ff, n_expert_used, n_tokens]
ggml_tensor * gate = ggml_mul_mat_id(ctx, ffn_gate_exps, selected_experts, cur); // [n_ff, n_expert_used, n_tokens]
gate = ggml_silu(ctx, gate);
ggml_tensor * par = ggml_mul(ctx, up, gate); // [n_ff, n_expert_used, n_tokens]
ggml_tensor * experts = ggml_mul_mat_id(ctx, ffn_down_exps, selected_experts, par); // [n_embd, n_expert_used, n_tokens]
printf("mul: src0: %ld %ld %ld %ld\n", experts->ne[0], experts->ne[1], experts->ne[2], experts->ne[3]);
printf("mul: src1: %ld %ld %ld %ld\n", 1, n_expert_used, n_tokens, 1);
experts = ggml_mul(ctx, experts,
ggml_reshape_3d(ctx, weights, 1, n_expert_used, n_tokens));
// aggregate experts
ggml_tensor * moe_out = nullptr;
for (int i = 0; i < n_expert_used; ++i) {
ggml_tensor * cur_expert = ggml_view_2d(ctx, experts, n_embd, n_tokens,
experts->nb[2], i*experts->nb[1]);
cur_expert = ggml_cont(ctx, cur_expert);
if (i == 0) {
moe_out = cur_expert;
} else {
moe_out = ggml_add(ctx, moe_out, cur_expert);
}
}
cur = moe_out;
return cur;
}
};
static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op_name) {
std::vector<std::unique_ptr<test_case>> test_cases;
std::default_random_engine rng(0);
@ -1875,6 +1960,9 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op
GGML_TYPE_IQ4_NL, GGML_TYPE_IQ3_S, GGML_TYPE_IQ4_XS,
};
test_cases.emplace_back(new test_moe(8, 2, 1, 4096, 8*1024));
test_cases.emplace_back(new test_moe(8, 2, 32, 4096, 8*1024));
// unary ops
for (int op = 0; op < GGML_UNARY_OP_COUNT; op++) {
test_cases.emplace_back(new test_unary((ggml_unary_op) op));
@ -1944,6 +2032,10 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op
}
};
// mul: src0: 4096 2 32 1
// mul: src1: 1 2 32 1
add_test_bin_bcast(GGML_TYPE_F32, {1, 2, 32, 1}, {4096, 1, 1, 1});
add_test_bin_bcast(GGML_TYPE_F32, {1, 1, 8, 1}, {1, 1, 1, 1});
add_test_bin_bcast(GGML_TYPE_F32, {1, 1, 1, 1}, {32, 1, 1, 1});
add_test_bin_bcast(GGML_TYPE_F32, {1, 1, 320, 320}, {1, 1, 1, 1});
@ -2012,10 +2104,14 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op
for (ggml_type type_a : all_types) {
for (ggml_type type_b : {GGML_TYPE_F32 /*, GGML_TYPE_F16 */}) {
for (int n_mats : {2, 4, 8}) {
for (int id = 0; id < n_mats; id++) {
for (bool v : {false, true}) {
test_cases.emplace_back(new test_mul_mat_id(type_a, type_b, n_mats, id, 16, 16, 256, v));
}
for (bool b : {false, true}) {
// cur shape: 4096 1 32 1
// ffn_up_exps shape: 4096 8192 8 1
// selected_experts shape: 2 32 1 1
int m = 8192;
int n = 32;
int k = 4096;
test_cases.emplace_back(new test_mul_mat_id(type_a, type_b, n_mats, 2, b, m, n, k));
}
}
}