clean code 2

This commit is contained in:
root 2024-06-09 21:15:02 +08:00
parent 1c5a8b7fec
commit 3a0f8b0697
5 changed files with 100 additions and 51 deletions

35
ggml.c
View file

@ -2724,7 +2724,6 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
"CROSS_ENTROPY_LOSS",
"CROSS_ENTROPY_LOSS_BACK",
};
static_assert(GGML_OP_COUNT == 74, "GGML_OP_COUNT != 74");
@ -2813,7 +2812,6 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
"cross_entropy_loss(x,y)",
"cross_entropy_loss_back(x,y)",
};
static_assert(GGML_OP_COUNT == 74, "GGML_OP_COUNT != 74");
@ -3078,10 +3076,9 @@ GGML_CALL size_t ggml_nbytes(const struct ggml_tensor * tensor) {
for (int i = 0; i < GGML_MAX_DIMS; ++i) {
nbytes += (tensor->ne[i] - 1)*tensor->nb[i];
}
if(tensor->type == 31){
if(tensor->type == GGML_TYPE_I2_S){
nbytes = nbytes / 4 + 32;
}
}
else {
nbytes = tensor->ne[0]*tensor->nb[0]/blck_size;
@ -3107,6 +3104,7 @@ GGML_CALL size_t ggml_type_size(enum ggml_type type) {
GGML_CALL size_t ggml_row_size(enum ggml_type type, int64_t ne) {
assert(ne % ggml_blck_size(type) == 0);
if (type == GGML_TYPE_I2_S) ne /= 4;
return ggml_type_size(type)*ne/ggml_blck_size(type);
}
@ -12333,11 +12331,11 @@ static void ggml_compute_forward_mul_mat_one_chunk(
return;
}
void * wdata = (src1->type == vec_dot_type) ? src1->data : params->wdata;
size_t row_size = ggml_row_size(vec_dot_type, ne10);
if (src0->type == 31) {
row_size = ne10;
}
const void * wdata = (src1->type == vec_dot_type) ? src1->data : params->wdata;
const size_t row_size = ggml_row_size(vec_dot_type, ne10);
// if (src0->type == 31) {
// row_size = ne10;
// }
assert(ne12 % ne02 == 0);
assert(ne13 % ne03 == 0);
@ -12351,9 +12349,8 @@ static void ggml_compute_forward_mul_mat_one_chunk(
// attempt to reduce false-sharing (does not seem to make a difference)
// 16 * 2, accounting for mmla kernels
float tmp[32];
uint8_t *i_weight = (uint8_t*) (src0->data);
float * scale = (float * )((i_weight) + (ne00 * ne01 / 4));
float * act_scales = (float*) ((char *) wdata + (ne11 * ne10));
float * scale = (float * )((uint8_t*) (src0->data) + (ne00 * ne01 / 4));
const float * act_scales = (const float*) ((const char *) wdata + (ne11 * ne10));
for (int64_t iir1 = ir1_start; iir1 < ir1_end; iir1 += blck_1) {
for (int64_t iir0 = ir0_start; iir0 < ir0_end; iir0 += blck_0) {
@ -12380,7 +12377,6 @@ static void ggml_compute_forward_mul_mat_one_chunk(
(src1_cont || src1->type != vec_dot_type
? (i11 + i12 * ne11 + i13 * ne12 * ne11) * row_size
: (i11 * nb11 + i12 * nb12 + i13 * nb13));
float * dst_col = (float*)((char*)dst->data + (i1 * nb1 + i2 * nb2 + i3 * nb3));
//for (int64_t ir0 = iir0; ir0 < iir0 + blck_0 && ir0 < ir0_end; ++ir0) {
@ -12388,13 +12384,12 @@ static void ggml_compute_forward_mul_mat_one_chunk(
//}
for (int64_t ir0 = iir0; ir0 < iir0 + blck_0 && ir0 < ir0_end; ir0 += num_rows_per_vec_dot) {
if (src0->type == GGML_TYPE_I2_S) {
vec_dot(ne00, &tmp[ir0 - iir0], (num_rows_per_vec_dot > 1 ? 16 : 0), src0_row + ir0 * nb01 / 4, (num_rows_per_vec_dot > 1 ? nb01 : 0), src1_col, (num_rows_per_vec_dot > 1 ? src1_col_stride : 0), num_rows_per_vec_dot);
tmp[ir0 - iir0] = tmp[ir0 - iir0] / (act_scales[i11]) * (*scale);
} else {
vec_dot(ne00, &tmp[ir0 - iir0], (num_rows_per_vec_dot > 1 ? 16 : 0), src0_row + ir0 * nb01, (num_rows_per_vec_dot > 1 ? nb01 : 0), src1_col, (num_rows_per_vec_dot > 1 ? src1_col_stride : 0), num_rows_per_vec_dot);
}
if (src0->type == GGML_TYPE_I2_S) {
vec_dot(ne00, &tmp[ir0 - iir0], (num_rows_per_vec_dot > 1 ? 16 : 0), src0_row + ir0 * nb01 / 4, (num_rows_per_vec_dot > 1 ? nb01 : 0), src1_col, (num_rows_per_vec_dot > 1 ? src1_col_stride : 0), num_rows_per_vec_dot);
tmp[ir0 - iir0] = tmp[ir0 - iir0] / (act_scales[i11]) * (*scale);
} else {
vec_dot(ne00, &tmp[ir0 - iir0], (num_rows_per_vec_dot > 1 ? 16 : 0), src0_row + ir0 * nb01, (num_rows_per_vec_dot > 1 ? nb01 : 0), src1_col, (num_rows_per_vec_dot > 1 ? src1_col_stride : 0), num_rows_per_vec_dot);
}
}
for (int cn = 0; cn < num_rows_per_vec_dot; ++cn) {
memcpy(&dst_col[iir0 + cn * nb1 / nb0], tmp + (cn * 16), (MIN(iir0 + blck_0, ir0_end) - iir0) * sizeof(float));