finish f16 hf bitnet e2e

This commit is contained in:
Eddie-Wang1120 2024-06-07 14:42:52 +08:00
parent 1f2e0ee012
commit 5e59660173
10 changed files with 440 additions and 11 deletions

202
ggml.c
View file

@ -569,6 +569,15 @@ static void ggml_vec_dot_f16(int n, float * restrict s, size_t bs, ggml_fp16_t *
static void ggml_vec_dot_bf16(int n, float * restrict s, size_t bs, ggml_bf16_t * restrict x, size_t bx, ggml_bf16_t * restrict y, size_t by, int nrc);
static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = {
[GGML_TYPE_I2] = {
.type_name = "i2",
.blck_size = 1,
.type_size = sizeof(int8_t),
.is_quantized = false,
.vec_dot = (ggml_vec_dot_t) ggml_vec_dot_i2_q8_0,
.vec_dot_type = GGML_TYPE_Q8_0,
.nrows = 1,
},
[GGML_TYPE_I8] = {
.type_name = "i8",
.blck_size = 1,
@ -1805,6 +1814,7 @@ inline static void ggml_vec_cpy_f32 (const int n, float * y, const float * x)
inline static void ggml_vec_neg_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = -x[i]; }
inline static void ggml_vec_mul_f32 (const int n, float * z, const float * x, const float * y) { for (int i = 0; i < n; ++i) z[i] = x[i]*y[i]; }
inline static void ggml_vec_div_f32 (const int n, float * z, const float * x, const float * y) { for (int i = 0; i < n; ++i) z[i] = x[i]/y[i]; }
inline static void ggml_vec_mul_f32_bitnet (const int n, float * y, const float x) { for (int i = 0; i < n; ++i) y[i] = y[i] * x; }
static void ggml_vec_dot_f32(int n, float * restrict s, size_t bs, const float * restrict x, size_t bx, const float * restrict y, size_t by, int nrc) {
assert(nrc == 1);
@ -2636,6 +2646,16 @@ inline static void ggml_vec_scaleroundclamp_f32(const int n, float * s, const fl
s[i] /= scale;
}
}
inline static void ggml_vec_scaleroundclamp_f32_v2(const int n, float * s, int8_t* inp, float scale, float min, float max) {
for (int i = 0; i < n; ++i) {
s[i] = round(s[i] * scale);
if (s[i] > max) s[i] = max;
if (s[i] < min) s[i] = min;
inp[i] = (int8_t)(s[i]);
}
}
//
// data types
@ -3081,6 +3101,10 @@ GGML_CALL size_t ggml_nbytes(const struct ggml_tensor * tensor) {
for (int i = 0; i < GGML_MAX_DIMS; ++i) {
nbytes += (tensor->ne[i] - 1)*tensor->nb[i];
}
if(tensor->type == 31){
nbytes = nbytes / 4 + 32;
}
}
else {
nbytes = tensor->ne[0]*tensor->nb[0]/blck_size;
@ -12411,7 +12435,10 @@ static void ggml_compute_forward_mul_mat_one_chunk(
}
const void * wdata = (src1->type == vec_dot_type) ? src1->data : params->wdata;
const size_t row_size = ggml_row_size(vec_dot_type, ne10);
size_t row_size = ggml_row_size(vec_dot_type, ne10);
if (src0->type == 31) {
row_size = ne10;
}
assert(ne12 % ne02 == 0);
assert(ne13 % ne03 == 0);
@ -12425,6 +12452,9 @@ static void ggml_compute_forward_mul_mat_one_chunk(
// attempt to reduce false-sharing (does not seem to make a difference)
// 16 * 2, accounting for mmla kernels
float tmp[32];
uint8_t *i_weight = (uint8_t*) (src0->data);
float * scale = (float * )((i_weight) + (ne00 * ne01 / 4));
float* act_scales = (float*) ((char *) wdata + ((ne11*nb11) / 4));
for (int64_t iir1 = ir1_start; iir1 < ir1_end; iir1 += blck_1) {
for (int64_t iir0 = ir0_start; iir0 < ir0_end; iir0 += blck_0) {
@ -12458,9 +12488,15 @@ static void ggml_compute_forward_mul_mat_one_chunk(
//}
for (int64_t ir0 = iir0; ir0 < iir0 + blck_0 && ir0 < ir0_end; ir0 += num_rows_per_vec_dot) {
if (src0->type == 31) {
vec_dot(ne00, &tmp[ir0 - iir0], (num_rows_per_vec_dot > 1 ? 16 : 0), src0_row + ir0 * nb01 / 4, (num_rows_per_vec_dot > 1 ? nb01 : 0), src1_col, (num_rows_per_vec_dot > 1 ? src1_col_stride : 0), num_rows_per_vec_dot);
tmp[ir0 - iir0] = tmp[ir0 - iir0] * (*scale) * (act_scales[i11]);
}else {
vec_dot(ne00, &tmp[ir0 - iir0], (num_rows_per_vec_dot > 1 ? 16 : 0), src0_row + ir0 * nb01, (num_rows_per_vec_dot > 1 ? nb01 : 0), src1_col, (num_rows_per_vec_dot > 1 ? src1_col_stride : 0), num_rows_per_vec_dot);
}
}
for (int cn = 0; cn < num_rows_per_vec_dot; ++cn) {
memcpy(&dst_col[iir0 + cn * nb1 / nb0], tmp + (cn * 16), (MIN(iir0 + blck_0, ir0_end) - iir0) * sizeof(float));
}
@ -12469,6 +12505,164 @@ static void ggml_compute_forward_mul_mat_one_chunk(
}
}
static void ggml_compute_forward_bitnet_mul_mat(
const struct ggml_compute_params * params,
struct ggml_tensor * dst,
struct ggml_compute_state * state) {
const struct ggml_tensor * src0 = dst->src[0];
const struct ggml_tensor * src1 = dst->src[1];
int64_t t0 = ggml_perf_time_us();
UNUSED(t0);
GGML_TENSOR_BINARY_OP_LOCALS
const int ith = params->ith;
const int nth = params->nth;
const enum ggml_type type = src0->type;
const bool src1_cont = ggml_is_contiguous(src1);
GGML_ASSERT(ne0 == ne01);
GGML_ASSERT(ne1 == ne11);
GGML_ASSERT(ne2 == ne12);
GGML_ASSERT(ne3 == ne13);
// we don't support permuted src0 or src1
GGML_ASSERT(nb00 == ggml_type_size(type));
GGML_ASSERT(nb10 == ggml_type_size(src1->type));
// dst cannot be transposed or permuted
GGML_ASSERT(nb0 == sizeof(float));
GGML_ASSERT(nb0 <= nb1);
GGML_ASSERT(nb1 <= nb2);
GGML_ASSERT(nb2 <= nb3);
// broadcast factors
const int64_t r2 = ne12 / ne02;
const int64_t r3 = ne13 / ne03;
UNUSED(r2);
UNUSED(r3);
// nb01 >= nb00 - src0 is not transposed
// compute by src0 rows
if (params->type == GGML_TASK_TYPE_INIT) {
if (ith != 0) {
return;
}
atomic_store(&state->shared->current_chunk, nth);
char * wdata = params->wdata;
float* act_scales = (float*) ((char *) wdata + ((ne11*nb11) / 4));
for (int64_t i13 = 0; i13 < ne13; i13++) {
for (int64_t i12 = 0; i12 < ne12; i12++) {
for (int64_t i11 = 0; i11 < ne11; i11++) {
float rowmax = 0.00001;
ggml_vec_absmaxclamp_f32(ne10, &rowmax, (float *) ((char *) src1->data + i11*nb11 + i12*nb12 + i13*nb13), 0.00001);
float s = 127 / rowmax;
act_scales[i11] = 1/s;
ggml_vec_scaleroundclamp_f32_v2(ne10,
(float *) ((char *) src1->data + i11*nb11 + i12*nb12 + i13*nb13),
(int8_t*) ((char *) wdata + ((i11*nb11 + i12*nb12 + i13*nb13) / 4)),
s, -128, 127);
}
}
}
// Every thread starts at ith, so the first unprocessed chunk is nth. This save a bit of coordination right at the start.
// atomic_store(&state->shared->current_chunk, nth);
// // char * wdata = params->wdata;
// const size_t row_size = ggml_row_size(GGML_TYPE_Q8_0, ne10);
// // printf("vec_dot_type:%d\n", vec_dot_type);
// // printf("row_size:%ld\n", row_size);
// assert(params->wsize >= ne11*ne12*ne13*row_size);
// GGML_ASSERT(src1->type == GGML_TYPE_F32);
// for (int64_t i13 = 0; i13 < ne13; ++i13) {
// for (int64_t i12 = 0; i12 < ne12; ++i12) {
// for (int64_t i11 = 0; i11 < ne11; ++i11) {
// quantize_row_q8_0((float *)((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11), (void *) wdata, ne10);
// wdata += row_size;
// }
// }
// }
return;
}
if (params->type == GGML_TASK_TYPE_FINALIZE) {
return;
}
// This is the size of the first dimension of the result, so we can iterate that way. (see the ASSERT above, these are the same numbers)
const int64_t nr0 = ne0;
// This is the size of the rest of the dimensions of the result
const int64_t nr1 = ne1 * ne2 * ne3;
// dot kernels can handle 1 row and col at a time, but mmla kernels can process 2 rows and cols
int64_t num_rows_per_vec_dot = 1;
// TODO: currently the mmla kernels support only even numbered rows/cols.
// this check can be removed once they are extended to support odd numbered rows/cols too
if ((nr0 % 2 != 0) || (ne11 % 2 != 0)) {
num_rows_per_vec_dot = 1;
}
// Now select a reasonable chunk size.
int chunk_size = 16;
// We need to step up the size if it's small
if (nr0 == 1 || nr1 == 1) {
chunk_size = 64;
}
// distribute the work across the inner or outer loop based on which one is larger
// The number of chunks in the 0/1 dim.
// CEIL(nr0/chunk_size)
int64_t nchunk0 = (nr0 + chunk_size - 1) / chunk_size;
int64_t nchunk1 = (nr1 + chunk_size - 1) / chunk_size;
// If the chunking is poor for the number of threads on this setup, scrap the whole plan. Re-chunk it by thread.
// Also, chunking by thread was measured to have perform better on NUMA systems. See https://github.com/ggerganov/llama.cpp/pull/6915
// In theory, chunking should be just as useful on NUMA and non NUMA systems, but testing disagreed with that.
if (nchunk0 * nchunk1 < nth * 4 || ggml_is_numa()) {
// distribute the thread work across the inner or outer loop based on which one is larger
nchunk0 = nr0 > nr1 ? nth : 1; // parallelize by src0 rows
nchunk1 = nr0 > nr1 ? 1 : nth; // parallelize by src1 rows
}
// The number of elements in each chunk
const int64_t dr0 = (nr0 + nchunk0 - 1) / nchunk0;
const int64_t dr1 = (nr1 + nchunk1 - 1) / nchunk1;
//if (ith == 0)
// printf("MUL_MAT = [%d, %d, %d, %d] x [%d, %d, %d, %d] = %d x %d = %d. Fp Ops/Ch %d\n", ne00, ne01, ne02, ne03, ne10, ne11, ne12, ne13, nchunk0, nchunk1, nchunk0 * nchunk1, ne00 * nr0 * nr1 / nchunk0 / nchunk1);
// The first chunk comes from our thread_id, the rest will get auto-assigned.
int current_chunk = ith;
while (current_chunk < nchunk0 * nchunk1) {
const int64_t ith0 = current_chunk % nchunk0;
const int64_t ith1 = current_chunk / nchunk0;
const int64_t ir0_start = dr0 * ith0;
const int64_t ir0_end = MIN(ir0_start + dr0, nr0);
const int64_t ir1_start = dr1 * ith1;
const int64_t ir1_end = MIN(ir1_start + dr1, nr1);
ggml_compute_forward_mul_mat_one_chunk(params, dst, num_rows_per_vec_dot, ir0_start, ir0_end, ir1_start, ir1_end);
if (nth >= nchunk0 * nchunk1) {
break;
}
current_chunk = atomic_fetch_add(&state->shared->current_chunk, 1);
}
}
static void ggml_compute_forward_mul_mat(
const struct ggml_compute_params * params,
struct ggml_tensor * dst,
@ -12482,6 +12676,11 @@ static void ggml_compute_forward_mul_mat(
GGML_TENSOR_BINARY_OP_LOCALS
if (src0->type == 31) {
ggml_compute_forward_bitnet_mul_mat(params, dst, state);
return;
}
const int ith = params->ith;
const int nth = params->nth;
@ -14349,6 +14548,7 @@ static void ggml_compute_forward_clamp(
case GGML_TYPE_I32:
case GGML_TYPE_I64:
case GGML_TYPE_F64:
case GGML_TYPE_I2:
case GGML_TYPE_COUNT:
{
GGML_ASSERT(false);