hf bitnet v1

This commit is contained in:
Eddie-Wang1120 2024-06-05 16:15:28 +08:00
parent 3b38d48609
commit 076b4a197b
7 changed files with 897 additions and 3 deletions

111
ggml.c
View file

@ -2621,6 +2621,22 @@ inline static void ggml_vec_argmax_f32(const int n, int * s, const float * x) {
*s = idx;
}
inline static void ggml_vec_absmaxclamp_f32(const int n, float * s, const float * x, float min) {
float max = min;
for (int i = 0; i < n; ++i) {
max = MAX(max, fabs(x[i]));
}
*s = max;
}
inline static void ggml_vec_scaleroundclamp_f32(const int n, float * s, const float * x, float scale, float min, float max) {
for (int i = 0; i < n; ++i) {
s[i] = round(x[i] * scale);
if (s[i] > max) s[i] = max;
if (s[i] < min) s[i] = min;
s[i] /= scale;
}
}
//
// data types
//
@ -2709,9 +2725,11 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
"CROSS_ENTROPY_LOSS",
"CROSS_ENTROPY_LOSS_BACK",
"BITLINEAR_QUANT"
};
static_assert(GGML_OP_COUNT == 74, "GGML_OP_COUNT != 74");
static_assert(GGML_OP_COUNT == 75, "GGML_OP_COUNT != 75");
static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
"none",
@ -2797,9 +2815,11 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
"cross_entropy_loss(x,y)",
"cross_entropy_loss_back(x,y)",
"bitlinear(x)",
};
static_assert(GGML_OP_COUNT == 74, "GGML_OP_COUNT != 74");
static_assert(GGML_OP_COUNT == 75, "GGML_OP_COUNT != 75");
static_assert(GGML_OP_POOL_COUNT == 2, "GGML_OP_POOL_COUNT != 2");
@ -4830,6 +4850,28 @@ struct ggml_tensor * ggml_mean(
return result;
}
// ggml_bitlinear_quant for bitnet
struct ggml_tensor * ggml_bitlinear_quant(
struct ggml_context * ctx,
struct ggml_tensor * a) {
bool is_node = false;
if (a->grad) {
GGML_ASSERT(false); // TODO: implement
is_node = true;
}
int64_t ne[GGML_MAX_DIMS] = { a->ne[0], a->ne[1], a->ne[2], a->ne[3] };
struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, ggml_n_dims(a), ne);
result->op = GGML_OP_BITLINEAR_QUANT;
result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
result->src[0] = a;
return result;
}
// ggml_argmax
struct ggml_tensor * ggml_argmax(
@ -10740,6 +10782,62 @@ static void ggml_compute_forward_mean(
}
}
static void ggml_compute_forward_bitlinear_quant_f32(
const struct ggml_compute_params * params,
const struct ggml_tensor * src0,
struct ggml_tensor * dst) {
assert(params->ith == 0);
if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) {
return;
}
assert(src0->nb[0] == sizeof(float));
GGML_TENSOR_UNARY_OP_LOCALS
assert(ne0 == ne00);
assert(ne1 == ne01);
assert(ne2 == ne02);
assert(ne3 == ne03);
UNUSED(ne0);
UNUSED(ne1);
UNUSED(ne2);
UNUSED(ne3);
for (int64_t i03 = 0; i03 < ne03; i03++) {
for (int64_t i02 = 0; i02 < ne02; i02++) {
for (int64_t i01 = 0; i01 < ne01; i01++) {
float rowmax = 0.00001;
ggml_vec_absmaxclamp_f32(ne00, &rowmax, (float *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03), 0.00001);
float s = 127 / rowmax;
ggml_vec_scaleroundclamp_f32(ne00,
(float *) ((char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3),
(float *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03),
s, -128, 127);
}
}
}
}
static void ggml_compute_forward_bitlinear_quant(
const struct ggml_compute_params * params,
const struct ggml_tensor * src0,
struct ggml_tensor * dst) {
switch (src0->type) {
case GGML_TYPE_F32:
{
ggml_compute_forward_bitlinear_quant_f32(params, src0, dst);
} break;
default:
{
GGML_ASSERT(false);
} break;
}
}
// ggml_compute_forward_argmax
static void ggml_compute_forward_argmax_f32(
@ -17318,6 +17416,10 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
{
ggml_compute_forward_mean(params, tensor);
} break;
case GGML_OP_BITLINEAR_QUANT:
{
ggml_compute_forward_bitlinear_quant(params, tensor->src[0], tensor);
} break;
case GGML_OP_ARGMAX:
{
ggml_compute_forward_argmax(params, tensor);
@ -18484,6 +18586,10 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
{
GGML_ASSERT(false); // TODO: not implemented
} break;
case GGML_OP_BITLINEAR_QUANT:
{
GGML_ASSERT(false); // TODO: not implemented
} break;
case GGML_OP_ARGSORT:
{
GGML_ASSERT(false); // TODO: not implemented
@ -19249,6 +19355,7 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads, int n_cur_
case GGML_OP_GET_REL_POS:
case GGML_OP_MAP_UNARY:
case GGML_OP_MAP_BINARY:
case GGML_OP_BITLINEAR_QUANT:
case GGML_OP_MAP_CUSTOM1_F32:
case GGML_OP_MAP_CUSTOM2_F32:
case GGML_OP_MAP_CUSTOM3_F32: